Gather
Gathers values from input along the axis specified by dim, reading them at the positions given by index.
Abstract Signature:
Gather(input: Tensor, dim: int, index: Tensor)
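The abstract signature does not spell out the element-wise semantics. The sketch below assumes the PyTorch `torch.gather` convention, which the argument order `(input, dim, index)` suggests. Under that convention, the output has the shape of `index`. Each output element is read from `input` at the same coordinates, except along `dim`, where the coordinate comes from `index`. `gather_reference` is a hypothetical, loop-based helper written for illustration. It is not part of any of the libraries listed here.

```python
import numpy as np


def gather_reference(x, dim, index):
    """Hypothetical loop-based sketch of Gather(input, dim, index).

    `x` plays the role of `input`. Assumes torch.gather-style semantics:
    the result has index's shape, and out[pos] = x[pos with pos[dim]
    replaced by index[pos]].
    """
    x = np.asarray(x)
    index = np.asarray(index)
    dim = dim % x.ndim  # allow negative axes such as -1
    out = np.empty(index.shape, dtype=x.dtype)
    for pos in np.ndindex(*index.shape):
        src = list(pos)
        src[dim] = index[pos]  # swap in the gathered coordinate along dim
        out[pos] = x[tuple(src)]
    return out


x = np.array([[10, 11, 12],
              [20, 21, 22]])
idx = np.array([[2, 0],
                [1, 1]])

print(gather_reference(x, 1, idx))
# [[12 10]
#  [21 21]]
```

This should agree with `np.take_along_axis(input, index, axis=dim)` for in-range indices whose shape matches `input` in every dimension except `dim`. That is consistent with most of the backends below mapping Gather onto a `take_along_axis` function.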
JAX (Core)
API:
jnp.take_along_axis
Strategy: Plugin (gather_adapter)
NumPy
API:
np.take_along_axis
Strategy: Plugin (gather_adapter)
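The NumPy entry names a `gather_adapter` plugin without showing it. Below is a minimal sketch of what such an adapter could do for NumPy, assuming it only needs to reorder arguments. `np.take_along_axis` takes `(arr, indices, axis)`, while Gather takes `(input, dim, index)`. The function name `gather_via_take_along_axis` is hypothetical and is not the plugin's actual code.

```python
import numpy as np


def gather_via_take_along_axis(x, dim, index):
    """Hypothetical adapter mapping Gather(input, dim, index) onto NumPy.

    np.take_along_axis expects (arr, indices, axis), so the adapter only
    reorders the arguments and passes dim as the keyword `axis`.
    """
    return np.take_along_axis(np.asarray(x), np.asarray(index), axis=dim)


scores = np.array([[0.1, 0.7, 0.2],
                   [0.5, 0.3, 0.9]])
# Column of each row's best score, kept 2-D so it has the same ndim as scores.
best = np.argmax(scores, axis=1)[:, None]  # shape (2, 1)
print(gather_via_take_along_axis(scores, 1, best))
# [[0.7]
#  [0.9]]
```

`np.take_along_axis` requires `index` to have the same number of dimensions as `input`. For that reason, the `argmax` result is expanded with `[:, None]` before gathering.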
Keras
API:
keras.ops.take_along_axis
Strategy: Plugin (gather_adapter)
TensorFlow
API:
tf.gather
Strategy: Macro 'tf.gather({input}, {index}, batch_dims={dim})'
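Unlike the other backends, TensorFlow is handled by a macro that passes `dim` as `batch_dims` and leaves `axis` unset. The `tf.gather` documentation states that `axis` then defaults to the first non-batch dimension, that is, `batch_dims`. It also gives the result shape as `params.shape[:axis] + indices.shape[batch_dims:] + params.shape[axis + 1:]`. The sketch below evaluates that documented shape rule in plain Python, so no TensorFlow is needed, and compares it with NumPy's `take_along_axis` output shape. `tf_gather_shape` is a hypothetical helper, and only non-negative `dim` values are considered.

```python
import numpy as np


def tf_gather_shape(params_shape, indices_shape, batch_dims):
    """Hypothetical helper: output shape of tf.gather(params, indices, batch_dims=b).

    With `axis` omitted, tf.gather uses axis = batch_dims, and the documented
    result shape is params.shape[:axis] + indices.shape[batch_dims:]
    + params.shape[axis + 1:]. Only non-negative batch_dims are handled.
    """
    axis = batch_dims
    return (tuple(params_shape[:axis])
            + tuple(indices_shape[batch_dims:])
            + tuple(params_shape[axis + 1:]))


x_shape = (2, 3, 4)
for dim in range(len(x_shape)):
    idx_shape = list(x_shape)
    idx_shape[dim] = 5  # index may be longer than input along the gathered axis
    expected = np.take_along_axis(
        np.zeros(x_shape), np.zeros(idx_shape, dtype=int), axis=dim
    ).shape
    macro = tf_gather_shape(x_shape, idx_shape, batch_dims=dim)
    print(dim, macro, expected, macro == expected)
# 0 (5, 3, 4, 3, 4) (5, 3, 4) False
# 1 (2, 5, 4, 4) (2, 5, 4) False
# 2 (2, 3, 5) (2, 3, 5) True
```

Under these assumptions, the macro's output appears to match the other backends only when `dim` is the last axis. For earlier axes, each index entry selects a whole slice of `input`, so the result gains extra trailing dimensions instead of being gathered element-wise.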
Apple MLX
API:
mlx.core.take_along_axis
Strategy: Plugin (gather_adapter)
Flax NNX
API:
jnp.take_along_axis
Strategy: Plugin (gather_adapter)
PaxML / Praxis
API:
jnp.take_along_axis
Strategy: Plugin (gather_adapter)