ml_switcheroo.plugins.device_checks

Plugin for translating device availability checks.

This module maps framework-specific availability checks (e.g., torch.cuda.is_available()) to the target framework’s equivalent by querying the active FrameworkAdapter.

Decoupling: Instead of hardcoding JAX or TensorFlow logic, this plugin delegates syntax generation to adapter.get_device_check_syntax().
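A minimal sketch of this delegation, assuming a simplified adapter interface. The `DeviceCheckAdapter` protocol, the `JaxAdapter` and `NumpyAdapter` classes, the `generate_check` helper, and the plain-string return type of `get_device_check_syntax()` are hypothetical stand-ins, not the real `FrameworkAdapter` API, whose exact signature is not documented here. The point is that the plugin holds no framework-specific strings; it only asks whichever adapter is active.

```python
from typing import Protocol


class DeviceCheckAdapter(Protocol):
    """Hypothetical stand-in for the FrameworkAdapter interface."""

    def get_device_check_syntax(self) -> str:
        """Return source code for a boolean 'is an accelerator available?' check."""
        ...


class JaxAdapter:
    # Hypothetical adapter; the expression matches the documented JAX output.
    def get_device_check_syntax(self) -> str:
        return "len(jax.devices('gpu')) > 0"


class NumpyAdapter:
    # Hypothetical adapter; matches the documented NumPy output (a constant).
    def get_device_check_syntax(self) -> str:
        return "False"


def generate_check(adapter: DeviceCheckAdapter) -> str:
    # The plugin never hardcodes framework logic; it delegates to the adapter.
    return adapter.get_device_check_syntax()


print(generate_check(JaxAdapter()))    # len(jax.devices('gpu')) > 0
print(generate_check(NumpyAdapter()))  # False
```

Supporting a new target then means adding one adapter method rather than editing the plugin.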

Functions

transform_cuda_check(node, ctx) → libcst.BaseExpression

Plugin Hook: Transforms a CUDA availability check.

Module Contents

ml_switcheroo.plugins.device_checks.transform_cuda_check(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) → libcst.BaseExpression

Plugin Hook: Transforms a CUDA availability check.

Triggers:

torch.cuda.is_available(), dispatched to this hook via the 'cuda_is_available' plugin key.
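Key-based triggering can be pictured as two lookups: source API to plugin key, then plugin key to hook. The sketch below is a minimal illustration under that assumption; `HOOKS`, `API_TO_PLUGIN_KEY`, `register_hook`, `cuda_check_hook`, and `dispatch` are hypothetical names, not ml_switcheroo's actual registration API, and the hook takes a target name instead of a CST node and `HookContext`.

```python
from typing import Callable, Dict

Hook = Callable[[str], str]

# Hypothetical registry: plugin key -> hook function.
HOOKS: Dict[str, Hook] = {}

# Hypothetical mapping: source API -> plugin key that handles it.
API_TO_PLUGIN_KEY = {"torch.cuda.is_available": "cuda_is_available"}


def register_hook(key: str) -> Callable[[Hook], Hook]:
    """Decorator that stores a hook under the given plugin key."""
    def decorator(fn: Hook) -> Hook:
        HOOKS[key] = fn
        return fn
    return decorator


@register_hook("cuda_is_available")
def cuda_check_hook(target: str) -> str:
    # Simplified stand-in: the real hook receives a libcst.Call and a HookContext.
    outputs = {"jax": "len(jax.devices('gpu')) > 0", "numpy": "False"}
    return outputs[target]


def dispatch(api_name: str, target: str) -> str:
    """Resolve the plugin key for an API call, then run the matching hook."""
    key = API_TO_PLUGIN_KEY[api_name]
    return HOOKS[key](target)


print(dispatch("torch.cuda.is_available", "numpy"))  # False
```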

Transformation:
Input: torch.cuda.is_available()

Output (JAX): len(jax.devices('gpu')) > 0
Output (Keras): len(keras.config.list_logical_devices('GPU')) > 0
Output (NumPy): False
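The rewrite can be illustrated with Python's standard-library `ast` module. The real hook operates on libcst nodes, which preserve formatting and comments. This sketch swaps in `ast.NodeTransformer` so it runs without extra dependencies. `CudaCheckRewriter`, `TARGET_CHECKS`, and `rewrite` are hypothetical names, and the table simply copies the outputs listed above.

```python
import ast

# Target expressions copied from the documented outputs.
TARGET_CHECKS = {
    "jax": "len(jax.devices('gpu')) > 0",
    "keras": "len(keras.config.list_logical_devices('GPU')) > 0",
    "numpy": "False",
}


class CudaCheckRewriter(ast.NodeTransformer):
    """Replace every argument-free torch.cuda.is_available() call."""

    def __init__(self, target: str) -> None:
        self.replacement = TARGET_CHECKS[target]

    def visit_Call(self, node: ast.Call) -> ast.AST:
        self.generic_visit(node)  # rewrite any nested calls first
        is_cuda_check = ast.unparse(node.func) == "torch.cuda.is_available"
        if is_cuda_check and not node.args and not node.keywords:
            # Parse the target snippet as an expression and splice it in.
            new_node = ast.parse(self.replacement, mode="eval").body
            return ast.copy_location(new_node, node)
        return node


def rewrite(source: str, target: str) -> str:
    tree = CudaCheckRewriter(target).visit(ast.parse(source))
    return ast.unparse(ast.fix_missing_locations(tree))


print(rewrite("use_gpu = torch.cuda.is_available()", "jax"))
# use_gpu = len(jax.devices('gpu')) > 0
```

`ast.unparse` requires Python 3.9 or later and discards the original formatting, which is one reason a CST library is the better fit for source-to-source translation.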

Parameters:
  • node – The original CST Call node.

  • ctx – HookContext giving access to the target framework.

Returns:

A CST expression representing the boolean availability check in the target framework.