BatchNorm

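Applies batch normalization to the input: values are normalized with a per-feature mean and variance (the batch statistics when training is true, otherwise running_mean and running_var), then scaled by weight and shifted by bias.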

Abstract Signature:

BatchNorm(input: Tensor, running_mean: Tensor, running_var: Tensor, weight: Tensor, bias: Tensor, training: bool = False, momentum: float = 0.1, eps: float = 1e-5)
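
The computation behind this signature can be written compactly. This is a sketch that assumes the abstract operation follows the conventions of torch.batch_norm, whose defaults (momentum = 0.1, eps = 1e-5) match the signature above. The symbols \mu_B and \sigma_B^2 (mean and biased variance of the current batch) and \tilde{\sigma}_B^2 (unbiased batch variance) are introduced here for illustration only.

```latex
% Statistics used for normalization
(\mu, \sigma^2) =
\begin{cases}
(\mu_B,\ \sigma_B^2) & \text{if training} \\
(\text{running\_mean},\ \text{running\_var}) & \text{otherwise}
\end{cases}

% Normalize, then scale and shift
y = \frac{x - \mu}{\sqrt{\sigma^2 + \text{eps}}} \cdot \text{weight} + \text{bias}

% Running-statistics update, applied only when training (PyTorch convention)
\text{running\_mean} \leftarrow (1 - \text{momentum}) \cdot \text{running\_mean} + \text{momentum} \cdot \mu_B
\text{running\_var} \leftarrow (1 - \text{momentum}) \cdot \text{running\_var} + \text{momentum} \cdot \tilde{\sigma}_B^2
```

Other frameworks may define momentum the other way around (as the weight on the old running value), so the update rule is worth checking whenever a mapping is marked Direct Mapping.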

PyTorch

API: torch.batch_norm
Strategy: Direct Mapping

JAX (Core)

API: (none)
Strategy: Custom / Partial

NumPy

API: (none)
Strategy: Custom / Partial
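
Because NumPy has no built-in batch normalization, a custom mapping has to implement the abstract signature by hand. The following is a minimal sketch of what that could look like, not the implementation this project uses. The function name batch_norm is hypothetical, and the code assumes channels on axis 1 and PyTorch's momentum convention. Unlike torch.batch_norm, it returns the updated running statistics instead of modifying the buffers in place.

```python
import numpy as np


def batch_norm(x, running_mean, running_var, weight, bias,
               training=False, momentum=0.1, eps=1e-5):
    """Hypothetical NumPy stand-in for the abstract BatchNorm signature.

    Assumes input of shape (N, C, ...) with channels on axis 1.
    weight and bias may be None to skip the scale or shift.
    Returns (output, new_running_mean, new_running_var).
    """
    x = np.asarray(x, dtype=np.float64)
    running_mean = np.asarray(running_mean, dtype=np.float64)
    running_var = np.asarray(running_var, dtype=np.float64)

    # Reduce over every axis except the channel axis.
    reduce_axes = tuple(i for i in range(x.ndim) if i != 1)
    # Shape that broadcasts a per-channel vector against x.
    bshape = [1] * x.ndim
    bshape[1] = x.shape[1]

    if training:
        mean = x.mean(axis=reduce_axes)
        var = x.var(axis=reduce_axes)  # biased variance, used to normalize
        n = x.size // x.shape[1]       # elements per channel
        unbiased_var = var * n / max(n - 1, 1)
        # Assumed PyTorch convention: momentum weights the new statistic.
        new_mean = (1 - momentum) * running_mean + momentum * mean
        new_var = (1 - momentum) * running_var + momentum * unbiased_var
    else:
        mean, var = running_mean, running_var
        new_mean, new_var = running_mean, running_var

    y = (x - mean.reshape(bshape)) / np.sqrt(var.reshape(bshape) + eps)
    if weight is not None:
        y = y * np.asarray(weight).reshape(bshape)
    if bias is not None:
        y = y + np.asarray(bias).reshape(bshape)
    return y, new_mean, new_var
```

A caller in training mode would feed the returned statistics back in on the next step, for example `y, rm, rv = batch_norm(x, rm, rv, w, b, training=True)`.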

Keras

API: keras.ops.batch_normalization
Strategy: Direct Mapping

TensorFlow

API: tf.nn.batch_normalization
Strategy: Direct Mapping

Apple MLX

API: mlx.nn.layers.normalization.BatchNorm
Strategy: Direct Mapping

Flax NNX

API: nnx.nn.normalization.BatchNorm
Strategy: Direct Mapping

PaxML / Praxis

API: (none)
Strategy: Custom / Partial