ml_switcheroo.frameworks.common.jax_stack
JAX Stack Common Logic (Level 0 & Level 1).
This module provides the JAXStackMixin, a reusable base for any Framework Adapter
built on top of the JAX ecosystem (e.g., Flax, PaxML, Haiku).
It standardizes:
- Level 0 (Core): JIT compilation templates, device syntax (jax.devices), and Array API mappings (jax.numpy).
- Level 1 (Common Libs):
  - Optax: optimization primitives and loss functions.
  - Orbax: checkpointing and serialization.
- Usage:

    from ml_switcheroo.frameworks.common.jax_stack import JAXStackMixin

    class MyJaxFramework(JAXStackMixin):
        ...  # framework-specific logic
Classes
- JAXStackMixin: Mixin providing shared implementations for JAX ecosystem adapters.
Module Contents
- class ml_switcheroo.frameworks.common.jax_stack.JAXStackMixin
Mixin providing shared implementations for JAX ecosystem adapters.
This ensures consistent translation for:
- Optimization (Torch Optimizers -> Optax Factory Functions).
- Serialization (Torch Save/Load -> Orbax Checkpointing).
- Device Management (Torch Device -> JAX Devices).
- Test Configuration (Gen-Tests templates).
- Verification Normalization (JAX Array -> NumPy).
- Weight Migration (Loading/Saving checkpoints via Orbax).
- property jax_test_config: Dict[str, str]
Returns standard JAX test generation templates that use JIT wrapping.
Defines:
- import: Libraries to import (including opt-in Chex support).
- convert_input: Syntax to convert a NumPy array to a JAX array.
- to_numpy: Identity transform (preserves PyTrees for Chex comparison).
- jit_template: JAX JIT wrapping syntax, with support for static arguments.
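The four template keys can be pictured as a plain dictionary of format strings. This is a minimal sketch: the key names come from the description above, but every template value, the `{np_var}`/`{res_var}`/`{fn}`/`{static}` placeholders, and the helper `render_jit` are hypothetical, not the mixin's actual values.

```python
from typing import Dict, Tuple

# Hypothetical templates keyed by the names documented for jax_test_config.
# The string values are illustrative assumptions, not the library's real output.
JAX_TEST_CONFIG: Dict[str, str] = {
    "import": "import jax\nimport jax.numpy as jnp",
    "convert_input": "jnp.array({np_var})",
    "to_numpy": "{res_var}",  # identity: leaves PyTrees intact for Chex comparison
    "jit_template": "jax.jit({fn}, static_argnums={static})",
}


def render_jit(fn_name: str, static_argnums: Tuple[int, ...] = ()) -> str:
    """Fill the hypothetical jit_template for one function under test."""
    return JAX_TEST_CONFIG["jit_template"].format(
        fn=fn_name, static=repr(static_argnums)
    )


print(render_jit("model_fn", (1,)))  # jax.jit(model_fn, static_argnums=(1,))
```

Marking an argument static (here index 1) tells `jax.jit` to treat it as a compile-time constant, which is why a generated test harness needs a template slot for it.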
- get_to_numpy_code() -> str
Returns logic to convert JAX arrays to NumPy. The generated code checks for the __array__ protocol, which JAX arrays implement.
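A minimal sketch of the kind of snippet such a method could emit, assuming the generated code imports NumPy as `np`; the helper name `_to_numpy` and the exact body are hypothetical. The generator needs only the standard library, so the emitted code is checked with `ast.parse` rather than executed.

```python
import ast


def get_to_numpy_code() -> str:
    """Hypothetical generator: emit a helper that converts array-likes to NumPy."""
    return (
        "import numpy as np\n"
        "\n"
        "def _to_numpy(obj):\n"
        "    # JAX arrays implement the __array__ protocol, so np.asarray turns\n"
        "    # them into host-side NumPy arrays; other objects pass through.\n"
        "    if hasattr(obj, '__array__'):\n"
        "        return np.asarray(obj)\n"
        "    return obj\n"
    )


snippet = get_to_numpy_code()
ast.parse(snippet)  # raises SyntaxError if the emitted code is malformed
print(snippet)
```

Duck-typing on `__array__` keeps the helper backend-agnostic: any array type that speaks the protocol converts, and everything else is returned unchanged.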
- get_device_syntax(device_type: str, device_index: str | None = None) -> str
Returns JAX-compliant syntax for device specification.
Maps 'cuda'/'gpu' to the 'gpu' backend and 'cpu' to the 'cpu' backend.
- Parameters:
device_type – String literal or variable representing the device type (e.g., "'cuda'").
device_index – Optional index string (e.g., "0").
- Returns:
Python code string constructing the device object, e.g. jax.devices('gpu')[0].
- Return type:
str
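The backend mapping above can be sketched as a small lookup table. This is a hypothetical re-implementation for illustration: it handles only quoted literals (the documented method also accepts variables), and defaulting the index to `0` when none is given is an assumption.

```python
from typing import Optional


def get_device_syntax(device_type: str, device_index: Optional[str] = None) -> str:
    """Sketch of the documented cuda/gpu -> 'gpu' and cpu -> 'cpu' mapping.

    device_type is a source-code fragment such as "'cuda'" (a quoted literal).
    """
    literal = device_type.strip().strip("'\"").lower()
    backend = {"cuda": "gpu", "gpu": "gpu", "cpu": "cpu"}.get(literal)
    if backend is None:
        # Variables and other backends are out of scope for this sketch.
        raise ValueError(f"unsupported device type: {device_type!r}")
    index = device_index if device_index is not None else "0"  # assumed default
    return f"jax.devices('{backend}')[{index}]"


print(get_device_syntax("'cuda'", "0"))  # jax.devices('gpu')[0]
```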
- get_device_check_syntax() -> str
Returns JAX syntax for checking if GPUs are available. Format:
len(jax.devices('gpu')) > 0
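On a CPU-only installation, `jax.devices('gpu')` may raise a `RuntimeError` (no GPU backend) instead of returning an empty list, so generated code that must run anywhere might wrap the check. The following is a hypothetical guarded variant, not the mixin's actual output; `_has_gpu` and `load_check` are illustrative names, and stub objects stand in for the real `jax` module so the sketch runs without JAX.

```python
import types

# Hypothetical guarded variant of the documented `len(jax.devices('gpu')) > 0`.
GUARDED_GPU_CHECK = """\
def _has_gpu():
    try:
        return len(jax.devices('gpu')) > 0
    except RuntimeError:
        # No GPU backend is available on this installation.
        return False
"""


def load_check(jax_module):
    """Exec the generated helper against a given `jax` module (or a stub)."""
    namespace = {"jax": jax_module}
    exec(GUARDED_GPU_CHECK, namespace)
    return namespace["_has_gpu"]


def _no_gpu(backend):
    raise RuntimeError(f"Unknown backend {backend}")


# Stubs mimic a CPU-only install and a single-GPU install.
cpu_only = types.SimpleNamespace(devices=_no_gpu)
one_gpu = types.SimpleNamespace(devices=lambda backend: ["gpu:0"])
print(load_check(cpu_only)(), load_check(one_gpu)())  # False True
```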
- get_rng_split_syntax(rng_var: str, key_var: str) -> str
Returns JAX syntax for splitting a PRNG key. Format (for rng_var='rng', key_var='key'):
rng, key = jax.random.split(rng)
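Because JAX PRNG keys are explicit values, each split rebinds the parent key and yields a fresh subkey; `jax.random.split` returns two keys by default, which unpack into the two names. A minimal sketch of how the format could be produced from the two variable names; the identifier check is an illustrative addition, not documented behavior.

```python
def get_rng_split_syntax(rng_var: str, key_var: str) -> str:
    """Emit a split that rebinds rng_var and binds a fresh subkey to key_var."""
    for name in (rng_var, key_var):
        if not name.isidentifier():
            raise ValueError(f"not a valid Python identifier: {name!r}")
    return f"{rng_var}, {key_var} = jax.random.split({rng_var})"


print(get_rng_split_syntax("rng", "key"))  # rng, key = jax.random.split(rng)
```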
- get_serialization_imports() -> List[str]
Returns standard imports for JAX serialization via Orbax.
- get_serialization_syntax(op: str, file_arg: str, object_arg: str | None = None) -> str
Returns Orbax syntax for save/load operations.
- Parameters:
op – Operation name (‘save’ or ‘load’).
file_arg – Path to checkpoint directory.
object_arg – The PyTree to save (required for save).
- Returns:
Python code string.
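A sketch of the import list and the save/load dispatch, assuming the emitted code targets `orbax.checkpoint.PyTreeCheckpointer` and its `save(directory, item)` / `restore(directory)` methods. Orbax also offers other checkpointer classes, so the calls the mixin actually emits may differ; both functions here are hypothetical re-implementations.

```python
import ast
from typing import List, Optional


def get_serialization_imports() -> List[str]:
    """Hypothetical import list for Orbax-based serialization."""
    return ["import orbax.checkpoint as ocp"]


def get_serialization_syntax(op: str, file_arg: str, object_arg: Optional[str] = None) -> str:
    """Emit an Orbax save/restore call; the argument strings are code fragments."""
    if op == "save":
        if object_arg is None:
            raise ValueError("object_arg (the PyTree to save) is required for 'save'")
        return f"ocp.PyTreeCheckpointer().save({file_arg}, {object_arg})"
    if op == "load":
        return f"ocp.PyTreeCheckpointer().restore({file_arg})"
    raise ValueError(f"unsupported op: {op!r}")


save_code = get_serialization_syntax("save", "'/tmp/ckpt'", "params")
ast.parse(save_code)  # well-formed Python expression
print(save_code)  # ocp.PyTreeCheckpointer().save('/tmp/ckpt', params)
```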
- get_weight_conversion_imports() -> List[str]
Returns imports required for the generated weight migration script.
- get_weight_load_code(path_var: str) -> str
Returns Python code that loads a checkpoint from path_var into a variable named raw_state. raw_state is a flat dictionary whose keys are dot-separated strings (e.g. 'layer.weight').
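A restored checkpoint is typically a nested dictionary, so producing `raw_state` amounts to flattening it into dot-separated keys. This stdlib-only sketch shows only that flattening step; `flatten_to_dotted` is a hypothetical helper, and the list leaves stand in for arrays.

```python
from typing import Any, Dict


def flatten_to_dotted(tree: Dict[str, Any], prefix: str = "") -> Dict[str, Any]:
    """Flatten nested dicts into {'a.b.c': leaf} form, as described for raw_state."""
    flat: Dict[str, Any] = {}
    for name, value in tree.items():
        key = f"{prefix}.{name}" if prefix else name
        if isinstance(value, dict):
            flat.update(flatten_to_dotted(value, key))
        else:
            flat[key] = value  # leaf, e.g. a JAX or NumPy array
    return flat


restored = {"layer": {"weight": [[1.0, 2.0]], "bias": [0.5]}}
raw_state = flatten_to_dotted(restored)
print(raw_state)  # {'layer.weight': [[1.0, 2.0]], 'layer.bias': [0.5]}
```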
- get_tensor_to_numpy_expr(tensor_var: str) -> str
Returns a Python expression string that converts tensor_var from a JAX array to a NumPy array.
- get_weight_save_code(state_var: str, path_var: str) -> str
Returns Python code that saves the dictionary state_var (mapping flat keys to NumPy arrays) to path_var. The code unflattens the flat keys back into a PyTree structure using unflatten_dict and saves the result via Orbax.
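The inverse step, rebuilding the PyTree from flat dot-separated keys before saving, can be sketched with the standard library alone. It is intended to mirror what `flax.traverse_util.unflatten_dict(..., sep='.')` does for this case; `unflatten_dotted` is a hypothetical name, and the Orbax save itself is omitted.

```python
from typing import Any, Dict


def unflatten_dotted(flat: Dict[str, Any]) -> Dict[str, Any]:
    """Rebuild nested dicts from keys like 'layer.weight'."""
    tree: Dict[str, Any] = {}
    for dotted_key, value in flat.items():
        *parents, leaf = dotted_key.split(".")
        node = tree
        for part in parents:
            child = node.setdefault(part, {})
            if not isinstance(child, dict):
                # e.g. both 'a' and 'a.b' present: 'a' cannot be leaf and subtree.
                raise ValueError(f"key {dotted_key!r} conflicts with a leaf at {part!r}")
            node = child
        node[leaf] = value
    return tree


state = {"layer.weight": [[1.0, 2.0]], "layer.bias": [0.5], "head.weight": [3.0]}
print(unflatten_dotted(state))
```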
- get_doc_url(api_name: str) -> str | None
Generates a default documentation URL for standard JAX APIs by mapping the name to its ReadTheDocs autosummary path. NOTE: Subclasses (Flax/Pax) should override this for their own namespaces.
- Parameters:
api_name – The fully qualified API path (e.g. ‘jax.numpy.abs’).
- Returns:
String URL or None.
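A minimal sketch of the autosummary mapping. The base URL is an assumption based on the layout of the JAX documentation's `_autosummary` pages; names outside the `jax.` namespace fall through to `None`, which is where a Flax or Pax subclass would supply its own override.

```python
from typing import Optional

# Assumed base path for JAX's autosummary API pages.
JAX_DOCS_BASE = "https://jax.readthedocs.io/en/latest/_autosummary"


def get_doc_url(api_name: str) -> Optional[str]:
    """Build an autosummary URL for jax.* APIs; return None for other namespaces."""
    if api_name.startswith("jax."):
        return f"{JAX_DOCS_BASE}/{api_name}.html"
    return None


print(get_doc_url("jax.numpy.abs"))
# https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.abs.html
```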