ForiLoop

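Functional loop that applies body_fun(i, val) for each integer i from lower (inclusive) to upper (exclusive), threading the carried value through every iteration and returning its final state.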

Abstract Signature:

ForiLoop(lower: int, upper: int, body_fun: Callable, init_val)
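The abstract semantics can be modeled eagerly in plain Python. This is a minimal sketch that assumes the documented behavior of `jax.lax.fori_loop` (half-open range, carry returned after the last iteration). `fori_loop_reference` is a hypothetical name used only for illustration, not part of any framework.

```python
from typing import Callable, TypeVar

T = TypeVar("T")


def fori_loop_reference(lower: int, upper: int,
                        body_fun: Callable[[int, T], T], init_val: T) -> T:
    """Hypothetical eager model of ForiLoop (not a library API)."""
    val = init_val
    # Half-open range: lower is visited, upper is not.
    for i in range(lower, upper):
        # body_fun receives the index and the current carry, returns the next carry.
        val = body_fun(i, val)
    # If lower >= upper the body never runs and init_val is returned unchanged.
    return val


print(fori_loop_reference(0, 5, lambda i, acc: acc + i * i, 0))  # 30
```

Because the carry is the only state, a backend can lower this pattern to a compiled loop as long as `body_fun` is side-effect free and returns a value with the same structure as `init_val`.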

PyTorch

API: no native equivalent
Strategy: Plugin (transform_for_loop)

JAX (Core)

API: jax.lax.fori_loop
Strategy: Direct Mapping
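A short usage sketch of `jax.lax.fori_loop`, which takes the same four arguments as the abstract signature. The helper names `sum_of_squares` and `cube` are illustrative only. Per the JAX documentation, traced bounds (as under `jax.jit` with `n` as an argument) lower to a `while_loop`, which does not support reverse-mode autodiff, while concrete Python-int bounds lower to a `scan`, which does.

```python
import jax
import jax.numpy as jnp


def body(i, acc):
    # i is the int32 loop index; acc is the carried value.
    return acc + i * i


@jax.jit
def sum_of_squares(n):
    # n is traced under jit, so the trip count is dynamic (while_loop lowering).
    return jax.lax.fori_loop(0, n, body, jnp.int32(0))


def cube(x):
    # Static bounds (Python ints) allow scan lowering, so jax.grad works here.
    return jax.lax.fori_loop(0, 3, lambda i, acc: acc * x, jnp.float32(1.0))


print(sum_of_squares(5))      # 30
print(jax.grad(cube)(2.0))    # 12.0, since d/dx x**3 = 3 * x**2
```

The carry must keep the same dtype and shape on every iteration, which is why both initial values are given explicit dtypes.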

TensorFlow

API: tf.while_loop
Strategy: Direct Mapping
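`tf.while_loop` has no index argument, so mapping ForiLoop onto it means threading the loop counter through the carried state. The sketch below models that desugaring in plain Python so it runs without TensorFlow. `while_loop` here only mimics the calling convention of `tf.while_loop(cond, body, loop_vars)`, where `cond` and `body` receive the unpacked loop variables, and `fori_loop_via_while` is a hypothetical helper, not part of any library.

```python
from typing import Callable


def while_loop(cond: Callable, body: Callable, loop_vars: tuple) -> tuple:
    # Eager stand-in for tf.while_loop: apply body while cond holds.
    while cond(*loop_vars):
        loop_vars = body(*loop_vars)
    return loop_vars


def fori_loop_via_while(lower: int, upper: int, body_fun: Callable, init_val):
    # Hypothetical desugaring: carry (i, val) instead of val alone.
    def cond(i, val):
        return i < upper

    def body(i, val):
        # Advance the counter and apply the user body to the current index.
        return (i + 1, body_fun(i, val))

    _, final_val = while_loop(cond, body, (lower, init_val))
    return final_val


print(fori_loop_via_while(1, 6, lambda i, acc: acc * i, 1))  # 120
```

With real TensorFlow, the same `cond`/`body` pair could in principle be passed to `tf.while_loop` with `loop_vars=(lower, init_val)`, provided the carried value keeps a consistent structure and dtype across iterations.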

Flax NNX

API: flax.nnx.fori_loop
Strategy: Direct Mapping

PaxML / Praxis

API: jax.lax.fori_loop
Strategy: Direct Mapping