LinearGeneral

A linear layer that contracts its input over one or more specified axes, rather than only the last axis.

Abstract Signature:

LinearGeneral(in_features: int, out_features: int, axis: Union[int, Tuple[int, ...]])
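The abstract signature can be read as a tensor contraction. The input axes listed in `axis` are summed against a kernel, the remaining axes pass through unchanged, and an `out_features` axis is appended last. Below is a minimal NumPy sketch of those semantics. It makes several assumptions that the signature does not state: `in_features` equals the product of the sizes of the contracted axes, the kernel is stored flat as `(in_features, out_features)` with the contracted axes flattened in ascending order, and a bias of shape `(out_features,)` is always added. The function name `linear_general` is illustrative only.

```python
from typing import Tuple, Union

import numpy as np


def linear_general(
    x: np.ndarray,
    kernel: np.ndarray,
    bias: np.ndarray,
    axis: Union[int, Tuple[int, ...]],
) -> np.ndarray:
    """Illustrative reference semantics for LinearGeneral (not a library API)."""
    axes = (axis,) if isinstance(axis, int) else tuple(axis)
    # Normalize negative axes and sort them so the kernel layout is well defined.
    axes = tuple(sorted(a % x.ndim for a in axes))
    contracted_shape = [x.shape[a] for a in axes]
    # Assumption: in_features is the product of the contracted axis sizes.
    in_features = int(np.prod(contracted_shape))
    if kernel.shape[0] != in_features:
        raise ValueError(f"kernel has {kernel.shape[0]} rows, expected {in_features}")
    # Give each contracted axis its own kernel dimension, then contract.
    k = kernel.reshape(*contracted_shape, kernel.shape[1])
    # tensordot keeps the uncontracted axes of x in order and appends out_features.
    y = np.tensordot(x, k, axes=(axes, tuple(range(len(axes)))))
    return y + bias


# Usage: contract the last two axes of a (2, 3, 4) input into 5 features.
rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3, 4))
kernel = rng.standard_normal((12, 5))
bias = np.zeros(5)
y = linear_general(x, kernel, bias, axis=(1, 2))
print(y.shape)  # (2, 5)
```

With `axis=-1` this reduces to an ordinary dense layer (`x @ kernel + bias`), which is the case every framework supports natively.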

PyTorch

API: None (no built-in equivalent)
Strategy: Plugin (linear_general_adapter)
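PyTorch's `torch.nn.Linear` contracts only the last dimension (`y = x @ W.T + b`), which is presumably why a plugin is needed here. The internals of `linear_general_adapter` are not shown in this document. The sketch below is a hypothetical adapter (`adapt_linear_general` is an illustrative name). It moves the contracted axes to the end, flattens them, and applies a last-axis linear map with a weight stored in the `torch.nn.Linear` layout `(out_features, in_features)`. It is written in NumPy so it runs without PyTorch, and the comments name the PyTorch calls (`torch.movedim`, `Tensor.reshape`, `torch.nn.functional.linear`) that would play the same role.

```python
from typing import Tuple, Union

import numpy as np


def adapt_linear_general(
    x: np.ndarray,
    weight: np.ndarray,
    bias: np.ndarray,
    axis: Union[int, Tuple[int, ...]],
) -> np.ndarray:
    """Hypothetical adapter: emulate LinearGeneral with a last-axis linear map.

    `weight` uses the torch.nn.Linear layout (out_features, in_features).
    """
    axes = (axis,) if isinstance(axis, int) else tuple(axis)
    axes = sorted(a % x.ndim for a in axes)
    n = len(axes)
    # 1. Move contracted axes to the end, keeping their relative order
    #    (torch.movedim in PyTorch).
    moved = np.moveaxis(x, axes, list(range(x.ndim - n, x.ndim)))
    batch_shape = moved.shape[: x.ndim - n]
    # 2. Flatten them into a single feature axis of size in_features.
    flat = moved.reshape(*batch_shape, -1)
    # 3. y = flat @ weight.T + bias (torch.nn.functional.linear in PyTorch).
    return flat @ weight.T + bias


# Usage: contract axes 0 and 2 of a (2, 3, 4) input; in_features = 2 * 4 = 8.
rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3, 4))
weight = rng.standard_normal((5, 8))
bias = rng.standard_normal(5)
y = adapt_linear_general(x, weight, bias, axis=(0, 2))
print(y.shape)  # (3, 5)
```

Sorting the axes fixes one flattening order. A weight converted from another framework would have to be flattened in that same order for the results to match.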

Keras

API: keras.layers.Dense
Strategy: Direct Mapping

Flax NNX

API: flax.nnx.LinearGeneral
Strategy: Direct Mapping

PaxML / Praxis

API: praxis.layers.Linear
Strategy: Direct Mapping