SetDefaultDtype

Sets the default floating-point dtype, that is, the dtype used for new floating-point values when none is given explicitly.

Abstract Signature:

SetDefaultDtype(d: DType)
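One way to picture the abstract signature is as a small immutable operation record that carries a single dtype argument. The sketch below is hypothetical: `DType`, its members, `FLOATING`, and the validation are illustrative assumptions, not part of PyTorch, JAX, or Keras. The check only encodes the description above, which limits the operation to floating-point dtypes.

```python
from dataclasses import dataclass
from enum import Enum


class DType(Enum):
    """Hypothetical framework-neutral dtype names (illustration only)."""
    FLOAT16 = "float16"
    BFLOAT16 = "bfloat16"
    FLOAT32 = "float32"
    FLOAT64 = "float64"
    INT32 = "int32"


# Assumed set of dtypes that may serve as the default floating-point dtype.
FLOATING = {DType.FLOAT16, DType.BFLOAT16, DType.FLOAT32, DType.FLOAT64}


@dataclass(frozen=True)
class SetDefaultDtype:
    """Abstract op: SetDefaultDtype(d: DType)."""
    d: DType

    def __post_init__(self) -> None:
        # Reject non-floating dtypes: the op sets the default floating-point dtype.
        if self.d not in FLOATING:
            raise ValueError(f"{self.d.value} is not a floating-point dtype")
```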

PyTorch

API: torch.set_default_dtype
Strategy: Direct Mapping

JAX (Core)

API: none (no direct equivalent)
Strategy: Macro "jax.config.update('jax_enable_x64', {d} == jnp.float64)"
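The macro can be read as a `str.format`-style template in which `{d}` is replaced by the target-side dtype expression. The minimal sketch below works under that assumption; the helper names `expand_jax_macro` and `effective_jax_default` are hypothetical. It also shows what the emitted call achieves at run time: a float64 request turns 64-bit mode on, and any other request turns it off, which leaves JAX on its float32 default. A dtype such as float16 therefore cannot become the default through this macro.

```python
# Template copied from the JAX (Core) strategy; {d} is its only placeholder.
JAX_MACRO = "jax.config.update('jax_enable_x64', {d} == jnp.float64)"


def expand_jax_macro(d_expr: str) -> str:
    """Substitute a rendered dtype expression (e.g. 'jnp.float32') for {d}."""
    return JAX_MACRO.format(d=d_expr)


def effective_jax_default(requested: str) -> str:
    """Default float dtype JAX ends up with after the expanded call runs.

    Only float64 switches jax_enable_x64 on; every other request switches
    it off, so JAX falls back to float32 (float16, for example, is not honored).
    """
    return "float64" if requested == "float64" else "float32"


print(expand_jax_macro("jnp.float32"))
# jax.config.update('jax_enable_x64', jnp.float32 == jnp.float64)
```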

Keras

API: keras.config.set_floatx
Strategy: Direct Mapping
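Both Direct Mapping rows pass the dtype through unchanged, but they spell the argument differently: `torch.set_default_dtype` takes a `torch.dtype` object such as `torch.float64`, while `keras.config.set_floatx` takes a dtype name string such as `'float64'`. The renderer below is a hypothetical sketch of that difference. The function name `render_direct_mapping` and the choice to emit source text are assumptions, not part of either library.

```python
def render_direct_mapping(framework: str, dtype_name: str) -> str:
    """Render the target call for the two Direct Mapping strategies.

    Hypothetical helper: PyTorch receives a dtype attribute of the torch
    module, while Keras receives the dtype name as a string literal.
    """
    if framework == "torch":
        return f"torch.set_default_dtype(torch.{dtype_name})"
    if framework == "keras":
        return f"keras.config.set_floatx({dtype_name!r})"
    raise ValueError(f"no direct mapping for framework {framework!r}")


for fw in ("torch", "keras"):
    print(render_direct_mapping(fw, "float64"))
```

JAX is deliberately rejected here because it has no direct API and uses the macro strategy instead.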