VMap

Vectorizing map. Creates a function that maps f over the array axes selected by in_axes and stacks the per-element results along out_axes.

Abstract Signature:

VMap(f: Callable, in_axes = 0, out_axes = 0)
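
To make the in_axes and out_axes arguments concrete, here is a minimal sketch using jax.vmap, one of the direct mappings listed on this page. The function names and array shapes are illustrative assumptions, not part of the abstract signature.

```python
import jax
import jax.numpy as jnp


def dot(x, y):
    # Written for a single pair of 1-D vectors; knows nothing about batches.
    return jnp.dot(x, y)


def double(x):
    # Written for a single 1-D vector.
    return 2.0 * x


xs = jnp.arange(6.0).reshape(3, 2)  # 3 vectors of length 2, batch on axis 0
ys = jnp.ones((2, 3))               # 3 vectors of length 2, batch on axis 1

# in_axes names the batch axis of each positional argument.
batched_dot = jax.vmap(dot, in_axes=(0, 1), out_axes=0)
dots = batched_dot(xs, ys)          # shape (3,), values [1., 5., 9.]

# out_axes names where the batch axis is placed in the result.
batched_double = jax.vmap(double, in_axes=0, out_axes=1)
doubled = batched_double(xs)        # shape (2, 3), i.e. (2 * xs).T
```

Passing in_axes=None for an argument in jax.vmap leaves that argument unbatched, which is how a shared parameter is usually broadcast across the mapped axis.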

PyTorch

API: torch.vmap
Strategy: Direct Mapping

JAX (Core)

API: jax.vmap
Strategy: Direct Mapping

TensorFlow

API: tf.vectorized_map
Strategy: Plugin (vmap_adapter)
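
The vmap_adapter plugin itself is not shown on this page. One likely reason an adapter is needed is that tf.vectorized_map(fn, elems) is applied immediately to its inputs and always maps over the leading axis, whereas the abstract signature returns a function and accepts in_axes and out_axes. The sketch below illustrates that kind of shim under stated assumptions: leading_axis_map is a stand-in with the same call shape as tf.vectorized_map (implemented here with jax.vmap so the example stays in one language), vmap_via_leading_axis is a hypothetical name, and only a single array argument with integer axes is handled.

```python
import jax
import jax.numpy as jnp


def leading_axis_map(fn, elems):
    # Stand-in (assumption) for a tf.vectorized_map-style primitive:
    # maps fn over axis 0 of elems and stacks the results on axis 0.
    return jax.vmap(fn)(elems)


def vmap_via_leading_axis(f, in_axes=0, out_axes=0):
    # Hypothetical adapter: returns a function, as VMap does, and emulates
    # in_axes/out_axes by moving the batch axis to and from position 0.
    def mapped(x):
        x_front = jnp.moveaxis(x, in_axes, 0)
        out = leading_axis_map(f, x_front)
        return jnp.moveaxis(out, 0, out_axes)

    return mapped


x = jnp.arange(6.0).reshape(2, 3)

# Map over columns (axis 1) and sum each column.
col_sums = vmap_via_leading_axis(jnp.sum, in_axes=1)(x)

# Map over rows (axis 0) and place the batch axis last in the output.
scaled = vmap_via_leading_axis(lambda v: 2.0 * v, in_axes=0, out_axes=1)(x)
```

A real adapter would also need to handle multiple arguments, nested structures, and unbatched (None) axes, none of which this sketch attempts.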

Apple MLX

API: mlx.core.vmap
Strategy: Direct Mapping

Flax NNX

API: flax.nnx.vmap
Strategy: Direct Mapping