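Remat
=====

Applies gradient checkpointing (rematerialization) to a function, trading compute for memory. The function's intermediate activations are not kept during the forward pass. They are recomputed during the backward pass instead.

**Abstract Signature:** ``Remat(f: Callable)``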
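A framework-free sketch can make the trade-off concrete. The example below hand-writes reverse-mode differentiation for a hypothetical chain of ``n`` scalar layers ``y = sin(x)``. It does not use any of the APIs listed on this page. Plain backprop keeps all ``n + 1`` activations. The checkpointed variant keeps only every ``k``-th activation and re-runs each segment's forward pass during the backward sweep. The layer, the helper names and the segment length ``k`` are illustrative assumptions, not part of the ``Remat`` abstraction.

.. code-block:: python

   import math


   def layer(x):
       # Hypothetical toy layer: y = sin(x).
       return math.sin(x)


   def layer_grad(x):
       # Derivative of the toy layer with respect to its input.
       return math.cos(x)


   def grad_full_storage(x0, n):
       """Plain backprop: keep every activation from the forward pass."""
       acts = [x0]
       for _ in range(n):
           acts.append(layer(acts[-1]))
       g = 1.0
       for i in reversed(range(n)):
           g *= layer_grad(acts[i])  # chain rule through layer i
       return acts[-1], g, len(acts)


   def grad_checkpointed(x0, n, k):
       """Keep only every k-th activation; recompute the others in backward."""
       checkpoints = {0: x0}
       x = x0
       for i in range(n):
           x = layer(x)
           if (i + 1) % k == 0:
               checkpoints[i + 1] = x
       y = x
       g = 1.0
       recomputed = 0
       # Sweep segments [start, start + k) from last to first.
       for start in reversed(range(0, n, k)):
           end = min(start + k, n)
           # Rematerialize the inputs of layers start..end-1 from the checkpoint.
           seg = [checkpoints[start]]
           for _ in range(start, end - 1):
               seg.append(layer(seg[-1]))
               recomputed += 1
           for a in reversed(seg):
               g *= layer_grad(a)
       return y, g, len(checkpoints), recomputed


   y_full, g_full, stored_full = grad_full_storage(0.5, n=12)
   y_ckpt, g_ckpt, stored_ckpt, extra = grad_checkpointed(0.5, n=12, k=4)
   print(stored_full, stored_ckpt, extra)  # 13 4 9

With ``n = 12`` and ``k = 4``, the checkpointed version stores 4 activations instead of 13. It pays for this with 9 extra layer evaluations and returns the same gradient. Peak storage is roughly ``n / k`` checkpoints plus one segment of ``k`` recomputed activations.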

PyTorch
-------

- **API:** ``torch.utils.checkpoint.checkpoint``
- **Strategy:** Plugin (``checkpoint_wrapper``)
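``torch.utils.checkpoint.checkpoint`` takes the function and its arguments in a single call (``checkpoint(function, *args, ...)``) and returns the outputs. The abstract ``Remat(f)`` takes only ``f``. The sketch below bridges that gap with a hypothetical ``remat`` helper. It is not the ``checkpoint_wrapper`` plugin named above, whose implementation is not shown here. ``use_reentrant=False`` selects PyTorch's non-reentrant implementation. The reentrant variant does not support ``torch.autograd.grad``.

.. code-block:: python

   import torch
   from torch.utils.checkpoint import checkpoint

   torch.manual_seed(0)

   # A small block whose internal activations we choose not to keep.
   block = torch.nn.Sequential(
       torch.nn.Linear(16, 64),
       torch.nn.GELU(),
       torch.nn.Linear(64, 16),
   )


   def remat(f):
       """Hypothetical adapter giving checkpoint the Remat(f) -> callable shape."""
       def wrapped(*args):
           # f's intermediate activations are discarded and recomputed in backward.
           return checkpoint(f, *args, use_reentrant=False)
       return wrapped


   x = torch.randn(8, 16, requires_grad=True)

   # Baseline: activations inside `block` are stored for backward.
   y_ref = block(x).sum()
   (gx_ref,) = torch.autograd.grad(y_ref, x)

   # Checkpointed: same values, less activation memory, one extra forward of `block`.
   y_ckpt = remat(block)(x).sum()
   (gx_ckpt,) = torch.autograd.grad(y_ckpt, x)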

JAX (Core)
----------

- **API:** ``jax.checkpoint``
- **Strategy:** Direct Mapping
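``jax.checkpoint`` already has the ``Remat(f)`` shape: it takes a function and returns a function with the same signature, so the mapping is one-to-one. The sketch compares gradients of a small two-layer function with and without it. The function ``mlp``, the shapes and the 0.1 weight scale are illustrative assumptions. An optional ``policy`` argument (see ``jax.checkpoint_policies``) controls which intermediates are saved rather than recomputed. It is left at its default here.

.. code-block:: python

   import jax
   import jax.numpy as jnp


   def mlp(params, x):
       # Illustrative two-layer function; its tanh activations are what remat drops.
       w1, w2 = params
       h = jnp.tanh(x @ w1)
       return jnp.sum(jnp.tanh(h @ w2))


   k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
   params = (
       0.1 * jax.random.normal(k1, (16, 64)),
       0.1 * jax.random.normal(k2, (64, 16)),
   )
   x = jax.random.normal(k3, (8, 16))

   # Remat(f) maps directly: jax.checkpoint(f) has the same signature as f.
   mlp_remat = jax.checkpoint(mlp)

   g_ref = jax.grad(mlp)(params, x)
   g_remat = jax.grad(mlp_remat)(params, x)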

Flax NNX
--------

- **API:** ``flax.nnx.remat``
- **Strategy:** Direct Mapping
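``flax.nnx.remat`` is the NNX-lifted counterpart of ``jax.checkpoint``. It wraps functions whose arguments include NNX modules. The sketch below applies it to a hypothetical ``apply`` helper and checks that ``nnx.grad`` returns the same parameter gradients with and without rematerialization. The ``MLP`` module, its sizes and the helper names are illustrative assumptions.

.. code-block:: python

   import jax
   import jax.numpy as jnp
   from flax import nnx


   class MLP(nnx.Module):
       # Illustrative module; layer sizes are arbitrary.
       def __init__(self, rngs: nnx.Rngs):
           self.l1 = nnx.Linear(16, 64, rngs=rngs)
           self.l2 = nnx.Linear(64, 16, rngs=rngs)

       def __call__(self, x):
           return self.l2(jax.nn.gelu(self.l1(x)))


   def apply(model, x):
       return model(x)


   # nnx.remat accepts functions whose arguments include NNX modules.
   apply_remat = nnx.remat(apply)


   def loss(model, x):
       return jnp.sum(apply(model, x))


   def loss_remat(model, x):
       return jnp.sum(apply_remat(model, x))


   model = MLP(nnx.Rngs(0))
   x = jnp.ones((8, 16))

   # nnx.grad differentiates with respect to the model's nnx.Param variables.
   g_ref = nnx.grad(loss)(model, x)
   g_remat = nnx.grad(loss_remat)(model, x)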