ml_switcheroo.plugins.gather
Plugin for Gather Semantics Adaptation.
Addresses the signature mismatch between:
1. PyTorch: torch.gather(input, dim, index, *, sparse_grad=False, out=None)
2. JAX/NumPy: jax.numpy.take_along_axis(arr, indices, axis)
Transformation:
- Reorders positional arguments: (input, dim, index) -> (input, index, dim).
- Maps keyword arguments: dim -> axis, index -> indices.
- Strips unsupported kwargs like sparse_grad or out.
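The rules above can be sketched in plain Python on an already-split argument list. This is a minimal illustration, not the plugin's code: `remap_gather_args` is a hypothetical helper that works on plain lists and dicts rather than libcst nodes. Two behaviours are assumptions added here: the `input -> arr` rename (take_along_axis names its first parameter `arr`), and the two-positional case, where a positional dim is turned into `axis=` because the second positional slot of take_along_axis is `indices`.

```python
from typing import Any, Dict, List, Tuple

# Keyword renames listed above, plus input -> arr (an assumption: the rules
# above do not mention it, but take_along_axis calls its first parameter arr).
KWARG_RENAMES: Dict[str, str] = {"dim": "axis", "index": "indices", "input": "arr"}

# torch-only keywords with no take_along_axis counterpart; they are dropped.
DROPPED_KWARGS = {"sparse_grad", "out"}


def remap_gather_args(
    args: List[Any], kwargs: Dict[str, Any]
) -> Tuple[List[Any], Dict[str, Any]]:
    """Hypothetical helper: map torch.gather arguments onto take_along_axis."""
    new_args = list(args)
    new_kwargs: Dict[str, Any] = {}
    if len(new_args) >= 3:
        # (input, dim, index) -> (input, index, dim)
        new_args[1], new_args[2] = new_args[2], new_args[1]
    elif len(new_args) == 2:
        # torch.gather(x, 1, index=idx): the positional dim cannot stay in
        # slot 2 (that slot is `indices` in take_along_axis), so pass it as axis=.
        new_kwargs["axis"] = new_args.pop()
    for name, value in kwargs.items():
        if name in DROPPED_KWARGS:
            continue  # e.g. sparse_grad=False, out=None
        new_kwargs[KWARG_RENAMES.get(name, name)] = value
    return new_args, new_kwargs


# torch.gather(x, 1, idx, sparse_grad=False) -> take_along_axis(x, idx, 1)
print(remap_gather_args(["x", 1, "idx"], {"sparse_grad": False}))
# (['x', 'idx', 1], {})
```

Handling the mixed positional/keyword case explicitly matters because a naive "swap positions, rename keywords" pass would turn torch.gather(x, 1, index=idx) into take_along_axis(x, 1, indices=idx), which passes `indices` twice.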
This ensures that torch.gather(x, 1, idx) correctly becomes jnp.take_along_axis(x, idx, 1).
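One way to check the equivalence behind this example is to compare a loop-based reference of torch.gather's rule (each output position copies the input element at the same position, except that the coordinate along dim is taken from index) against NumPy's take_along_axis, which jax.numpy.take_along_axis mirrors for in-bounds indices. The sketch uses NumPy only so it runs without PyTorch or JAX; `gather_reference` is a hypothetical helper, and the 2x2 inputs are the ones from the PyTorch gather documentation.

```python
import numpy as np


def gather_reference(inp: np.ndarray, dim: int, index: np.ndarray) -> np.ndarray:
    """Hypothetical loop-based reference for torch.gather semantics."""
    out = np.empty(index.shape, dtype=inp.dtype)
    for pos in np.ndindex(*index.shape):
        src = list(pos)
        src[dim] = index[pos]  # only the coordinate along `dim` is redirected
        out[pos] = inp[tuple(src)]
    return out


x = np.array([[1, 2], [3, 4]])
idx = np.array([[0, 0], [1, 0]])

# torch.gather(x, 1, idx) vs take_along_axis(x, idx, 1): note the swapped order.
print(gather_reference(x, 1, idx).tolist())    # [[1, 1], [4, 3]]
print(np.take_along_axis(x, idx, 1).tolist())  # [[1, 1], [4, 3]]
```

The comparison assumes index has the same extent as the input outside the gather dimension. torch.gather also accepts an index that is smaller than the input in those dimensions, while take_along_axis broadcasts the two shapes instead, so such calls may raise an error or produce a differently shaped result after a pure argument rewrite.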
Functions
transform_gather(node, ctx) | Hook: Adapts gather calls to take_along_axis semantics.
Module Contents
- ml_switcheroo.plugins.gather.transform_gather(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) -> libcst.Call
Hook: Adapts gather calls to take_along_axis semantics.
Target Frameworks: JAX, NumPy.