ml_switcheroo.plugins.casting

Plugin for Type Casting Methods.

Addresses the syntax mismatch between:

1. PyTorch shorthands: x.float(), x.long(), x.half(), etc.
2. JAX/NumPy/Array API: x.astype(dtype).

Transformation:

1. Detects calls to known shorthand methods.
2. Changes the method name to astype.
3. Injects the corresponding JAX/NumPy dtype object as the argument.

Mappings:

.float()  -> .astype(jnp.float32)
.double() -> .astype(jnp.float64)
.half()   -> .astype(jnp.float16)
.long()   -> .astype(jnp.int64)
.int()    -> .astype(jnp.int32)
.bool()   -> .astype(jnp.bool_)
.byte()   -> .astype(jnp.uint8)
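The transformation steps and the mappings above can be sketched with Python's standard-library ast module. This is a stand-in, not the plugin's code: the real hook works on libcst nodes. CAST_MAP, CastRewriter and rewrite_casts are hypothetical names, and the dict mirrors the Mappings table rather than the actual TYPE_MAP, whose structure is not documented here. Leaving calls that pass arguments untouched is also an assumption of this sketch.

```python
import ast

# Hypothetical mirror of the Mappings table: shorthand method -> jnp dtype attribute.
# The real TYPE_MAP may be structured differently.
CAST_MAP = {
    "float": "float32",
    "double": "float64",
    "half": "float16",
    "long": "int64",
    "int": "int32",
    "bool": "bool_",
    "byte": "uint8",
}


class CastRewriter(ast.NodeTransformer):
    """Rewrites x.<shorthand>() into x.astype(jnp.<dtype>) (stdlib ast, not libcst)."""

    def visit_Call(self, node: ast.Call) -> ast.AST:
        # Rewrite inner calls first so chains like x.float().long() are handled.
        self.generic_visit(node)
        func = node.func
        # Step 1: detect a call to a known shorthand method.
        if not (isinstance(func, ast.Attribute) and func.attr in CAST_MAP):
            return node
        if node.args or node.keywords:
            return node  # assumption of this sketch: leave calls with arguments alone
        # Step 2: rename the method to astype, keeping the receiver expression.
        new_func = ast.Attribute(value=func.value, attr="astype", ctx=ast.Load())
        # Step 3: inject jnp.<dtype> as the single positional argument.
        dtype = ast.Attribute(
            value=ast.Name(id="jnp", ctx=ast.Load()),
            attr=CAST_MAP[func.attr],
            ctx=ast.Load(),
        )
        new_call = ast.Call(func=new_func, args=[dtype], keywords=[])
        return ast.copy_location(new_call, node)


def rewrite_casts(source: str) -> str:
    """Parses source, applies CastRewriter, and returns the regenerated code."""
    tree = CastRewriter().visit(ast.parse(source))
    return ast.unparse(ast.fix_missing_locations(tree))


print(rewrite_casts("y = x.float()\nz = model(x).long().bool()"))
# y = x.astype(jnp.float32)
# z = model(x).astype(jnp.int64).astype(jnp.bool_)
```

This sketch is purely syntactic: it rewrites any zero-argument .float() call regardless of the receiver's type, and ast.unparse requires Python 3.9 or later.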

Attributes

TYPE_MAP

Functions

transform_casting(node, ctx) → libcst.Call

Hook: Converts shorthand casts to astype calls.

Module Contents

ml_switcheroo.plugins.casting.TYPE_MAP
ml_switcheroo.plugins.casting.transform_casting(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) → libcst.Call

Hook: Converts shorthand casts to astype calls.

Trigger: operations mapped to the ‘Cast’ category, or specific methods configured with requires_plugin: “type_methods”.
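The documented signature takes one call node plus a context and returns a call node. A hook with that shape might look like the sketch below, which again uses the standard-library ast module in place of libcst. transform_casting_sketch and CAST_MAP are hypothetical names, and ctx is only a placeholder for HookContext (whose attributes are not documented here), so the sketch ignores it. A node that is not a recognised zero-argument shorthand is returned unchanged.

```python
import ast
from typing import Any, Optional

# Hypothetical copy of the documented mappings (not the real TYPE_MAP object).
CAST_MAP = {"float": "float32", "double": "float64", "half": "float16",
            "long": "int64", "int": "int32", "bool": "bool_", "byte": "uint8"}


def transform_casting_sketch(node: ast.Call, ctx: Optional[Any] = None) -> ast.Call:
    """Hook-shaped sketch: one call node in, one call node out.

    `ctx` is a placeholder for HookContext and is ignored here.
    """
    func = node.func
    is_shorthand = isinstance(func, ast.Attribute) and func.attr in CAST_MAP
    if not is_shorthand or node.args or node.keywords:
        # Not a zero-argument shorthand cast: return the node untouched.
        return node
    # Build jnp.<dtype> and attach it to <receiver>.astype(...).
    dtype = ast.Attribute(value=ast.Name(id="jnp", ctx=ast.Load()),
                          attr=CAST_MAP[func.attr], ctx=ast.Load())
    astype = ast.Attribute(value=func.value, attr="astype", ctx=ast.Load())
    return ast.Call(func=astype, args=[dtype], keywords=[])


call = ast.parse("logits.half()", mode="eval").body
print(ast.unparse(transform_casting_sketch(call)))  # logits.astype(jnp.float16)
```

In a driver such as an ast.NodeTransformer, a function like this would be called from visit_Call for every call node; returning the original object when nothing applies keeps the common case cheap.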