While autograd's hvp tool seems to work very well for plain functions, once a model is involved the Hessian-vector products come out as all zeros. Here is a minimal example.
First, I define the world's simplest model:
import torch
from torch import nn

class SimpleMLP(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(in_dim, out_dim),
        )

    def forward(self, x):
        '''Forward pass'''
        return self.layers(x)
Then, a loss function:
def objective(x):
    return torch.sum(0.25 * torch.sum(x)**4)
Next, I instantiate the model:
Arows = 2
Acols = 2
mlp = SimpleMLP(Arows, Acols)
Finally, I'm going to define a "forward" function (distinct from the model's forward function) that will serve as the full model+loss that we want to analyze:
def forward(*params_list):
    for param_val, model_param in zip(params_list, mlp.parameters()):
        model_param.data = param_val
    x = torch.ones((Arows,))
    return objective(mlp(x))
This passes a vector of ones through the single-layer "mlp" and feeds the output into our quartic loss.
Now I attempt to compute a Hessian-vector product. First, the vector v, split to match the parameter shapes:
v = torch.ones((6,))
v_tensors = []
idx = 0
# Split v into chunks shaped like each model parameter
for param in mlp.parameters():
    numel = param.numel()
    v_tensors.append(v[idx:idx+numel].reshape(param.shape))
    idx += numel
And finally:
param_tensors = tuple(mlp.parameters())
reshaped_v = tuple(v_tensors)
soln = torch.autograd.functional.hvp(forward, param_tensors, v=reshaped_v)
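As a sanity check on what a correct result should look like, the Hessian-vector product for this particular setup can be worked out by hand. The symbols θ (all six parameters flattened into one vector) and s (the sum of the model outputs) are notation introduced here for illustration and do not appear in the code above.

```latex
\begin{align*}
s &= \sum_{i}\Big(\sum_{j} W_{ij} x_j + b_i\Big)
   = \sum_{i,j} W_{ij} + \sum_{i} b_i \quad (x = \mathbf{1}), \\
f(\theta) &= \tfrac{1}{4}\, s^4, \\
\nabla f(\theta) &= s^3\, \mathbf{1}_6, \\
\nabla^2 f(\theta) &= 3 s^2\, \mathbf{1}_6 \mathbf{1}_6^\top, \\
\nabla^2 f(\theta)\, v &= 3 s^2\, (\mathbf{1}_6^\top v)\, \mathbf{1}_6
   = 18\, s^2\, \mathbf{1}_6 \quad \text{for } v = \mathbf{1}_6 .
\end{align*}
```

So every entry of the product should equal 18 s², which is zero only if the parameters happen to sum to zero.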
But, alas, the Hessian-vector product in soln is all zeros. What is happening?
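A likely explanation, offered as a sketch rather than a definitive answer: assigning to model_param.data copies values into the model but bypasses autograd, so the output of forward is not connected to the tensors in params_list, and hvp (with its default strict=False) reports zeros for inputs the output does not depend on. One way to keep the connection is torch.func.functional_call, assuming PyTorch 2.0 or later. The names param_names, params, and hv below are introduced here for illustration.

```python
import torch
from torch import nn
from torch.func import functional_call  # assumption: PyTorch >= 2.0


class SimpleMLP(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.layers = nn.Sequential(nn.Linear(in_dim, out_dim))

    def forward(self, x):
        return self.layers(x)


def objective(x):
    return torch.sum(0.25 * torch.sum(x) ** 4)


Arows, Acols = 2, 2
mlp = SimpleMLP(Arows, Acols)
param_names = [name for name, _ in mlp.named_parameters()]


def forward(*params_list):
    # Run the model with the hvp inputs themselves acting as its parameters,
    # so the output stays connected to them in the autograd graph.
    params_dict = dict(zip(param_names, params_list))
    x = torch.ones((Arows,))
    return objective(functional_call(mlp, params_dict, (x,)))


# Detached copies of the current parameter values serve as the hvp inputs.
params = tuple(p.detach().clone() for p in mlp.parameters())
v = tuple(torch.ones_like(p) for p in params)

_, hv = torch.autograd.functional.hvp(forward, params, v=v)

# Hand check for this setup: with x all ones, the loss is 0.25 * s**4 where
# s is the sum of all six parameters, so the Hessian is 3 * s**2 in every
# entry and its product with a vector of six ones is 18 * s**2 per entry.
s = sum(p.sum() for p in params)
expected = 18 * s ** 2
for h in hv:
    print(h, expected)
```

Passing strict=True to hvp in the original version should raise an error instead of silently returning zeros, which can help diagnose this kind of disconnect.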
from Trouble with minimal hvp on pytorch model