BatchNormalization
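Layer that normalizes its inputs to roughly zero mean and unit variance along the feature axis, using per-batch statistics during training and moving averages of those statistics during inference.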
Abstract Signature:
BatchNormalization(axis: int = -1, momentum: float = 0.99, epsilon: float = 0.001, center: bool = True, scale: bool = True)
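For reference, the computation the abstract signature parameterizes is written out below with Keras semantics. Statistics are taken over every axis except `axis`. Backends differ in how they spell `momentum`; PyTorch-style momentum, for instance, is the complement of the one shown here.

```latex
% Batch statistics over every axis except `axis` (m = number of reduced elements)
\mu_B = \frac{1}{m}\sum_{i=1}^{m} x_i, \qquad
\sigma_B^2 = \frac{1}{m}\sum_{i=1}^{m} (x_i - \mu_B)^2

% Training-mode output; \gamma = 1 when scale=False, \beta = 0 when center=False
y_i = \gamma \, \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}} + \beta

% Moving-average update on each training-mode call
\hat{\mu} \leftarrow \mathrm{momentum}\cdot\hat{\mu} + (1 - \mathrm{momentum})\,\mu_B, \qquad
\hat{\sigma}^2 \leftarrow \mathrm{momentum}\cdot\hat{\sigma}^2 + (1 - \mathrm{momentum})\,\sigma_B^2

% Inference-mode output uses the moving statistics instead of batch statistics
y_i = \gamma \, \frac{x_i - \hat{\mu}}{\sqrt{\hat{\sigma}^2 + \epsilon}} + \beta
```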
PyTorch
API: torch.nn.BatchNorm2d
Strategy: Direct Mapping
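Even with a direct mapping, the arguments do not carry over verbatim. The sketch below is a hypothetical helper (`keras_bn_to_torch_kwargs` is not part of any library). It assumes 4-D image input and shows the translation: PyTorch's `momentum` is the complement of the abstract one, its default `eps` is 1e-5 rather than 0.001, a single `affine` flag covers both `center` and `scale`, and `BatchNorm2d` always normalizes dimension 1 of an (N, C, H, W) tensor. It does not import torch, so it only builds keyword arguments.

```python
from typing import Dict, Tuple


def keras_bn_to_torch_kwargs(
    num_features: int,
    axis: int = -1,
    momentum: float = 0.99,
    epsilon: float = 0.001,
    center: bool = True,
    scale: bool = True,
) -> Tuple[Dict[str, object], bool]:
    """Hypothetical helper: map abstract BatchNormalization arguments to
    torch.nn.BatchNorm2d keyword arguments, assuming 4-D input.

    Returns (kwargs, needs_permute). needs_permute is True when the source
    layout is channels-last (axis=-1, i.e. NHWC), so the input has to be
    permuted to NCHW before it reaches the PyTorch layer.
    """
    # BatchNorm2d has one `affine` flag that toggles weight (scale) and
    # bias (center) together, so the two cannot be enabled independently.
    if center != scale:
        raise ValueError("BatchNorm2d cannot enable only one of center/scale")
    # BatchNorm2d always normalizes dimension 1 of an (N, C, H, W) tensor,
    # so only channels-first (1) or channels-last (-1) sources are handled.
    if axis not in (1, -1):
        raise ValueError(f"unsupported axis {axis}; expected 1 or -1")
    kwargs: Dict[str, object] = {
        # Keras infers the feature count from the input; PyTorch needs it up front.
        "num_features": num_features,
        # PyTorch's own default is 1e-5, so pass the abstract value explicitly.
        "eps": epsilon,
        # Abstract/Keras: moving = m * moving + (1 - m) * batch
        # PyTorch:        running = (1 - m) * running + m * batch
        "momentum": 1.0 - momentum,
        "affine": center,
    }
    return kwargs, axis == -1
```

Usage would look like `torch.nn.BatchNorm2d(**kwargs)`, with `x.permute(0, 3, 1, 2)` applied before the layer and `.permute(0, 2, 3, 1)` after it when `needs_permute` is true. One residual difference the helper cannot remove: PyTorch stores an unbiased variance estimate in `running_var`, so moving statistics will not match Keras bit for bit.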
JAX (Core)
API: none (core JAX has no built-in batch-normalization layer)
Strategy: Custom / Partial
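Because core JAX offers no layer, the mapping has to supply the arithmetic itself. The following is a minimal framework-agnostic sketch in plain Python. It assumes 2-D input of shape (batch, features) with `axis=-1`, uses the biased batch variance, and follows the abstract momentum convention. The function names are hypothetical. In an actual JAX port, the loops would become `jnp.mean(x, axis=0)` and `jnp.var(x, axis=0)`, and the updated moving statistics would be returned as new state instead of being mutated.

```python
from typing import List, Tuple

Vector = List[float]
Matrix = List[List[float]]


def _moments(x: Matrix) -> Tuple[Vector, Vector]:
    """Per-feature mean and biased variance over the batch axis (axis 0)."""
    n = len(x)
    cols = list(zip(*x))  # transpose rows into feature columns
    mean = [sum(c) / n for c in cols]
    var = [sum((v - m) ** 2 for v in c) / n for c, m in zip(cols, mean)]
    return mean, var


def _normalize(x: Matrix, mean: Vector, var: Vector,
               gamma: Vector, beta: Vector, epsilon: float) -> Matrix:
    """Apply gamma * (x - mean) / sqrt(var + epsilon) + beta per feature."""
    return [
        [g * (v - m) / (s + epsilon) ** 0.5 + b
         for v, m, s, g, b in zip(row, mean, var, gamma, beta)]
        for row in x
    ]


def batch_norm_train(x: Matrix, gamma: Vector, beta: Vector,
                     moving_mean: Vector, moving_var: Vector,
                     momentum: float = 0.99, epsilon: float = 0.001
                     ) -> Tuple[Matrix, Vector, Vector]:
    """Training mode: normalize with batch statistics and return the
    updated moving statistics alongside the output (no mutation)."""
    mean, var = _moments(x)
    y = _normalize(x, mean, var, gamma, beta, epsilon)
    new_mean = [momentum * mm + (1.0 - momentum) * m for mm, m in zip(moving_mean, mean)]
    new_var = [momentum * mv + (1.0 - momentum) * v for mv, v in zip(moving_var, var)]
    return y, new_mean, new_var


def batch_norm_infer(x: Matrix, gamma: Vector, beta: Vector,
                     moving_mean: Vector, moving_var: Vector,
                     epsilon: float = 0.001) -> Matrix:
    """Inference mode: normalize with the stored moving statistics."""
    return _normalize(x, moving_mean, moving_var, gamma, beta, epsilon)
```

Returning the new moving statistics rather than storing them keeps the functions pure, which matches how JAX code typically threads mutable state through `jit`-compiled steps.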
Keras
API: keras.layers.BatchNormalization
Strategy: Direct Mapping
TensorFlow
API: tf.keras.layers.BatchNormalization
Strategy: Direct Mapping
Apple MLX
API: mlx.nn.BatchNorm
Strategy: Direct Mapping
Flax NNX
API: flax.nnx.BatchNorm
Strategy: Direct Mapping
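For Flax NNX the momentum convention matches the abstract signature, but two arguments are renamed and the feature count must be given up front. The sketch below is a hypothetical helper (`keras_bn_to_nnx_kwargs` is not a Flax function). It assumes the `flax.nnx.BatchNorm` keywords `num_features`, `axis`, `momentum`, `epsilon`, `use_bias` and `use_scale`, and it leaves `rngs` to the caller.

```python
from typing import Dict


def keras_bn_to_nnx_kwargs(
    num_features: int,
    axis: int = -1,
    momentum: float = 0.99,
    epsilon: float = 0.001,
    center: bool = True,
    scale: bool = True,
) -> Dict[str, object]:
    """Hypothetical helper: map abstract BatchNormalization arguments to
    flax.nnx.BatchNorm keyword arguments (rngs must still be supplied)."""
    return {
        "num_features": num_features,  # NNX needs the feature size at construction
        "axis": axis,  # same meaning: the feature axis that is not reduced
        "momentum": momentum,  # same convention: m * running + (1 - m) * batch
        "epsilon": epsilon,  # NNX defaults to 1e-5, so pass the value explicitly
        "use_bias": center,  # abstract `center` -> learned offset (beta)
        "use_scale": scale,  # abstract `scale` -> learned scale (gamma)
    }
```

The result would be passed as `nnx.BatchNorm(**kwargs, rngs=nnx.Rngs(0))`. Switching between batch and moving statistics is then controlled by `use_running_average`, which plays the role of the Keras `training` flag.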
PaxML / Praxis
API: praxis.layers.BatchNorm
Strategy: Direct Mapping