TrainMode

Puts the model or node into training mode, enabling training-time behavior such as dropout and batch-normalization statistic updates.

Abstract Signature:

TrainMode(node)
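Here is a minimal pure-Python sketch of the abstract operation, for illustration only. `ToyModule`, `ToyDropout`, and `train_mode` are hypothetical names and are not part of any framework. They loosely follow the PyTorch convention: entering training mode sets a `training` flag on the node and on all of its children, then returns the node so the call can be chained. The toy modules start in inference mode so the switch is visible.

```python
import random


class ToyModule:
    """Hypothetical base class: holds a training flag and child modules."""

    def __init__(self, *children):
        self.training = False  # start in inference mode for this illustration
        self.children = list(children)

    def train_mode(self):
        # Set the flag on this node, then recurse into every child.
        self.training = True
        for child in self.children:
            child.train_mode()
        # Return the node so TrainMode(node) can be used as an expression.
        return self


class ToyDropout(ToyModule):
    """Hypothetical dropout: zeroes inputs only while in training mode."""

    def __init__(self, p=0.5, seed=0):
        super().__init__()
        self.p = p
        self.rng = random.Random(seed)

    def __call__(self, xs):
        if not self.training:
            # Inference behavior: pass inputs through unchanged.
            return list(xs)
        # Training behavior: drop each element with probability p and
        # rescale survivors so the expected value is preserved.
        scale = 1.0 / (1.0 - self.p)
        return [0.0 if self.rng.random() < self.p else x * scale for x in xs]


def train_mode(node):
    """Sketch of the abstract TrainMode(node) operation."""
    return node.train_mode()


dropout = ToyDropout(p=0.5)
model = ToyModule(dropout)
print(dropout([1.0, 2.0, 3.0]))  # [1.0, 2.0, 3.0] (inference: identity)
assert train_mode(model) is model
print(dropout.training)  # True (the flag reached the child)
```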

PyTorch

API: none
Strategy: Macro '{node}.train()'
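A macro strategy can be read as a template in which `{node}` stands for the source expression of the node being converted. The sketch below assumes a plain textual substitution. `expand_macro` is a hypothetical helper and is not the converter's actual code. PyTorch's `nn.Module.train()` returns the module itself, so the expanded expression evaluates to the node, just as the abstract `TrainMode(node)` does.

```python
def expand_macro(template: str, node_expr: str) -> str:
    """Hypothetical helper: substitute a node expression into a macro template."""
    # str.replace is used instead of str.format so that any other literal
    # braces in a template are left untouched.
    return template.replace("{node}", node_expr)


PYTORCH_TRAIN_MODE = "{node}.train()"

print(expand_macro(PYTORCH_TRAIN_MODE, "model"))         # model.train()
print(expand_macro(PYTORCH_TRAIN_MODE, "self.encoder"))  # self.encoder.train()
```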

JAX (Core)

API: none
Strategy: Custom / Partial

Keras

API: none
Strategy: Macro "setattr({node}, 'trainable', True) or {node}"
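The trailing `or {node}` in the Keras macro exists because `setattr` always returns `None`. The expression `None or node` therefore evaluates to `node`, so the macro yields the node itself, like PyTorch's `.train()`. The sketch below checks this with a hypothetical `FakeLayer` class standing in for a Keras layer. One hedged caveat: in Keras, `trainable` controls whether a layer's weights are updated, while layers such as `Dropout` choose their behavior from the `training` argument passed at call time. The macro therefore appears to approximate training mode rather than reproduce it exactly.

```python
class FakeLayer:
    """Hypothetical stand-in for a Keras layer; only `trainable` matters here."""

    def __init__(self):
        self.trainable = False


layer = FakeLayer()

# The Keras macro with {node} replaced by `layer`:
# setattr(...) returns None, so `None or layer` evaluates to `layer`.
result = setattr(layer, "trainable", True) or layer

print(layer.trainable)  # True
print(result is layer)  # True
```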

TensorFlow

API: none
Strategy: Custom / Partial

Apple MLX

API: none
Strategy: Custom / Partial

Flax NNX

API: flax.nnx.train_mode
Strategy: Direct Mapping