ml_switcheroo.plugins.casting
Plugin for Type Casting Methods.
Addresses the syntax mismatch between:
1. PyTorch shorthands: x.float(), x.long(), x.half(), etc.
2. JAX/NumPy/Array API: x.astype(dtype).
Transformation:
1. Detects calls to known shorthand methods.
2. Changes the method name to astype.
3. Injects the corresponding JAX/NumPy dtype object as the argument.
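The three transformation steps above can be sketched with Python's standard-library ast module. This is an illustration only: the real plugin operates on libcst.Call nodes (per the signature of transform_casting), the names SHORTHAND_DTYPES, ShorthandCastRewriter and rewrite_casts are hypothetical, and the sketch assumes that only zero-argument shorthand calls should be rewritten.

```python
import ast

# Hypothetical lookup table for this sketch: shorthand method -> jnp dtype attribute.
SHORTHAND_DTYPES = {
    "float": "float32",
    "double": "float64",
    "half": "float16",
    "long": "int64",
    "int": "int32",
    "bool": "bool_",
    "byte": "uint8",
}


class ShorthandCastRewriter(ast.NodeTransformer):
    """Rewrites x.float()-style calls into x.astype(jnp.float32)-style calls."""

    def visit_Call(self, node: ast.Call) -> ast.AST:
        # Rewrite inner calls first so chains like x.int().float() are handled.
        self.generic_visit(node)
        func = node.func
        # Step 1: detect a zero-argument method call whose name is a known shorthand.
        # (Assumption: calls that already pass arguments are left alone.)
        if (
            isinstance(func, ast.Attribute)
            and func.attr in SHORTHAND_DTYPES
            and not node.args
            and not node.keywords
        ):
            # Step 3: build the dtype expression jnp.<dtype>.
            dtype_expr = ast.Attribute(
                value=ast.Name(id="jnp", ctx=ast.Load()),
                attr=SHORTHAND_DTYPES[func.attr],
                ctx=ast.Load(),
            )
            # Step 2: keep the receiver, but call astype instead of the shorthand.
            new_call = ast.Call(
                func=ast.Attribute(value=func.value, attr="astype", ctx=ast.Load()),
                args=[dtype_expr],
                keywords=[],
            )
            return ast.copy_location(new_call, node)
        return node


def rewrite_casts(source: str) -> str:
    """Parse `source`, rewrite shorthand casts, and return the new source text."""
    tree = ShorthandCastRewriter().visit(ast.parse(source))
    ast.fix_missing_locations(tree)
    return ast.unparse(tree)  # requires Python 3.9+


print(rewrite_casts("y = x.float()"))  # prints y = x.astype(jnp.float32)
```

Calling generic_visit before the check means chained casts are rewritten from the inside out, and builtin calls such as float(x) are untouched because their func is a Name, not an Attribute.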
- Mappings:
    .float()  -> .astype(jnp.float32)
    .double() -> .astype(jnp.float64)
    .half()   -> .astype(jnp.float16)
    .long()   -> .astype(jnp.int64)
    .int()    -> .astype(jnp.int32)
    .bool()   -> .astype(jnp.bool_)
    .byte()   -> .astype(jnp.uint8)
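As a sketch of how a table like TYPE_MAP could drive the rewrite, the snippet below stores the seven mappings as plain strings and renders the astype spelling for a given receiver. The names CAST_TABLE and render_astype, the string-valued entries, and the namespace parameter are assumptions for illustration; the documented TYPE_MAP may store its values differently.

```python
from typing import Dict

# Hypothetical stand-in for TYPE_MAP: shorthand method name -> dtype attribute name.
# The real TYPE_MAP's value types are not documented here; strings are assumed.
CAST_TABLE: Dict[str, str] = {
    "float": "float32",
    "double": "float64",
    "half": "float16",
    "long": "int64",
    "int": "int32",
    "bool": "bool_",
    "byte": "uint8",
}


def render_astype(receiver: str, method: str, namespace: str = "jnp") -> str:
    """Return the astype() spelling of `<receiver>.<method>()` as source text.

    Raises KeyError if `method` is not one of the known shorthand casts.
    """
    dtype_name = CAST_TABLE[method]
    return f"{receiver}.astype({namespace}.{dtype_name})"


if __name__ == "__main__":
    # Print the full mapping, one shorthand per line.
    for method in CAST_TABLE:
        print(f"x.{method}() -> {render_astype('x', method)}")
```

Because NumPy exposes the same dtype names (np.float32, np.bool_, np.uint8, and so on), passing namespace="np" is enough to target NumPy in this sketch. One caveat worth checking in practice: JAX uses 32-bit types by default, so jnp.float64 and jnp.int64 are only honored when the jax_enable_x64 flag is enabled; otherwise JAX truncates them to their 32-bit counterparts, typically with a warning.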
Attributes
| TYPE_MAP |
Functions
| transform_casting(node, ctx) | Hook: Converts shorthand casts to astype calls. |
Module Contents
- ml_switcheroo.plugins.casting.TYPE_MAP
- ml_switcheroo.plugins.casting.transform_casting(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) -> libcst.Call
Hook: Converts shorthand casts to astype calls.
Trigger: operations mapped to the 'Cast' category, or specific methods declared with requires_plugin: "type_methods".
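The hook contract implied by the signature, one call node in and one call node out, can be sketched as follows. This version uses the standard-library ast module in place of libcst, and FakeHookContext with its array_namespace field is a hypothetical stand-in: the actual fields of ml_switcheroo.core.hooks.HookContext are not documented on this page. The property shown is that calls which are not shorthand casts are returned unchanged.

```python
import ast
from dataclasses import dataclass


# Hypothetical stand-in for ml_switcheroo.core.hooks.HookContext.
# Only the `array_namespace` field is assumed, and only for this sketch.
@dataclass
class FakeHookContext:
    array_namespace: str = "jnp"


_DTYPES = {
    "float": "float32",
    "double": "float64",
    "half": "float16",
    "long": "int64",
    "int": "int32",
    "bool": "bool_",
    "byte": "uint8",
}


def transform_casting_sketch(node: ast.Call, ctx: FakeHookContext) -> ast.Call:
    """Rewrite a zero-argument shorthand cast into astype(); otherwise return `node`."""
    func = node.func
    is_shorthand = isinstance(func, ast.Attribute) and func.attr in _DTYPES
    if not is_shorthand or node.args or node.keywords:
        return node  # not a shorthand cast: hand the original node back untouched
    # Build <namespace>.<dtype>, e.g. jnp.float16, from the (hypothetical) context.
    dtype_expr = ast.Attribute(
        value=ast.Name(id=ctx.array_namespace, ctx=ast.Load()),
        attr=_DTYPES[func.attr],
        ctx=ast.Load(),
    )
    # Reuse the original receiver and swap the method for astype(dtype).
    return ast.Call(
        func=ast.Attribute(value=func.value, attr="astype", ctx=ast.Load()),
        args=[dtype_expr],
        keywords=[],
    )


call = ast.parse("x.half()", mode="eval").body
print(ast.unparse(transform_casting_sketch(call, FakeHookContext())))  # prints x.astype(jnp.float16)
```

Returning the input node untouched for non-matching calls keeps the hook safe to invoke on any call it is dispatched to; whether the real hook behaves the same way for unknown methods is an assumption of this sketch.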