ml_switcheroo.plugins.flatten

Plugin for Dimension-Range Flattening.

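PyTorch’s flatten(start_dim, end_dim) collapses a contiguous range of dimensions into a single dimension. Both indices are inclusive. Mapping strategies:

1. JAX: jax.lax.collapse(x, start, stop). Most robust for dynamic shapes, since it needs only dimension indices rather than an explicit target shape. Its stop index is exclusive, so for a non-negative end_dim it corresponds to end_dim + 1.
2. NumPy/Default: x.reshape(…) or x.ravel(). x.ravel() is only equivalent when every dimension is flattened (start_dim=0, end_dim=-1).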
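The index arithmetic behind both strategies can be sketched in plain Python. This is a minimal illustration, not the plugin's code: the helper names flattened_shape and collapse_bounds are hypothetical. The first computes the shape torch.flatten would produce (usable as a reshape target), and the second converts PyTorch's inclusive, possibly negative end_dim into the exclusive stop expected by jax.lax.collapse.

```python
from math import prod
from typing import Sequence, Tuple


def _normalize(dim: int, ndim: int) -> int:
    """Map a possibly negative dimension index into [0, ndim), as PyTorch does."""
    if not -ndim <= dim < ndim:
        raise IndexError(f"dimension {dim} out of range for a {ndim}-d tensor")
    return dim + ndim if dim < 0 else dim


def flattened_shape(
    shape: Sequence[int], start_dim: int = 0, end_dim: int = -1
) -> Tuple[int, ...]:
    """Shape that torch.flatten(x, start_dim, end_dim) produces when x.shape == shape."""
    ndim = len(shape)
    if ndim == 0:
        return (1,)  # flattening a 0-d tensor yields a one-element 1-d tensor
    start, end = _normalize(start_dim, ndim), _normalize(end_dim, ndim)
    if start > end:
        raise ValueError("start_dim cannot come after end_dim")
    merged = prod(shape[start:end + 1])  # end_dim is inclusive in PyTorch
    return tuple(shape[:start]) + (merged,) + tuple(shape[end + 1:])


def collapse_bounds(ndim: int, start_dim: int = 0, end_dim: int = -1) -> Tuple[int, int]:
    """(start, stop) arguments for jax.lax.collapse, whose stop index is exclusive."""
    start, end = _normalize(start_dim, ndim), _normalize(end_dim, ndim)
    if start > end:
        raise ValueError("start_dim cannot come after end_dim")
    return start, end + 1
```

For a NumPy array x, x.reshape(flattened_shape(x.shape, 1, -1)) mirrors torch.flatten(x, 1), and under JAX the same range maps to jax.lax.collapse(x, *collapse_bounds(x.ndim, 1, -1)).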

Functions

transform_flatten(node, ctx) → libcst.Call

Hook: Transforms flatten(x, start, end) into the target framework's equivalent, such as jax.lax.collapse for JAX or a reshape for NumPy.

Module Contents

ml_switcheroo.plugins.flatten.transform_flatten(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) → libcst.Call

Hook: Transforms flatten(x, start, end) into the target framework's equivalent, such as jax.lax.collapse for JAX or a reshape for NumPy.
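The hook itself works on libcst nodes and a HookContext whose API is not described here. As a hedged illustration of the same kind of rewrite, the sketch below swaps in the standard-library ast module for libcst so it runs without extra dependencies; FlattenToCollapse and rewrite_source are hypothetical names. It only rewrites torch.flatten calls whose dimension arguments are integer literals, turns PyTorch's inclusive end_dim into collapse's exclusive stop, and emits stop=None for end_dim=-1, which assumes a JAX release where jax.lax.collapse accepts stop_dimension=None.

```python
import ast
from typing import Optional


def _literal_int(expr: ast.expr) -> Optional[int]:
    """Return the value of an integer literal such as 2 or -1, else None."""
    if isinstance(expr, ast.Constant) and type(expr.value) is int:
        return expr.value
    if (
        isinstance(expr, ast.UnaryOp)
        and isinstance(expr.op, ast.USub)
        and isinstance(expr.operand, ast.Constant)
        and type(expr.operand.value) is int
    ):
        return -expr.operand.value
    return None


def _dim_arg(call: ast.Call, pos: int, name: str, default: int) -> Optional[int]:
    """Read a dimension passed positionally or by keyword; None if not a literal."""
    if len(call.args) > pos:
        return _literal_int(call.args[pos])
    for kw in call.keywords:
        if kw.arg == name:
            return _literal_int(kw.value)
    return default


class FlattenToCollapse(ast.NodeTransformer):
    """Hypothetical rewrite of torch.flatten(x, start_dim, end_dim) into jax.lax.collapse."""

    def visit_Call(self, node: ast.Call) -> ast.AST:
        self.generic_visit(node)  # rewrite nested calls first
        func = node.func
        is_torch_flatten = (
            isinstance(func, ast.Attribute)
            and func.attr == "flatten"
            and isinstance(func.value, ast.Name)
            and func.value.id == "torch"
        )
        if not is_torch_flatten or not 1 <= len(node.args) <= 3:
            return node
        if any(kw.arg not in ("start_dim", "end_dim") for kw in node.keywords):
            return node
        start = _dim_arg(node, 1, "start_dim", default=0)
        end = _dim_arg(node, 2, "end_dim", default=-1)
        # Without the rank, only a non-negative start and end_dim >= 0 or == -1 are safe.
        if start is None or end is None or start < 0 or end < -1:
            return node
        if end != -1 and start > end:
            return node  # invalid in PyTorch; leave for the user to fix
        # PyTorch's end_dim is inclusive; collapse's stop is exclusive.
        stop = None if end == -1 else end + 1
        new_call = ast.Call(
            func=ast.parse("jax.lax.collapse", mode="eval").body,
            args=[node.args[0], ast.Constant(start), ast.Constant(stop)],
            keywords=[],
        )
        return ast.copy_location(new_call, node)


def rewrite_source(src: str) -> str:
    """Apply the rewrite to a source string and return the regenerated code."""
    tree = FlattenToCollapse().visit(ast.parse(src))
    return ast.unparse(ast.fix_missing_locations(tree))
```

With this sketch, y = torch.flatten(x, 1) becomes y = jax.lax.collapse(x, 1, None). Calls with non-literal or unsupported negative dimensions are left untouched, because the rank of x is not known at translation time. Unlike libcst, ast does not preserve comments or formatting, which is one reason a source-to-source tool would prefer libcst.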