Half
Casts the input tensor to 16-bit floating point (IEEE 754 half precision, float16).
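The following sketch is a rough illustration of what the cast does numerically. It rounds Python floats to 16-bit floating point with the standard library's struct "e" format (IEEE 754 binary16). It stands in for a framework cast and is not any framework's API. The helper name to_float16 is hypothetical, and the sketch assumes each backend's float16 is binary16 with round-to-nearest-even, which is the usual behavior. One difference: struct raises an exception for values outside the float16 range, whereas framework casts typically produce inf.

```python
import struct

def to_float16(x: float) -> float:
    """Hypothetical helper: round x to the nearest IEEE 754 binary16 value.

    Packs x as a half-precision float with struct's 'e' format and unpacks
    it again, so the result is the float16 value widened back to a Python
    float. Raises an exception if |x| is too large for binary16.
    """
    return struct.unpack("<e", struct.pack("<e", x))[0]

print(to_float16(0.1))      # 0.0999755859375: only 10 explicit mantissa bits
print(to_float16(65504.0))  # 65504.0: the largest finite float16
print(to_float16(2049.0))   # 2048.0: above 2048, float16 steps by 2 (ties round to even)
```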
Abstract Signature:
Half(input: Tensor)
PyTorch
API: torch.half
Strategy: Plugin (type_methods)
JAX (Core)
API: jax.numpy.float16
Strategy: Macro '{input}.astype(jax.numpy.float16)'
NumPy
API: numpy.astype
Strategy: Macro '{input}.astype(numpy.float16)'
Keras
API: keras.ops.cast
Strategy: Macro "keras.ops.cast({input}, 'float16')"
TensorFlow
API: tf.cast
Strategy: Macro 'tf.cast({input}, tf.float16)'
Apple MLX
API: mlx.core.astype
Strategy: Macro '{input}.astype(mlx.core.float16)'
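The Macro strategies listed for Half can be read as text templates in which {input} stands for the source expression being converted. The sketch below assumes plain string substitution. HALF_MACROS and expand_macro are hypothetical names, and the PyTorch Plugin (type_methods) strategy is not modeled because its behavior is not specified here.

```python
# Hypothetical table of the Macro templates listed for Half; "{input}" is
# assumed to be replaced by the source expression being converted.
HALF_MACROS = {
    "jax": "{input}.astype(jax.numpy.float16)",
    "numpy": "{input}.astype(numpy.float16)",
    "keras": "keras.ops.cast({input}, 'float16')",
    "tensorflow": "tf.cast({input}, tf.float16)",
    "mlx": "{input}.astype(mlx.core.float16)",
}

def expand_macro(template: str, input_expr: str) -> str:
    """Hypothetical helper: substitute input_expr for the {input} placeholder.

    str.format only interprets braces in the template, so braces inside
    input_expr are inserted verbatim.
    """
    return template.format(input=input_expr)

# Show what each backend's template produces for a tensor named x.
for framework, template in HALF_MACROS.items():
    print(f"{framework:>10}: {expand_macro(template, 'x')}")
```

Plain substitution is only safe for simple operands. Expanding '{input}.astype(numpy.float16)' with a + b yields a + b.astype(numpy.float16), which casts only b. A converter would presumably need to parenthesize compound expressions before substituting them.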