ml_switcheroo.plugins.device_checks
Plugin for translating device availability checks.
Maps torch.cuda.is_available() to JAX's len(jax.devices('gpu')) > 0.
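At runtime, the emitted JAX expression asks whether any GPU devices are visible. The sketch below wraps it in a function; the name `jax_cuda_is_available` is hypothetical and not part of ml_switcheroo. One caveat: `jax.devices('gpu')` raises `RuntimeError` instead of returning an empty list when no GPU backend is present (for example, with a CPU-only jaxlib), so the sketch catches that case and returns `False`, which is what `torch.cuda.is_available()` reports on a machine without CUDA.

```python
import jax


def jax_cuda_is_available() -> bool:
    """Hypothetical runtime analogue of torch.cuda.is_available() (sketch)."""
    try:
        # The expression the plugin emits: at least one GPU device is visible.
        return len(jax.devices("gpu")) > 0
    except RuntimeError:
        # jax.devices raises RuntimeError when no GPU backend is installed
        # or it failed to initialize (e.g. a CPU-only jaxlib).
        return False


print(jax_cuda_is_available())
```

If the plugin's output is used unguarded, as in the mapping above, translated code run on a CPU-only install would raise rather than evaluate to `False`; whether that matters depends on where the translated code is expected to run.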
Functions
- transform_cuda_check(node, ctx): Plugin Hook: Transforms CUDA availability check.
Module Contents
- ml_switcheroo.plugins.device_checks.transform_cuda_check(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) -> libcst.BaseExpression
Plugin hook: transforms the CUDA availability check.
- Triggers:
  torch.cuda.is_available(), via the 'cuda_is_available' plugin key.
- Transformation:
  Input: torch.cuda.is_available()
  Output: len(jax.devices('gpu')) > 0
Note: the 'gpu' backend string is hardcoded, following the standard JAX idiom for targeting CUDA devices.