EmbeddingBag
Computes sums or means of bags of embeddings, without instantiating the intermediate embeddings.
Abstract Signature:
EmbeddingBag(input: Tensor, weight: Tensor, offsets: Tensor, max_norm: float, norm_type: float = 2.0, scale_grad_by_freq: bool = False, mode: str = "mean", sparse: bool = False, per_sample_weights: Tensor, include_last_offset: bool = False, padding_idx: int)
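To make the roles of `input`, `offsets`, and `mode` concrete, here is a minimal plain-Python sketch of the reduction. It is a reference for the semantics only, not the code any of the plugins below use. The function name `embedding_bag_reference` and the variables `table` and `ids` are hypothetical. The sketch assumes PyTorch-style conventions: `offsets` holds the start position of each bag in a flat index list, empty bags yield zeros, entries equal to `padding_idx` are skipped, and `per_sample_weights` is accepted only with `mode="sum"`. The `max_norm`, `norm_type`, `scale_grad_by_freq`, and `sparse` parameters are left out because they affect renormalization and gradients rather than the forward reduction.

```python
from typing import List, Optional, Sequence


def embedding_bag_reference(
    indices: Sequence[int],
    weight: Sequence[Sequence[float]],
    offsets: Sequence[int],
    mode: str = "mean",
    per_sample_weights: Optional[Sequence[float]] = None,
    include_last_offset: bool = False,
    padding_idx: Optional[int] = None,
) -> List[List[float]]:
    """Reduce each bag of embedding rows to a single vector (hypothetical helper)."""
    if mode not in ("sum", "mean"):
        raise ValueError("this sketch only covers mode='sum' and mode='mean'")
    if per_sample_weights is not None and mode != "sum":
        raise ValueError("per_sample_weights is only handled for mode='sum'")

    dim = len(weight[0])

    # Bag b covers indices[bounds[b]:bounds[b + 1]]. Without
    # include_last_offset, the final bag runs to the end of `indices`.
    bounds = list(offsets)
    if not include_last_offset:
        bounds.append(len(indices))

    out: List[List[float]] = []
    for start, end in zip(bounds[:-1], bounds[1:]):
        acc = [0.0] * dim  # running sum; the gathered rows are never stored
        count = 0
        for pos in range(start, end):
            idx = indices[pos]
            if idx == padding_idx:  # assumption: padded entries are skipped
                continue
            scale = 1.0 if per_sample_weights is None else per_sample_weights[pos]
            row = weight[idx]
            for d in range(dim):
                acc[d] += scale * row[d]
            count += 1
        if mode == "mean" and count > 0:
            acc = [v / count for v in acc]
        out.append(acc)  # an empty bag stays all zeros
    return out


# Four embeddings of width 2; two bags: ids[0:2] and ids[2:5].
table = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
ids = [0, 2, 1, 3, 3]

print(embedding_bag_reference(ids, table, offsets=[0, 2], mode="sum"))
# [[6.0, 8.0], [17.0, 20.0]]
```

Each bag is accumulated row by row into a single vector, which is what "without instantiating the intermediate embeddings" refers to: the gathered `(num_indices, dim)` matrix is never built. A framework plugin would typically express the same reduction with a segment-sum or scatter-add primitive instead of Python loops.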
JAX (Core)
API:
(none)
Strategy: Plugin (embedding_bag)
Keras
API:
(none)
Strategy: Plugin (embedding_bag_keras)
TensorFlow
API:
tf.nn.embedding_lookup_sparse
Strategy: Plugin (embedding_bag_tf)
Apple MLX
API:
(none)
Strategy: Plugin (embedding_bag_mlx)
Flax NNX
API:
(none)
Strategy: Plugin (embedding_bag)