ml_switcheroo.plugins.gather
Plugin for Gather Semantics Adaptation.
Addresses the signature mismatch between:
1. PyTorch: torch.gather(input, dim, index, *, sparse_grad=False, out=None)
2. Target framework (e.g. JAX/NumPy): target.take_along_axis(arr, indices, axis)
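The two conventions differ in argument order and naming, not in which elements are selected. Below is a minimal pure-Python sketch for 2-D nested lists, with no PyTorch or NumPy required. `gather_2d` and `take_along_axis_2d` are hypothetical helpers written only for illustration. The sketch assumes `index` has the same shape as the input outside the gathered dimension, which is the case where torch.gather and take_along_axis pick the same elements.

```python
from typing import List

Matrix = List[List[int]]


def gather_2d(inp: Matrix, dim: int, index: Matrix) -> Matrix:
    """torch.gather-style selection for 2-D lists: (input, dim, index)."""
    rows, cols = len(index), len(index[0])
    if dim == 0:
        # out[i][j] = inp[index[i][j]][j]
        return [[inp[index[i][j]][j] for j in range(cols)] for i in range(rows)]
    if dim == 1:
        # out[i][j] = inp[i][index[i][j]]
        return [[inp[i][index[i][j]] for j in range(cols)] for i in range(rows)]
    raise ValueError("this 2-D sketch only supports dim 0 or 1")


def take_along_axis_2d(arr: Matrix, indices: Matrix, axis: int) -> Matrix:
    """Same selection rule, exposed with the (arr, indices, axis) ordering."""
    return gather_2d(arr, axis, indices)


x = [[1, 2], [3, 4]]
idx = [[0, 0], [1, 0]]
print(gather_2d(x, 1, idx))           # [[1, 1], [4, 3]]
print(take_along_axis_2d(x, idx, 1))  # [[1, 1], [4, 3]]
```

Only the position of the dimension and index arguments changes between the two calls. That is why the plugin can treat the mapping as an argument rewrite.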
Transformation:
- Reorders positional arguments: (input, dim, index) -> (input, index, dim).
- Maps keyword arguments: dim -> axis, index -> indices.
- Strips unsupported keyword arguments such as sparse_grad and out.
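The three transformation rules can be sketched on plain Python values that stand in for CST arguments. This is an illustration under assumptions, not the plugin's code. `remap_gather_args` and the two constant tables are hypothetical names. The branch that moves a positional `dim` to `axis=` when `index` is passed by keyword is our own addition, because the description above does not say how that mixed form is handled.

```python
from typing import Any, Dict, List, Tuple

# Hypothetical tables mirroring the rules above.
KWARG_RENAMES = {"dim": "axis", "index": "indices"}
UNSUPPORTED_KWARGS = {"sparse_grad", "out"}


def remap_gather_args(
    args: List[Any], kwargs: Dict[str, Any]
) -> Tuple[List[Any], Dict[str, Any]]:
    """Turn torch.gather-style arguments into take_along_axis-style ones."""
    new_args = list(args)
    new_kwargs: Dict[str, Any] = {}
    for name, value in kwargs.items():
        if name in UNSUPPORTED_KWARGS:
            continue  # no take_along_axis equivalent, so drop it
        new_kwargs[KWARG_RENAMES.get(name, name)] = value
    if len(new_args) >= 3:
        # Fully positional: (input, dim, index) -> (input, index, dim)
        new_args[1], new_args[2] = new_args[2], new_args[1]
    elif len(new_args) == 2:
        # (input, dim, index=...): slot 1 of the target expects indices,
        # so dim is passed by keyword as axis instead (assumed behaviour).
        new_kwargs["axis"] = new_args.pop()
    return new_args, new_kwargs


print(remap_gather_args(["x", 1, "idx"], {"sparse_grad": False}))
# (['x', 'idx', 1], {})
```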
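Decoupling logic:
- No hardcoded framework checks: the target API comes from the semantics mapping rather than from framework-specific branches.
- Strict lookup: if Gather is not mapped in the semantics, the original call is preserved unchanged.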
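A minimal sketch of the strict-lookup behaviour follows. It uses the standard-library `ast` module as a stand-in for libcst. Unlike libcst, `ast.unparse` does not preserve the original formatting, and it requires Python 3.9+. The `semantics` dict, its `"Gather"` key, and the `jnp.take_along_axis` target string are hypothetical placeholders for whatever the HookContext actually provides. When no entry exists, the source is returned untouched.

```python
import ast
from typing import Dict, Optional

# Hypothetical rename/strip tables mirroring the transformation rules.
KWARG_RENAMES = {"dim": "axis", "index": "indices"}
UNSUPPORTED_KWARGS = {"sparse_grad", "out"}


class GatherRewriter(ast.NodeTransformer):
    """Rewrites torch.gather(...) calls into calls to `target`."""

    def __init__(self, target: str) -> None:
        super().__init__()
        self.target = target

    def visit_Call(self, node: ast.Call) -> ast.AST:
        self.generic_visit(node)  # rewrite nested calls first
        if ast.unparse(node.func) != "torch.gather":
            return node
        args = list(node.args)
        keywords = [
            ast.keyword(arg=KWARG_RENAMES.get(kw.arg, kw.arg), value=kw.value)
            for kw in node.keywords
            if kw.arg not in UNSUPPORTED_KWARGS
        ]
        if len(args) == 3:
            # (input, dim, index) -> (input, index, dim)
            args[1], args[2] = args[2], args[1]
        elif len(args) == 2:
            # dim was positional but index was a keyword: pass dim as axis=
            keywords.append(ast.keyword(arg="axis", value=args.pop()))
        node.func = ast.parse(self.target, mode="eval").body
        node.args = args
        node.keywords = keywords
        return node


def rewrite_gather(source: str, semantics: Dict[str, str]) -> str:
    target: Optional[str] = semantics.get("Gather")
    if target is None:
        # Strict lookup: no mapping means the original code is kept as-is.
        return source
    tree = GatherRewriter(target).visit(ast.parse(source))
    return ast.unparse(ast.fix_missing_locations(tree))


print(rewrite_gather("y = torch.gather(x, 1, idx, sparse_grad=False)",
                     {"Gather": "jnp.take_along_axis"}))
# y = jnp.take_along_axis(x, idx, 1)
```

Returning the unmodified input on a missing mapping keeps the hook safe to run against any target. Unsupported targets simply leave gather calls alone.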
Functions
- transform_gather(node, ctx): Hook that adapts gather calls to take_along_axis semantics.
Module Contents
- ml_switcheroo.plugins.gather.transform_gather(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) -> libcst.Call
Hook: Adapts gather calls to take_along_axis semantics.
Source (Torch) convention: func(input, dim, index).
Target API convention: func(input, indices, axis).
- Parameters:
node – The original CST Call node.
ctx – Hook Context containing semantic definitions.
- Returns:
Transformed Call node if API mapping exists, else original node.