MapState

Applies a function to every leaf of a State object and returns a new State with the same structure.

Abstract Signature:

MapState(f: Callable, state: State)
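The abstract operation can be read as the following minimal, framework-free sketch. It assumes that a State is modeled as a nested mapping whose non-mapping values are leaves. The function name `map_state_sketch` is hypothetical and only illustrates the semantics, not any library's implementation.

```python
from typing import Any, Callable, Mapping


def map_state_sketch(f: Callable[[Any], Any], state: Mapping[str, Any]) -> dict:
    """Hypothetical reference semantics for MapState(f, state).

    Assumption: a State is a nested mapping; any value that is not itself
    a mapping is treated as a leaf. Returns a new nested dict with the same
    keys and nesting; the input state is left unmodified.
    """
    result = {}
    for key, value in state.items():
        if isinstance(value, Mapping):
            # Recurse into sub-states so the nesting structure is preserved.
            result[key] = map_state_sketch(f, value)
        else:
            # Apply f to each leaf value.
            result[key] = f(value)
    return result


state = {"linear": {"kernel": 2.0, "bias": 0.5}, "step": 3}
doubled = map_state_sketch(lambda x: x * 2, state)
print(doubled)  # {'linear': {'kernel': 4.0, 'bias': 1.0}, 'step': 6}
```

Returning a fresh structure rather than mutating in place matches the functional style of both backends listed below.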

JAX (Core)

API: jax.tree_util.tree_map
Strategy: Direct Mapping
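A sketch of the JAX mapping, assuming the State is represented as a plain nested dict of arrays. JAX treats dicts, lists and tuples as pytree nodes, so `MapState(f, state)` becomes `jax.tree_util.tree_map(f, state)` with the same argument order.

```python
import jax
import jax.numpy as jnp

# Assumption: the State is a plain nested dict of arrays (a valid pytree).
state = {
    "linear": {"kernel": jnp.ones((2, 3)), "bias": jnp.zeros((3,))},
    "count": jnp.array(1),
}

# MapState(f, state) -> jax.tree_util.tree_map(f, state)
scaled = jax.tree_util.tree_map(lambda x: x * 0.5, state)

# The pytree structure (keys and nesting) is unchanged; only leaves differ.
assert jax.tree_util.tree_structure(scaled) == jax.tree_util.tree_structure(state)
print(scaled["linear"]["kernel"])
```

Note that `tree_map` passes only leaf values to `f`. If the mapped function also needs each leaf's location in the tree, `jax.tree_util.tree_map_with_path` passes the key path as an extra first argument.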

Flax NNX

API: flax.nnx.map_state
Strategy: Direct Mapping