Thursday, 29 July 2021

Trouble with minimal hvp on pytorch model

While torch.autograd.functional.hvp works well for plain functions, once an nn.Module is involved the Hessian-vector products come back as all zeros. Here is a minimal example.

First, I define the world's simplest model:

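import torch
import torch.nn as nn

class SimpleMLP(nn.Module):
  def __init__(self, in_dim, out_dim):
    super().__init__()
    # A single linear layer: weight is (out_dim, in_dim), bias is (out_dim,)
    self.layers = nn.Sequential(
      nn.Linear(in_dim, out_dim),
    )

  def forward(self, x):
    '''Forward pass'''
    return self.layers(x)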

Then, a loss function:

def objective(x):
  return torch.sum(0.25 * torch.sum(x)**4)

Next, I instantiate the model:

Arows = 2
Acols = 2

mlp = SimpleMLP(Arows, Acols)

Finally, I'm going to define a "forward" function (distinct from the model's forward method) that will serve as the full model+loss that we want to analyze:

def forward(*params_list):
  for param_val, model_param in zip(params_list, mlp.parameters()):
    model_param.data = param_val
 
  x = torch.ones((Arows,))
  return objective(mlp(x))

This passes a vector of ones through the single-layer "mlp" and feeds the output into our quartic loss.

Now, I build the vector v for the Hessian-vector product:

v = torch.ones((6,))
v_tensors = []
idx = 0
# Split the flat v into one tensor per parameter, matching each parameter's shape
for param in mlp.parameters():
  numel = param.numel()
  v_tensors.append(v[idx:idx+numel].reshape(param.shape))
  idx += numel

And finally:

param_tensors = tuple(mlp.parameters())
reshaped_v = tuple(v_tensors)
# hvp returns a pair: (value of forward, Hessian-vector product)
soln = torch.autograd.functional.hvp(forward, param_tensors, v=reshaped_v)
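
For reference, the value a correct product should take can be worked out by hand. Because x is all ones, the model output sums to s, the sum of all six weights and biases, so the loss is 0.25 * s^4. Every entry of the 6x6 Hessian is then 3 * s^2, and multiplying by an all-ones v gives 18 * s^2 in every slot. The plain-Python sketch below checks this against a central finite difference; the helper names (loss, grad, hvp_fd) and the sample parameter values are illustrative assumptions, not part of the original code.

```python
# Closed-form check of the expected Hessian-vector product (plain Python, no torch).
# Assumes x is all ones, so the loss depends only on s = sum of all six parameters.

def loss(params):
    s = sum(params)  # sum(W @ ones + b) collapses to the sum of every weight and bias
    return 0.25 * s ** 4

def grad(params):
    s = sum(params)
    return [s ** 3 for _ in params]  # d(loss)/d(param_i) = s^3 for every i

def hvp_fd(params, v, eps=1e-5):
    """Central finite difference of the gradient along v (hypothetical helper)."""
    plus = grad([p + eps * vi for p, vi in zip(params, v)])
    minus = grad([p - eps * vi for p, vi in zip(params, v)])
    return [(gp - gm) / (2 * eps) for gp, gm in zip(plus, minus)]

params = [0.1, -0.2, 0.3, 0.4, 0.05, -0.15]  # illustrative values: 4 weights + 2 biases
v = [1.0] * 6
s = sum(params)
expected = 3 * s ** 2 * sum(v)  # every Hessian entry is 3 * s^2, summed against v
approx = hvp_fd(params, v)
```

Since s is almost never exactly zero for a randomly initialized layer, an all-zero result points to a problem in how the product is computed, not to the loss itself.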

But, alas, the Hessian-vector product in soln is all zeros. What is happening?
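
One plausible explanation, offered as a sketch and not a definitive diagnosis: assigning to model_param.data copies values without recording anything in the autograd graph, so the loss ends up depending on the module's own parameters and not on the tensors that hvp passes into forward. Derivatives with respect to those inputs would then be reported as zero. The variant below keeps the inputs in the graph by evaluating the module through torch.func.functional_call, which assumes PyTorch 2.0 or newer; the names list and the detached parameter copies are illustrative choices.

```python
import torch
import torch.nn as nn
from torch.func import functional_call  # assumption: PyTorch 2.0 or newer

class SimpleMLP(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.layers = nn.Sequential(nn.Linear(in_dim, out_dim))

    def forward(self, x):
        return self.layers(x)

def objective(y):
    return 0.25 * torch.sum(y) ** 4

torch.manual_seed(0)
mlp = SimpleMLP(2, 2)
names = [name for name, _ in mlp.named_parameters()]

def forward(*params_list):
    # Evaluate the module with the tensors hvp passes in, instead of
    # overwriting .data, so the output stays connected to those inputs.
    x = torch.ones(2)
    return objective(functional_call(mlp, dict(zip(names, params_list)), (x,)))

params = tuple(p.detach().clone() for p in mlp.parameters())
v = tuple(torch.ones_like(p) for p in params)
loss_value, soln = torch.autograd.functional.hvp(forward, params, v=v)

# With x all ones the loss is 0.25 * s**4, so every HVP entry should be 18 * s**2.
s = sum(p.sum() for p in params)
expected = 18 * s ** 2
for h in soln:
    print(torch.allclose(h, expected.expand_as(h), rtol=1e-4, atol=1e-6))
```

On PyTorch versions older than 2.0, torch.nn.utils.stateless.functional_call may serve the same purpose, though that has not been checked here.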


