ml_switcheroo.frameworks.common.jax_stack

JAX Stack Common Logic (Level 0 & Level 1).

This module provides the JAXStackMixin, a reusable base for any Framework Adapter built on top of the JAX ecosystem (e.g., Flax, PaxML, Haiku).

It standardizes:

  1. Level 0 (Core): JIT compilation templates, Device syntax (jax.devices), and Array API mappings (jax.numpy).

  2. Level 1 (Common Libs):

    • Optax: Optimization primitives and loss functions.

    • Orbax: Checkpointing and Serialization.

Usage:

    from ml_switcheroo.frameworks.common.jax_stack import JAXStackMixin

    class MyJaxFramework(JAXStackMixin):
        # … logic …
        pass

Classes

JAXStackMixin: Mixin providing shared implementations for JAX ecosystem adapters.

Module Contents

class ml_switcheroo.frameworks.common.jax_stack.JAXStackMixin

Mixin providing shared implementations for JAX ecosystem adapters.

This ensures consistent translation for:

  • Optimization (Torch Optimizers -> Optax Factory Functions).

  • Serialization (Torch Save/Load -> Orbax Checkpointing).

  • Device Management (Torch Device -> JAX Devices).

  • Test Configuration (Gen-Tests templates).

  • Verification Normalization (JAX Array -> NumPy).

  • Weight Migration (Loading/Saving checkpoints via Orbax).

property jax_test_config: Dict[str, str]

Returns standard JAX test generation templates using JIT wrapping.

Defines:

  • import: Libraries to import (including opt-in Chex support).

  • convert_input: Syntax to convert a NumPy array to a JAX array.

  • to_numpy: Identity transform (preserves PyTrees for Chex comparison).

  • jit_template: Detailed JAX JIT syntax with static argument support.
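The actual template strings are not reproduced on this page. As a hypothetical sketch, a dictionary with the four documented keys might look like the following; every template string and every {placeholder} name is an assumption, not the adapter's real output.

```python
from typing import Dict

# Hypothetical templates; only the four key names come from the documentation.
JAX_TEST_CONFIG: Dict[str, str] = {
    # Chex is opt-in: the generated test falls back to None if it is not installed.
    "import": (
        "import jax\n"
        "import jax.numpy as jnp\n"
        "try:\n"
        "    import chex\n"
        "except ImportError:\n"
        "    chex = None"
    ),
    # Convert a NumPy input into a JAX array.
    "convert_input": "jnp.array({np_expr})",
    # Identity: results stay as PyTrees so Chex can compare them structurally.
    "to_numpy": "{res_var}",
    # Wrap the function under test in jax.jit, marking some arguments as static.
    "jit_template": "jax.jit({fn}, static_argnums={static_argnums})({args})",
}

snippet = JAX_TEST_CONFIG["jit_template"].format(
    fn="model_fn", static_argnums="(1,)", args="x, 3"
)
print(snippet)  # jax.jit(model_fn, static_argnums=(1,))(x, 3)
```

Keeping the templates as plain str.format strings lets a test generator fill them per test case without importing JAX itself.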

get_to_numpy_code() -> str

Returns code that converts JAX arrays to NumPy arrays. The check relies on the __array__ protocol, which JAX arrays implement.
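The exact generated code is not shown here. A minimal sketch of the runtime check it describes, assuming a plain hasattr test on __array__ followed by np.asarray, is below; FakeJaxArray is a hypothetical stand-in so the example runs without JAX installed.

```python
import numpy as np


def to_numpy(value):
    """Convert anything exposing __array__ (as JAX arrays do) to a NumPy array."""
    if hasattr(value, "__array__"):
        return np.asarray(value)
    # Values without the protocol (e.g. Python scalars) pass through unchanged.
    return value


class FakeJaxArray:
    """Hypothetical stand-in for a JAX array: it only implements __array__."""

    def __init__(self, data):
        self._data = data

    def __array__(self, dtype=None, copy=None):
        return np.asarray(self._data, dtype=dtype)


result = to_numpy(FakeJaxArray([1.0, 2.0, 3.0]))
print(type(result).__name__, result.tolist())  # ndarray [1.0, 2.0, 3.0]
```

Duck-typing on __array__ avoids importing jax in the verification harness while still accepting real JAX arrays.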

get_device_syntax(device_type: str, device_index: str | None = None) -> str

Returns JAX-compliant syntax for device specification.

Maps ‘cuda’ and ‘gpu’ to the ‘gpu’ backend, and ‘cpu’ to the ‘cpu’ backend.

Parameters:
  • device_type – String literal or variable representing the device type (e.g., “‘cuda’”).

  • device_index – Optional index string (e.g., “0”).

Returns:

Python code string constructing the device object (e.g., jax.devices('gpu')[0]).

Return type:

str
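A hypothetical re-implementation of the documented mapping is sketched below. Only the ‘cuda’/‘gpu’ → ‘gpu’ and ‘cpu’ → ‘cpu’ mapping and the jax.devices('gpu')[0] output shape come from this page; the default index "0", the pass-through of unknown literals, and the handling of variable names are assumptions.

```python
from typing import Optional

# Documented mapping of device names to JAX backends.
_BACKENDS = {"cuda": "gpu", "gpu": "gpu", "cpu": "cpu"}


def get_device_syntax(device_type: str, device_index: Optional[str] = None) -> str:
    """Hypothetical sketch that emits jax.devices(...)[i] source code."""
    index = device_index if device_index is not None else "0"  # assumed default
    text = device_type.strip()
    is_literal = len(text) >= 2 and text[0] == text[-1] and text[0] in "'\""
    if is_literal:
        name = text[1:-1]
        # Unknown literals are passed through unchanged (an assumption).
        backend = repr(_BACKENDS.get(name, name))
    else:
        # A variable cannot be mapped statically; it is emitted as-is (an assumption).
        backend = text
    return f"jax.devices({backend})[{index}]"


print(get_device_syntax("'cuda'"))      # jax.devices('gpu')[0]
print(get_device_syntax("'cpu'", "1"))  # jax.devices('cpu')[1]
print(get_device_syntax("dev_name"))    # jax.devices(dev_name)[0]
```

Because device_type is source text rather than a runtime value, the sketch only rewrites quoted literals; anything else must be resolved when the generated code runs.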

get_device_check_syntax() -> str

Returns JAX syntax for checking if GPUs are available. Format: len(jax.devices('gpu')) > 0

get_rng_split_syntax(rng_var: str, key_var: str) -> str

Returns JAX syntax for splitting a PRNG key. Format: rng, key = jax.random.split(rng)
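Assuming rng and key in the documented format stand for rng_var and key_var, a minimal sketch of the string construction is:

```python
def get_rng_split_syntax(rng_var: str, key_var: str) -> str:
    """Hypothetical sketch: emit code that splits rng_var into a new rng and a subkey."""
    # jax.random.split(key) returns two keys by default, so the result unpacks
    # into the rebound carry key (rng_var) and a fresh subkey (key_var).
    return f"{rng_var}, {key_var} = jax.random.split({rng_var})"


print(get_rng_split_syntax("rng", "dropout_key"))
# rng, dropout_key = jax.random.split(rng)
```

Rebinding rng_var on every split follows JAX's explicit-PRNG convention: a key is never reused once it has been split.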

get_serialization_imports() -> List[str]

Returns standard imports for JAX serialization via Orbax.

get_serialization_syntax(op: str, file_arg: str, object_arg: str | None = None) -> str

Returns Orbax syntax for save/load operations.

Parameters:
  • op – Operation name (‘save’ or ‘load’).

  • file_arg – Path to checkpoint directory.

  • object_arg – The PyTree to save (required for save).

Returns:

Python code string.
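The Orbax calls the adapter actually emits are not listed here. The sketch below assumes the generated code uses orbax.checkpoint.PyTreeCheckpointer with its save(directory, item) and restore(directory) methods under an assumed ocp alias; the real adapter may target a different Orbax API (for example StandardCheckpointer).

```python
from typing import List, Optional


def get_serialization_imports() -> List[str]:
    # Assumed import alias for the generated code.
    return ["import orbax.checkpoint as ocp"]


def get_serialization_syntax(op: str, file_arg: str, object_arg: Optional[str] = None) -> str:
    """Hypothetical sketch emitting Orbax PyTreeCheckpointer calls."""
    if op == "save":
        if object_arg is None:
            raise ValueError("object_arg is required when op == 'save'")
        return f"ocp.PyTreeCheckpointer().save({file_arg}, {object_arg})"
    if op == "load":
        return f"ocp.PyTreeCheckpointer().restore({file_arg})"
    raise ValueError(f"unsupported op: {op!r}")


print(get_serialization_syntax("save", "ckpt_dir", "params"))
# ocp.PyTreeCheckpointer().save(ckpt_dir, params)
print(get_serialization_syntax("load", "ckpt_dir"))
# ocp.PyTreeCheckpointer().restore(ckpt_dir)
```

Raising early when object_arg is missing surfaces a malformed save call at translation time rather than when the generated script runs.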

get_weight_conversion_imports() -> List[str]

Returns imports required for the generated weight migration script.

get_weight_load_code(path_var: str) -> str

Returns Python code that loads a checkpoint from path_var into a variable named raw_state. raw_state is a flat dictionary whose keys are dot-separated strings (e.g. ‘layer.weight’).
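The flat raw_state layout can be illustrated with a small pure-Python helper. This is a hypothetical sketch of the key format only; the generated script might instead rely on something like flax.traverse_util.flatten_dict(tree, sep=".") after restoring the checkpoint, which is an assumption not confirmed by this page.

```python
from typing import Any, Dict


def flatten_to_dot_keys(tree: Dict[str, Any], prefix: str = "") -> Dict[str, Any]:
    """Flatten nested dicts into {'layer.weight': ...} style keys."""
    flat: Dict[str, Any] = {}
    for name, value in tree.items():
        key = f"{prefix}.{name}" if prefix else name
        if isinstance(value, dict):
            # Recurse into sub-trees; empty sub-dicts contribute no keys.
            flat.update(flatten_to_dot_keys(value, key))
        else:
            flat[key] = value
    return flat


params = {"layer": {"weight": [[1.0]], "bias": [0.0]}, "scale": 2.0}
raw_state = flatten_to_dot_keys(params)
print(sorted(raw_state))  # ['layer.bias', 'layer.weight', 'scale']
```

A flat, string-keyed mapping is framework-neutral, which makes it a convenient intermediate format for matching parameter names across frameworks.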

get_tensor_to_numpy_expr(tensor_var: str) -> str

Returns a Python expression string that converts tensor_var from a JAX array to a NumPy array.

get_weight_save_code(state_var: str, path_var: str) -> str

Returns Python code that saves the dictionary state_var (mapping flat keys to NumPy arrays) to path_var. The generated code restores the flat keys to a nested PyTree structure using unflatten_dict and then saves the result via Orbax.
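The unflattening step can be sketched in pure Python. This hypothetical helper mirrors what an unflatten_dict call with a "." separator would produce (for example flax.traverse_util.unflatten_dict(state, sep="."), which is an assumption about the function the generated code calls); the Orbax save itself is omitted.

```python
from typing import Any, Dict


def unflatten_dot_keys(flat: Dict[str, Any]) -> Dict[str, Any]:
    """Rebuild a nested dict (PyTree) from 'a.b.c' style keys."""
    tree: Dict[str, Any] = {}
    for dotted, value in flat.items():
        *parents, leaf = dotted.split(".")
        node = tree
        for part in parents:
            # Create intermediate levels on first use.
            node = node.setdefault(part, {})
        node[leaf] = value
    return tree


state = {"layer.weight": [[1.0]], "layer.bias": [0.0], "scale": 2.0}
print(unflatten_dot_keys(state))
# {'layer': {'weight': [[1.0]], 'bias': [0.0]}, 'scale': 2.0}
```

Restoring the nesting before saving matters because Orbax stores a PyTree, and JAX frameworks such as Flax expect parameters as nested mappings rather than flat dotted keys.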

get_doc_url(api_name: str) -> str | None

Generates a default documentation URL for standard JAX APIs, following the ReadTheDocs autosummary path. NOTE: Subclasses (Flax/Pax) should override this for their own namespaces.

Parameters:

api_name – The fully qualified API path (e.g. ‘jax.numpy.abs’).

Returns:

String URL or None.
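A hypothetical sketch of the default mapping follows. The base URL below is JAX's ReadTheDocs autosummary location, assumed here to be the target; returning None for names outside the jax namespace is also an assumption.

```python
from typing import Optional

# Assumed base URL of JAX's ReadTheDocs autosummary pages.
_JAX_AUTOSUMMARY = "https://jax.readthedocs.io/en/latest/_autosummary"


def get_doc_url(api_name: str) -> Optional[str]:
    """Hypothetical default: only names under the jax namespace get a URL."""
    if api_name.startswith("jax."):
        return f"{_JAX_AUTOSUMMARY}/{api_name}.html"
    return None


print(get_doc_url("jax.numpy.abs"))
# https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.abs.html
print(get_doc_url("flax.linen.Dense"))  # None
```

A Flax or Pax subclass would override this method so that its own namespace maps to its own documentation site.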