ml_switcheroo.plugins.device_checks

Plugin for translating Device Availability Checks.

Maps torch.cuda.is_available() to the JAX expression len(jax.devices('gpu')) > 0.

Functions

transform_cuda_check(node, ctx) → libcst.BaseExpression

Plugin Hook: Transforms CUDA availability check.

Module Contents

ml_switcheroo.plugins.device_checks.transform_cuda_check(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) → libcst.BaseExpression

Plugin Hook: Transforms CUDA availability check.

Triggers:

torch.cuda.is_available(), via the 'cuda_is_available' plugin key.

Transformation:

Input: torch.cuda.is_available()

Output: len(jax.devices('gpu')) > 0
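The rewrite above can be illustrated with a minimal, dependency-free sketch. It uses the standard-library ast module rather than libcst, so it is not the plugin's actual implementation. The names CudaCheckRewriter and rewrite are hypothetical, and the HookContext argument is left out because its interface is not described here.

```python
import ast


class CudaCheckRewriter(ast.NodeTransformer):
    """Hypothetical stand-in for the hook: swaps the CUDA check for the JAX idiom."""

    def visit_Call(self, node: ast.Call) -> ast.AST:
        # Rewrite nested calls first.
        self.generic_visit(node)
        is_cuda_check = (
            ast.unparse(node.func) == "torch.cuda.is_available"
            and not node.args
            and not node.keywords
        )
        if is_cuda_check:
            # Parse the replacement expression and put it where the call was.
            replacement = ast.parse("len(jax.devices('gpu')) > 0", mode="eval").body
            return ast.copy_location(replacement, node)
        return node


def rewrite(source: str) -> str:
    """Return source with every torch.cuda.is_available() call rewritten."""
    tree = CudaCheckRewriter().visit(ast.parse(source))
    ast.fix_missing_locations(tree)
    return ast.unparse(tree)


print(rewrite("use_gpu = torch.cuda.is_available()"))
# use_gpu = len(jax.devices('gpu')) > 0
```

The result is a comparison expression rather than a call, which matches the documented return type of libcst.BaseExpression instead of libcst.Call. Unlike libcst, ast.unparse does not preserve comments or the original formatting, so this sketch only demonstrates the shape of the substitution.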

Note: the 'gpu' backend string is hardcoded, following the standard JAX idiom for CUDA devices.
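One difference in runtime behavior may matter when using the translated code. torch.cuda.is_available() returns False on a machine without CUDA, whereas jax.devices('gpu') can raise RuntimeError in JAX releases where no GPU backend is present, so the emitted expression may raise instead of evaluating to False. The sketch below shows one way to guard against that. The helper gpu_available and the two fake device listers are hypothetical and are not part of ml_switcheroo or JAX. Check the behavior of your installed JAX version before relying on it.

```python
from typing import Callable, Sequence


def gpu_available(list_devices: Callable[[str], Sequence[object]]) -> bool:
    """Return True if list_devices('gpu') reports at least one device."""
    try:
        return len(list_devices("gpu")) > 0
    except RuntimeError:
        # Assumption: the backend lookup raises RuntimeError when no GPU platform exists.
        return False


# Hypothetical stand-ins for jax.devices, so the sketch runs without JAX installed.
def fake_gpu_host(backend: str) -> Sequence[object]:
    return ["gpu:0"]


def fake_cpu_only_host(backend: str) -> Sequence[object]:
    raise RuntimeError(f"Unknown backend: {backend!r}")


print(gpu_available(fake_gpu_host))       # True
print(gpu_available(fake_cpu_only_host))  # False
```

With JAX installed, the call would be gpu_available(jax.devices).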