EvalShape

Computes the output shape and dtype of a function without performing the actual computation.

Abstract Signature:

EvalShape(f: Callable, args)

PyTorch

API: torch.func.eval_shape
Strategy: Direct Mapping

JAX (Core)

API: jax.eval_shape
Strategy: Direct Mapping
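A minimal sketch of the direct mapping, using a toy dense-layer function `f` that is hypothetical and only for illustration. `jax.ShapeDtypeStruct` describes each input by shape and dtype alone, and `jax.eval_shape` traces `f` abstractly and returns the output's shape and dtype without running the matrix multiply.

```python
import jax
import jax.numpy as jnp


def f(x, w):
    # Hypothetical example function: a dense layer followed by tanh.
    # Under jax.eval_shape this body is traced, not executed on real data.
    return jnp.tanh(x @ w)


# Abstract inputs: only shape and dtype, no data buffers are allocated.
x_spec = jax.ShapeDtypeStruct((32, 128), jnp.float32)
w_spec = jax.ShapeDtypeStruct((128, 10), jnp.float32)

# Returns a ShapeDtypeStruct (or a pytree of them for pytree outputs).
out = jax.eval_shape(f, x_spec, w_spec)
print(out.shape, out.dtype)  # (32, 10) float32
```

Concrete arrays can also be passed in place of the `ShapeDtypeStruct` specs; only their shapes and dtypes are used.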

TensorFlow

API: None (no direct equivalent)
Strategy: Custom / Partial
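TensorFlow has no single `eval_shape` call. One possible workaround, sketched here as an assumption rather than an official equivalent, is to trace the function with `tf.function` against `tf.TensorSpec` inputs and read the symbolic outputs of the resulting concrete function. Tracing builds a graph without executing it, but outputs of data-dependent ops can come back with partially unknown (`None`) dimensions, which is one reason the mapping is only partial. The function `f` is hypothetical.

```python
import tensorflow as tf


def f(x, w):
    # Hypothetical example function: a dense layer followed by tanh.
    return tf.tanh(tf.matmul(x, w))


# Assumption: tracing via tf.function stands in for eval_shape.
# get_concrete_function traces f into a graph from the specs alone;
# the graph is not run, so no real computation happens here.
concrete = tf.function(f).get_concrete_function(
    tf.TensorSpec(shape=(32, 128), dtype=tf.float32),
    tf.TensorSpec(shape=(128, 10), dtype=tf.float32),
)

# Symbolic output tensor(s) carrying the inferred static shape and dtype.
out = concrete.structured_outputs
print(out.shape.as_list(), out.dtype.name)
```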

Flax NNX

API: flax.nnx.eval_shape
Strategy: Direct Mapping