ml_switcheroo.plugins.gather
============================

.. py:module:: ml_switcheroo.plugins.gather

.. autoapi-nested-parse::

   Plugin for Gather Semantics Adaptation.

   Addresses the signature mismatch between:

   1. PyTorch: `torch.gather(input, dim, index, *, sparse_grad=False, out=None)`
   2. JAX/NumPy: `jax.numpy.take_along_axis(arr, indices, axis)`

   Transformation:

   - Reorders positional arguments: `(input, dim, index)` -> `(input, index, dim)`.
   - Maps keyword arguments: `dim` -> `axis`, `index` -> `indices`.
   - Strips unsupported kwargs like `sparse_grad` or `out`.

   This ensures that `torch.gather(x, 1, idx)` correctly becomes
   `jnp.take_along_axis(x, idx, 1)`.


Functions
---------

.. autoapisummary::

   ml_switcheroo.plugins.gather.transform_gather


Module Contents
---------------

.. py:function:: transform_gather(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) -> libcst.Call

   Hook: Adapts gather calls to take_along_axis semantics.

   Target Frameworks: JAX, NumPy.
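
The argument rewrite described for this module can be sketched as follows. This is a minimal, dependency-free illustration built on Python's standard ``ast`` module rather than ``libcst``, so it should not be read as the plugin's actual implementation. The names ``rewrite_gather``, ``KEYWORD_RENAMES``, and ``DROPPED_KWARGS`` are hypothetical, and the handling of a positional ``dim`` combined with a keyword ``index`` is an assumption that the module description does not cover.

```python
import ast

# Hypothetical tables for this sketch, taken from the module description:
# `dim` -> `axis`, `index` -> `indices`; `sparse_grad` and `out` are dropped.
KEYWORD_RENAMES = {"dim": "axis", "index": "indices"}
DROPPED_KWARGS = {"sparse_grad", "out"}


def rewrite_gather(source: str) -> str:
    """Rewrite a `torch.gather(...)` call expression as `jnp.take_along_axis(...)`."""
    call = ast.parse(source, mode="eval").body
    if not isinstance(call, ast.Call):
        raise ValueError("expected a single call expression")

    # Rename supported keywords and drop the unsupported ones.
    keywords = [
        ast.keyword(arg=KEYWORD_RENAMES.get(kw.arg, kw.arg), value=kw.value)
        for kw in call.keywords
        if kw.arg not in DROPPED_KWARGS
    ]

    positional = list(call.args)
    if len(positional) >= 3:
        # (input, dim, index) -> (input, index, dim)
        positional[1], positional[2] = positional[2], positional[1]
    elif len(positional) == 2:
        # Assumption: `dim` was positional while `index` was a keyword,
        # so `dim` is passed by name to avoid landing in the `indices` slot.
        keywords.append(ast.keyword(arg="axis", value=positional.pop()))

    new_call = ast.Call(
        func=ast.parse("jnp.take_along_axis", mode="eval").body,
        args=positional,
        keywords=keywords,
    )
    return ast.unparse(new_call)


if __name__ == "__main__":
    print(rewrite_gather("torch.gather(x, 1, idx)"))
    print(rewrite_gather("torch.gather(x, dim=1, index=idx, sparse_grad=True)"))
```

The sketch does not check that the callee is ``torch.gather``; it assumes the caller has already selected the call to rewrite, much as a hook would receive only matching nodes.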