Monday, 26 December 2022

Fast way to calculate Hessian matrix of model parameters in PyTorch

I want to calculate the Hessian matrix of a loss w.r.t. model parameters in PyTorch, but using torch.autograd.functional.hessian is not an option for me: it takes a callable and recomputes the model output and loss internally, both of which I already have from previous calls. My current implementation is as follows:

import torch
import time

# Create model
model = torch.nn.Sequential(torch.nn.Linear(1, 100), torch.nn.Tanh(), torch.nn.Linear(100, 1))
num_param = sum(p.numel() for p in model.parameters())

# Evaluate some loss on a random dataset
x = torch.rand((1000,1))
y = torch.rand((1000,1))
y_hat = model(x)
loss = ((y_hat - y)**2).mean()

''' Calculate Hessian '''
start = time.time()

# Allocate the (num_param x num_param) Hessian
H = torch.zeros((num_param, num_param))

# Calculate the gradient (Jacobian of the scalar loss) w.r.t. model parameters;
# create_graph=True keeps it differentiable for the second pass below
J = torch.autograd.grad(loss, list(model.parameters()), create_graph=True)
J = torch.cat([e.flatten() for e in J]) # flatten into a single vector

# Fill in the Hessian one row at a time: row i is the gradient of J[i]
for i in range(num_param):
    result = torch.autograd.grad(J[i], list(model.parameters()), retain_graph=True)
    H[i] = torch.cat([r.flatten() for r in result]) # flatten into row i

print(time.time() - start)
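
For reference, using torch.autograd.functional.hessian would look roughly like the sketch below: it takes a callable and re-executes it internally, so the forward pass and the loss get recomputed from scratch, which is exactly what I want to avoid. (The sketch assumes PyTorch >= 2.0 for torch.func.functional_call, and loss_from_flat is just my own helper name.)

from torch.func import functional_call

names = [n for n, _ in model.named_parameters()]
shapes = [p.shape for p in model.parameters()]

def loss_from_flat(flat):
    # Rebuild the parameter dict from the flat vector and re-run the
    # entire forward pass -- this recomputation is what I want to avoid
    chunks = flat.split([s.numel() for s in shapes])
    params = {n: c.view(s) for n, c, s in zip(names, chunks, shapes)}
    return ((functional_call(model, params, (x,)) - y)**2).mean()

flat = torch.cat([p.detach().flatten() for p in model.parameters()])
H_ref = torch.autograd.functional.hessian(loss_from_flat, flat)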

Is there any way to do this faster? Ideally without the for loop, since it calls autograd.grad once for every scalar parameter in the model.
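
One direction I have been looking at is batching all the Hessian rows into a single autograd.grad call via the is_grads_batched flag (experimental, available since PyTorch 1.11). Passing the identity matrix as grad_outputs requests the vector-Jacobian product of J with every basis vector at once, i.e. all rows of the Hessian in one vmap-powered backward pass. A sketch of what I mean, reusing J from above:

# Each row i of H is the VJP of J with basis vector e_i; with
# is_grads_batched=True the leading dimension of grad_outputs is
# treated as a batch, so one call computes every row at once
basis = torch.eye(num_param)
rows = torch.autograd.grad(J, list(model.parameters()),
                           grad_outputs=basis, is_grads_batched=True)
H_batched = torch.cat([r.reshape(num_param, -1) for r in rows], dim=1)

This removes the Python-level loop entirely, although since the flag relies on vmap internally, whether it actually runs faster (or runs at all) presumably depends on the ops in the model.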



