CachedPartial

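Create a partial from an NNX-transformed function (for example one decorated with nnx.jit), binding some arguments up front and caching the traversal of their NNX graph nodes to reduce Python overhead on repeated calls.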

Abstract Signature:

CachedPartial(f: Callable)

PyTorch

API: functools.partial
Strategy: Direct Mapping
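PyTorch has no dedicated partial type, so the mapping points at the standard library. A minimal sketch of binding an argument with functools.partial; scaled_loss is a hypothetical helper used only for illustration:

```python
import functools


def scaled_loss(pred, target, weight):
    # Hypothetical helper: weighted squared error on plain floats.
    return weight * (pred - target) ** 2


# Bind `weight` once; the result is a callable expecting (pred, target).
half_loss = functools.partial(scaled_loss, weight=0.5)

loss = half_loss(3.0, 1.0)
```

The bound function and arguments stay inspectable through half_loss.func, half_loss.args and half_loss.keywords.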

JAX (Core)

API: jax.tree_util.Partial
Strategy: Direct Mapping
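Unlike a plain functools.partial, which jax.jit rejects as an argument because it is not a valid JAX type, jax.tree_util.Partial is registered as a pytree: the wrapped function is static auxiliary data and the bound arguments are leaves. A minimal sketch, assuming a recent JAX install; scale, double and apply are hypothetical names for illustration:

```python
import jax
import jax.numpy as jnp
from jax.tree_util import Partial


def scale(factor, x):
    return factor * x


# Partial is a pytree: `scale` is stored as static auxiliary data and the
# bound argument (2.0) becomes a leaf, so it can cross the jit boundary.
double = Partial(scale, 2.0)


@jax.jit
def apply(fn, x):
    # `fn` arrives as a Partial whose bound leaf is now a traced value.
    return fn(x)


out = apply(double, jnp.arange(3.0))
# Binding a different factor reuses the same compiled `apply`, since only
# a leaf value changed, not the static function.
triple_out = apply(Partial(scale, 3.0), jnp.arange(3.0))
```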

Flax NNX

API: flax.nnx.cached_partial
Strategy: Direct Mapping