Thursday, 25 March 2021

Why does PyTorch EmbeddingBag with mode "max" not accept `per_sample_weights`?

PyTorch's EmbeddingBag allows for efficient lookup-and-reduce operations on variable-length collections of embedding indices. There are 3 modes for the reduce operation: "sum", "mean" and "max". With "sum", you can also provide per_sample_weights, giving you a weighted sum.
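For concreteness, here is a minimal sketch of the weighted-sum case described above. The table size, indices and weights are made-up values for illustration, and the hand-computed `expected` tensor is only there to show what the reduction is assumed to do.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical sizes chosen for illustration only.
bag = nn.EmbeddingBag(num_embeddings=10, embedding_dim=4, mode="sum")

# Two bags of different lengths packed into one 1-D tensor:
# bag 0 = indices [1, 2, 4], bag 1 = indices [5, 9].
indices = torch.tensor([1, 2, 4, 5, 9])
offsets = torch.tensor([0, 3])
weights = torch.tensor([0.5, 1.0, 2.0, 1.0, 3.0])

# per_sample_weights must have the same shape as the index tensor.
weighted_sum = bag(indices, offsets, per_sample_weights=weights)

# Reference: the same weighted sum computed by hand from the embedding table.
table = bag.weight
expected = torch.stack([
    (weights[:3, None] * table[indices[:3]]).sum(dim=0),
    (weights[3:, None] * table[indices[3:]]).sum(dim=0),
])
```

Constructing the bag with mode="max" and passing the same per_sample_weights is the call that gets rejected.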

Why is per_sample_weights not allowed for the "max" operation? Looking at how it's implemented, I can only assume there is an issue with performing a "ReduceMean" or "ReduceMax" operation after a "Mul" operation. Could that have something to do with calculating gradients?


P.S. It's easy enough to turn a weighted sum into a weighted average by dividing by the sum of the weights, but for "max" you can't get a weighted equivalent that way.
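A rough sketch of both points, with made-up indices and weights. The "weighted max" here is just one possible definition (scale each looked-up vector by its weight, then take the element-wise max over the bag); that definition is an assumption for illustration, not something EmbeddingBag provides.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical sizes chosen for illustration only.
bag = nn.EmbeddingBag(num_embeddings=10, embedding_dim=4, mode="sum")

# 2-D input: each row is one fixed-length bag, so no offsets are needed.
indices = torch.tensor([[1, 2, 4], [5, 9, 0]])
weights = torch.tensor([[0.5, 1.0, 2.0], [1.0, 3.0, 0.5]])

# Weighted average: weighted sum divided by the per-bag sum of weights.
weighted_sum = bag(indices, per_sample_weights=weights)
weighted_mean = weighted_sum / weights.sum(dim=1, keepdim=True)

# Manual "weighted max" (assumed definition): scale each looked-up vector,
# then take the element-wise max over the bag dimension. This materializes
# the full (bags, bag_len, dim) tensor, so it gives up the memory savings
# of the fused EmbeddingBag kernel.
vectors = bag.weight[indices]  # shape (2, 3, 4)
weighted_max = (weights.unsqueeze(-1) * vectors).max(dim=1).values
```

The manual version stays differentiable through ordinary autograd, so it can serve as a workaround when the fused operation is not required.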


