ml_switcheroo.frameworks.common.jax_stack
=========================================

.. py:module:: ml_switcheroo.frameworks.common.jax_stack

.. autoapi-nested-parse::

   JAX Stack Common Logic (Level 0 & Level 1).

   This module provides ``JAXStackMixin``, a reusable base for any Framework Adapter
   built on top of the JAX ecosystem (e.g., Flax, PaxML, Haiku).

   It standardizes:

   1. **Level 0 (Core)**: JIT compilation templates, device syntax (``jax.devices``),
      and Array API mappings (``jax.numpy``).
   2. **Level 1 (Common Libs)**:

      - **Optax**: Optimization primitives and loss functions.
      - **Orbax**: Checkpointing and serialization.

   Usage::

       class MyJaxFramework(JAXStackMixin):
           def apply_wiring(self, snapshot):
               self._apply_stack_wiring(snapshot)
               # ... custom logic ...

Classes
-------

.. autoapisummary::

   ml_switcheroo.frameworks.common.jax_stack.JAXStackMixin

Module Contents
---------------

.. py:class:: JAXStackMixin

   Mixin providing shared implementations for JAX ecosystem adapters.

   It ensures consistent translation for:

   - **Optimization** (Torch optimizers -> Optax factory functions).
   - **Serialization** (Torch save/load -> Orbax checkpointing).
   - **Device Management** (Torch devices -> JAX devices).
   - **Test Configuration** (Gen-Tests templates).

   .. py:property:: jax_test_config
      :type: Dict[str, str]

      Returns the standard JAX test generation templates, which wrap the code under test in JIT compilation.

   .. py:method:: get_device_syntax(device_type: str, device_index: Optional[str] = None) -> str

      Returns JAX-compliant syntax for a device specification.

      Maps ``'cuda'`` and ``'gpu'`` to the ``'gpu'`` backend, and ``'cpu'`` to the ``'cpu'`` backend.

      :param device_type: String literal or variable representing the device type (e.g., ``"'cuda'"``).
      :param device_index: Optional index string (e.g., ``"0"``).
      :returns: Python code string constructing the device object, e.g. ``jax.devices('gpu')[0]``.
      :rtype: str

   .. py:method:: get_serialization_imports() -> List[str]

      Returns the standard imports for JAX serialization via Orbax.

   .. py:method:: get_serialization_syntax(op: str, file_arg: str, object_arg: Optional[str] = None) -> str

      Returns Orbax syntax for save/load operations.

      :param op: Operation name (``'save'`` or ``'load'``).
      :param file_arg: Path to the checkpoint directory.
      :param object_arg: The PyTree to save (required for ``'save'``).
      :returns: Python code string.
      :rtype: str
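The device mapping described for ``get_device_syntax`` can be illustrated with a small standalone function. This is a hedged sketch, not the mixin's actual code: the name ``device_syntax_sketch``, the ``"0"`` default index, and the pass-through of unquoted variables are all assumptions made for illustration. It only shows how a quoted Torch-style device literal can be resolved to a JAX backend at translation time.

```python
from typing import Optional

# Hypothetical mapping from Torch-style device names to JAX backend names.
_BACKEND_MAP = {"cuda": "gpu", "gpu": "gpu", "cpu": "cpu"}


def device_syntax_sketch(device_type: str, device_index: Optional[str] = None) -> str:
    """Build a JAX device expression as source code (illustrative only)."""
    # Assumption: a missing index defaults to the first device.
    index = device_index if device_index is not None else "0"
    stripped = device_type.strip()
    # A quoted literal such as "'cuda'" can be resolved at translation time.
    if len(stripped) >= 2 and stripped[0] == stripped[-1] and stripped[0] in "'\"":
        name = stripped[1:-1].lower()
        backend = _BACKEND_MAP.get(name, name)
        return f"jax.devices('{backend}')[{index}]"
    # Assumption: a variable is passed through unchanged and resolved at runtime.
    return f"jax.devices({stripped})[{index}]"


print(device_syntax_sketch("'cuda'", "1"))  # jax.devices('gpu')[1]
print(device_syntax_sketch("'cpu'"))        # jax.devices('cpu')[0]
```

Resolving literals statically keeps the emitted code free of Torch device names; how the real adapter treats non-literal arguments may differ.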
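The save/load contract of ``get_serialization_imports`` and ``get_serialization_syntax`` can be sketched as follows. The helper names are hypothetical, and the assumption is that the emitted code targets Orbax's ``PyTreeCheckpointer`` under the ``ocp`` alias; the real adapter may emit a different Orbax entry point. The sketch does enforce the documented rule that ``object_arg`` is required for ``'save'``.

```python
from typing import List, Optional


def serialization_imports_sketch() -> List[str]:
    # Assumed import alias; the real adapter may emit different lines.
    return ["import orbax.checkpoint as ocp"]


def serialization_syntax_sketch(op: str, file_arg: str, object_arg: Optional[str] = None) -> str:
    """Emit Orbax save/load source code (illustrative only)."""
    if op == "save":
        # Saving needs the PyTree expression to write.
        if object_arg is None:
            raise ValueError("'save' requires object_arg (the PyTree to write)")
        return f"ocp.PyTreeCheckpointer().save({file_arg}, {object_arg})"
    if op == "load":
        # Restoring only needs the checkpoint directory.
        return f"ocp.PyTreeCheckpointer().restore({file_arg})"
    raise ValueError(f"unsupported op: {op!r}")


print(serialization_syntax_sketch("save", "'ckpt_dir'", "params"))
# ocp.PyTreeCheckpointer().save('ckpt_dir', params)
print(serialization_syntax_sketch("load", "'ckpt_dir'"))
# ocp.PyTreeCheckpointer().restore('ckpt_dir')
```

Unlike ``torch.save``, which writes a single file, Orbax writes a checkpoint directory, which is why ``file_arg`` is documented as a directory path.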
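The "Torch optimizers -> Optax factory functions" translation can be pictured as a lookup plus a keyword rewrite. This is a minimal sketch under stated assumptions: the mapping table, the helper name ``optimizer_syntax_sketch``, and the restriction to the learning rate are hypothetical. The target names ``optax.sgd``, ``optax.adam``, and ``optax.adamw`` are real Optax factories that accept a ``learning_rate`` argument.

```python
from typing import Dict

# Hypothetical subset of a Torch -> Optax optimizer mapping.
_OPTIMIZER_MAP: Dict[str, str] = {
    "torch.optim.SGD": "optax.sgd",
    "torch.optim.Adam": "optax.adam",
    "torch.optim.AdamW": "optax.adamw",
}


def optimizer_syntax_sketch(torch_name: str, lr: str) -> str:
    """Rewrite a Torch optimizer constructor as an Optax factory call (illustrative)."""
    factory = _OPTIMIZER_MAP[torch_name]
    # Optax factories take hyperparameters only; the model parameters are
    # supplied later via optimizer.init(params), so model.parameters() is dropped.
    return f"{factory}(learning_rate={lr})"


print(optimizer_syntax_sketch("torch.optim.Adam", "1e-3"))  # optax.adam(learning_rate=1e-3)
```

The key design difference is that a Torch optimizer wraps the parameters at construction, while an Optax factory returns a stateless gradient transformation whose state is created separately from the parameters.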