Framework        API                      Strategy
---------------  -----------------------  -----------------------------------
PyTorch          —                        Plugin (torch_functional_grad)
JAX (Core)       —                        Plugin (jax_value_and_grad_wrapper)
Keras            —                        Plugin (keras_grad_tape)
TensorFlow       —                        Plugin (tf_grad_tape)
Apple MLX        mlx.nn.value_and_grad    Direct Mapping
Flax NNX         nnx.value_and_grad       Direct Mapping
PaxML / Praxis   —                        Plugin (pax_grad_wrapper)
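The framework mapping can be read as a small lookup structure: each framework either names a target API (Direct Mapping) or defers to a named plugin. The following is a minimal Python sketch of that idea. The names `GradMapping`, `VALUE_AND_GRAD` and `resolve` are hypothetical and are not taken from any real tool; only the framework, API and plugin strings come from the table.

```python
from dataclasses import dataclass
from typing import Dict, Optional


@dataclass(frozen=True)
class GradMapping:
    """One table row (hypothetical): either a direct target API or a plugin name."""
    api: Optional[str] = None     # set when the strategy is Direct Mapping
    plugin: Optional[str] = None  # set when the strategy is Plugin (...)

    def __post_init__(self) -> None:
        # Exactly one of the two fields must be set.
        if (self.api is None) == (self.plugin is None):
            raise ValueError("set exactly one of api / plugin")

    @property
    def strategy(self) -> str:
        return "Direct Mapping" if self.api is not None else f"Plugin ({self.plugin})"


# Hypothetical registry mirroring the framework / API / strategy rows.
VALUE_AND_GRAD: Dict[str, GradMapping] = {
    "PyTorch": GradMapping(plugin="torch_functional_grad"),
    "JAX (Core)": GradMapping(plugin="jax_value_and_grad_wrapper"),
    "Keras": GradMapping(plugin="keras_grad_tape"),
    "TensorFlow": GradMapping(plugin="tf_grad_tape"),
    "Apple MLX": GradMapping(api="mlx.nn.value_and_grad"),
    "Flax NNX": GradMapping(api="nnx.value_and_grad"),
    "PaxML / Praxis": GradMapping(plugin="pax_grad_wrapper"),
}


def resolve(framework: str) -> str:
    """Return the API to emit directly, or the plugin that must handle the call."""
    mapping = VALUE_AND_GRAD[framework]
    if mapping.api is not None:
        return mapping.api
    return f"plugin:{mapping.plugin}"
```

For example, `resolve("Flax NNX")` returns `"nnx.value_and_grad"`, while `resolve("PyTorch")` returns `"plugin:torch_functional_grad"`, signalling that no one-to-one API exists and custom handling is required.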
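To illustrate the kind of gap a Plugin strategy may have to bridge, here is a minimal sketch for the PyTorch row, assuming PyTorch 2.x. `torch.func.grad_and_value` exists, but it returns `(grad, value)`, the reverse of the `(value, grad)` order returned by `mlx.nn.value_and_grad` and `nnx.value_and_grad`. That is one plausible reason a direct mapping does not suffice. The `value_and_grad` shim, the `loss` function and the data below are hypothetical illustrations, not the actual `torch_functional_grad` plugin.

```python
import torch
from torch.func import grad_and_value


def value_and_grad(fn, argnums=0, has_aux=False):
    """Hypothetical shim: expose torch.func with (value, grad) output order."""
    grad_fn = grad_and_value(fn, argnums=argnums, has_aux=has_aux)

    def wrapped(*args, **kwargs):
        # grad_and_value returns (grads, value); with has_aux=True,
        # value is itself an (output, aux) pair.
        grads, value = grad_fn(*args, **kwargs)
        return value, grads

    return wrapped


def loss(w, x):
    # Scalar loss: sum of squared linear outputs.
    return ((x @ w) ** 2).sum()


w = torch.tensor([1.0, 2.0])
x = torch.eye(2)
value, grad = value_and_grad(loss)(w, x)  # differentiates w.r.t. w (argnums=0)
```

With `x` as the identity matrix, the loss is 1² + 2² = 5 and its gradient with respect to `w` is 2w = [2, 4]. Keeping the reordering in a thin wrapper leaves the call site identical to the direct-mapping frameworks.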