Monday, 21 March 2022

How to use Real-World-Weight Cross-Entropy loss in PyTorch

I'm working on a multiclass classification problem where some mistakes are more severe than others, so I would like to incorporate those costs into my loss function. I found this approach under the name Real-World-Weight Cross-Entropy, described in this paper. The formula is as follows:

[Image: the Real-World-Weight Cross-Entropy loss formula from the paper]

I haven't found any ready-to-use implementation, apart from the weight argument of the standard CrossEntropyLoss, which I believe works quite differently from my use case (as far as I understand, it assigns a single weight per class, so the cost of misclassifying a category is the same no matter which category it was confused with).
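For contrast, here is roughly what the weight argument does: it rescales each sample's loss by the single factor w[target], so the penalty never depends on which wrong class received the probability mass. The weights and logits below are made up for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# One scalar weight per *true* class (illustrative values).
w = torch.tensor([1.0, 2.0, 1.0, 0.5, 1.0])

logits = torch.randn(3, 5)
target = torch.tensor([1, 3, 0])

weighted = nn.CrossEntropyLoss(weight=w)(logits, target)

# Equivalent by hand: scale each sample's -log p_true by w[target];
# the default 'mean' reduction divides by the sum of the weights used.
log_p = F.log_softmax(logits, dim=1)
per_sample = -log_p.gather(1, target.unsqueeze(1)).squeeze(1) * w[target]
manual = per_sample.sum() / w[target].sum()
```

The two values agree, which confirms that weight only looks at the true class, never at the class it was confused with.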

How can I apply this in PyTorch?

import torch
import torch.nn as nn

loss = nn.CrossEntropyLoss()
input = torch.randn(3, 5, requires_grad=True)
target = torch.empty(3, dtype=torch.long).random_(5)

# Symmetric cost matrix: cost_matrix[j, k] is the cost of confusing
# class j with class k; the diagonal (correct predictions) is zero.
cost_matrix = torch.tensor([
    [0.0, 0.4, 0.1, 0.4, 0.1],
    [0.4, 0.0, 0.9, 0.9, 0.4],
    [0.1, 0.9, 0.0, 0.1, 0.9],
    [0.4, 0.9, 0.1, 0.0, 0.1],
    [0.1, 0.4, 0.9, 0.1, 0.0],
])
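One possible implementation, written as a sketch based on my reading of the paper (this interpretation is an assumption, not the paper's verbatim algorithm): keep the usual -log p_true term, and additionally penalize the probability mass placed on each wrong class k by cost_matrix[true, k] * -log(1 - p_k). With an all-zero cost matrix this reduces to plain cross-entropy.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class RealWorldWeightCrossEntropy(nn.Module):
    """Sketch of a single-label RWWCE-style loss (interpretation assumed):
    standard cross-entropy plus a per-confusion penalty
    -cost_matrix[true, k] * log(1 - p_k) for every wrong class k."""

    def __init__(self, cost_matrix: torch.Tensor):
        super().__init__()
        # Registered as a buffer so .to(device) moves it with the module.
        self.register_buffer("cost_matrix", cost_matrix)

    def forward(self, logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        probs = F.softmax(logits, dim=1)  # (N, C)
        eps = 1e-12                       # numerical safety for log()
        # Standard term: -log p_true per sample.
        ce = -torch.log(
            probs.gather(1, target.unsqueeze(1)).clamp_min(eps)
        ).squeeze(1)
        # Mislabel penalty: cost-weighted -log(1 - p_k) summed over classes.
        # The cost matrix has a zero diagonal, so the true class drops out.
        penalty = -(
            self.cost_matrix[target] * torch.log((1 - probs).clamp_min(eps))
        ).sum(dim=1)
        return (ce + penalty).mean()
```

Usage with the tensors above:

    loss_fn = RealWorldWeightCrossEntropy(cost_matrix)
    out = loss_fn(input, target)
    out.backward()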


from How to use Real-World-Weight Cross-Entropy loss in PyTorch
