ml_switcheroo.plugins.gather

Plugin for Gather Semantics Adaptation.

Addresses the signature mismatch between:

1. PyTorch: torch.gather(input, dim, index, *, sparse_grad=False, out=None)
2. JAX/NumPy: jax.numpy.take_along_axis(arr, indices, axis)
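Apart from argument order, both functions apply the same element-selection rule when the index array has the same shape as the input. The following is a minimal pure-Python sketch of that rule for 2-D nested lists; the helper names gather_2d and take_along_axis_2d are hypothetical and are not part of ml_switcheroo, PyTorch, or JAX.

```python
from typing import List

Matrix = List[List[int]]


def gather_2d(values: Matrix, dim: int, index: Matrix) -> Matrix:
    """Hypothetical reference for the 2-D rule used by torch.gather.

    dim == 0: out[i][j] = values[index[i][j]][j]
    dim == 1: out[i][j] = values[i][index[i][j]]
    """
    if dim not in (0, 1):
        raise ValueError("this sketch only supports dim 0 or 1")
    return [
        [values[k][j] if dim == 0 else values[i][k] for j, k in enumerate(row)]
        for i, row in enumerate(index)
    ]


def take_along_axis_2d(arr: Matrix, indices: Matrix, axis: int) -> Matrix:
    """Same rule, with arguments in the (arr, indices, axis) order."""
    return gather_2d(arr, axis, indices)


values = [[1, 2], [3, 4]]
index = [[0, 0], [1, 0]]

# Both argument orders select the same elements.
assert gather_2d(values, 1, index) == take_along_axis_2d(values, index, 1)
```

The real functions differ in details this sketch ignores, such as broadcasting of the index array and support for more than two dimensions.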

Transformation:

- Reorders positional arguments: (input, dim, index) -> (input, index, dim).
- Maps keyword arguments: dim -> axis, index -> indices.
- Strips unsupported kwargs like sparse_grad or out.

This ensures that torch.gather(x, 1, idx) correctly becomes jnp.take_along_axis(x, idx, 1).
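The rewrite described above can be sketched with the standard-library ast module. This is an illustration only: the actual hook operates on libcst.Call nodes, and the names rewrite_gather, KWARG_RENAMES, and DROPPED_KWARGS are hypothetical. The handling of a call with two positional arguments (moving dim into an axis keyword) is an assumption of this sketch, not documented plugin behavior.

```python
import ast

# Keyword renames and unsupported kwargs, as listed in the transformation notes.
KWARG_RENAMES = {"dim": "axis", "index": "indices"}
DROPPED_KWARGS = {"sparse_grad", "out"}


def rewrite_gather(source: str) -> str:
    """Rewrite a single torch.gather(...) call expression (hypothetical helper)."""
    call = ast.parse(source, mode="eval").body
    if not isinstance(call, ast.Call):
        raise ValueError("expected a single call expression")

    args = list(call.args)
    # Rename mapped keywords and drop the unsupported ones.
    keywords = [
        ast.keyword(arg=KWARG_RENAMES.get(kw.arg, kw.arg), value=kw.value)
        for kw in call.keywords
        if kw.arg not in DROPPED_KWARGS
    ]

    if len(args) >= 3:
        # (input, dim, index) -> (input, index, dim)
        args[1], args[2] = args[2], args[1]
    elif len(args) == 2:
        # Assumption: a lone positional dim is moved to axis=... so that it
        # is not mistaken for the indices argument.
        keywords.append(ast.keyword(arg="axis", value=args.pop()))

    target = ast.parse("jnp.take_along_axis", mode="eval").body
    return ast.unparse(ast.Call(func=target, args=args, keywords=keywords))


print(rewrite_gather("torch.gather(x, 1, idx)"))
print(rewrite_gather("torch.gather(x, dim=1, index=idx, sparse_grad=True)"))
```

Unlike libcst, ast discards comments and original formatting, which is one reason a source-to-source tool would prefer a concrete syntax tree.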

Functions

transform_gather(node, ctx) → libcst.Call

Hook: Adapts gather calls to take_along_axis semantics.

Module Contents

ml_switcheroo.plugins.gather.transform_gather(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) → libcst.Call

Hook: Adapts gather calls to take_along_axis semantics.

Target Frameworks: JAX, NumPy.