ValueAndGrad
Computes the value and gradient of a function.
Abstract Signature:
ValueAndGrad(model, fn: Callable)
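The abstract operation returns a pair: the value of fn and its gradient. The sketch below shows that contract in plain Python. It makes two assumptions the signature above does not spell out. First, the gradient is taken with respect to the parameters held by model, which is represented here as a dict of named floats. Second, it uses central finite differences in place of automatic differentiation, so it is an illustration of the input/output shape, not of how any backend computes gradients. The name value_and_grad is used only for this illustration.

```python
from typing import Callable, Dict, Tuple

# Assumption: "model" is represented as a mapping of parameter name -> scalar value.
Params = Dict[str, float]


def value_and_grad(model: Params, fn: Callable[[Params], float],
                   eps: float = 1e-6) -> Tuple[float, Params]:
    """Return fn(model) and the gradient of fn with respect to each parameter.

    Illustrative only: gradients are estimated with central differences,
    whereas real frameworks use automatic differentiation.
    """
    value = fn(model)
    grads: Params = {}
    for name, p in model.items():
        # Perturb one parameter at a time, leaving the others unchanged.
        plus = dict(model, **{name: p + eps})
        minus = dict(model, **{name: p - eps})
        grads[name] = (fn(plus) - fn(minus)) / (2 * eps)
    return value, grads


# Usage: squared error of the one-parameter model y = w * x at x = 3, target 6.
model = {"w": 1.0}
loss = lambda p: (p["w"] * 3.0 - 6.0) ** 2
value, grads = value_and_grad(model, loss)
# value == 9.0; grads["w"] is approximately -18.0 (d/dw of (3w - 6)^2 at w = 1)
```

Returning the value alongside the gradient matters in training loops, where the loss is usually logged while its gradient drives the parameter update; computing both in one call avoids evaluating fn twice.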
PyTorch
API:
Strategy: Plugin (torch_functional_grad)
JAX (Core)
API:
Strategy: Plugin (jax_value_and_grad_wrapper)
Keras
API:
Strategy: Plugin (keras_grad_tape)
TensorFlow
API:
Strategy: Plugin (tf_grad_tape)
PaxML / Praxis
API:
Strategy: Plugin (pax_grad_wrapper)
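Every backend listed above handles ValueAndGrad through a plugin rather than a direct API mapping, so the list can be read as a lookup from framework to plugin name. A minimal sketch of that lookup follows, assuming a plain dictionary keyed by lowercase framework name. The registry VALUE_AND_GRAD_PLUGINS, the keys, and the function resolve_strategy are hypothetical and are not the tool's actual API; only the plugin names come from the list above.

```python
# Hypothetical registry mirroring the ValueAndGrad strategy list.
# Keys are assumed framework identifiers; values are the plugin names listed above.
VALUE_AND_GRAD_PLUGINS = {
    "torch": "torch_functional_grad",
    "jax": "jax_value_and_grad_wrapper",
    "keras": "keras_grad_tape",
    "tensorflow": "tf_grad_tape",
    "paxml": "pax_grad_wrapper",
}


def resolve_strategy(framework: str) -> str:
    """Return the plugin name registered for ValueAndGrad on the given framework."""
    try:
        return VALUE_AND_GRAD_PLUGINS[framework.lower()]
    except KeyError:
        raise ValueError(
            f"No ValueAndGrad strategy registered for {framework!r}"
        ) from None


# Usage: lookups are case-insensitive; unknown frameworks raise ValueError.
plugin = resolve_strategy("JAX")
```

Raising an explicit error for an unregistered framework, rather than falling back silently, keeps a missing translation visible instead of producing code that fails later.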