ml_switcheroo.plugins.flatten¶

Plugin for Dimension-Range Flattening.

PyTorch’s flatten(start_dim, end_dim) collapses a contiguous range of dimensions into a single dimension. JAX and NumPy instead rely on:

  1. ravel(): Flattens everything (equivalent to flatten(0, -1)).

  2. reshape(): Flattens an arbitrary dimension range, provided the new shape is computed correctly.

Common Use Case:

x = torch.flatten(x, 1) -> x = x.reshape(x.shape[0], -1).
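For reference, a minimal NumPy sketch of the two equivalences the plugin targets (jax.numpy exposes the same ravel and reshape calls):

    import numpy as np

    x = np.arange(24).reshape(2, 3, 4)

    # torch.flatten(x, 1): collapse dims 1..-1, keep the batch dimension
    flat_range = x.reshape(x.shape[0], -1)   # shape (2, 12)

    # torch.flatten(x) or torch.flatten(x, 0, -1): collapse everything
    flat_full = x.ravel()                    # shape (24,)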

Transformation Strategy:
  1. If start_dim=1 and end_dim=-1 (explicit or implied): Generate target_reshape(x, (x.shape[0], -1)) (see the libcst sketch after this list).

  2. If start_dim=0 and end_dim=-1: Generate target_ravel(x).

  3. Decoupling: If the target APIs are not found in the semantics, the original node is returned unchanged.
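The sketch below shows how cases 1 and 2 can be expressed as libcst node construction. It is a simplified illustration, not the plugin’s actual code: the target module name ("jnp") and the build_replacement helper are assumptions, and the real hook derives the target calls from the semantics rather than hard-coding them.

    import libcst as cst

    def build_replacement(array_expr: cst.BaseExpression,
                          target_mod: str = "jnp",
                          keep_batch_dim: bool = True) -> cst.Call:
        """Illustrative only: build `<mod>.ravel(x)` or `<mod>.reshape(x, (x.shape[0], -1))`."""
        mod = cst.Name(target_mod)
        if not keep_batch_dim:
            # Case 2: flatten(x, 0, -1) -> target_ravel(x)
            return cst.Call(
                func=cst.Attribute(value=mod, attr=cst.Name("ravel")),
                args=[cst.Arg(value=array_expr)],
            )
        # Case 1: flatten(x, 1) -> target_reshape(x, (x.shape[0], -1))
        batch_dim = cst.Subscript(
            value=cst.Attribute(value=array_expr, attr=cst.Name("shape")),
            slice=[cst.SubscriptElement(slice=cst.Index(value=cst.Integer("0")))],
        )
        minus_one = cst.UnaryOperation(operator=cst.Minus(), expression=cst.Integer("1"))
        new_shape = cst.Tuple(elements=[cst.Element(value=batch_dim),
                                        cst.Element(value=minus_one)])
        return cst.Call(
            func=cst.Attribute(value=mod, attr=cst.Name("reshape")),
            args=[cst.Arg(value=array_expr), cst.Arg(value=new_shape)],
        )

    print(cst.Module(body=[]).code_for_node(build_replacement(cst.Name("x"))))
    # jnp.reshape(x, (x.shape[0], -1))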

Functions¶

transform_flatten(→ libcst.Call)

Hook: Transforms flatten(x, start, end) into reshape or ravel.

Module Contents¶

ml_switcheroo.plugins.flatten.transform_flatten(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) → libcst.Call[source]¶

Hook: Transforms flatten(x, start, end) into reshape or ravel.

Decoupling Update: The logic checks ctx.plugin_traits.has_numpy_compatible_arrays and strictly looks up the flatten_full or flatten_range abstract ops; if neither is available, the original node is returned unchanged.
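A hedged sketch of the dispatch described above. Only the trait ctx.plugin_traits.has_numpy_compatible_arrays and the abstract-op names flatten_full / flatten_range come from this docstring; the resolve_op and build_call helpers, the literal-argument parsing, and the mapping of flatten_full to the full-ravel case are assumptions made for illustration.

    import libcst as cst

    def _literal_dim(arg: cst.Arg, default: int) -> int:
        # Best-effort read of a literal dim argument such as 1 or -1.
        expr = arg.value
        if isinstance(expr, cst.Integer):
            return int(expr.value)
        if (isinstance(expr, cst.UnaryOperation)
                and isinstance(expr.operator, cst.Minus)
                and isinstance(expr.expression, cst.Integer)):
            return -int(expr.expression.value)
        return default

    def transform_flatten_sketch(node: cst.Call, ctx) -> cst.Call:
        # Illustrative control flow only, not the shipped implementation.
        if not ctx.plugin_traits.has_numpy_compatible_arrays:
            return node                                  # target cannot express the rewrite

        args = list(node.args)
        start_dim = _literal_dim(args[1], 0) if len(args) > 1 else 0
        end_dim = _literal_dim(args[2], -1) if len(args) > 2 else -1

        if start_dim == 0 and end_dim == -1:
            op = ctx.resolve_op("flatten_full")          # hypothetical semantics lookup
        elif start_dim == 1 and end_dim == -1:
            op = ctx.resolve_op("flatten_range")
        else:
            op = None

        if op is None:                                   # decoupling: no abstract op found
            return node                                  # keep the original call untouched
        return op.build_call(node)                       # hypothetical call builder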