ml_switcheroo.plugins.scatter

Plugin for Scatter/Gather Syntax Transformation.

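Handles the fundamental semantic difference between:

1. PyTorch (x.scatter_(dim, index, src)): index supplies only the coordinate along dim; every other coordinate comes implicitly from each element's own position in index.
2. JAX (jax.numpy arrays, x.at[index].set(src)): index addresses the target elements explicitly through the special .at accessor.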
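To make the PyTorch side concrete, here is a minimal pure-Python sketch of the scatter rule on 2-D nested lists. The helper name torch_style_scatter is hypothetical (it is not part of ml_switcheroo or PyTorch). It only covers dim 0 and 1 and assumes index is no larger than src.

```python
from copy import deepcopy


def torch_style_scatter(x, dim, index, src):
    """Out-of-place reference for PyTorch's scatter rule on 2-D nested lists.

    For every position (i, j) of `index`, src[i][j] is written to
        out[index[i][j]][j]  when dim == 0
        out[i][index[i][j]]  when dim == 1
    so `index` only supplies the coordinate along `dim`.
    """
    if dim not in (0, 1):
        raise ValueError("this sketch only handles 2-D inputs (dim 0 or 1)")
    out = deepcopy(x)  # copy first, like torch.scatter (not the in-place scatter_)
    for i, row in enumerate(index):
        for j, target in enumerate(row):
            if dim == 0:
                out[target][j] = src[i][j]
            else:
                out[i][target] = src[i][j]
    return out


# dim=1: each row of `index` picks a column within that same row.
print(torch_style_scatter([[0, 0, 0], [0, 0, 0]], 1, [[2], [0]], [[5], [7]]))
# [[0, 0, 5], [7, 0, 0]]
```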

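This plugin converts:

- x.scatter_(dim, index, src) -> x.at[index].set(src) (in-place variant, simple case)
- x.scatter(dim, index, src) -> x.at[index].set(src) (out-of-place)

In both cases the JAX form returns a new array, because JAX arrays are immutable.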
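The rewrite can be sketched with Python's standard-library ast module (Python 3.9+ for ast.unparse). This is an illustration, not the plugin's code: the real hook operates on libcst nodes, which also preserve comments and formatting, and the names ScatterToAtSet and rewrite are hypothetical. The sketch only matches method calls with exactly three positional arguments and simply discards dim, mirroring the structural approximation this plugin makes.

```python
import ast


class ScatterToAtSet(ast.NodeTransformer):
    """Rewrite x.scatter_(dim, index, src) and x.scatter(dim, index, src)
    into x.at[index].set(src), dropping `dim`."""

    def visit_Call(self, node: ast.Call) -> ast.AST:
        self.generic_visit(node)  # rewrite nested calls (e.g. inside src) first
        func = node.func
        if (
            isinstance(func, ast.Attribute)
            and func.attr in ("scatter", "scatter_")
            and len(node.args) == 3  # method form only; torch.scatter(x, ...) has 4
            and not node.keywords  # keyword-argument calls are left unchanged
        ):
            _dim, index, src = node.args
            at_indexed = ast.Subscript(  # x.at[index]
                value=ast.Attribute(value=func.value, attr="at", ctx=ast.Load()),
                slice=index,
                ctx=ast.Load(),
            )
            new_call = ast.Call(  # x.at[index].set(src)
                func=ast.Attribute(value=at_indexed, attr="set", ctx=ast.Load()),
                args=[src],
                keywords=[],
            )
            return ast.copy_location(new_call, node)
        return node


def rewrite(source: str) -> str:
    """Parse `source`, apply the rewrite, and return the regenerated code."""
    tree = ScatterToAtSet().visit(ast.parse(source))
    return ast.unparse(ast.fix_missing_locations(tree))


print(rewrite("y = x.scatter_(1, idx, src)"))  # y = x.at[idx].set(src)
```

Because .at[].set() returns a new array, a bare x.scatter_(...) statement rewritten this way discards its result; a complete translation would also need to reassign it (x = x.at[index].set(src)). Keyword-argument calls such as x.scatter_(dim=0, index=i, src=s) are left untouched by this sketch.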

Warning: this plugin reliably handles only 1-D scattering, where both forms write out[index[i]] = src[i], and cases where index is already a valid JAX advanced index for the target. For higher-rank tensors, a non-trivial dim usually requires jax.lax.scatter, whose signature (explicit scatter dimension numbers) is much closer to tf.scatter_nd. This implementation maps to the high-level at[].set() utility, which covers most user-facing patterns (e.g. masking and simple updates).

Functions

transform_scatter(node, ctx) -> libcst.CSTNode

Hook: Transforms scatter method calls into JAX index-update syntax.

Module Contents

ml_switcheroo.plugins.scatter.transform_scatter(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) -> libcst.CSTNode

Hook: Transforms scatter method calls into JAX index-update syntax.

Trigger: operations mapped with requires_plugin: "scatter_indexer".

Target: JAX/Flax.

Transformation:

Input: tensor.scatter_(dim, index, src)

Output: tensor.at[index].set(src)

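Note regarding dim: JAX's at[index] treats index as an ordinary advanced index, so index addresses target elements explicitly (a single integer array selects along the first axis). PyTorch's scatter instead replaces only the coordinate along dim and takes every other coordinate from the element's position in index. Rewriting scatter(dim, idx, src) as at[idx].set(src) is therefore valid only when idx, used directly as a JAX advanced index, selects the same elements as the scatter would; this holds for 1-D tensors but not in general.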
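The mismatch can be made concrete by listing which (row, column) positions each form writes. This hypothetical pure-Python sketch assumes a 2-D target and a 2-D integer index, and relies on the standard advanced-indexing rule that a single integer array indexes the first axis, so at[index] addresses whole rows.

```python
def scatter_targets(index, dim):
    """(row, col) positions PyTorch's scatter writes for a 2-D `index`."""
    coords = set()
    for i, row in enumerate(index):
        for j, target in enumerate(row):
            # `index` replaces only the coordinate along `dim`.
            coords.add((target, j) if dim == 0 else (i, target))
    return coords


def at_index_targets(index, n_cols):
    """(row, col) positions covered by x.at[index] for a 2-D x and a 2-D integer
    `index`: each integer selects a whole row along axis 0."""
    return {(target, c) for row in index for target in row for c in range(n_cols)}


index = [[2], [0], [1]]  # one column index per row of a 3x3 array
print(sorted(scatter_targets(index, dim=1)))   # [(0, 2), (1, 0), (2, 1)]
print(len(at_index_targets(index, n_cols=3)))  # 9, i.e. every element
```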

Still, for structural transpilation, at[idx].set(src) is the nearest syntactic equivalent. Handling dim correctly generally means building explicit index arrays for the other axes (the same idea behind take_along_axis) or using lower-level primitives such as jax.lax.scatter, or jax.lax.scatter_add for accumulating updates.
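One way to recover the PyTorch semantics is to spell out the missing coordinates so that every update has a full (row, column) address. For a 2-D array in JAX, and assuming src has the same shape (n, k) as index, this should correspond to x.at[jnp.arange(n)[:, None], index].set(src) for dim=1 and x.at[index, jnp.arange(k)[None, :]].set(src) for dim=0. The sketch below emulates that with plain lists so it runs without JAX; expand_scatter_index and at_set are hypothetical helpers, not ml_switcheroo APIs.

```python
def expand_scatter_index(index, dim):
    """Build explicit (rows, cols) coordinate grids for a 2-D scatter along `dim`.

    Mirrors the broadcast pair (arange(n)[:, None], index) for dim=1
    or (index, arange(k)[None, :]) for dim=0, where (n, k) is index's shape.
    """
    n, k = len(index), len(index[0])
    if dim == 1:
        rows = [[i] * k for i in range(n)]  # row coordinate = position in index
        cols = [list(r) for r in index]     # column coordinate comes from index
    elif dim == 0:
        rows = [list(r) for r in index]     # row coordinate comes from index
        cols = [list(range(k)) for _ in range(n)]  # column = position in index
    else:
        raise ValueError("this sketch only handles 2-D inputs (dim 0 or 1)")
    return rows, cols


def at_set(x, rows, cols, src):
    """Functional emulation of x.at[rows, cols].set(src) for same-shape grids."""
    out = [list(r) for r in x]  # never mutate the input
    for i in range(len(rows)):
        for j in range(len(rows[i])):
            out[rows[i][j]][cols[i][j]] = src[i][j]
    return out


rows, cols = expand_scatter_index([[2], [0]], dim=1)
print(at_set([[0, 0, 0], [0, 0, 0]], rows, cols, [[5], [7]]))
# [[0, 0, 5], [7, 0, 0]]
```

With the coordinates expanded, the update no longer depends on dim at all, which is why this form maps onto JAX's explicit .at indexing.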