Thursday, 15 September 2022

PyTorch: Load/unpickle only the state dict of a model saved with torch.save

I saved an nn.Module model using (logically):

model = MyWeirdModel()
model.patched_features = .....
train(model)
torch.save(model, file)

Ideally, one would load this model using

model = torch.load(file)

However, in my case this doesn't work. pickle stores only a reference to the class, not its definition, and looks up the current static class definition when unpickling, so I get AttributeError: 'MyWeirdModel' object has no attribute 'patched_features' (this attribute was added to MyWeirdModel at runtime).
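To illustrate the pickle behaviour described above, here is a minimal sketch using only the standard library. The class below is a toy stand-in, not the real model, and patched_method is a hypothetical name introduced just for this example. It shows that attributes stored on the instance travel with the pickle, while anything patched onto the class at runtime does not. Whether this is exactly what triggers the error in the real model depends on how patched_features was attached, which the question does not spell out.

```python
import pickle


class MyWeirdModel:
    """Toy stand-in for the real nn.Module (hypothetical, for illustration)."""

    def __init__(self):
        self.weights = [0.1, 0.2]


# Patch the class at runtime, then give one instance an extra attribute.
# `patched_method` is a made-up name used only in this sketch.
MyWeirdModel.patched_method = lambda self: sum(self.weights)
model = MyWeirdModel()
model.patched_features = [1, 2, 3]

payload = pickle.dumps(model)

# The payload refers to the class by name but carries none of its code,
# so the class-level patch is not part of the pickled data.
assert b"MyWeirdModel" in payload
assert b"patched_method" not in payload

# Simulate a fresh process in which the runtime patch was never applied.
del MyWeirdModel.patched_method

restored = pickle.loads(payload)
instance_attr_survived = hasattr(restored, "patched_features")
class_patch_survived = hasattr(restored, "patched_method")
```

In this sketch the instance attribute comes back after loading, but the class-level patch is gone, because unpickling rebuilds the object against whatever class definition is importable at load time.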

I would like to avoid having to re-train the model, so I don't want to change the code for saving, only loading.

# Initialise the model in the same way as before
model = MyWeirdModel()
model.patched_features = .....

state_dict = load_state_dict_only(file) # How does one do this?

model.load_state_dict(state_dict)

My understanding is that torch.save(model, file) pickles the entire model object, including the parameters and buffers that make up its state dict. How do I load only the state dict from the pickled file, so that I can recover the model?


