FilterState

Splits a State into one or more States, selecting entries by type or by predicate filters.

Abstract Signature:

FilterState(state: State, filters)

PyTorch

API: N/A
Strategy: Custom / Partial
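
Because no direct PyTorch API is listed, the "Custom" strategy could look roughly like the sketch below. It is only an illustration: the helper name `filter_state`, the `(name, value)` predicate signature, and the first-match-wins rule are assumptions made here, not part of PyTorch. The state is modelled as a plain dict so the example runs without any dependency; with PyTorch it would typically be the mapping returned by `module.state_dict()`.

```python
from typing import Any, Callable, Dict, Tuple

State = Dict[str, Any]
Filter = Callable[[str, Any], bool]  # hypothetical predicate: (name, value) -> bool


def filter_state(state: State, *filters: Filter) -> Tuple[State, ...]:
    """Hypothetical helper: return one sub-state per filter.

    Each entry goes to the first filter that accepts it (an assumed
    convention); entries that match no filter are dropped.
    """
    groups: Tuple[State, ...] = tuple({} for _ in filters)
    for name, value in state.items():
        for group, matches in zip(groups, filters):
            if matches(name, value):
                group[name] = value
                break  # first match wins
    return groups


# Stand-in for a state_dict: parameter names mapped to (placeholder) values.
state = {
    "linear.weight": [[0.1, 0.2], [0.3, 0.4]],
    "linear.bias": [0.0, 0.0],
    "bn.running_mean": [0.5, 0.5],
}

weights, biases = filter_state(
    state,
    lambda name, _: name.endswith(".weight"),
    lambda name, _: name.endswith(".bias"),
)
```

Here `weights` holds only `linear.weight`, `biases` holds only `linear.bias`, and `bn.running_mean` is dropped because no filter accepts it. The strategy is "Partial" in the sense that type-based filtering has to be written as an explicit predicate.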

JAX (Core)

API: flax.core.traverse_util.filter_state
Strategy: Direct Mapping

Flax NNX

API: flax.nnx.filter_state
Strategy: Direct Mapping