PyTorch's EmbeddingBag allows for efficient lookup + reduce operations on varying-length collections of embedding indices. There are 3 modes for the reduce operation: "sum", "mean" and "max". With "sum", you can also provide per_sample_weights, giving you a weighted sum.
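For reference, here is a minimal sketch of the weighted-sum case. The table size, indices, offsets and weights are made-up values for illustration only; the manual computation at the end is just my own sanity check of what the layer returns.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical sizes: 10 embeddings of dimension 4.
bag = nn.EmbeddingBag(num_embeddings=10, embedding_dim=4, mode="sum")

# Two bags of different lengths, flattened into one 1-D index tensor.
indices = torch.tensor([1, 2, 4, 5, 4, 3])
offsets = torch.tensor([0, 2])  # bag 0 = indices[0:2], bag 1 = indices[2:6]
weights = torch.tensor([0.5, 0.5, 1.0, 2.0, 1.0, 0.0])  # one weight per index

weighted_sum = bag(indices, offsets, per_sample_weights=weights)

# Manual check: scale each looked-up row by its weight, then sum per bag.
rows = bag.weight[indices] * weights.unsqueeze(1)
expected = torch.stack([rows[0:2].sum(dim=0), rows[2:6].sum(dim=0)])
print(torch.allclose(weighted_sum, expected, atol=1e-6))
```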
Why is per_sample_weights not allowed for the "max" operation? Looking at how it's implemented, I can only assume there is an issue with performing a "ReduceMean" or "ReduceMax" operation after a "Mul" operation. Could that have something to do with calculating gradients?
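To make the restriction concrete, this is roughly how I reproduce it. The sizes and values are arbitrary, and I catch a few exception types because I am assuming, not asserting, which one a given PyTorch version raises (the versions I have seen raise NotImplementedError).

```python
import torch
import torch.nn as nn

max_bag = nn.EmbeddingBag(num_embeddings=10, embedding_dim=4, mode="max")

indices = torch.tensor([1, 2, 4, 5, 4, 3])
offsets = torch.tensor([0, 2])
weights = torch.ones(6)

rejected = False
try:
    max_bag(indices, offsets, per_sample_weights=weights)
except (NotImplementedError, ValueError, RuntimeError) as err:
    # The exact exception type and message may differ between versions.
    rejected = True
    print(type(err).__name__, err)

# Without per_sample_weights, the same call works.
plain_max = max_bag(indices, offsets)
```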
P.S. It's easy enough to turn a weighted sum into a weighted average by dividing by the sum of the weights, but for "max" you can't get a weighted equivalent like that.
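The weighted-average trick looks something like the sketch below. Again the indices and weights are invented for the example, and `bag_ids` / `weight_totals` are helper names I made up; they are not part of the EmbeddingBag API.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

bag = nn.EmbeddingBag(num_embeddings=10, embedding_dim=4, mode="sum")

indices = torch.tensor([1, 2, 4, 5, 4, 3])
offsets = torch.tensor([0, 2])  # bag 0 has 2 indices, bag 1 has 4
weights = torch.tensor([1.0, 3.0, 1.0, 2.0, 1.0, 4.0])

weighted_sum = bag(indices, offsets, per_sample_weights=weights)

# Work out which bag each index belongs to, then total the weights per bag.
ends = torch.cat([offsets, torch.tensor([indices.numel()])])
lengths = ends[1:] - ends[:-1]
bag_ids = torch.repeat_interleave(torch.arange(offsets.numel()), lengths)
weight_totals = torch.zeros(offsets.numel()).index_add_(0, bag_ids, weights)

# Weighted average = weighted sum / sum of weights (assumes nonzero totals).
weighted_mean = weighted_sum / weight_totals.unsqueeze(1)
```

No analogous post-hoc rescaling exists for "max", since the weights would have to be applied before the element-wise maximum is taken.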
from Why does Pytorch EmbeddingBag with mode "max" not accept `per_sample_weights`?