ml_switcheroo.plugins.scatter
Plugin for Scatter/Gather Syntax Transformation.
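Handles the fundamental semantic difference between:
1. PyTorch (x.scatter_(dim, index, src)): index supplies positions along dim only; the coordinates on every other axis are taken implicitly from each element's position in index.
2. JAX (x.at[index].set(src)): explicit indexing via the special .at accessor, following NumPy-style indexing rules.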
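This plugin converts:
- x.scatter_(dim, index, src) -> x.at[index].set(src) (simple case, in-place in PyTorch)
- x.scatter(dim, index, src) -> x.at[index].set(src) (out-of-place)
Because JAX arrays are immutable, x.at[index].set(src) returns a new array in both cases rather than mutating x.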
Warning: This plugin currently handles only the primary cases: indices that match the tensor rank, or simple 1D scattering. Complex dim arguments often require jax.lax.scatter, whose signature is very different and closer to tf.scatter_nd. This implementation instead maps to the high-level at[].set() utility, which covers the majority of user-facing logic (e.g. masking, simple updates).
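To make the limitation concrete, here is a minimal pure-Python sketch of the documented PyTorch scatter rule on nested lists (no PyTorch or JAX required). The helper names scatter_1d and scatter_2d are hypothetical, chosen for illustration only; they are not part of ml_switcheroo.

```python
from copy import deepcopy


def scatter_1d(x, index, src):
    """1D rule: out[index[i]] = src[i]."""
    out = list(x)
    for i, target in enumerate(index):
        out[target] = src[i]
    return out


def scatter_2d(x, dim, index, src):
    """2D rule.

    dim == 0: out[index[i][j]][j] = src[i][j]
    dim == 1: out[i][index[i][j]] = src[i][j]
    """
    out = deepcopy(x)
    for i, row in enumerate(index):
        for j, target in enumerate(row):
            if dim == 0:
                out[target][j] = src[i][j]
            else:
                out[i][target] = src[i][j]
    return out


# 1D: the rule is a plain index-set, so at[index].set(src) matches it.
print(scatter_1d([0, 0, 0, 0], [3, 0], [7, 9]))  # [9, 0, 0, 7]

# 2D, dim=0: the column j comes from the element's position in index,
# not from index itself.
x2 = [[0, 0, 0], [0, 0, 0]]
print(scatter_2d(x2, 0, [[1, 0, 1]], [[10, 20, 30]]))  # [[0, 20, 0], [10, 0, 30]]
```

In the 1D case the rule is exactly an index-set, which is why x.at[index].set(src) is a faithful translation there. In the 2D case the second coordinate is implied by position, so a bare at[index] does not carry enough information on its own.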
Functions
| transform_scatter(node, ctx) | Hook: Transforms scatter method calls into JAX index-update syntax. |
Module Contents
- ml_switcheroo.plugins.scatter.transform_scatter(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) → libcst.CSTNode
Hook: Transforms scatter method calls into JAX index-update syntax.
Trigger: Operations mapped with requires_plugin: "scatter_indexer".
Target: JAX/Flax.
- Transformation:
  Input: tensor.scatter_(dim, index, src)
  Output: tensor.at[index].set(src)
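The shape of this rewrite can be sketched as follows. This is not the plugin's implementation: the plugin operates on libcst.Call nodes and receives a HookContext, whereas this sketch swaps in the standard-library ast module so that it runs without extra dependencies. The names rewrite_scatter and transpile are hypothetical, and the sketch assumes exactly three positional arguments and only looks at a top-level call.

```python
import ast

SCATTER_NAMES = {"scatter", "scatter_"}


def rewrite_scatter(call: ast.Call) -> ast.AST:
    """Turn obj.scatter[_](dim, index, src) into obj.at[index].set(src)."""
    func = call.func
    is_scatter = isinstance(func, ast.Attribute) and func.attr in SCATTER_NAMES
    # Leave anything that is not a plain 3-positional-argument scatter call alone.
    if not is_scatter or len(call.args) != 3 or call.keywords:
        return call

    _dim, index, src = call.args  # dim is dropped, as in the simple case
    at_accessor = ast.Attribute(value=func.value, attr="at", ctx=ast.Load())
    indexed = ast.Subscript(value=at_accessor, slice=index, ctx=ast.Load())
    setter = ast.Attribute(value=indexed, attr="set", ctx=ast.Load())
    return ast.Call(func=setter, args=[src], keywords=[])


def transpile(expr: str) -> str:
    """Parse one expression, rewrite a top-level scatter call, emit source."""
    node = ast.parse(expr, mode="eval").body
    if isinstance(node, ast.Call):
        node = rewrite_scatter(node)
    return ast.unparse(node)  # requires Python 3.9+


print(transpile("tensor.scatter_(dim, index, src)"))  # tensor.at[index].set(src)
print(transpile("tensor.scatter(0, idx, values)"))    # tensor.at[idx].set(values)
```

Note that the dim argument is discarded entirely, so the rewrite is purely structural; whether the result is semantically equivalent depends on the shape of index.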
Note regarding dim: JAX’s at[index] syntax expects index to specify the target positions fully (or to be a slice), whereas PyTorch’s scatter applies index along a single dim. Rewriting scatter(dim, idx, src) as at[idx].set(src) is therefore only valid if idx is compatible with JAX advanced indexing for that shape.
For the purpose of structural transpilation, however, at[idx].set(src) is the nearest syntactic equivalent. Proper dimension handling often requires take_along_axis-style generic indexing or the lower-level jax.lax.scatter_add.
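As an illustration of what dimension-aware handling can look like while staying with the at[] accessor, the sketch below spells out the coordinates for the non-scatter axis explicitly. It is an assumption-laden example, not something the plugin is stated to emit: scatter_along_dim0 is a hypothetical helper, it covers only 2D arrays with dim=0, and it assumes index and src share a shape.

```python
import jax.numpy as jnp


def scatter_along_dim0(x, index, src):
    """Emulate x.scatter(0, index, src) for 2D inputs.

    PyTorch rule for dim=0: out[index[i][j]][j] = src[i][j].
    The column coordinate j is implicit in PyTorch, so it is built here
    explicitly and broadcast against index.
    """
    cols = jnp.arange(index.shape[1])[None, :]  # shape (1, n_cols_of_index)
    return x.at[index, cols].set(src)


x = jnp.zeros((2, 3), dtype=jnp.int32)
index = jnp.array([[1, 0, 1]])
src = jnp.array([[10, 20, 30]], dtype=jnp.int32)

out = scatter_along_dim0(x, index, src)
print(out.tolist())  # [[0, 20, 0], [10, 0, 30]]
```

A bare x.at[index].set(src) with the same inputs would treat index as row selectors only, which is why the extra coordinate array is needed here.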