ml_switcheroo.plugins.flatten¶
Plugin for Dimension-Range Flattening.
PyTorch’s flatten(start_dim, end_dim) is a powerful operation that collapses a contiguous range of dimensions. JAX and NumPy instead rely on:
1. ravel(): flattens everything (equivalent to flatten(0, -1)).
2. reshape(): flattens a range of dimensions, provided the new shape is computed correctly.
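A minimal NumPy check of these equivalences (the array `x` and its shape are arbitrary illustration values):

```python
import numpy as np

x = np.arange(24).reshape(2, 3, 4)

# flatten(0, -1): collapse every dimension -> ravel()
full = x.ravel()
assert full.shape == (24,)

# flatten(1, -1): collapse everything after the batch dimension -> reshape
batch_flat = x.reshape(x.shape[0], -1)
assert batch_flat.shape == (2, 12)
```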
- Common Use Case:
x = torch.flatten(x, 1) -> Flattens from dim 1 to the end (preserving the batch dimension). Used in almost every CNN classifier head.
- Transformation Strategy:
If start_dim=1 and end_dim=-1 (or implicit): Generate reshape(x, (x.shape[0], -1)).
If start_dim=0 and end_dim=-1: Generate ravel(x).
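For an arbitrary (start_dim, end_dim) pair, the reshape target can be computed by multiplying the extents of the collapsed range. A sketch of that computation (the helper name flatten_shape is hypothetical, not part of the plugin API):

```python
import math
import numpy as np

def flatten_shape(shape, start_dim=0, end_dim=-1):
    # Hypothetical helper: compute the reshape target that mimics
    # torch.flatten(x, start_dim, end_dim) for a known input shape.
    nd = len(shape)
    start = start_dim % nd          # normalize negative dims
    end = end_dim % nd
    collapsed = math.prod(shape[start:end + 1])
    return shape[:start] + (collapsed,) + shape[end + 1:]

x = np.zeros((2, 3, 4, 5))
assert flatten_shape(x.shape, 1, -1) == (2, 60)    # classifier-head case
assert flatten_shape(x.shape, 0, -1) == (120,)     # full ravel
assert flatten_shape(x.shape, 1, 2) == (2, 12, 5)  # interior range
```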
Input safety: If the input is a complex expression (e.g. a function call), it is duplicated in the shape lookup. This is acceptable for pure functions but suboptimal; we assume the input is usually a plain variable (x).
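The duplication concern can be illustrated with a counting producer: rewriting flatten(f(x), 1) as reshape(f(x), (f(x).shape[0], -1)) evaluates the expression twice (the producer function here is a hypothetical stand-in):

```python
import numpy as np

calls = {"n": 0}

def producer():
    # Stand-in for a complex input expression; counts evaluations.
    calls["n"] += 1
    return np.ones((2, 6))

# The generated code duplicates the expression inside the shape lookup:
out = np.reshape(producer(), (producer().shape[0], -1))
assert out.shape == (2, 6)
assert calls["n"] == 2  # the expression ran twice
```

For a pure, cheap expression this only wastes work; for a side-effecting one it would change behavior, which is why the plugin assumes a plain variable input.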
Functions¶
transform_flatten(node, ctx): Hook that transforms flatten(x, start, end) into reshape or ravel.
Module Contents¶
- ml_switcheroo.plugins.flatten.transform_flatten(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) → libcst.Call¶
Hook: Transforms flatten(x, start, end) into reshape or ravel.
Target Frameworks: JAX, NumPy, TensorFlow, MLX.