Vjp
Compute the vector-Jacobian product (VJP) of fun evaluated at primals, applied to cotangents.
Abstract Signature:
Vjp(fun: Callable, primals: List[Tensor], cotangents: List[Tensor])
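For a function f: R^n -> R^m with Jacobian J at the primal point, the VJP with cotangent v is v^T J, a vector shaped like the input. The sketch below checks that identity with jax.vjp and jax.jacobian; the function f, the point x, and the cotangent v are hypothetical values chosen for illustration, not part of this mapping.

```python
import jax
import jax.numpy as jnp

# Hypothetical example function f: R^3 -> R^2 (not part of the original page).
def f(x):
    return jnp.array([x[0] * x[1], jnp.sin(x[2])])

x = jnp.array([1.0, 2.0, 3.0])
v = jnp.array([1.0, -1.0])  # cotangent, same shape as f(x)

# jax.vjp returns the primal output and a function that maps output
# cotangents to input cotangents.
y, vjp_fn = jax.vjp(f, x)
(x_bar,) = vjp_fn(v)  # one cotangent per primal argument, returned as a tuple

# The VJP equals v^T J, where J is the Jacobian of f at x.
J = jax.jacobian(f)(x)  # shape (2, 3)
print(jnp.allclose(x_bar, v @ J))  # True
```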
JAX (Core)
API:
Strategy: Macro 'lambda f, p, c: (lambda out, vjp_fn: (out, vjp_fn(c)))(*jax.vjp(f, *p))'
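The macro unpacks primals into jax.vjp, then returns a pair: the primal output and the tuple of input cotangents produced by calling vjp_fn on c. A minimal usage sketch follows; the name vjp_macro and the two-argument function f are hypothetical and only illustrate how the expanded lambda behaves.

```python
import jax
import jax.numpy as jnp

# The mapping's macro, bound to a hypothetical name for illustration.
vjp_macro = lambda f, p, c: (lambda out, vjp_fn: (out, vjp_fn(c)))(*jax.vjp(f, *p))

def f(w, x):
    # Hypothetical two-argument function: a linear map followed by tanh.
    return jnp.tanh(w @ x)

w = jnp.ones((2, 3))
x = jnp.arange(3.0)
ct = jnp.ones(2)  # cotangent with the same shape and structure as f(w, x)

# primals are passed as a list and unpacked with *p inside the macro;
# the second element of the result holds one cotangent per primal.
out, (w_bar, x_bar) = vjp_macro(f, [w, x], ct)
print(out.shape, w_bar.shape, x_bar.shape)  # (2,) (2, 3) (3,)
```

Note that c is handed to vjp_fn unchanged, so it must match the pytree structure of fun's output: a single array for a single-array output, or a tuple of arrays for a tuple output. Because the Flax NNX entry uses the identical macro, the same call pattern applies there.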
Flax NNX
API:
Strategy: Macro 'lambda f, p, c: (lambda out, vjp_fn: (out, vjp_fn(c)))(*jax.vjp(f, *p))'