TakeAlongAxis
=============

Take values from an input array by matching 1-D index and data slices along an axis.

**Abstract Signature:**

``TakeAlongAxis(arr: Tensor, indices: Tensor, axis: int | None = -1, mode: str | None = None, fill_value = None)``
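A minimal NumPy sketch of the core semantics may help. It assumes that ``TakeAlongAxis`` matches ``numpy.take_along_axis`` for the ``arr``, ``indices`` and ``axis`` parameters. The function name ``take_along_axis_ref`` is hypothetical, and ``mode`` and ``fill_value`` are deliberately not modeled. Each output element is read from ``arr`` at the same coordinates, except that the coordinate along ``axis`` comes from ``indices``.

```python
import numpy as np


def take_along_axis_ref(arr, indices, axis=-1):
    """Hypothetical reference sketch of TakeAlongAxis (not a library API).

    For every output position, the coordinate along ``axis`` is replaced
    by the value stored in ``indices`` at that position:
        out[i_0, ..., j, ..., i_n] = arr[i_0, ..., indices[i_0, ..., j, ..., i_n], ..., i_n]
    The ``mode`` and ``fill_value`` parameters are not modeled here.
    """
    arr = np.asarray(arr)
    indices = np.asarray(indices)

    if axis is None:
        # axis=None: arr is treated as if flattened to 1-D, so indices must be 1-D.
        if indices.ndim != 1:
            raise ValueError("when axis is None, indices must be 1-D")
        return arr.ravel()[indices]

    if arr.ndim != indices.ndim:
        raise ValueError("arr and indices must have the same number of dimensions")
    if not -arr.ndim <= axis < arr.ndim:
        raise ValueError(f"axis {axis} is out of bounds for an array of dimension {arr.ndim}")
    axis %= arr.ndim

    # Build one integer index array per dimension. Along `axis` we use the
    # caller's indices; every other dimension gets an arange shaped so that it
    # broadcasts against `indices` (open-mesh style fancy indexing).
    index_arrays = []
    for dim in range(arr.ndim):
        if dim == axis:
            index_arrays.append(indices)
        else:
            shape = [1] * arr.ndim
            shape[dim] = arr.shape[dim]
            index_arrays.append(np.arange(arr.shape[dim]).reshape(shape))
    return arr[tuple(index_arrays)]
```

For example, ``take_along_axis_ref([[10, 30, 20], [60, 40, 50]], [[2], [0]], axis=1)`` picks column 2 from the first row and column 0 from the second, giving ``[[20], [60]]``.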
**Related APIs:**

- ``jax.numpy.take_along_axis``
- ``numpy.take_along_axis``
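A short usage sketch with the NumPy counterpart listed above: pairing ``numpy.take_along_axis`` with index-returning functions such as ``numpy.argsort`` and ``numpy.argmax`` gathers the values those indices point to. The ``scores`` array is arbitrary illustrative data.

```python
import numpy as np

scores = np.array([[0.1, 0.7, 0.2],
                   [0.5, 0.3, 0.9]])

# Sort each row: argsort yields per-row indices, take_along_axis gathers them.
order = np.argsort(scores, axis=1)
sorted_rows = np.take_along_axis(scores, order, axis=1)

# Per-row maximum: argmax drops the reduced axis, so restore it with
# expand_dims to keep indices.ndim equal to scores.ndim.
best = np.expand_dims(np.argmax(scores, axis=1), axis=1)
row_max = np.take_along_axis(scores, best, axis=1)  # shape (2, 1)
```

Because ``indices`` must have the same number of dimensions as ``arr``, the size-1 axis kept by ``numpy.expand_dims`` is what makes the ``argmax`` result usable here.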