ml_switcheroo.plugins.scatter
Plugin for Scatter/Gather Syntax Transformation.
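Handles the fundamental semantic difference between:
1. PyTorch (x.scatter_(dim, index, src)): writes each element of src into x along dimension dim. index gives the target position along dim, and every other coordinate is taken from the element's own position in index.
2. JAX (jax.numpy arrays, x.at[index].set(src)): explicit NumPy-style indexing through the special .at accessor, which returns an updated copy of x.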
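A minimal pure-Python sketch of the two semantics in the 1-D case, where they coincide. The helper names torch_scatter_1d and jax_at_set_1d are hypothetical stand-ins written for illustration. They model the documented behavior of PyTorch's scatter with dim=0 and JAX's .at[index].set without importing either library.

```python
from typing import List


def torch_scatter_1d(x: List[float], index: List[int], src: List[float]) -> List[float]:
    """Model of torch's x.scatter(0, index, src) for 1-D inputs (hypothetical helper).

    Sets out[index[i]] = src[i] for every i in range(len(index)).
    PyTorch allows src to be longer than index; the extra entries are ignored.
    """
    out = list(x)  # out-of-place: leave x untouched
    for i, k in enumerate(index):
        out[k] = src[i]
    return out


def jax_at_set_1d(x: List[float], index: List[int], src: List[float]) -> List[float]:
    """Model of JAX's x.at[index].set(src) for a 1-D x and 1-D integer index (hypothetical helper).

    JAX requires src to broadcast to the shape of x[index]; this sketch only
    handles the equal-length case and rejects anything else.
    """
    if len(src) != len(index):
        raise ValueError("this sketch requires len(src) == len(index)")
    out = list(x)  # JAX arrays are immutable, so .set always returns a new array
    for k, value in zip(index, src):
        out[k] = value
    return out


x = [0, 0, 0, 0]
print(torch_scatter_1d(x, [3, 1], [10, 20]))  # [0, 20, 0, 10]
print(jax_at_set_1d(x, [3, 1], [10, 20]))     # [0, 20, 0, 10]
```

In 1-D, dropping dim does not change the result. The models diverge once src is longer than index: PyTorch accepts this, while JAX rejects it because src cannot broadcast to the shape of x[index].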
This plugin converts:
- x.scatter_(dim, index, src) -> x.at[index].set(src) (simple case)
- x.scatter(dim, index, src) -> x.at[index].set(src) (out-of-place)
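The conversion can be sketched with Python's standard-library ast module. This is an illustration under assumptions, not the plugin's code: the real hook works on LibCST nodes, which preserve formatting and comments, whereas ast discards them. The class name ScatterToAtSet and the function rewrite are hypothetical. ast.unparse requires Python 3.9+.

```python
import ast


class ScatterToAtSet(ast.NodeTransformer):
    """Rewrite x.scatter_(dim, index, src) and x.scatter(dim, index, src) into x.at[index].set(src).

    Hypothetical stand-in for the LibCST hook; dim is dropped, matching the documented output.
    """

    def visit_Call(self, node: ast.Call) -> ast.AST:
        self.generic_visit(node)  # rewrite nested calls first
        func = node.func
        if (
            isinstance(func, ast.Attribute)
            and func.attr in ("scatter", "scatter_")
            and len(node.args) == 3
            and not node.keywords
        ):
            _dim, index, src = node.args
            # Build x.at[index]
            at_index = ast.Subscript(
                value=ast.Attribute(value=func.value, attr="at", ctx=ast.Load()),
                slice=index,
                ctx=ast.Load(),
            )
            # Build x.at[index].set(src)
            new_call = ast.Call(
                func=ast.Attribute(value=at_index, attr="set", ctx=ast.Load()),
                args=[src],
                keywords=[],
            )
            return ast.copy_location(new_call, node)
        return node


def rewrite(source: str) -> str:
    """Parse source, apply the rewrite, and return the regenerated code."""
    tree = ScatterToAtSet().visit(ast.parse(source))
    return ast.unparse(ast.fix_missing_locations(tree))


print(rewrite("y = x.scatter_(1, idx, src)"))  # y = x.at[idx].set(src)
```

Because JAX updates are not in place, a bare statement such as x.scatter_(0, idx, src) becomes an expression whose new array is discarded. This section does not say whether the actual plugin rebinds the result.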
- Decoupling Logic:
The plugin performs no framework checks of its own. It runs whenever an operation is mapped with requires_plugin="scatter_indexer".
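A minimal sketch of what this decoupling means, assuming a simple name-to-hook registry. HOOKS, register_hook, scatter_indexer_hook and dispatch are hypothetical names, not ml_switcheroo's API. The hook works on source strings instead of LibCST nodes so the example stays self-contained. Dispatch depends only on the requires_plugin value in an operation's mapping, never on the source or target framework.

```python
from typing import Callable, Dict, Mapping

# Hypothetical registry: plugin name -> hook (simplified to str -> str).
HOOKS: Dict[str, Callable[[str], str]] = {}


def register_hook(name: str) -> Callable[[Callable[[str], str]], Callable[[str], str]]:
    """Decorator that records a hook under the given plugin name."""
    def decorator(fn: Callable[[str], str]) -> Callable[[str], str]:
        HOOKS[name] = fn
        return fn
    return decorator


@register_hook("scatter_indexer")
def scatter_indexer_hook(call_text: str) -> str:
    # Placeholder: a real hook would rewrite the call; here we only tag it.
    return f"hooked:{call_text}"


def dispatch(op_mapping: Mapping[str, str], call_text: str) -> str:
    """Run the hook named by requires_plugin, if any; no framework checks are made."""
    plugin = op_mapping.get("requires_plugin")
    if plugin is None:
        return call_text  # no plugin requested: leave the call alone
    return HOOKS[plugin](call_text)  # KeyError if the plugin was never registered


print(dispatch({"requires_plugin": "scatter_indexer"}, "x.scatter_(0, i, s)"))
# hooked:x.scatter_(0, i, s)
```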
Functions
| transform_scatter(node, ctx) | Hook: Transforms scatter method calls into index-update syntax (JAX style). |
Module Contents
- ml_switcheroo.plugins.scatter.transform_scatter(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) -> libcst.CSTNode
Hook: Transforms scatter method calls into index-update syntax (JAX style).
Trigger: Operations mapped with requires_plugin: "scatter_indexer".
- Transformation:
Input: tensor.scatter_(dim, index, src)
Output: tensor.at[index].set(src)
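Note regarding dim: JAX's at[index] applies NumPy-style indexing, where each integer array or slice in index addresses one axis, starting from the first. PyTorch's scatter instead uses index only for the coordinate along dim. This transformation drops dim and maps to the high-level at[index].set(src) utility, so it reproduces PyTorch's result only in simple cases, such as a 1-D tensor with dim=0 and index and src of equal shape. For a 2-D tensor, an integer index array in at[index] selects whole rows rather than individual elements along dim.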
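The 2-D mismatch can be checked with another pure-Python model. These are hypothetical helpers, with no torch or jax import. torch_scatter_2d follows PyTorch's documented rule for dim=0 (out[index[i][j]][j] = src[i][j]) and dim=1. jax_at_set_rows models x.at[index].set(src) when index has shape (1, n) and src has shape (1, C). In that case x[index] has shape (1, n, C), src broadcasts to it, and every selected row is overwritten.

```python
from typing import List

Matrix = List[List[float]]


def torch_scatter_2d(x: Matrix, dim: int, index: List[List[int]], src: Matrix) -> Matrix:
    """Model of torch's out-of-place x.scatter(dim, index, src) for 2-D inputs (hypothetical helper)."""
    if dim not in (0, 1):
        raise ValueError("dim must be 0 or 1 for 2-D input")
    out = [list(row) for row in x]
    for i, index_row in enumerate(index):
        for j, k in enumerate(index_row):
            if dim == 0:
                out[k][j] = src[i][j]  # index replaces the row coordinate
            else:
                out[i][k] = src[i][j]  # index replaces the column coordinate
    return out


def jax_at_set_rows(x: Matrix, index: List[List[int]], src: Matrix) -> Matrix:
    """Model of JAX's x.at[index].set(src) for 2-D x, index of shape (1, n), src of shape (1, C).

    Each integer in index selects a whole row, and src[0] is broadcast into it.
    """
    if len(index) != 1 or len(src) != 1 or len(src[0]) != len(x[0]):
        raise ValueError("this sketch only covers a single-row index and a full-width src row")
    out = [list(row) for row in x]
    for r in index[0]:
        out[r] = list(src[0])
    return out


x = [[0, 0, 0], [0, 0, 0], [0, 0, 0]]
index = [[2, 0, 1]]
src = [[7, 8, 9]]
print(torch_scatter_2d(x, 0, index, src))  # [[0, 8, 0], [0, 0, 9], [7, 0, 0]]
print(jax_at_set_rows(x, index, src))      # [[7, 8, 9], [7, 8, 9], [7, 8, 9]]
```

With the same inputs, PyTorch writes one element per column, while the rewritten call overwrites three full rows.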
- Parameters:
node – The original CST Call (scatter).
ctx – The hook context.
- Returns:
The transformed CST Call (at[].set()).