ml_switcheroo.plugins.casting

Plugin for Type Casting Methods.

Addresses the syntax mismatch between:

  1. PyTorch shorthands: x.float(), x.long(), x.half(), etc.

  2. JAX/NumPy/Array API: x.astype(dtype).
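To make the mismatch concrete, the table below lists common PyTorch shorthand casts with the dtype each one converts to (for example, Tensor.float() is equivalent to Tensor.to(torch.float32)), plus a helper that spells the same cast in the astype style. The names TORCH_SHORTHAND_DTYPES and astype_equivalent, and the default "jnp" namespace, are hypothetical illustrations, not part of ml_switcheroo.

```python
# PyTorch Tensor shorthand cast methods and the dtype each one produces.
TORCH_SHORTHAND_DTYPES = {
    "float": "float32",
    "double": "float64",
    "half": "float16",
    "short": "int16",
    "int": "int32",
    "long": "int64",
}


def astype_equivalent(receiver: str, method: str, namespace: str = "jnp") -> str:
    """Hypothetical helper: spell `receiver.method()` as an Array API style
    `receiver.astype(namespace.dtype)` expression string."""
    dtype = TORCH_SHORTHAND_DTYPES[method]  # KeyError for unknown shorthands
    return f"{receiver}.astype({namespace}.{dtype})"


print(astype_equivalent("x", "long"))  # x.astype(jnp.int64)
```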

Transformation:

  1. Detects calls to known shorthand methods (triaged by the Rewriter via the ‘type_methods’ plugin).

  2. Checks whether the target framework declares has_numpy_compatible_arrays in its traits.

  3. Looks up the target_type in the semantic metadata for the abstract operation.

  4. Queries the Semantics Manager for the target framework’s implementation of that type.

  5. Generates an .astype(…) call using the retrieved dtype API.
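The detection and rewrite steps (1, 2 and 5) can be sketched with Python's standard-library ast module standing in for libcst. This is a minimal illustration, not the plugin's code: the flat DTYPE_API table, ShorthandCastRewriter and rewrite are hypothetical, and the table replaces the metadata and Semantics Manager lookups of steps 3 and 4 with hard-coded values.

```python
import ast

# Hypothetical flattened lookup: shorthand method -> target dtype API path.
# In ml_switcheroo this would come from the semantic metadata and the
# Semantics Manager; here it is hard-coded for illustration.
DTYPE_API = {
    "float": "jax.numpy.float32",
    "long": "jax.numpy.int64",
    "half": "jax.numpy.float16",
}


def _dotted(path: str) -> ast.expr:
    """Build an expression node for a dotted name such as 'jax.numpy.float32'."""
    head, *rest = path.split(".")
    node: ast.expr = ast.Name(id=head, ctx=ast.Load())
    for attr in rest:
        node = ast.Attribute(value=node, attr=attr, ctx=ast.Load())
    return node


class ShorthandCastRewriter(ast.NodeTransformer):
    """Rewrites x.<shorthand>() into x.astype(<dtype API>)."""

    def __init__(self, target_has_numpy_arrays: bool):
        # Step 2: only rewrite when the target has NumPy-compatible arrays.
        self.enabled = target_has_numpy_arrays

    def visit_Call(self, node: ast.Call) -> ast.AST:
        self.generic_visit(node)  # rewrite nested calls (e.g. chains) first
        if not self.enabled:
            return node
        func = node.func
        # Step 1: a zero-argument call to a known shorthand method.
        if (isinstance(func, ast.Attribute) and func.attr in DTYPE_API
                and not node.args and not node.keywords):
            # Step 5: same receiver, method swapped for astype(<dtype API>).
            return ast.Call(
                func=ast.Attribute(value=func.value, attr="astype", ctx=ast.Load()),
                args=[_dotted(DTYPE_API[func.attr])],
                keywords=[],
            )
        return node


def rewrite(source: str, target_has_numpy_arrays: bool = True) -> str:
    tree = ShorthandCastRewriter(target_has_numpy_arrays).visit(ast.parse(source))
    return ast.unparse(ast.fix_missing_locations(tree))  # Python 3.9+


print(rewrite("y = x.long().half()"))
# y = x.astype(jax.numpy.int64).astype(jax.numpy.float16)
```

Unlike this sketch, which rewrites every matching method name and discards comments and formatting through ast.unparse, the real hook relies on the Rewriter to route only relevant calls and on libcst to preserve the surrounding source layout.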

Functions

transform_casting(node, ctx) → libcst.Call

Hook: Converts shorthand casts to astype calls.

Module Contents

ml_switcheroo.plugins.casting.transform_casting(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) → libcst.Call

Hook: Converts shorthand casts to astype calls.

Logic:
  1. Verify that the target framework supports NumPy-compatible array semantics (the has_numpy_compatible_arrays trait).

  2. Access the Abstract Operation ID (e.g., ‘CastFloat’).

  3. Look up the ‘target_type’ metadata (e.g., ‘Float32’).

  4. Look up the target framework’s API for ‘Float32’ (e.g., ‘jax.numpy.float32’).

  5. Rewrite x.float() as x.astype(jax.numpy.float32).
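The lookup chain in steps 1 to 4 can be sketched against a stub context. FakeHookContext, its fields (traits, op_metadata, type_apis) and resolve_astype_dtype are hypothetical stand-ins; the real HookContext exposes the Knowledge Base through its own API, which is not documented here. Returning None when any lookup fails is also an assumption, meant to let a caller leave the node unchanged.

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class FakeHookContext:
    """Hypothetical stand-in for HookContext; field names are illustrative only."""
    target_framework: str
    traits: dict = field(default_factory=dict)       # framework -> trait flags
    op_metadata: dict = field(default_factory=dict)  # op id -> semantic metadata
    type_apis: dict = field(default_factory=dict)    # (type, framework) -> API path


def resolve_astype_dtype(op_id: str, ctx: FakeHookContext) -> Optional[str]:
    """Return the dtype API to pass to .astype(), or None if any step fails."""
    # 1. Only targets with NumPy-compatible arrays get an .astype() rewrite.
    flags = ctx.traits.get(ctx.target_framework, {})
    if not flags.get("has_numpy_compatible_arrays"):
        return None
    # 2-3. Abstract operation ID (e.g. "CastFloat") -> target_type (e.g. "Float32").
    target_type = ctx.op_metadata.get(op_id, {}).get("target_type")
    if target_type is None:
        return None
    # 4. Abstract type -> the target framework's concrete dtype API.
    return ctx.type_apis.get((target_type, ctx.target_framework))


ctx = FakeHookContext(
    target_framework="jax",
    traits={"jax": {"has_numpy_compatible_arrays": True}},
    op_metadata={"CastFloat": {"target_type": "Float32"}},
    type_apis={("Float32", "jax"): "jax.numpy.float32"},
)
print(resolve_astype_dtype("CastFloat", ctx))  # jax.numpy.float32
```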

Parameters:
  • node (cst.Call) – The source call node.

  • ctx (HookContext) – Context providing access to the Knowledge Base.

Returns:

The transformed node, rewritten as an .astype() call.

Return type:

cst.Call