Gather

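Gathers values from input along the axis dim, at the positions given by index. The result has the same shape as index; for a 2-D input with dim=1, out[i][j] = input[i][index[i][j]].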

Abstract Signature:

Gather(input: Tensor, dim: int, index: Tensor)
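As a concrete reference for these semantics, here is a pure-Python sketch restricted to 2-D nested lists. The function name gather_2d is hypothetical and belongs to none of the listed frameworks. The example values are the 2-D case from the torch.gather documentation.

```python
from typing import List

Matrix = List[List[float]]
IndexMatrix = List[List[int]]


def gather_2d(input: Matrix, dim: int, index: IndexMatrix) -> Matrix:
    """Hypothetical reference for Gather(input, dim, index) on 2-D lists.

    The output has the same shape as `index`:
      dim == 0: out[i][j] = input[index[i][j]][j]
      dim == 1: out[i][j] = input[i][index[i][j]]
    """
    if dim not in (0, 1, -1, -2):
        raise ValueError(f"dim must address one of the two axes, got {dim}")
    dim = dim % 2  # normalize negative axes: -2 -> 0, -1 -> 1
    out: Matrix = []
    for i, row in enumerate(index):
        out_row = []
        for j, k in enumerate(row):
            # Replace the coordinate along `dim` with the value stored in index.
            out_row.append(input[k][j] if dim == 0 else input[i][k])
        out.append(out_row)
    return out


x = [[1, 2],
     [3, 4]]
idx = [[0, 0],
       [1, 0]]
print(gather_2d(x, 1, idx))  # [[1, 1], [4, 3]]
print(gather_2d(x, 0, idx))  # [[1, 2], [3, 2]]
```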

PyTorch

API: torch.gather
Strategy: Direct Mapping

JAX (Core)

API: jnp.take_along_axis
Strategy: Plugin (gather_adapter)
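The source of the gather_adapter plugin is not shown on this page. The sketch below is a hypothetical illustration of the argument rewrite that a take_along_axis backend requires: index moves to the second positional slot, and dim becomes the axis keyword. It is built on Python's ast module and needs Python 3.9+ for ast.unparse. The class and function names are invented for illustration. Calls that carry extra keywords, such as sparse_grad, are deliberately left unchanged rather than silently dropping those arguments.

```python
import ast


class GatherToTakeAlongAxis(ast.NodeTransformer):
    """Hypothetical sketch: torch.gather(input, dim, index)
    -> jnp.take_along_axis(input, index, axis=dim)."""

    NAMES = ["input", "dim", "index"]

    def visit_Call(self, node: ast.Call) -> ast.AST:
        self.generic_visit(node)  # rewrite nested calls first
        func = node.func
        is_gather = (
            isinstance(func, ast.Attribute)
            and func.attr == "gather"
            and isinstance(func.value, ast.Name)
            and func.value.id == "torch"
        )
        if not is_gather:
            return node
        # Bail out on forms this sketch does not handle (extra kwargs, **kw).
        if len(node.args) > 3 or any(kw.arg not in self.NAMES for kw in node.keywords):
            return node
        # Collect arguments by name, whether passed positionally or by keyword.
        args = dict(zip(self.NAMES, node.args))
        args.update({kw.arg: kw.value for kw in node.keywords})
        if set(args) != set(self.NAMES):
            return node
        new_call = ast.Call(
            func=ast.Attribute(value=ast.Name(id="jnp", ctx=ast.Load()),
                               attr="take_along_axis", ctx=ast.Load()),
            args=[args["input"], args["index"]],
            keywords=[ast.keyword(arg="axis", value=args["dim"])],
        )
        return ast.copy_location(new_call, node)


def rewrite(source: str) -> str:
    tree = GatherToTakeAlongAxis().visit(ast.parse(source))
    return ast.unparse(ast.fix_missing_locations(tree))


print(rewrite("y = torch.gather(x, 1, idx)"))
# y = jnp.take_along_axis(x, idx, axis=1)
```

Only the target module prefix differs for the NumPy, Keras, MLX, Flax NNX and PaxML entries, which share the same plugin and argument order.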

NumPy

API: np.take_along_axis
Strategy: Plugin (gather_adapter)
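In NumPy the same call takes the form np.take_along_axis(arr, indices, axis). The following minimal example uses the 2-D values from the torch.gather documentation:

```python
import numpy as np

x = np.array([[1, 2],
              [3, 4]])
idx = np.array([[0, 0],
                [1, 0]])

# Gather(input=x, dim=1, index=idx) expressed in NumPy:
# index becomes the second argument and dim becomes the `axis` keyword.
out = np.take_along_axis(x, idx, axis=1)
print(out)
# [[1 1]
#  [4 3]]
```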

Keras

API: keras.ops.take_along_axis
Strategy: Plugin (gather_adapter)

TensorFlow

API: tf.gather
Strategy: Macro 'tf.gather({input}, {index}, batch_dims={dim})'
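The Macro strategy substitutes the call's arguments into the template string. The sketch below assumes plain str.format-style substitution. The helper names expand and tf_gather_shape are hypothetical. It models the output shape of the expanded call with the rule from the tf.gather documentation, params.shape[:axis] + indices.shape[batch_dims:] + params.shape[axis + 1:], where axis defaults to batch_dims. Under that rule, the expansion reproduces Gather's output shape (the shape of index) when dim is the last axis, but not when dim is an earlier axis.

```python
from typing import Tuple

Shape = Tuple[int, ...]

# Template copied from the TensorFlow entry above.
MACRO = "tf.gather({input}, {index}, batch_dims={dim})"


def expand(input: str, dim: int, index: str) -> str:
    """Hypothetical helper: substitute argument source text into the macro."""
    return MACRO.format(input=input, index=index, dim=dim)


def tf_gather_shape(params: Shape, indices: Shape, batch_dims: int) -> Shape:
    """Output shape of tf.gather per its documented rule, assuming a
    non-negative batch_dims and axis left at its default (axis == batch_dims)."""
    axis = batch_dims
    return params[:axis] + indices[batch_dims:] + params[axis + 1:]


print(expand("x", 1, "idx"))  # tf.gather(x, idx, batch_dims=1)

# Gather's output always has the shape of `index`.
# dim=1 (last axis of a 2-D input): shapes agree.
print(tf_gather_shape((4, 5), (4, 3), batch_dims=1))  # (4, 3)
# dim=0 on the same input: tf.gather keeps an extra trailing axis.
print(tf_gather_shape((4, 5), (3, 5), batch_dims=0))  # (3, 5, 5), not (3, 5)
```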

Apple MLX

API: mlx.core.take_along_axis
Strategy: Plugin (gather_adapter)

Flax NNX

API: jnp.take_along_axis
Strategy: Plugin (gather_adapter)

PaxML / Praxis

API: jnp.take_along_axis
Strategy: Plugin (gather_adapter)