ml_switcheroo.plugins.device_checks
===================================

.. py:module:: ml_switcheroo.plugins.device_checks

.. autoapi-nested-parse::

   Plugin for translating device availability checks.

   Maps ``torch.cuda.is_available()`` to JAX's ``len(jax.devices('gpu')) > 0``.


Functions
---------

.. autoapisummary::

   ml_switcheroo.plugins.device_checks.transform_cuda_check


Module Contents
---------------

.. py:function:: transform_cuda_check(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) -> libcst.BaseExpression

   Plugin hook: transforms a CUDA availability check.

   Triggers:
      ``torch.cuda.is_available()``, dispatched via the ``'cuda_is_available'`` plugin key.

   Transformation:
      - Input: ``torch.cuda.is_available()``
      - Output: ``len(jax.devices('gpu')) > 0``

   Note: the ``'gpu'`` backend string is hardcoded, following the standard JAX idiom for CUDA devices.
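To illustrate the rewrite, here is a minimal sketch that assumes Python 3.9+ and uses the standard-library ``ast`` module in place of libcst. Because of that substitution, it does not preserve comments or formatting the way a libcst transformer does. The names ``CudaCheckTransformer`` and ``translate`` are hypothetical and not part of ``ml_switcheroo``. The sketch also matches the call by its dotted name, whereas the real hook is dispatched through the ``'cuda_is_available'`` plugin key and receives a ``HookContext``.

.. code-block:: python

   import ast

   # Hypothetical stdlib analogue of the libcst hook; not part of ml_switcheroo.
   TARGET = "torch.cuda.is_available"
   REPLACEMENT = "len(jax.devices('gpu')) > 0"


   class CudaCheckTransformer(ast.NodeTransformer):
       """Rewrite zero-argument torch.cuda.is_available() calls."""

       def visit_Call(self, node: ast.Call) -> ast.AST:
           # Rewrite any nested calls (e.g. inside arguments) first.
           self.generic_visit(node)
           is_target = (
               ast.unparse(node.func) == TARGET  # textual match only; aliases are missed
               and not node.args
               and not node.keywords
           )
           if not is_target:
               return node
           # Parse a fresh replacement each time so no AST node is shared.
           new_expr = ast.parse(REPLACEMENT, mode="eval").body
           return ast.copy_location(new_expr, node)


   def translate(source: str) -> str:
       """Parse source, apply the rewrite, and return regenerated code."""
       tree = CudaCheckTransformer().visit(ast.parse(source))
       return ast.unparse(ast.fix_missing_locations(tree))


   print(translate("use_gpu = torch.cuda.is_available()"))
   # prints: use_gpu = len(jax.devices('gpu')) > 0

Because the match in this sketch is purely textual, aliased imports such as ``import torch as t`` are left unchanged. Separately, in recent JAX versions ``jax.devices('gpu')`` raises ``RuntimeError`` when no GPU backend is available instead of returning an empty list. On CPU-only machines, the emitted expression may therefore raise rather than evaluate to ``False``.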