ml_switcheroo.plugins.casting
=============================

.. py:module:: ml_switcheroo.plugins.casting

.. autoapi-nested-parse::

   Plugin for Type Casting Methods.

   Addresses the syntax mismatch between:

   1. PyTorch shorthands: ``x.float()``, ``x.long()``, ``x.half()``, etc.
   2. JAX/NumPy/Array API: ``x.astype(dtype)``.

   Transformation:

   1. Detects calls to known shorthand methods.
   2. Changes the method name to ``astype``.
   3. Injects the corresponding JAX/NumPy dtype object as the argument.

   Mappings::

      .float()  -> .astype(jnp.float32)
      .double() -> .astype(jnp.float64)
      .half()   -> .astype(jnp.float16)
      .long()   -> .astype(jnp.int64)
      .int()    -> .astype(jnp.int32)
      .bool()   -> .astype(jnp.bool_)
      .byte()   -> .astype(jnp.uint8)


Attributes
----------

.. autoapisummary::

   ml_switcheroo.plugins.casting.TYPE_MAP


Functions
---------

.. autoapisummary::

   ml_switcheroo.plugins.casting.transform_casting


Module Contents
---------------

.. py:data:: TYPE_MAP

.. py:function:: transform_casting(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) -> libcst.Call

   Hook: converts shorthand casts into ``astype`` calls.

   Trigger: operations mapped to the ``Cast`` category, or specific methods
   declared with ``requires_plugin: "type_methods"``.
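The three transformation steps listed in the module docstring can be sketched with the standard library's ``ast`` module. This is a minimal, hypothetical illustration and not the plugin's implementation. The real hook operates on ``libcst.Call`` nodes and receives a ``HookContext``, and neither is modelled here. ``SHORTHAND_TO_DTYPE``, ``CastingRewriter`` and ``rewrite_casts`` are illustrative names. The dictionary assumes that ``TYPE_MAP`` pairs each shorthand method with a ``jnp`` dtype name, as listed under Mappings; its real structure is not shown on this page. ``ast.unparse`` requires Python 3.9 or later.

.. code-block:: python

   import ast

   # Hypothetical stand-in for TYPE_MAP, copied from the "Mappings" table.
   # Assumption: shorthand method name -> attribute name on ``jnp``.
   SHORTHAND_TO_DTYPE = {
       "float": "float32",
       "double": "float64",
       "half": "float16",
       "long": "int64",
       "int": "int32",
       "bool": "bool_",
       "byte": "uint8",
   }


   class CastingRewriter(ast.NodeTransformer):
       """Rewrite zero-argument shorthand casts, e.g. x.float() -> x.astype(jnp.float32)."""

       def visit_Call(self, node: ast.Call) -> ast.AST:
           # Rewrite inner calls first so chains like x.float().long() are handled.
           self.generic_visit(node)
           func = node.func
           # Step 1: detect a method call whose name is a known shorthand.
           if (
               isinstance(func, ast.Attribute)
               and func.attr in SHORTHAND_TO_DTYPE
               and not node.args
               and not node.keywords
           ):
               # Step 2: keep the receiver but rename the method to ``astype``.
               new_func = ast.Attribute(value=func.value, attr="astype", ctx=ast.Load())
               # Step 3: inject ``jnp.<dtype>`` as the single positional argument.
               dtype = ast.Attribute(
                   value=ast.Name(id="jnp", ctx=ast.Load()),
                   attr=SHORTHAND_TO_DTYPE[func.attr],
                   ctx=ast.Load(),
               )
               new_call = ast.Call(func=new_func, args=[dtype], keywords=[])
               return ast.copy_location(new_call, node)
           return node


   def rewrite_casts(source: str) -> str:
       """Parse source, rewrite shorthand casts, and return the regenerated code."""
       tree = CastingRewriter().visit(ast.parse(source))
       return ast.unparse(ast.fix_missing_locations(tree))

For example, ``rewrite_casts("z = x.float().long()")`` returns ``"z = x.astype(jnp.float32).astype(jnp.int64)"``. The sketch only rewrites zero-argument method calls, so a builtin call such as ``int(x)`` is left untouched. Because ``ast`` discards comments and original formatting, a source-to-source tool has good reason to prefer LibCST, whose concrete syntax tree preserves them.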