Remat

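Applies gradient checkpointing (rematerialization) to a function. Intermediate activations inside the function are not kept for the backward pass but are recomputed when gradients are taken, which trades extra compute for lower peak memory.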

Abstract Signature:

Remat(f: Callable)

PyTorch

API: torch.utils.checkpoint.checkpoint
Strategy: Plugin (checkpoint_wrapper)
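A minimal sketch of the underlying PyTorch call, assuming PyTorch 2.x. The `block` module and tensor shapes are illustrative only, and this shows the raw `torch.utils.checkpoint.checkpoint` API, not the `checkpoint_wrapper` plugin itself. It checks that checkpointing changes only what is stored for backward, not the outputs or gradients.

```python
import torch
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
# Illustrative block whose intermediate activations would normally be stored.
block = torch.nn.Sequential(
    torch.nn.Linear(8, 16), torch.nn.Tanh(), torch.nn.Linear(16, 8)
)
x = torch.randn(4, 8, requires_grad=True)

# Plain forward/backward: activations inside `block` are saved for backward.
y_plain = block(x)
y_plain.pow(2).sum().backward()
grad_plain = x.grad.clone()

# Reset gradients before the second run.
x.grad = None
block.zero_grad()

# Checkpointed forward: activations are recomputed during backward.
# use_reentrant=False selects the non-reentrant implementation recommended
# by current PyTorch releases.
y_ckpt = checkpoint(block, x, use_reentrant=False)
y_ckpt.pow(2).sum().backward()
grad_ckpt = x.grad.clone()
```

Passing `use_reentrant` explicitly avoids the warning recent PyTorch versions emit when it is omitted.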

JAX (Core)

API: jax.checkpoint
Strategy: Direct Mapping
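A minimal JAX sketch, with illustrative function names and shapes. `jax.checkpoint` (also exposed as `jax.remat`) wraps a function so that its internal residuals are recomputed in the backward pass. The gradients should match the unwrapped version up to floating-point tolerance.

```python
import jax
import jax.numpy as jnp

def mlp_block(x, w1, w2):
    # Without remat, the value of `h` is saved as a residual for backward.
    h = jnp.tanh(x @ w1)
    return jnp.sin(h @ w2)

# Rematerialized version: `h` is recomputed when gradients are taken.
remat_block = jax.checkpoint(mlp_block)

def loss_remat(x, w1, w2):
    return jnp.sum(remat_block(x, w1, w2) ** 2)

def loss_plain(x, w1, w2):
    return jnp.sum(mlp_block(x, w1, w2) ** 2)

k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(k1, (4, 8))
w1 = jax.random.normal(k2, (8, 16))
w2 = jax.random.normal(k3, (16, 8))

# Gradients with respect to both weight matrices.
g_remat = jax.grad(loss_remat, argnums=(1, 2))(x, w1, w2)
g_plain = jax.grad(loss_plain, argnums=(1, 2))(x, w1, w2)
```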

Flax NNX

API: flax.nnx.remat
Strategy: Direct Mapping