FromFlatState

Converts a flat state mapping, in which each leaf value is keyed by its full path, into a nested State object.

Abstract Signature:

FromFlatState(flat_state: Mapping)
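To make the abstract operation concrete, here is a minimal, framework-free sketch. It assumes the flat state is keyed by tuples of path components, the convention used by `flax.traverse_util.flatten_dict` (PyTorch `state_dict` keys are dot-separated strings instead). Plain nested dicts stand in for a framework State object, and the helpers `from_flat_state_sketch` and `to_flat_state_sketch` are hypothetical names, not part of any library.

```python
from typing import Any, Dict, Mapping, Tuple

# Hypothetical path type: a tuple of keys from the root down to one leaf.
Path = Tuple[str, ...]


def from_flat_state_sketch(flat_state: Mapping[Path, Any]) -> Dict[str, Any]:
    """Rebuild a nested dict from a {path_tuple: leaf} mapping.

    Sketch only: plain dicts stand in for a framework State object.
    """
    nested: Dict[str, Any] = {}
    for path, leaf in flat_state.items():
        if not path:
            raise ValueError("empty path")
        node = nested
        for key in path[:-1]:
            # Create intermediate levels on demand.
            child = node.setdefault(key, {})
            if not isinstance(child, dict):
                raise ValueError(f"path {path} collides with a leaf at {key!r}")
            node = child
        if path[-1] in node:
            # The slot is already taken by a leaf or by a subtree.
            raise ValueError(f"conflicting path {path}")
        node[path[-1]] = leaf
    return nested


def to_flat_state_sketch(nested: Mapping[str, Any], prefix: Path = ()) -> Dict[Path, Any]:
    """Inverse sketch: flatten a nested dict into {path_tuple: leaf}.

    Note: empty sub-dicts produce no entries, so they do not round-trip.
    """
    flat: Dict[Path, Any] = {}
    for key, value in nested.items():
        path = prefix + (key,)
        if isinstance(value, dict):
            flat.update(to_flat_state_sketch(value, path))
        else:
            flat[path] = value
    return flat


# Usage: leaves are plain lists standing in for arrays.
flat = {
    ("encoder", "dense", "kernel"): [[1.0, 2.0]],
    ("encoder", "dense", "bias"): [0.0],
    ("head", "bias"): [0.5],
}
state = from_flat_state_sketch(flat)
# state == {"encoder": {"dense": {"kernel": [[1.0, 2.0]], "bias": [0.0]}},
#           "head": {"bias": [0.5]}}
```

Rejecting paths that collide with an existing leaf or subtree keeps the conversion unambiguous, so flattening the result gives back the original mapping.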

PyTorch

API: model.load_state_dict
Strategy: Direct Mapping

JAX (Core)

API: flax.serialization.from_state_dict
Strategy: Direct Mapping

Flax NNX

API: flax.nnx.from_flat_state
Strategy: Direct Mapping