ml_switcheroo.plugins.flatten
=============================

.. py:module:: ml_switcheroo.plugins.flatten

.. autoapi-nested-parse::

   Plugin for Dimension-Range Flattening.

   PyTorch's `flatten(start_dim, end_dim)` collapses the contiguous range of dimensions from
   `start_dim` to `end_dim` (inclusive) into a single dimension. JAX and NumPy offer no
   range-based equivalent and instead rely on:

   1. `ravel()`: Flattens all dimensions (equivalent to `flatten(0, -1)`).
   2. `reshape()`: Can flatten any range of dimensions, provided the target shape is computed explicitly.

   Common Use Case:
   `x = torch.flatten(x, 1)` flattens from dimension 1 through the last dimension, preserving the
   batch dimension. This pattern appears in most CNN classifier heads, where feature maps are
   flattened before the final linear layer.

   Transformation Strategy:

   1. If `start_dim=1` and `end_dim=-1` (explicit or default): generate `reshape(x, (x.shape[0], -1))`.
   2. If `start_dim=0` and `end_dim=-1`: generate `ravel(x)`.
   3. Input safety: if the input is a complex expression (e.g. a function call), it is duplicated in
      the shape lookup and therefore evaluated twice. This is correct for pure functions but wasteful,
      and it would repeat any side effects. The plugin assumes the input is usually a plain variable (`x`).

Functions
---------

.. autoapisummary::

   ml_switcheroo.plugins.flatten.transform_flatten

Module Contents
---------------

.. py:function:: transform_flatten(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) -> libcst.Call

   Hook that transforms `flatten(x, start, end)` into `reshape` or `ravel`.

   Target Frameworks: JAX, NumPy, TensorFlow, MLX.
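To make the range semantics concrete, the following is a minimal NumPy sketch of what `torch.flatten(x, start_dim, end_dim)` computes when expressed through `reshape`, and a check that the two rewrite rules above agree with it. The helper `flatten_range` is hypothetical and is not part of `ml_switcheroo`; it only illustrates how a target shape for an arbitrary range would be derived. Zero-dimensional inputs are deliberately left out of scope.

```python
import math

import numpy as np


def flatten_range(x: np.ndarray, start_dim: int = 0, end_dim: int = -1) -> np.ndarray:
    """Hypothetical reference helper: emulate torch.flatten(x, start_dim, end_dim) via reshape."""
    ndim = x.ndim
    if ndim == 0:
        # 0-d inputs are out of scope for this sketch.
        raise ValueError("expected an array with at least one dimension")
    for d in (start_dim, end_dim):
        if not -ndim <= d < ndim:
            raise IndexError(f"dimension {d} out of range for {ndim}-d input")
    # Normalise negative indices (-1 means the last dimension).
    start = start_dim % ndim
    end = end_dim % ndim
    if start > end:
        raise ValueError("start_dim must not come after end_dim")
    # Collapse dims start..end (inclusive) into one dimension of their combined size.
    collapsed = math.prod(x.shape[start:end + 1])
    new_shape = x.shape[:start] + (collapsed,) + x.shape[end + 1:]
    return x.reshape(new_shape)


x = np.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)

# Rule 1: flatten(x, 1) -> reshape(x, (x.shape[0], -1))
assert np.array_equal(flatten_range(x, 1), np.reshape(x, (x.shape[0], -1)))
# Rule 2: flatten(x, 0, -1) -> ravel(x)
assert np.array_equal(flatten_range(x, 0, -1), np.ravel(x))

print(flatten_range(x, 1).shape)     # (2, 60)
print(flatten_range(x, 1, 2).shape)  # (2, 12, 5)
```

The `-1` in the rule 1 target shape lets `reshape` infer the collapsed size at runtime, which is why the generated code only needs `x.shape[0]` rather than the full shape.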
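The dispatch in the Transformation Strategy can be sketched at the source-text level as follows. This is a simplified, hypothetical illustration: the real `transform_flatten` hook operates on `libcst.Call` nodes, whereas `rewrite_flatten_source` and its `prefix` parameter are invented here for readability. How the actual hook handles other `start_dim`/`end_dim` combinations is not documented above, so this sketch simply returns `None` for them.

```python
from typing import Optional


def rewrite_flatten_source(
    arg_src: str, start_dim: int = 0, end_dim: int = -1, prefix: str = "jnp"
) -> Optional[str]:
    """Hypothetical string-level sketch of the two documented rewrite rules."""
    if end_dim != -1:
        # Not covered by the documented rules; the sketch gives up.
        return None
    if start_dim == 1:
        # Rule 1: keep the batch dimension. Note that arg_src appears twice,
        # so a complex expression would be evaluated twice (rule 3).
        return f"{prefix}.reshape({arg_src}, ({arg_src}.shape[0], -1))"
    if start_dim == 0:
        # Rule 2: flatten everything.
        return f"{prefix}.ravel({arg_src})"
    return None


print(rewrite_flatten_source("x", 1))     # jnp.reshape(x, (x.shape[0], -1))
print(rewrite_flatten_source("x"))        # jnp.ravel(x)
print(rewrite_flatten_source("f(x)", 1))  # jnp.reshape(f(x), (f(x).shape[0], -1))
```

The last call shows the duplication described in rule 3: `f(x)` is emitted twice, which is harmless for a pure function but doubles the work.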