Framework        API                Strategy
PyTorch          torch.no_grad      Direct Mapping
JAX (Core)       (none)             Plugin (context_to_function_wrap)
TensorFlow       tf.stop_gradient   Direct Mapping
Flax NNX         (none)             Plugin (context_to_function_wrap)
PaxML / Praxis   (none)             Plugin (context_to_function_wrap)
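JAX, Flax NNX, and PaxML / Praxis have no context manager that switches gradient tracking off. In JAX, gradient blocking is expressed on values with `jax.lax.stop_gradient`, so a `with` block has to become something that produces values. The sketch below shows one way a context-to-function rewrite could work. It assumes the plugin moves the body of a `with torch.no_grad():` block into a nested function and passes that function's outputs through `jax.lax.stop_gradient`. The class `NoGradToFunctionWrap`, the helper `wrap_no_grad`, and the `_no_grad_block_N` naming are hypothetical illustrations, not the actual `context_to_function_wrap` plugin. The sketch uses only the standard-library `ast` module and needs Python 3.9+ for `ast.unparse`.

```python
import ast


class NoGradToFunctionWrap(ast.NodeTransformer):
    """Hypothetical sketch of a context-to-function rewrite for torch.no_grad.

    Rewrites
        with torch.no_grad():
            <body>
    into a nested function that runs <body> and returns the names it binds,
    followed by an assignment that passes those values through
    jax.lax.stop_gradient. Assumes the body only binds plain names, has no
    early return, and does not read a name before rebinding it (moving the
    body into a function makes every bound name local to that function).
    """

    def __init__(self) -> None:
        self._counter = 0  # numbers the generated _no_grad_block_N functions

    @staticmethod
    def _is_torch_no_grad(item: ast.withitem) -> bool:
        # Matches exactly `torch.no_grad()` with no arguments and no `as` target.
        call = item.context_expr
        return (
            item.optional_vars is None
            and isinstance(call, ast.Call)
            and not call.args
            and not call.keywords
            and isinstance(call.func, ast.Attribute)
            and call.func.attr == "no_grad"
            and isinstance(call.func.value, ast.Name)
            and call.func.value.id == "torch"
        )

    def visit_With(self, node: ast.With):
        self.generic_visit(node)  # rewrite nested blocks first
        if len(node.items) != 1 or not self._is_torch_no_grad(node.items[0]):
            return node

        # Collect names bound anywhere in the block, in first-seen order.
        names = []
        for stmt in node.body:
            for sub in ast.walk(stmt):
                if isinstance(sub, ast.Name) and isinstance(sub.ctx, ast.Store):
                    if sub.id not in names:
                        names.append(sub.id)

        fn_name = f"_no_grad_block_{self._counter}"
        self._counter += 1
        fn = ast.parse(f"def {fn_name}():\n    pass").body[0]
        fn.body = list(node.body)

        if not names:
            # Nothing to return: just call the function for its side effects.
            return [fn, ast.parse(f"{fn_name}()").body[0]]

        outs = ", ".join(names)
        fn.body.append(ast.Return(value=ast.parse(outs, mode="eval").body))
        assign = ast.parse(f"{outs} = jax.lax.stop_gradient({fn_name}())").body[0]
        return [fn, assign]


def wrap_no_grad(source: str) -> str:
    """Apply the sketch transform to Python source text (Python 3.9+)."""
    tree = NoGradToFunctionWrap().visit(ast.parse(source))
    return ast.unparse(ast.fix_missing_locations(tree))


if __name__ == "__main__":
    src = (
        "with torch.no_grad():\n"
        "    y = model(x)\n"
        "    z = y * 2\n"
    )
    print(wrap_no_grad(src))
```

Applied to a block that binds `y` and `z`, the sketch emits a `_no_grad_block_0()` function that returns both names. It then rebinds them with `y, z = jax.lax.stop_gradient(_no_grad_block_0())`. Because `jax.lax.stop_gradient` accepts pytrees, one call covers the whole returned tuple. Rows marked Direct Mapping would not need this restructuring.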