EmbeddingBag

Computes sums, means, or maxima of bags of embeddings without instantiating the intermediate embeddings.

Abstract Signature:

EmbeddingBag(input: Tensor, weight: Tensor, offsets: Optional[Tensor] = None, max_norm: Optional[float] = None, norm_type: float = 2.0, scale_grad_by_freq: bool = False, mode: str = 'mean', sparse: bool = False, per_sample_weights: Optional[Tensor] = None, include_last_offset: bool = False, padding_idx: Optional[int] = None)
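The forward reduction behind this signature can be pinned down with a small pure-Python reference. This is a sketch under stated assumptions, not the code of any backend or plugin: the function name `embedding_bag_reference` is hypothetical, it handles only the 1-D `input` plus `offsets` form, it treats `padding_idx` as a non-negative row index, and it ignores `max_norm`, `norm_type`, `scale_grad_by_freq`, and `sparse`, which concern weight renormalization or gradients rather than the reduction itself. Where it makes a choice, it follows PyTorch's documented behavior: empty bags yield zero vectors, `per_sample_weights` requires `mode='sum'`, and rows equal to `padding_idx` are excluded from the reduction.

```python
from typing import List, Optional, Sequence


def embedding_bag_reference(
    input: Sequence[int],
    weight: Sequence[Sequence[float]],
    offsets: Sequence[int],
    mode: str = "mean",
    per_sample_weights: Optional[Sequence[float]] = None,
    include_last_offset: bool = False,
    padding_idx: Optional[int] = None,
) -> List[List[float]]:
    """Hypothetical reference: reduce each bag of rows of `weight` to one vector."""
    if mode not in ("sum", "mean", "max"):
        raise ValueError(f"unsupported mode: {mode!r}")
    if per_sample_weights is not None and mode != "sum":
        raise ValueError("per_sample_weights is only supported with mode='sum'")
    dim = len(weight[0])
    # Bag b covers input[bounds[b]:bounds[b + 1]]. Without include_last_offset,
    # the end of the final bag is implicitly len(input).
    bounds = list(offsets) if include_last_offset else list(offsets) + [len(input)]
    out: List[List[float]] = []
    for start, end in zip(bounds[:-1], bounds[1:]):
        rows = []
        for pos in range(start, end):
            idx = input[pos]
            if idx == padding_idx:
                continue  # padding rows do not contribute (nor count toward the mean)
            scale = per_sample_weights[pos] if per_sample_weights is not None else 1.0
            rows.append([scale * w for w in weight[idx]])
        if not rows:
            out.append([0.0] * dim)  # empty bag -> zero vector
        elif mode == "sum":
            out.append([sum(col) for col in zip(*rows)])
        elif mode == "mean":
            out.append([sum(col) / len(rows) for col in zip(*rows)])
        else:  # mode == "max": element-wise maximum over the bag
            out.append([max(col) for col in zip(*rows)])
    return out


if __name__ == "__main__":
    W = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
    # Two bags: input[0:2] = rows 0, 1 and input[2:4] = rows 2, 1.
    print(embedding_bag_reference([0, 1, 2, 1], W, [0, 2]))  # [[2.0, 3.0], [4.0, 5.0]]
```

Passing `offsets=[0, 2, 4]` with `include_last_offset=True` describes the same two bags, because the final offset then marks the end of the last bag explicitly.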

PyTorch

API: torch.embedding_bag
Strategy: Direct Mapping

JAX (Core)

API: None (no direct equivalent)
Strategy: Plugin (embedding_bag)
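Because JAX has no built-in counterpart, a plugin has to express the bag reduction with other primitives. One plausible approach, stated here as an assumption rather than a description of the actual `embedding_bag` plugin, is to gather rows with `weight[input]` and reduce them with `jax.ops.segment_sum(data, segment_ids, num_segments=num_bags)`. Segment reductions expect one bag id per index instead of offsets, so the offsets have to be expanded first. The helper below (hypothetical name `offsets_to_segment_ids`) shows that conversion in plain Python.

```python
from typing import List, Sequence


def offsets_to_segment_ids(
    offsets: Sequence[int], num_indices: int, include_last_offset: bool = False
) -> List[int]:
    """Hypothetical helper: map each position of the flat input to its bag id."""
    # Same convention as PyTorch: without include_last_offset, the final bag
    # ends at num_indices (the length of the flat input).
    bounds = list(offsets) if include_last_offset else list(offsets) + [num_indices]
    segment_ids: List[int] = []
    for bag, (start, end) in enumerate(zip(bounds[:-1], bounds[1:])):
        segment_ids.extend([bag] * (end - start))  # empty bags emit nothing
    return segment_ids
```

For example, `offsets_to_segment_ids([0, 2, 2], 5)` gives `[0, 0, 2, 2, 2]`: bag 1 is empty, so no index maps to it. Passing an explicit `num_segments` equal to the number of bags keeps such empty bags as zero rows in a segment sum; a `mean` could then divide by per-bag counts obtained from a second segment sum over ones, guarding against division by zero.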

Keras

API: None (no direct equivalent)
Strategy: Plugin (embedding_bag_keras)

TensorFlow

API: tf.nn.embedding_lookup_sparse
Strategy: Plugin (embedding_bag_tf)
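`tf.nn.embedding_lookup_sparse` takes its ids as a `tf.SparseTensor` with one row per bag, rather than a flat index tensor plus offsets, and its `combiner` argument accepts `'sum'`, `'mean'`, and `'sqrtn'`, so `mode='max'` has no direct combiner. A minimal sketch, assuming the plugin performs this kind of conversion (the helper name `offsets_to_sparse_components` is hypothetical), builds the `indices`, `values`, and `dense_shape` components in plain Python; they could then be passed to `tf.SparseTensor(indices, values, dense_shape)`.

```python
from typing import List, Sequence, Tuple


def offsets_to_sparse_components(
    input: Sequence[int], offsets: Sequence[int], include_last_offset: bool = False
) -> Tuple[List[List[int]], List[int], List[int]]:
    """Hypothetical helper: describe the bags as a [num_bags, max_bag_len] sparse id matrix."""
    bounds = list(offsets) if include_last_offset else list(offsets) + [len(input)]
    indices: List[List[int]] = []
    values: List[int] = []
    max_len = 0
    for bag, (start, end) in enumerate(zip(bounds[:-1], bounds[1:])):
        for col, pos in enumerate(range(start, end)):
            indices.append([bag, col])  # row = bag, column = position inside the bag
            values.append(input[pos])
        max_len = max(max_len, end - start)
    # Indices come out in row-major order, which SparseTensor expects.
    return indices, values, [len(bounds) - 1, max_len]
```

`per_sample_weights`, when present, would map to the `sp_weights` argument, built with the same `indices` and the weights as `values`.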

Apple MLX

API: None (no direct equivalent)
Strategy: Plugin (embedding_bag_mlx)

Flax NNX

API: None (no direct equivalent)
Strategy: Plugin (embedding_bag)