I want to calculate the Hessian matrix of a loss w.r.t. the model parameters in PyTorch, but using torch.autograd.functional.hessian
is not an option for me, since it recomputes the model output and the loss, both of which I already have from previous calls (I sketch that rejected baseline after my code below). My current implementation is as follows:
import torch
import time
# Create model
model = torch.nn.Sequential(torch.nn.Linear(1, 100), torch.nn.Tanh(), torch.nn.Linear(100, 1))
num_param = sum(p.numel() for p in model.parameters())
# Evaluate some loss on a random dataset
x = torch.rand((1000,1))
y = torch.rand((1000,1))
y_hat = model(x)
loss = ((y_hat - y)**2).mean()
''' Calculate Hessian '''
start = time.time()
# Preallocate the Hessian
H = torch.zeros((num_param, num_param))
# Calculate Jacobian w.r.t. model parameters
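# (create_graph=True makes the gradient itself part of the autograd graph,
# so each entry of J can be differentiated a second time below)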
J = torch.autograd.grad(loss, list(model.parameters()), create_graph=True)
J = torch.cat([e.flatten() for e in J]) # flatten
# Fill in Hessian
for i in range(num_param):
    # One backward pass per Hessian row: gradient of J[i] w.r.t. all parameters
    result = torch.autograd.grad(J[i], list(model.parameters()), retain_graph=True)
    H[i] = torch.cat([r.flatten() for r in result])  # flatten row i
print(time.time() - start)
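To make the constraint from the top concrete: torch.autograd.functional.hessian differentiates a closure, so the forward pass and the loss are recomputed inside it. A minimal sketch of that rejected baseline (assuming PyTorch >= 2.0 for torch.func.functional_call; loss_fn is my own hypothetical closure, not part of the API):

# Rejected baseline: functional.hessian re-evaluates the closure internally,
# so the forward pass I already paid for runs again.
names = [n for n, _ in model.named_parameters()]

def loss_fn(*params):
    # re-runs the model forward pass -- exactly the work I want to avoid
    y_hat = torch.func.functional_call(model, dict(zip(names, params)), x)
    return ((y_hat - y) ** 2).mean()

H_ref = torch.autograd.functional.hessian(loss_fn, tuple(model.parameters()))
# H_ref is a tuple of tuples, one block per pair of parameter tensors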
Is there any way to do this faster? Ideally without the for loop, since it calls autograd.grad
once for every single model parameter.
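The kind of vectorization I have in mind looks something like the sketch below. This is an untested assumption on my part: it relies on the is_grads_batched argument to torch.autograd.grad (added in PyTorch 1.11), which batches the backward passes with vmap instead of a Python loop:

# Untested sketch: one batched call instead of the Python loop.
# is_grads_batched=True vmaps over the leading dimension of grad_outputs,
# so each row of the identity matrix extracts the gradient of one entry
# of J, i.e. one row of the Hessian.
params = list(model.parameters())
J = torch.autograd.grad(loss, params, create_graph=True)
J = torch.cat([e.flatten() for e in J])
eye = torch.eye(num_param)
rows = torch.autograd.grad(J, params, grad_outputs=eye, is_grads_batched=True)
# each element of rows has shape (num_param, *param.shape)
H = torch.cat([r.reshape(num_param, -1) for r in rows], dim=1)

If this works, each row of the identity matrix plays the role of selecting one J[i] as in the loop above, so the result should match H up to numerical noise.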
from Fast way to calculate Hessian matrix of model parameters in PyTorch