Checkpoint

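Gradient checkpointing (activation rematerialization): runs the given function without storing its intermediate activations and recomputes them during the backward pass. This reduces peak memory at the cost of extra compute.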

Abstract Signature:

Checkpoint(function: Callable, args)

PyTorch

API: torch.utils.checkpoint.checkpoint
Strategy: Direct Mapping
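A minimal sketch of the PyTorch mapping, assuming a toy two-layer block. The names block and grads and the tensor shapes are illustrative and not from the source. The abstract Checkpoint(function, args) corresponds to torch.utils.checkpoint.checkpoint(function, *args). Passing use_reentrant=False selects the non-reentrant implementation that recent PyTorch releases recommend. The gradients should match an ordinary, non-checkpointed call.

```python
import torch
from torch.utils.checkpoint import checkpoint

def block(x, w1, w2):
    # Toy two-layer block (hypothetical); h is the activation checkpointing avoids keeping.
    h = torch.tanh(x @ w1)
    return torch.tanh(h @ w2)

def grads(fn, x, w1, w2):
    # Fresh leaf copies so each run computes its own gradients.
    w1 = w1.detach().clone().requires_grad_(True)
    w2 = w2.detach().clone().requires_grad_(True)
    loss = fn(x, w1, w2).sum()
    return torch.autograd.grad(loss, (w1, w2))

torch.manual_seed(0)
x = torch.randn(4, 8)
w1 = torch.randn(8, 16)
w2 = torch.randn(16, 8)

# Plain call: autograd saves the intermediate activations for backward.
ref_grads = grads(block, x, w1, w2)

# Checkpointed call: intermediates are discarded and block is re-run during backward.
ckpt_grads = grads(
    lambda *a: checkpoint(block, *a, use_reentrant=False), x, w1, w2
)

for r, c in zip(ref_grads, ckpt_grads):
    # Checkpointing changes memory use, not the gradient values.
    print(torch.allclose(r, c))
```

The trade-off is one extra forward pass through block per backward pass, in exchange for not holding its activations in memory.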

JAX (Core)

API: jax.checkpoint
Strategy: Direct Mapping

Apple MLX

API: mlx.core.checkpoint
Strategy: Direct Mapping