I'm working on multiclass classification where some mistakes are more severe than others, so I would like to incorporate these misclassification costs into my loss function. I found an approach called Real-World-Weight Cross-Entropy, described in this paper. The formula is as follows:
I haven't found any ready-to-use implementation apart from the weight argument of the standard CrossEntropyLoss, which I believe works quite differently from what I need: as far as I understand, the cost of misclassifying a sample of a given class is the same no matter which class it was confused with.
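The claim about the weight argument can be checked with a small sketch. The class weights and logits below are made-up values for illustration. Both rows give the true class (0) the same probability but push the remaining mass toward different wrong classes, and the weighted loss comes out identical, because CrossEntropyLoss only looks up the weight of the true class.

```python
import torch
import torch.nn as nn

# Made-up per-class weights: one number per *true* class
class_weight = torch.tensor([1.0, 2.0, 0.5])
criterion = nn.CrossEntropyLoss(weight=class_weight, reduction="none")

# Two predictions for a sample whose true class is 0. Class 0 gets the same
# probability in both rows; only the wrong class that "wins" differs.
logits = torch.tensor([[3.0, 2.0, 0.0],   # mostly confused with class 1
                       [3.0, 0.0, 2.0]])  # mostly confused with class 2
target = torch.tensor([0, 0])

# Per-sample loss is -weight[y] * log p_y, so both entries are equal
losses = criterion(logits, target)
print(losses)
```

With the default reduction="mean", PyTorch divides the summed weighted losses by the sum of the selected weights rather than by the batch size, but either way the penalty depends only on the true class.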
How can I apply this in PyTorch?
import torch
import torch.nn as nn

# Standard cross-entropy; its weight argument only scales by the true class
loss = nn.CrossEntropyLoss()

# Dummy batch: 3 samples, 5 classes
logits = torch.randn(3, 5, requires_grad=True)
target = torch.randint(0, 5, (3,))

# Misclassification costs: cost_matrix[i, j] is the cost of confusing
# class i with class j. The matrix is symmetric with a zero diagonal.
cost_matrix = torch.tensor([
    [0.0, 0.4, 0.1, 0.4, 0.1],
    [0.4, 0.0, 0.9, 0.9, 0.4],
    [0.1, 0.9, 0.0, 0.1, 0.9],
    [0.4, 0.9, 0.1, 0.0, 0.1],
    [0.1, 0.4, 0.9, 0.1, 0.0],
])
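One way to use such a cost matrix is a custom nn.Module. The sketch below assumes the multiclass loss keeps the usual -log p_y term and adds a cost-weighted penalty -C[y, j] * log(1 - p_j) for every wrong class j. This is my reading of the paper's idea, not a verified transcription of its equation: the paper may also put a separate weight on the true-class term, which is fixed to 1 here. The class name RealWorldWeightCrossEntropy and the eps parameter are hypothetical, and the convention that C[i, j] means "true class i predicted as j" is an assumption (it makes no difference for the symmetric matrix above).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class RealWorldWeightCrossEntropy(nn.Module):
    """Hypothetical cost-sensitive cross-entropy in the spirit of RWWCE.

    Assumed per-sample form: -log p_y - sum_j C[y, j] * log(1 - p_j),
    with C[i, j] taken (by assumption) as the cost of predicting j when
    the true class is i.
    """

    def __init__(self, cost_matrix: torch.Tensor, eps: float = 1e-7):
        super().__init__()
        # A buffer moves with the module on .to(device) but is not trained
        self.register_buffer("cost_matrix", cost_matrix.float())
        self.eps = eps

    def forward(self, logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # Ordinary cross-entropy term, -log p_y, one value per sample
        ce = F.cross_entropy(logits, target, reduction="none")

        # Class probabilities, clamped below 1 so that log(1 - p) stays finite
        probs = F.softmax(logits, dim=1).clamp(max=1.0 - self.eps)
        log_one_minus_p = torch.log1p(-probs)

        # Cost row of each sample's true class, shape (batch, num_classes)
        costs = self.cost_matrix[target]

        # Penalise probability placed on wrong classes, scaled by their cost.
        # The zero diagonal means the true class contributes nothing here.
        penalty = -(costs * log_one_minus_p).sum(dim=1)

        return (ce + penalty).mean()


# Cost matrix from the question (symmetric, zero diagonal)
cost_matrix = torch.tensor([
    [0.0, 0.4, 0.1, 0.4, 0.1],
    [0.4, 0.0, 0.9, 0.9, 0.4],
    [0.1, 0.9, 0.0, 0.1, 0.9],
    [0.4, 0.9, 0.1, 0.0, 0.1],
    [0.1, 0.4, 0.9, 0.1, 0.0],
])

criterion = RealWorldWeightCrossEntropy(cost_matrix)
logits = torch.randn(3, 5, requires_grad=True)
target = torch.randint(0, 5, (3,))
loss = criterion(logits, target)
loss.backward()  # gradients flow through both terms
```

With an all-zero cost matrix this reduces to plain CrossEntropyLoss, and since costs are non-negative and log(1 - p) is at most 0, the extra term can only increase the loss. Clamping the probabilities is a simple guard against log(0); a log-sum-exp formulation over the other logits would be more numerically careful if that matters for your data.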
from How to use Real-World-Weight Cross-Entropy loss in PyTorch