ml_switcheroo.plugins.flatten

Plugin for Dimension-Range Flattening.

PyTorch’s flatten(start_dim, end_dim) collapses a contiguous range of dimensions into a single dimension. The JAX and NumPy translations rely on:
  1. ravel(): flattens everything (equivalent to flatten(0, -1)).
  2. reshape(): flattens any range of dimensions, provided the new shape is computed correctly.
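Why reshape can stand in for any flatten(start_dim, end_dim): the target shape keeps every dimension outside the range and replaces the range with the product of its sizes. The sketch below shows only that shape arithmetic. flatten_shape is a hypothetical helper, not part of ml_switcheroo; it normalises negative indices the way PyTorch does and assumes the input has at least one dimension.

```python
import math
from typing import Sequence, Tuple


def flatten_shape(shape: Sequence[int], start_dim: int = 0, end_dim: int = -1) -> Tuple[int, ...]:
    """Shape that reshape() needs in order to mimic torch.flatten(x, start_dim, end_dim).

    Hypothetical helper for illustration; assumes len(shape) >= 1.
    """
    ndim = len(shape)
    # Negative dims count from the end, e.g. -1 -> ndim - 1.
    start = start_dim + ndim if start_dim < 0 else start_dim
    end = end_dim + ndim if end_dim < 0 else end_dim
    if not 0 <= start <= end < ndim:
        raise ValueError(f"invalid dim range ({start_dim}, {end_dim}) for {ndim}-d input")
    # Dimensions inside [start, end] merge into one of size prod(sizes).
    merged = math.prod(shape[start:end + 1])
    return tuple(shape[:start]) + (merged,) + tuple(shape[end + 1:])


print(flatten_shape((8, 64, 7, 7), 1))    # (8, 3136): the flatten(x, 1) case
print(flatten_shape((8, 64, 7, 7)))       # (25088,): flatten(0, -1), i.e. ravel
print(flatten_shape((2, 3, 4, 5), 1, 2))  # (2, 12, 5): a range in the middle
```

Only the two full-tail ranges are expressible as ravel or as reshape with a single -1 placeholder. A middle range such as (1, 2) still needs its explicit sizes.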

Common Use Case:

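x = torch.flatten(x, 1) -> Flattens from dim 1 to the end, preserving the batch dimension. This is typical in CNN classifier heads, where it turns the final (N, C, H, W) feature map into the (N, C*H*W) input of the first linear layer.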
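For this call, the reshape form keeps the batch axis and lets -1 absorb the remaining sizes. The following is a short NumPy sketch on a dummy feature map. NumPy is an assumption made here for runnability; jax.numpy.reshape accepts the same (array, shape) arguments.

```python
import numpy as np

# Dummy feature map laid out as (batch, channels, height, width).
feats = np.arange(2 * 3 * 4 * 4, dtype=np.float32).reshape(2, 3, 4, 4)

# NumPy counterpart of torch.flatten(feats, 1): keep the batch axis and
# let reshape infer the merged size (3 * 4 * 4 = 48) from -1.
flat = np.reshape(feats, (feats.shape[0], -1))
print(flat.shape)  # (2, 48)

# Both libraries index elements in row-major (C) order by default, so each
# row holds that sample's features in the same order torch.flatten gives.
assert np.array_equal(flat[1], feats[1].ravel())
```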

Transformation Strategy:
  1. If start_dim=1 and end_dim=-1 (or end_dim is omitted): generate reshape(x, (x.shape[0], -1)).

  2. If start_dim=0 and end_dim=-1: Generate ravel(x).

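  3. Input safety: if the input is a complex expression (e.g. a function call), it is duplicated inside the shape lookup and therefore evaluated twice. This is correct for pure functions but wasteful, and it would repeat any side effects of an impure one. The plugin assumes the input is usually a plain variable (x).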
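The two rewrite rules and the input-duplication caveat can be sketched with Python's standard ast module. This is not the plugin's implementation. The real hook works on libcst nodes and receives a HookContext. This sketch hard-codes the target namespace (default "jnp") and leaves every other start/end combination untouched. FlattenRewriter and rewrite_flatten are hypothetical names. Mapping a call with no dims to ravel follows torch.flatten's defaults (start_dim=0, end_dim=-1); that is an assumption about how the hook treats omitted arguments. ast.unparse requires Python 3.9+.

```python
import ast
import copy


class FlattenRewriter(ast.NodeTransformer):
    """Rewrite torch.flatten(...) calls into ravel/reshape (sketch only).

    Assumptions: stdlib ast instead of libcst, and a fixed target
    namespace instead of a HookContext.
    """

    def __init__(self, ns: str = "jnp") -> None:
        self.ns = ns

    def _target(self, name: str) -> ast.Attribute:
        # Builds e.g. jnp.reshape or jnp.ravel
        return ast.Attribute(value=ast.Name(id=self.ns, ctx=ast.Load()), attr=name, ctx=ast.Load())

    def visit_Call(self, node: ast.Call) -> ast.AST:
        self.generic_visit(node)  # rewrite nested calls first
        func = node.func
        if not (isinstance(func, ast.Attribute) and func.attr == "flatten"
                and isinstance(func.value, ast.Name) and func.value.id == "torch"):
            return node
        args = list(node.args)
        kwargs = {kw.arg: kw.value for kw in node.keywords}
        x = args[0] if args else kwargs.get("input")
        start = args[1] if len(args) > 1 else kwargs.get("start_dim")
        end = args[2] if len(args) > 2 else kwargs.get("end_dim")
        if x is None:
            return node
        try:
            # Omitted dims fall back to torch.flatten's defaults (0, -1).
            start_v = 0 if start is None else ast.literal_eval(start)
            end_v = -1 if end is None else ast.literal_eval(end)
        except ValueError:
            return node  # dims are not literals: leave the call alone
        if (start_v, end_v) == (0, -1):
            # Rule 2: flatten everything -> ravel(x)
            return ast.Call(func=self._target("ravel"), args=[x], keywords=[])
        if (start_v, end_v) == (1, -1):
            # Rule 1: keep the batch axis -> reshape(x, (x.shape[0], -1)).
            # x is copied into the shape lookup, so an expression such as
            # f(x) ends up evaluated twice (the input-safety caveat).
            batch = ast.Subscript(
                value=ast.Attribute(value=copy.deepcopy(x), attr="shape", ctx=ast.Load()),
                slice=ast.Constant(value=0),
                ctx=ast.Load(),
            )
            minus_one = ast.UnaryOp(op=ast.USub(), operand=ast.Constant(value=1))
            shape = ast.Tuple(elts=[batch, minus_one], ctx=ast.Load())
            return ast.Call(func=self._target("reshape"), args=[x, shape], keywords=[])
        return node  # other ranges are outside the two rules sketched here


def rewrite_flatten(src: str, ns: str = "jnp") -> str:
    """Parse src, apply FlattenRewriter, and return the rewritten source."""
    tree = FlattenRewriter(ns).visit(ast.parse(src))
    return ast.unparse(ast.fix_missing_locations(tree))


print(rewrite_flatten("y = torch.flatten(x, 1)"))     # y = jnp.reshape(x, (x.shape[0], -1))
print(rewrite_flatten("y = torch.flatten(f(x), 1)"))  # y = jnp.reshape(f(x), (f(x).shape[0], -1))
print(rewrite_flatten("y = torch.flatten(x)"))        # y = jnp.ravel(x)
```

The second call shows the cost of duplication: f(x) appears twice in the output. A side-effect-free alternative would bind the input to a temporary name first, at the price of emitting an extra statement instead of a single expression.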

Functions

transform_flatten(node, ctx) → libcst.Call

Hook: Transforms flatten(x, start, end) into reshape or ravel.

Module Contents

ml_switcheroo.plugins.flatten.transform_flatten(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) → libcst.Call

Hook: Transforms flatten(x, start, end) into reshape or ravel.

Target Frameworks: JAX, NumPy, TensorFlow, MLX.