RecursiveMap

Recursively applies a function f to node and to every nested node and leaf beneath it.

Abstract Signature:

RecursiveMap(f: Callable, node)
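The abstract operation can be sketched framework-free as follows. This is a minimal illustration only, not the reference implementation. The Node class and the recursive_map function are hypothetical. The sketch assumes children are mapped before their parent, mirroring the order PyTorch's Module.apply uses. It also assumes f returns a replacement node, so the result is a new tree.

```python
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Callable


@dataclass
class Node:
    # Hypothetical tree node: a name plus zero or more children (a leaf has none).
    name: str
    children: list[Node] = field(default_factory=list)


def recursive_map(f: Callable[[Node], Node], node: Node) -> Node:
    """Hypothetical RecursiveMap: map every child subtree first, then the node itself."""
    mapped_children = [recursive_map(f, child) for child in node.children]
    # f receives a copy of the node whose children are already mapped.
    return f(Node(node.name, mapped_children))


tree = Node("root", [Node("a"), Node("b", [Node("c")])])

# Record the visiting order; f returns the node unchanged.
order: list[str] = []

def record(n: Node) -> Node:
    order.append(n.name)
    return n

recursive_map(record, tree)
print(order)  # ['a', 'c', 'b', 'root']

# Build a new tree with every name upper-cased; the original tree is untouched.
upper = recursive_map(lambda n: Node(n.name.upper(), n.children), tree)
```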

PyTorch

API: torch.nn.modules.module.Module.apply
Strategy: Direct Mapping
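A short PyTorch illustration of Module.apply follows. In PyTorch's implementation, apply calls fn on every submodule returned by children() before calling it on the module itself. It ignores fn's return value and returns the module, so any changes must be made in place. The two helper functions are illustrative names only, and the layer sizes are arbitrary.

```python
import torch
from torch import nn

# Record the order in which Module.apply visits modules.
visited: list[str] = []

def record(m: nn.Module) -> None:
    visited.append(type(m).__name__)

def zero_linear_bias(m: nn.Module) -> None:
    # Typical use: re-initialise parameters of one layer type in place.
    if isinstance(m, nn.Linear):
        nn.init.zeros_(m.bias)

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
result = model.apply(record)   # apply returns the same module object
model.apply(zero_linear_bias)
print(visited)  # ['Linear', 'ReLU', 'Linear', 'Sequential']
```

Because the parent is visited last, fn sees each container only after all of its submodules have been processed.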

Flax NNX

API: flax.nnx.recursive_map
Strategy: Direct Mapping