ml_switcheroo.frameworks.common.jax_stack

JAX Stack Common Logic (Level 0 & Level 1).

This module provides the JAXStackMixin, a reusable base for any Framework Adapter built on top of the JAX ecosystem (e.g., Flax, PaxML, Haiku).

It standardizes:

  1. Level 0 (Core): JIT compilation templates, device syntax (jax.devices), and Array API mappings (jax.numpy).

  2. Level 1 (Common Libs):

       • Optax: Optimization primitives and loss functions.

       • Orbax: Checkpointing and serialization.

Usage:

    class MyJaxFramework(JAXStackMixin):
        def apply_wiring(self, snapshot):
            self._apply_stack_wiring(snapshot)
            # … custom logic …

Classes

JAXStackMixin

Mixin providing shared implementations for JAX ecosystem adapters.

Module Contents

class ml_switcheroo.frameworks.common.jax_stack.JAXStackMixin

Mixin providing shared implementations for JAX ecosystem adapters.

This ensures consistent translation for:

  • Optimization (Torch optimizers -> Optax factory functions).

  • Serialization (Torch save/load -> Orbax checkpointing).

  • Device management (Torch device -> JAX devices).

  • Test configuration (Gen-Tests templates).

property jax_test_config: Dict[str, str]

Returns standard JAX test generation templates using JIT wrapping.
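As a rough illustration of what "templates using JIT wrapping" could look like, the sketch below stores code templates as format strings and renders one of them. The dictionary keys, the template text, and the helper `render_jit_call` are all hypothetical; the actual contents of `jax_test_config` are not shown on this page and may differ.

```python
from typing import Dict

# Hypothetical keys and template text, chosen only for illustration.
SKETCH_JAX_TEST_CONFIG: Dict[str, str] = {
    "import": "import jax\nimport jax.numpy as jnp",
    # Convert a NumPy input variable into a JAX array.
    "convert_input": "jnp.array({np_var})",
    # Wrap the function under test in jax.jit before calling it.
    "jit_wrap": "jax.jit({fn})({args})",
}


def render_jit_call(fn: str, args: str) -> str:
    """Fill the JIT template with a function name and an argument list."""
    return SKETCH_JAX_TEST_CONFIG["jit_wrap"].format(fn=fn, args=args)
```

Under these assumptions, `render_jit_call("model_fn", "x, y")` produces the source string `jax.jit(model_fn)(x, y)`, which a test generator could paste into a generated test body.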

get_device_syntax(device_type: str, device_index: str | None = None) → str

Returns JAX-compliant syntax for device specification.

Maps ‘cuda’/‘gpu’ to the ‘gpu’ backend and ‘cpu’ to the ‘cpu’ backend.

Parameters:
  • device_type – String literal or variable representing device type (e.g., “‘cuda’”).

  • device_index – Optional index string (e.g., “0”).

Returns:

Python code string constructing the device object, e.g. jax.devices(‘gpu’)[0].

Return type:

str
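The mapping described above can be sketched as a small standalone function. This is not the mixin's implementation: the name `sketch_device_syntax`, the lookup table, the fallback to ‘cpu’ for unknown names, and the default index of 0 are assumptions made for illustration, and the sketch only handles quoted literals, not variable names.

```python
from typing import Optional

# Hypothetical lookup table; the real mixin may recognise more device types.
_BACKENDS = {"cuda": "gpu", "gpu": "gpu", "cpu": "cpu"}


def sketch_device_syntax(device_type: str, device_index: Optional[str] = None) -> str:
    """Build a JAX device expression as a source-code string (illustrative only)."""
    # The argument may arrive as a quoted literal such as "'cuda'"; strip the quotes.
    name = device_type.strip("'\"").lower()
    # Assumption: unknown names fall back to the 'cpu' backend.
    backend = _BACKENDS.get(name, "cpu")
    # Assumption: a missing index selects the first device.
    index = device_index if device_index is not None else "0"
    return f"jax.devices('{backend}')[{index}]"
```

With these assumptions, `sketch_device_syntax("'cuda'", "0")` yields the string `jax.devices('gpu')[0]`, matching the example given in the return description.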

get_serialization_imports() → List[str]

Returns standard imports for JAX serialization via Orbax.

get_serialization_syntax(op: str, file_arg: str, object_arg: str | None = None) → str

Returns Orbax syntax for save/load operations.

Parameters:
  • op – Operation name (‘save’ or ‘load’).

  • file_arg – Path to checkpoint directory.

  • object_arg – The PyTree to save (required for save).

Returns:

Python code string.
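A minimal sketch of how such a code string might be assembled is shown below. It assumes the generated code goes through Orbax's `PyTreeCheckpointer` with its `save` and `restore` methods; the page does not say which Orbax entry point the mixin emits, so that choice, the function names `sketch_serialization_imports` and `sketch_serialization_syntax`, and the error handling are assumptions.

```python
from typing import List, Optional


def sketch_serialization_imports() -> List[str]:
    """Imports the generated snippet would need (assumed, not taken from the source)."""
    return ["import orbax.checkpoint"]


def sketch_serialization_syntax(
    op: str, file_arg: str, object_arg: Optional[str] = None
) -> str:
    """Return an Orbax call as a source-code string (illustrative only)."""
    # Assumption: the emitted code uses PyTreeCheckpointer for both operations.
    checkpointer = "orbax.checkpoint.PyTreeCheckpointer()"
    if op == "save":
        # Saving needs the expression for the PyTree being written.
        if object_arg is None:
            raise ValueError("'save' needs the PyTree expression to write")
        return f"{checkpointer}.save({file_arg}, {object_arg})"
    if op == "load":
        return f"{checkpointer}.restore({file_arg})"
    raise ValueError(f"unsupported operation: {op!r}")
```

Here `file_arg` and `object_arg` are treated as source-code fragments rather than runtime values, so `sketch_serialization_syntax("load", "ckpt_dir")` returns the text `orbax.checkpoint.PyTreeCheckpointer().restore(ckpt_dir)` without touching the filesystem.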