Vjp

Compute the vector-Jacobian product (VJP) of fun: evaluate fun at primals and multiply the cotangents by its Jacobian at that point.

Abstract Signature:

Vjp(fun: Callable, primals: List[Tensor], cotangents: List[Tensor])
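Written out with the usual definition (symbols chosen here for illustration), let fun map primals x_1, ..., x_k to outputs y_1, ..., y_m. Each cotangent v_j has the shape of y_j, and each result has the shape of the corresponding primal. If every tensor is treated as a flattened vector, the operation computes:

```latex
\bar{x}_i \;=\; \sum_{j=1}^{m} v_j^{\top}\, \frac{\partial y_j}{\partial x_i}\bigl(x_1, \ldots, x_k\bigr),
\qquad i = 1, \ldots, k
```

For a single input and a single output this reduces to \bar{x} = v^\top J_f(x). With a scalar output and v = 1, it is the gradient of f. Each mapping below also returns the outputs y_j alongside the \bar{x}_i.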

PyTorch

API: torch.autograd.functional.vjp
Strategy: Direct Mapping
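The following is a minimal usage sketch of the direct mapping, assuming PyTorch is installed. The function f and its inputs are illustrative only. The abstract primals become the inputs tuple, and the cotangent becomes v, which must match the shape of the function's output:

```python
import torch
from torch.autograd.functional import vjp


# Illustrative two-input, single-output function (not part of the API).
def f(x, y):
    return x * y + torch.sin(x)


x = torch.tensor([1.0, 2.0])
y = torch.tensor([3.0, 4.0])
v = torch.ones(2)  # cotangent: same shape as f(x, y)

# Returns (f(x, y), vjp). The vjp part is a tuple because inputs is a tuple.
out, (grad_x, grad_y) = vjp(f, (x, y), v)

# Analytic partials for comparison: df/dx = y + cos(x), df/dy = x
print(grad_x, grad_y)
```

If inputs is passed as a single tensor instead of a tuple, PyTorch returns the VJP as a single tensor rather than a one-element tuple.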

JAX (Core)

API: None (no direct equivalent)
Strategy: Macro 'lambda f, p, c: (lambda out, vjp_fn: (out, vjp_fn(c)))(*jax.vjp(f, *p))'
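The macro above can be unpacked into an ordinary function for readability. This sketch assumes JAX is installed. The helper name vjp_adapter and the function f are hypothetical and for illustration only. The block also evaluates the macro verbatim to show that the two forms agree:

```python
import jax
import jax.numpy as jnp

# The macro from the mapping above, copied verbatim.
vjp_macro = lambda f, p, c: (lambda out, vjp_fn: (out, vjp_fn(c)))(*jax.vjp(f, *p))


# Hypothetical step-by-step equivalent of the macro (not a JAX API).
def vjp_adapter(f, primals, cotangents):
    # jax.vjp evaluates f at the unpacked primals and returns the outputs
    # together with a function that pulls cotangents back through f.
    out, vjp_fn = jax.vjp(f, *primals)
    # vjp_fn returns a tuple with one entry per primal.
    return out, vjp_fn(cotangents)


# Illustrative two-input, single-output function.
def f(x, y):
    return x * y + jnp.sin(x)


x = jnp.array([1.0, 2.0])
y = jnp.array([3.0, 4.0])
ct = jnp.ones(2)  # cotangent: same shape and pytree structure as f(x, y)

out, (grad_x, grad_y) = vjp_adapter(f, [x, y], ct)
m_out, (m_grad_x, m_grad_y) = vjp_macro(f, [x, y], ct)

# Analytic partials for comparison: df/dx = y + cos(x), df/dy = x
print(grad_x, grad_y)
```

The macro forwards c unchanged to vjp_fn, so c appears to need the same pytree structure as the output of f. For a single-array output like the one above, that means a bare array rather than a one-element list. JAX raises an error when the structures do not match.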

Apple MLX

API: mlx.core.vjp
Strategy: Direct Mapping

Flax NNX

API: None (no direct equivalent)
Strategy: Macro 'lambda f, p, c: (lambda out, vjp_fn: (out, vjp_fn(c)))(*jax.vjp(f, *p))'