ml_switcheroo.plugins.scatter

Plugin for Scatter/Gather Syntax Transformation.

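Handles the fundamental semantic difference between:
1. PyTorch (x.scatter_(dim, index, src)): index selects positions along dim only; the coordinates in every other dimension come implicitly from each element's position in index.
2. JAX (x.at[index].set(src)): index addresses the array explicitly, NumPy-style, through the special .at accessor, and a new array is returned.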
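The sketch below spells out, as plain Python loops, the rule PyTorch documents for scatter on 2-D inputs, so the role of dim is visible. It is a reference illustration only: scatter_2d is a hypothetical helper, neither PyTorch nor this plugin is used, and index and src are assumed to have the same shape.

```python
def scatter_2d(x, dim, index, src):
    """Out-of-place reference for PyTorch's scatter on 2-D nested lists.

    Hypothetical helper; assumes index and src have the same shape.
    dim == 0: out[index[i][j]][j] = src[i][j]
    dim == 1: out[i][index[i][j]] = src[i][j]
    """
    out = [row[:] for row in x]  # copy rows, like the out-of-place x.scatter(...)
    for i, index_row in enumerate(index):
        for j, k in enumerate(index_row):
            if dim == 0:
                out[k][j] = src[i][j]  # index picks the row; j fixes the column
            else:
                out[i][k] = src[i][j]  # index picks the column; i fixes the row
    return out


zeros = [[0, 0, 0], [0, 0, 0]]
print(scatter_2d(zeros, 0, [[1, 0, 1]], [[5, 6, 7]]))  # [[0, 6, 0], [5, 0, 7]]
print(scatter_2d(zeros, 1, [[2, 0]], [[5, 6]]))        # [[6, 0, 5], [0, 0, 0]]
```

The same index [[1, 0, 1]] would address different elements under dim=1, which is why dim cannot simply be ignored in general.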

This plugin converts:
  • x.scatter_(dim, index, src) -> x.at[index].set(src) (simple case)
  • x.scatter(dim, index, src) -> x.at[index].set(src) (out-of-place)
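A minimal sketch of this rewrite, using Python's standard-library ast module as a stand-in for libcst (the real hook works on libcst nodes, which, unlike ast, preserve comments and formatting). ScatterToAtSet and rewrite are hypothetical names; only positional three-argument calls are handled, and dim is discarded.

```python
import ast

# Method names this toy pass rewrites (PyTorch's in-place and out-of-place forms).
SCATTER_METHODS = {"scatter", "scatter_"}


class ScatterToAtSet(ast.NodeTransformer):
    """Hypothetical stdlib-ast pass: x.scatter[_](dim, index, src) -> x.at[index].set(src)."""

    def visit_Call(self, node: ast.Call) -> ast.AST:
        self.generic_visit(node)  # rewrite nested calls first
        func = node.func
        if (
            isinstance(func, ast.Attribute)
            and func.attr in SCATTER_METHODS
            and len(node.args) == 3
            and not node.keywords
        ):
            _dim, index, src = node.args  # dim is dropped from the output
            at_index = ast.Subscript(
                value=ast.Attribute(value=func.value, attr="at", ctx=ast.Load()),
                slice=index,
                ctx=ast.Load(),
            )
            new_call = ast.Call(
                func=ast.Attribute(value=at_index, attr="set", ctx=ast.Load()),
                args=[src],
                keywords=[],
            )
            return ast.copy_location(new_call, node)
        return node


def rewrite(source: str) -> str:
    """Parse, transform, and unparse a snippet (ast.unparse needs Python 3.9+)."""
    tree = ScatterToAtSet().visit(ast.parse(source))
    return ast.unparse(ast.fix_missing_locations(tree))


print(rewrite("y = x.scatter_(0, idx, src)"))  # y = x.at[idx].set(src)
print(rewrite("y = x.scatter(1, idx, src)"))   # y = x.at[idx].set(src)
print(rewrite("y = x.gather(1, idx)"))         # y = x.gather(1, idx)
```

Calls that do not match the pattern, such as gather, pass through unchanged.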

Decoupling Logic:

The plugin performs no framework checks of its own: it runs whenever an operation is mapped with requires_plugin="scatter_indexer".
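This decoupling can be pictured with a toy registry in which the mapping alone decides whether a hook runs. Everything here (HOOK_REGISTRY, register, apply_mapping, the mapping dictionaries and their keys) is hypothetical and is not ml_switcheroo's actual hook API; it only illustrates dispatch by plugin name without framework checks.

```python
from typing import Callable, Dict

# Hypothetical registry: plugin name -> hook. Not ml_switcheroo's real API.
HOOK_REGISTRY: Dict[str, Callable[[str], str]] = {}


def register(name: str):
    """Decorator that files a hook under a plugin name."""
    def decorator(fn: Callable[[str], str]) -> Callable[[str], str]:
        HOOK_REGISTRY[name] = fn
        return fn
    return decorator


@register("scatter_indexer")
def scatter_hook(call_source: str) -> str:
    # Toy hook: it never inspects which frameworks are involved.
    return f"rewritten({call_source})"


def apply_mapping(call_source: str, mapping: dict) -> str:
    """Run the hook named by requires_plugin, if any; otherwise leave the code as-is."""
    plugin = mapping.get("requires_plugin")
    if plugin is None:
        return call_source
    return HOOK_REGISTRY[plugin](call_source)


# The same hook fires for any framework pair, because only the mapping decides.
pair_a = {"source": "torch", "target": "jax", "requires_plugin": "scatter_indexer"}
pair_b = {"source": "torch", "target": "keras", "requires_plugin": "scatter_indexer"}
print(apply_mapping("x.scatter_(0, i, s)", pair_a))  # rewritten(x.scatter_(0, i, s))
print(apply_mapping("x.scatter_(0, i, s)", pair_b))  # rewritten(x.scatter_(0, i, s))
print(apply_mapping("x.add(y)", {}))                  # x.add(y)
```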

Functions

transform_scatter(node, ctx) → libcst.CSTNode

Hook: Transforms scatter method calls into index-update syntax (JAX style).

Module Contents

ml_switcheroo.plugins.scatter.transform_scatter(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) → libcst.CSTNode

Hook: Transforms scatter method calls into index-update syntax (JAX style).

Trigger: Operations mapped with requires_plugin: "scatter_indexer".

Transformation:

Input: tensor.scatter_(dim, index, src)
Output: tensor.at[index].set(src)

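Note regarding dim: the dim argument is dropped. In PyTorch's scatter, index chooses positions along dim only, and the coordinates in every other dimension come from each element's position in index. JAX's at[index] instead treats index as a complete NumPy-style index (integer arrays or slices) into the array. The transformation maps to the high-level at[index].set(src) utility, which matches PyTorch for simple cases such as a 1-D tensor with dim=0 and src shaped like index; for higher-dimensional inputs the two forms can write different elements.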
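To see what dropping dim can change, the snippet below compares the two indexing rules with NumPy stand-ins, so neither PyTorch nor JAX is required. It assumes that np.put_along_axis reproduces PyTorch's scatter along an axis when index and src share a shape, and that copying the array and assigning out[index] = src mirrors JAX's out-of-place x.at[index].set(src). The helper names are hypothetical, and the plugin itself is not called.

```python
import numpy as np


def torch_style_scatter(x, dim, index, src):
    """PyTorch-style scatter along `dim`, emulated out-of-place with np.put_along_axis."""
    out = x.copy()
    np.put_along_axis(out, index, src, axis=dim)
    return out


def jax_style_set(x, index, src):
    """NumPy equivalent of JAX's functional x.at[index].set(src)."""
    out = x.copy()
    out[index] = src
    return out


# 1-D, dim=0, src shaped like index: both rules write the same elements.
x1 = np.zeros(4, dtype=int)
idx1, src1 = np.array([3, 1]), np.array([7, 8])
print(torch_style_scatter(x1, 0, idx1, src1).tolist())  # [0, 8, 0, 7]
print(jax_style_set(x1, idx1, src1).tolist())           # [0, 8, 0, 7]

# 2-D, dim=0: ignoring dim changes which elements are written.
x2 = np.zeros((2, 2), dtype=int)
idx2, src2 = np.array([[1, 0]]), np.array([[5, 6]])
print(torch_style_scatter(x2, 0, idx2, src2).tolist())  # [[0, 6], [5, 0]]
print(jax_style_set(x2, idx2, src2).tolist())           # [[5, 6], [5, 6]]
```

In the 2-D case the JAX-style index [[1, 0]] selects whole rows and broadcasts src into them, whereas the PyTorch rule writes one element per column.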

Parameters:
  • node – The original CST Call (scatter).

  • ctx – The hook context.

Returns:

The transformed CST Call (at[].set()).