PyTorch
API: torch.addbmm
Strategy: Direct Mapping
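Every macro below should reproduce the same reference semantics: `beta * input + alpha * sum over b of (batch1[b] @ batch2[b])`, with `batch1` of shape (b, n, m), `batch2` of shape (b, m, p), and the result of shape (n, p). The following is a minimal dependency-free sketch of those semantics. It assumes `input` is already an (n, p) matrix (no broadcasting) and that the batch holds at least one matrix pair. The helper names `matmul` and `addbmm_reference` are hypothetical. It also models the documented PyTorch rule that `input` is ignored when `beta == 0`.

```python
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain 2-D matrix product: (n x m) @ (m x p) -> (n x p)."""
    m = len(b)
    p = len(b[0])
    return [[sum(row[k] * b[k][j] for k in range(m)) for j in range(p)] for row in a]


def addbmm_reference(input_: Matrix, batch1: List[Matrix], batch2: List[Matrix],
                     beta: float = 1.0, alpha: float = 1.0) -> Matrix:
    """Hypothetical reference for torch.addbmm; assumes input_ is exactly (n x p)."""
    if len(batch1) != len(batch2):
        raise ValueError("batch1 and batch2 must have the same batch size")
    n = len(batch1[0])
    p = len(batch2[0][0])
    # Sum the batch of (n x p) products into a single (n x p) accumulator.
    acc = [[0.0] * p for _ in range(n)]
    for a, b in zip(batch1, batch2):
        prod = matmul(a, b)
        for i in range(n):
            for j in range(p):
                acc[i][j] += prod[i][j]
    # PyTorch ignores input when beta == 0, so NaN/inf in input do not propagate.
    if beta == 0:
        return [[alpha * acc[i][j] for j in range(p)] for i in range(n)]
    return [[beta * input_[i][j] + alpha * acc[i][j] for j in range(p)] for i in range(n)]
```

The `beta == 0` branch is the one detail that the plain arithmetic macros do not reproduce: `0 * NaN` is still `NaN`.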
JAX (Core)
API: jax.numpy.sum
Strategy: Macro '{beta} * {input} + {alpha} * jax.numpy.sum(jax.numpy.matmul({batch1}, {batch2}), axis=0)'
NumPy
API: numpy.matmul
Strategy: Macro '{beta} * {input} + {alpha} * numpy.sum(numpy.matmul({batch1}, {batch2}), axis=0)'
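As a sanity check, the NumPy macro can be expanded literally and compared with an explicit per-batch loop. This is a sketch, and the function names are hypothetical. The `einsum` variant is an alternative not used by the mapping above: it sums over the batch axis as part of the contraction instead of first building the (b, n, p) stack that `numpy.matmul` returns. The last check shows the caveat that the macro propagates NaN from `input` even when `beta` is 0, whereas PyTorch documents that `input` is ignored in that case.

```python
import numpy as np


def addbmm_numpy_macro(input_, batch1, batch2, beta=1.0, alpha=1.0):
    # Literal expansion of the NumPy macro: matmul yields (b, n, p), then sum over b.
    return beta * input_ + alpha * np.sum(np.matmul(batch1, batch2), axis=0)


def addbmm_numpy_einsum(input_, batch1, batch2, beta=1.0, alpha=1.0):
    # Alternative (not part of the mapping): contract m and b in a single einsum.
    return beta * input_ + alpha * np.einsum("bnm,bmp->np", batch1, batch2)


def addbmm_loop(input_, batch1, batch2, beta=1.0, alpha=1.0):
    # Explicit per-batch accumulation, mirroring the torch.addbmm definition.
    acc = np.zeros((batch1.shape[1], batch2.shape[2]),
                   dtype=np.result_type(batch1, batch2))
    for a, b in zip(batch1, batch2):
        acc += a @ b
    return beta * input_ + alpha * acc


rng = np.random.default_rng(0)
b1 = rng.standard_normal((4, 3, 5))  # (b, n, m)
b2 = rng.standard_normal((4, 5, 2))  # (b, m, p)
inp = rng.standard_normal((3, 2))    # (n, p)

out = addbmm_numpy_macro(inp, b1, b2, beta=0.5, alpha=2.0)
print(out.shape)
print(np.allclose(out, addbmm_loop(inp, b1, b2, 0.5, 2.0)))
print(np.allclose(out, addbmm_numpy_einsum(inp, b1, b2, 0.5, 2.0)))
# Caveat: 0.0 * NaN is NaN, so the macro does not ignore input when beta == 0.
print(np.isnan(addbmm_numpy_macro(np.full((3, 2), np.nan), b1, b2, beta=0.0)).all())
```

The same literal expansion applies to the JAX, Keras, TensorFlow, MLX and Flax NNX macros, which differ only in the namespace of `sum`/`matmul`.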
Keras
API: keras.ops.sum
Strategy: Macro '{beta} * {input} + {alpha} * keras.ops.sum(keras.ops.matmul({batch1}, {batch2}), axis=0)'
TensorFlow
API: tf.linalg.matmul
Strategy: Macro '{beta} * {input} + {alpha} * tf.reduce_sum(tf.linalg.matmul({batch1}, {batch2}), axis=0)'
Apple MLX
API: mlx.core.matmul
Strategy: Macro '{beta} * {input} + {alpha} * mlx.core.sum(mlx.core.matmul({batch1}, {batch2}), axis=0)'
Flax NNX
API: jax.numpy.sum
Strategy: Macro '{beta} * {input} + {alpha} * jax.numpy.sum(jax.numpy.matmul({batch1}, {batch2}), axis=0)'