ml_switcheroo.plugins.flatten
Plugin for Dimension-Range Flattening.
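PyTorch’s flatten(start_dim, end_dim) collapses a contiguous range of dimensions, from start_dim through end_dim inclusive, into a single dimension. Mapping strategies:
1. JAX: jax.lax.collapse(x, start, stop), the most robust option for dynamic shapes. Its stop bound is exclusive, so once negative indices are normalized it corresponds to end_dim + 1.
2. NumPy/Default: x.reshape(…) or x.ravel() (the latter only when every dimension is flattened).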
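A minimal, dependency-free sketch of the index arithmetic behind the JAX mapping, assuming PyTorch's defaults start_dim=0 and end_dim=-1. The helpers normalize_dim, collapse_args and flatten_shape are hypothetical and not part of ml_switcheroo. They only illustrate how PyTorch's inclusive end_dim becomes the exclusive stop that jax.lax.collapse expects.

```python
from math import prod


def normalize_dim(dim: int, ndim: int) -> int:
    """Map a possibly negative PyTorch-style dim index into range(ndim)."""
    if not -ndim <= dim < ndim:
        raise IndexError(f"dim {dim} out of range for a {ndim}-d tensor")
    return dim % ndim


def collapse_args(ndim: int, start_dim: int = 0, end_dim: int = -1) -> tuple[int, int]:
    """Turn torch.flatten's inclusive (start_dim, end_dim) into the
    (start, stop) pair for jax.lax.collapse, where stop is exclusive."""
    start = normalize_dim(start_dim, ndim)
    end = normalize_dim(end_dim, ndim)
    if start > end:
        raise ValueError("start_dim cannot come after end_dim")
    return start, end + 1


def flatten_shape(shape: tuple[int, ...], start_dim: int = 0, end_dim: int = -1) -> tuple[int, ...]:
    """Shape produced by flattening dims start_dim..end_dim (inclusive)."""
    if len(shape) == 0:
        return (1,)  # torch.flatten turns a 0-d tensor into a 1-element 1-d tensor
    start, stop = collapse_args(len(shape), start_dim, end_dim)
    return tuple(shape[:start]) + (prod(shape[start:stop]),) + tuple(shape[stop:])
```

For example, flatten_shape((2, 3, 4, 5), 1, 2) gives (2, 12, 5), and collapse_args(4, 1, 2) gives (1, 3). This matches jax.lax.collapse(x, 1, 3) on an array of that shape.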
Functions
| transform_flatten(node, ctx) | Hook: Transforms flatten(x, start, end) into target-specific logic. |
Module Contents
- ml_switcheroo.plugins.flatten.transform_flatten(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) → libcst.Call
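Hook: Transforms a flatten(x, start, end) call into the equivalent call for the target framework. Receives the original libcst.Call node and the HookContext, and returns the resulting libcst.Call.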
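The rewrite can be sketched with the standard-library ast module standing in for libcst, so the example runs without extra dependencies. This is an illustration under stated assumptions, not the plugin's implementation. rewrite_flatten and translate are hypothetical names, and the input is assumed to already be a torch.flatten(x, start_dim, end_dim) call. HookContext is omitted because its interface is not documented here. Only integer-literal dimensions are handled, and anything else is returned unchanged.

```python
import ast


def _int_literal(node: ast.expr):
    """Return the value of an integer literal such as 2 or -1, else None."""
    if isinstance(node, ast.Constant) and type(node.value) is int:
        return node.value
    if (
        isinstance(node, ast.UnaryOp)
        and isinstance(node.op, ast.USub)
        and isinstance(node.operand, ast.Constant)
        and type(node.operand.value) is int
    ):
        return -node.operand.value
    return None


def rewrite_flatten(call: ast.Call) -> ast.Call:
    """Hypothetical sketch: torch.flatten(x, start_dim, end_dim) -> jax.lax.collapse(x, start, stop).

    Only integer-literal dims are handled; any other call is returned unchanged.
    """
    if not call.args:
        return call
    x, *rest = call.args
    # PyTorch defaults, overridden by positional and then keyword arguments.
    dims = {"start_dim": ast.Constant(0), "end_dim": ast.Constant(-1)}
    for name, value in zip(("start_dim", "end_dim"), rest):
        dims[name] = value
    for kw in call.keywords:
        if kw.arg in dims:
            dims[kw.arg] = kw.value

    start = _int_literal(dims["start_dim"])
    end = _int_literal(dims["end_dim"])
    if start is None or end is None or start < 0:
        return call  # dynamic or negative start_dim: left unchanged
    if end == -1:
        # Flatten through the last dim. Note: this repeats the expression x.
        stop = ast.Attribute(value=x, attr="ndim", ctx=ast.Load())
    elif end >= start:
        stop = ast.Constant(end + 1)  # inclusive end_dim -> exclusive stop
    else:
        return call  # other negative end_dim values, or start_dim after end_dim
    func = ast.parse("jax.lax.collapse", mode="eval").body
    return ast.Call(func=func, args=[x, ast.Constant(start), stop], keywords=[])


def translate(source: str) -> str:
    """Parse one call expression, rewrite it, and return the new source text."""
    call = ast.parse(source, mode="eval").body
    if not isinstance(call, ast.Call):
        return source
    return ast.unparse(rewrite_flatten(call))
```

For example, translate("torch.flatten(x, 1, 2)") yields "jax.lax.collapse(x, 1, 3)", and translate("torch.flatten(x, 1)") yields "jax.lax.collapse(x, 1, x.ndim)". The calls this sketch leaves unchanged are the ones where a fallback such as the reshape strategy would be needed.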