ComputeMoments

Computes the mean and variance of inputs over the dimensions listed in reduce_over_dims, counting only the valid (non-padded) elements indicated by padding.
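The computation is a masked average. The formulas below rest on two assumptions that should be confirmed against the backend in use. First, padding follows the Praxis/Lingvo convention, where p = 1.0 marks a padded position and p = 0.0 a valid one. Second, the variance is the biased (population) estimate. Let R be the set of indices reduced over, that is, those along reduce_over_dims:

```latex
\begin{aligned}
m_i &= 1 - p_i, \qquad n = \sum_{i \in R} m_i, \\
\mu &= \frac{1}{n} \sum_{i \in R} m_i \, x_i, \\
\sigma^2 &= \frac{1}{n} \sum_{i \in R} m_i \, (x_i - \mu)^2 .
\end{aligned}
```

For example, with x = (1, 2, 3, 100) and p = (0, 0, 0, 1), the padded 100 is ignored. This gives n = 3, μ = 2, and σ² = ((1 - 2)² + (2 - 2)² + (3 - 2)²) / 3 = 2/3. When every element of a slice is padded, n = 0 and both quotients are undefined, so an implementation needs a fallback such as clamping n to at least 1.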

Abstract Signature:

ComputeMoments(inputs: Tensor, padding: Tensor, reduce_over_dims: List[int])

PyTorch

API: none
Strategy: Plugin (compute_moments_masked)
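The following is a hypothetical sketch of what a compute_moments_masked plugin could look like. It is not the project's actual plugin. The sketch makes several assumptions:

* padding is a tensor that broadcasts to inputs, with 1.0 marking padded elements (the Praxis convention).
* The variance is the biased estimate.
* Both outputs drop the reduced dimensions.
* The valid count is clamped to at least 1, so a fully padded slice yields 0 instead of NaN.

Padded entries must be finite, because they are zeroed by multiplication rather than selected away.

```python
from typing import List, Tuple

import torch


def compute_moments_masked(
    inputs: torch.Tensor,
    padding: torch.Tensor,
    reduce_over_dims: List[int],
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical sketch: masked mean and biased variance over reduce_over_dims.

    Assumes padding == 1.0 marks a padded element and 0.0 a valid one.
    """
    # Normalize negative dims and de-duplicate so squeezing below is well defined.
    dims = tuple(sorted({d % inputs.dim() for d in reduce_over_dims}))
    if not dims:
        raise ValueError("reduce_over_dims must be non-empty")

    # Valid-element mask, broadcast to the full input shape so the count is
    # correct even when padding has size-1 dims that are reduced over.
    mask = (1.0 - padding.to(inputs.dtype)).expand_as(inputs)

    # Number of valid elements per output slot; clamping avoids 0/0.
    count = mask.sum(dim=dims, keepdim=True).clamp(min=1.0)

    # Keep reduced dims for now so the mean broadcasts against inputs.
    mean = (inputs * mask).sum(dim=dims, keepdim=True) / count
    variance = ((inputs - mean) ** 2 * mask).sum(dim=dims, keepdim=True) / count

    # Drop reduced dims, highest index first so earlier indices stay valid.
    for d in reversed(dims):
        mean = mean.squeeze(d)
        variance = variance.squeeze(d)
    return mean, variance
```

The sketch reduces with keepdim=True and squeezes only at the end. This keeps the mean broadcastable against inputs while the squared deviations are formed, whichever dimensions are reduced.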

PaxML / Praxis

API: praxis.layers.compute_moments
Strategy: Direct Mapping