ml_switcheroo.frameworks.common.jax_stack
JAX Stack Common Logic (Level 0 & Level 1).
This module provides the JAXStackMixin, a reusable base for any Framework Adapter built on top of the JAX ecosystem (e.g., Flax, PaxML, Haiku).
It standardizes:
1. Level 0 (Core): JIT compilation templates, device syntax (jax.devices), and Array API mappings (jax.numpy).
2. Level 1 (Common Libs):
   - Optax: Optimization primitives and loss functions.
   - Orbax: Checkpointing and serialization.
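Usage:
    from ml_switcheroo.frameworks.common.jax_stack import JAXStackMixin

    class MyJaxFramework(JAXStackMixin):
        def apply_wiring(self, snapshot):
            # Apply the shared JAX stack wiring first.
            self._apply_stack_wiring(snapshot)
            # … custom logic …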
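The usage pattern above can be tried without installing ml_switcheroo by substituting a stand-in mixin. The sketch below is only an illustration: `StubJAXStackMixin`, the dict-shaped `snapshot`, and the `"applied"` key are assumptions made for the example, not the library's real data model. It shows the ordering implied by the usage snippet, where the shared stack wiring runs before any framework-specific logic.

```python
from typing import Any, Dict


class StubJAXStackMixin:
    """Hypothetical stand-in for JAXStackMixin (not the real class)."""

    def _apply_stack_wiring(self, snapshot: Dict[str, Any]) -> None:
        # Assumption: the snapshot is a mutable dict; the real type may differ.
        snapshot.setdefault("applied", []).append("jax_stack")


class MyJaxFramework(StubJAXStackMixin):
    def apply_wiring(self, snapshot: Dict[str, Any]) -> None:
        # Shared JAX stack wiring runs first.
        self._apply_stack_wiring(snapshot)
        # Framework-specific wiring follows (placeholder for custom logic).
        snapshot.setdefault("applied", []).append("my_framework")


snapshot: Dict[str, Any] = {}
MyJaxFramework().apply_wiring(snapshot)
```

Calling the shared helper first lets the adapter override or extend whatever the common layer set up.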
Classes
JAXStackMixin: Mixin providing shared implementations for JAX ecosystem adapters.
Module Contents
- class ml_switcheroo.frameworks.common.jax_stack.JAXStackMixin
Mixin providing shared implementations for JAX ecosystem adapters.
This ensures consistent translation for:
  - Optimization (Torch Optimizers -> Optax Factory Functions).
  - Serialization (Torch Save/Load -> Orbax Checkpointing).
  - Device Management (Torch Device -> JAX Devices).
  - Test Configuration (Gen-Tests templates).
- property jax_test_config: Dict[str, str]
Returns standard JAX test generation templates using JIT wrapping.
- get_device_syntax(device_type: str, device_index: str | None = None) -> str
Returns JAX-compliant syntax for device specification.
Maps ‘cuda’ and ‘gpu’ to the ‘gpu’ backend, and ‘cpu’ to the ‘cpu’ backend.
- Parameters:
device_type – String literal or variable representing device type (e.g., “‘cuda’”).
device_index – Optional index string (e.g., “0”).
- Returns:
Python code string constructing the device object, e.g. jax.devices('gpu')[0].
- Return type:
str
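The mapping described for get_device_syntax can be sketched as a small string-building function. This is a minimal illustration, not the library's implementation: `sketch_device_syntax` is a hypothetical name, and the pass-through for unrecognized input and the default index of 0 are assumptions made for the example.

```python
from typing import Optional


def sketch_device_syntax(device_type: str, device_index: Optional[str] = None) -> str:
    """Hypothetical stand-in for get_device_syntax; returns a code string."""
    # Strip the quotes of a string literal such as "'cuda'".
    name = device_type.strip().strip("'\"").lower()
    # 'cuda' and 'gpu' map to the 'gpu' backend; 'cpu' maps to 'cpu'.
    backend_map = {"cuda": "gpu", "gpu": "gpu", "cpu": "cpu"}
    if name in backend_map:
        backend_expr = repr(backend_map[name])
    else:
        # Assumption: anything else is a variable name, emitted unchanged.
        backend_expr = device_type.strip()
    # Assumption: default to the first device when no index is given.
    index = device_index if device_index is not None else "0"
    return f"jax.devices({backend_expr})[{index}]"


cuda_code = sketch_device_syntax("'cuda'", "0")
cpu_code = sketch_device_syntax("'cpu'")
var_code = sketch_device_syntax("dev", "1")
```

The function returns source text rather than a device object, which matches the documented return value: the string is meant to be spliced into generated code.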
- get_serialization_imports() -> List[str]
Returns standard imports for JAX serialization via Orbax.
- get_serialization_syntax(op: str, file_arg: str, object_arg: str | None = None) -> str
Returns Orbax syntax for save/load operations.
- Parameters:
op – Operation name (‘save’ or ‘load’).
file_arg – Path to checkpoint directory.
object_arg – The PyTree to save (required for save).
- Returns:
Python code string.
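To make the save/load contract concrete, the sketch below generates Orbax call strings from the documented arguments. It is a hedged illustration only: `sketch_serialization_imports` and `sketch_serialization_syntax` are hypothetical names, and the choice of `orbax.checkpoint.PyTreeCheckpointer` with its `save` and `restore` methods is an assumption about what the emitted code might look like; the strings the library actually produces are not shown in this reference.

```python
from typing import List, Optional


def sketch_serialization_imports() -> List[str]:
    """Hypothetical stand-in for get_serialization_imports."""
    return ["import orbax.checkpoint"]


def sketch_serialization_syntax(
    op: str, file_arg: str, object_arg: Optional[str] = None
) -> str:
    """Hypothetical stand-in for get_serialization_syntax."""
    # Assumption: a PyTreeCheckpointer is created inline for each call.
    checkpointer = "orbax.checkpoint.PyTreeCheckpointer()"
    if op == "save":
        if object_arg is None:
            # The PyTree argument is documented as required for save.
            raise ValueError("object_arg is required for 'save'")
        return f"{checkpointer}.save({file_arg}, {object_arg})"
    if op == "load":
        return f"{checkpointer}.restore({file_arg})"
    raise ValueError(f"unsupported op: {op!r}")


save_code = sketch_serialization_syntax("save", "ckpt_dir", "params")
load_code = sketch_serialization_syntax("load", "ckpt_dir")
```

Both arguments are treated as source expressions rather than values, so `ckpt_dir` and `params` appear unquoted in the generated code.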