CustomVjp

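Defines a custom Vector-Jacobian Product (VJP), i.e. a user-supplied backward pass, so that a function is differentiated with a rule you provide instead of the one derived automatically.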
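This is what the custom rule replaces. For a function f: R^n -> R^m with Jacobian J_f(x), the backward pass maps an output cotangent v to an input cotangent. A custom VJP supplies that map directly. The elementwise exp case is included only as an illustration.

```latex
% Reverse mode: an output cotangent v is pulled back through the Jacobian.
\operatorname{vjp}_f(x)(v) = v^{\top} J_f(x),
\qquad J_f(x) = \frac{\partial f}{\partial x}(x) \in \mathbb{R}^{m \times n},
\quad v \in \mathbb{R}^{m}.

% Example: elementwise f(x) = \exp(x) has J_f(x) = \operatorname{diag}(\exp(x)), so
\operatorname{vjp}_f(x)(v) = \bigl(v \odot \exp(x)\bigr)^{\top}.
```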

Abstract Signature:

CustomVjp(fun: Callable, nondiff_argnums: tuple = ())
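The abstract signature does not specify how the forward and backward rules are attached. The sketch below is a hypothetical, framework-free illustration. It assumes the fwd/bwd contract of `jax.custom_vjp`: rules are attached with a `defvjp(fwd, bwd)` method, `fwd` returns `(output, residuals)`, and `bwd` receives the non-differentiable arguments first, then the residuals and the output cotangent. The `vjp` method and the `clip_gradient` example are illustrative names only, not part of any framework.

```python
from typing import Any, Callable, Tuple


class CustomVjp:
    """Hypothetical sketch of the abstract CustomVjp(fun, nondiff_argnums) contract."""

    def __init__(self, fun: Callable, nondiff_argnums: tuple = ()):
        self.fun = fun
        self.nondiff_argnums = tuple(nondiff_argnums)
        self.fwd = None
        self.bwd = None

    def defvjp(self, fwd: Callable, bwd: Callable) -> None:
        # Attach the forward rule (output + residuals) and the backward rule.
        self.fwd, self.bwd = fwd, bwd

    def __call__(self, *args):
        # Plain evaluation ignores the custom rule entirely.
        return self.fun(*args)

    def vjp(self, *args) -> Tuple[Any, Callable]:
        """Return (output, pullback); pullback(g) yields cotangents for the
        differentiable arguments only."""
        out, res = self.fwd(*args)
        nondiff = tuple(args[i] for i in self.nondiff_argnums)

        def pullback(g):
            # Non-differentiable arguments are passed to bwd first.
            return self.bwd(*nondiff, res, g)

        return out, pullback


# Hypothetical example: identity forward pass, gradient clipped to [lo, hi].
def clip_gradient(lo, hi, x):
    return x


clip = CustomVjp(clip_gradient, nondiff_argnums=(0, 1))
clip.defvjp(
    lambda lo, hi, x: (x, None),                    # fwd: no residuals needed
    lambda lo, hi, res, g: (min(max(g, lo), hi),),  # bwd: one cotangent, for x
)

out, pullback = clip.vjp(-1.0, 1.0, 3.0)
print(out, pullback(5.0))  # prints: 3.0 (1.0,)
```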

PyTorch

API: torch.autograd.function.Function
Strategy: Direct Mapping
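A minimal PyTorch sketch, assuming a hypothetical gradient-clipping op where `forward` is the identity and `backward` clamps the incoming gradient. `ClipGradient` is an illustrative name. `torch.autograd.Function` is the same class as `torch.autograd.function.Function`. PyTorch has no `nondiff_argnums` parameter here. Instead, `backward` returns one value per `forward` input and uses `None` for inputs that are not differentiated.

```python
import torch


class ClipGradient(torch.autograd.Function):
    """Identity in the forward pass; clamps the incoming gradient in backward."""

    @staticmethod
    def forward(ctx, x, lo, hi):
        # lo and hi are plain Python floats, i.e. non-differentiable inputs.
        ctx.lo, ctx.hi = lo, hi
        return x.clone()

    @staticmethod
    def backward(ctx, grad_output):
        # One return value per forward input; None for the bounds.
        return grad_output.clamp(ctx.lo, ctx.hi), None, None


x = torch.tensor([3.0, -2.0], requires_grad=True)
y = ClipGradient.apply(x, -1.0, 1.0)
(5.0 * y).sum().backward()
print(x.grad)  # tensor([1., 1.])
```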

JAX (Core)

API: jax.custom_vjp
Strategy: Direct Mapping
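A minimal `jax.custom_vjp` sketch, assuming a hypothetical gradient-clipping function whose forward pass is the identity and whose backward pass clips the cotangent. The bounds `lo` and `hi` are declared through `nondiff_argnums`, so `bwd` receives them before the residuals and the cotangent. The PaxML / Praxis mapping below also uses `jax.custom_vjp`, so the same pattern applies there.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.custom_vjp, nondiff_argnums=(0, 1))
def clip_gradient(lo, hi, x):
    return x  # identity in the forward pass


def clip_gradient_fwd(lo, hi, x):
    return x, None  # no residuals needed


def clip_gradient_bwd(lo, hi, res, g):
    # Non-differentiable args come first; return one cotangent per
    # differentiable argument (only x here).
    return (jnp.clip(g, lo, hi),)


clip_gradient.defvjp(clip_gradient_fwd, clip_gradient_bwd)

grad_fn = jax.grad(lambda x: 5.0 * clip_gradient(-1.0, 1.0, x))
print(grad_fn(3.0))  # 1.0
```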

TensorFlow

API: tf.custom_gradient
Strategy: Direct Mapping
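A minimal TensorFlow sketch of the same hypothetical gradient-clipping op. `tf.custom_gradient` has no `nondiff_argnums` parameter, so this sketch closes over the bounds instead. The decorated function returns its output together with a `grad` function that maps the upstream gradient to a gradient for each tensor input. `clip_gradient` and `_identity` are illustrative names.

```python
import tensorflow as tf


def clip_gradient(lo, hi):
    """Build an identity op whose gradient is clipped to [lo, hi]."""

    @tf.custom_gradient
    def _identity(x):
        def grad(upstream):
            # Gradient for the single tensor input x.
            return tf.clip_by_value(upstream, lo, hi)

        return tf.identity(x), grad

    return _identity


clip = clip_gradient(-1.0, 1.0)
x = tf.constant([3.0, -2.0])
with tf.GradientTape() as tape:
    tape.watch(x)
    y = tf.reduce_sum(5.0 * clip(x))
dx = tape.gradient(y, x)  # a non-persistent tape allows only one call
print(dx.numpy())  # [1. 1.]
```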

Flax NNX

API: flax.nnx.custom_vjp
Strategy: Direct Mapping

PaxML / Praxis

API: jax.custom_vjp
Strategy: Direct Mapping