ml_switcheroo.plugins.reshape
=============================

.. py:module:: ml_switcheroo.plugins.reshape

.. autoapi-nested-parse::

   Plugin for "view" semantics and reshape strictness.

   Addresses:
      PyTorch `tensor.view(*shape)` requires the tensor's strides to be compatible with the new shape (in practice, usually contiguous memory) and returns a tensor that shares data with the input.
      JAX `jnp.reshape(arr, shape)` works on any array, copying if necessary, and produces an immutable result.

   Semantic Mismatch:
      In PyTorch, `view` is often used as an assertion of zero-copy reshaping.
      In JAX, arrays are immutable, so the copy/view distinction does not affect correctness, but it can still affect performance.

   Plugin Logic:

   1. **Strict Mode**: If `config.strict_mode` is enabled, this plugin can inject synchronization (`block_until_ready()`) or explicit copies, depending on the configuration policy, so that performance artifacts of the copy/view difference can be isolated.
   2. **Argument Packing**: Converts varargs calls such as `view(a, b)` into the tuple form `reshape(x, (a, b))`, unless a prior pass has already done so.

Functions
---------

.. autoapisummary::

   ml_switcheroo.plugins.reshape.transform_view_semantics

Module Contents
---------------

.. py:function:: transform_view_semantics(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) -> libcst.Call

   Hook: Maps `view` -> `reshape` with optional strictness injections.

   Transformation:
      - Input: `x.view(a, b)`
      - Standard Output: `jax.numpy.reshape(x, (a, b))`
      - Strict Output: `jax.numpy.reshape(x, (a, b)).block_until_ready()`
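The transformation described for `transform_view_semantics` can be sketched with Python's standard-library `ast` module. This is an illustration under stated assumptions, not the plugin's implementation: the real hook operates on `libcst.Call` nodes with a `HookContext`, whereas `ViewToReshape` and `rewrite` below are hypothetical names, and the `strict_mode` boolean only stands in for `config.strict_mode`. The sketch covers argument packing and the `block_until_ready()` injection. It does not model the explicit-copy policy.

```python
import ast


class ViewToReshape(ast.NodeTransformer):
    """Hypothetical sketch: rewrite `<expr>.view(...)` into `jax.numpy.reshape(<expr>, shape)`.

    Uses the stdlib `ast` module instead of libcst, so comments and formatting
    in the rewritten source are not preserved.
    """

    def __init__(self, strict_mode: bool = False) -> None:
        # Assumption: a plain bool standing in for `config.strict_mode`.
        self.strict_mode = strict_mode

    def visit_Call(self, node: ast.Call) -> ast.AST:
        # Rewrite inner calls first so chains like `x.view(a, b).view(-1)` are handled.
        self.generic_visit(node)
        func = node.func
        if not (isinstance(func, ast.Attribute) and func.attr == "view"):
            return node
        if node.keywords or not node.args:
            return node  # leave unusual call forms untouched in this sketch

        # Argument packing: `view(a, b)` and `view(*shape)` become a tuple;
        # a single positional argument such as `view((a, b))` or `view(-1)` passes through.
        args = node.args
        if len(args) == 1 and not isinstance(args[0], ast.Starred):
            shape = args[0]
        else:
            shape = ast.Tuple(elts=list(args), ctx=ast.Load())

        # Build `jax.numpy.reshape(<receiver>, <shape>)`.
        reshape_fn = ast.Attribute(
            value=ast.Attribute(value=ast.Name(id="jax", ctx=ast.Load()), attr="numpy", ctx=ast.Load()),
            attr="reshape",
            ctx=ast.Load(),
        )
        new_call: ast.expr = ast.Call(func=reshape_fn, args=[func.value, shape], keywords=[])

        # Strict mode: append `.block_until_ready()` to force synchronization.
        if self.strict_mode:
            new_call = ast.Call(
                func=ast.Attribute(value=new_call, attr="block_until_ready", ctx=ast.Load()),
                args=[],
                keywords=[],
            )
        return new_call


def rewrite(source: str, strict_mode: bool = False) -> str:
    """Parse `source`, apply the sketch transformer, and return the regenerated code."""
    tree = ViewToReshape(strict_mode).visit(ast.parse(source))
    return ast.unparse(ast.fix_missing_locations(tree))


if __name__ == "__main__":
    print(rewrite("y = x.view(a, b)"))
    print(rewrite("y = x.view(a, b)", strict_mode=True))
```

With the input `y = x.view(a, b)`, the standard run yields `y = jax.numpy.reshape(x, (a, b))` and the strict run yields `y = jax.numpy.reshape(x, (a, b)).block_until_ready()`, matching the outputs listed above. One caveat of this simplified sketch is that it cannot tell a shape argument from a dtype argument, so a PyTorch dtype reinterpretation such as `x.view(torch.float16)` would be rewritten incorrectly. A concrete syntax tree like libcst also keeps comments and whitespace, which `ast.unparse` discards.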