I want to use a pretrained model as the encoder part of my model. Here is a version of my model:
import os
import torch
import torch.nn as nn

# S3D_featureExtractor_multi_output is my modified version of the pretrained network (defined elsewhere).
class MyClass(nn.Module):
    def __init__(self, pretrained=False):
        super(MyClass, self).__init__()
        self.encoder = S3D_featureExtractor_multi_output()
        if pretrained:
            weight_dict = torch.load(os.path.join('models', 'weights.pt'))
            model_dict = self.encoder.state_dict()
            list_weight_dict = list(weight_dict.items())
            list_model_dict = list(model_dict.items())
            # Copy weights by position: the i-th checkpoint tensor goes into the i-th model tensor.
            for i in range(len(list_model_dict)):
                assert list_model_dict[i][1].shape == list_weight_dict[i][1].shape
                model_dict[list_model_dict[i][0]].copy_(weight_dict[list_weight_dict[i][0]])
            # Verify that every tensor was copied correctly.
            for i in range(len(list_model_dict)):
                assert torch.all(torch.eq(model_dict[list_model_dict[i][0]], weight_dict[list_weight_dict[i][0]].to('cpu')))
            print('Loading finished!')

    def forward(self, x):
        a, b = self.encoder(x)
        return a, b
Because I modified some parts of the code of this pretrained model, based on this post I need to pass strict=False to avoid an error. However, given the way I load the pretrained weights, I cannot find a place in the code to apply strict=False. How can I apply it, or how can I change the way I load the pretrained model so that strict=False can be applied?
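For context, strict=False is an argument of nn.Module.load_state_dict, so it only comes into play if the weights are loaded through that method instead of being copied tensor by tensor. The following is a minimal sketch of that route, not a drop-in fix. OriginalEncoder and ModifiedEncoder are hypothetical stand-ins for the pretrained network and the modified S3D_featureExtractor_multi_output, and the checkpoint is built in memory instead of being read from models/weights.pt. It also assumes that parameter names in the checkpoint match those in the model, which the positional loop above does not require. As far as I know, strict=False only tolerates missing or unexpected keys, so tensors whose shapes differ are filtered out first.

```python
import torch
import torch.nn as nn


class OriginalEncoder(nn.Module):
    """Hypothetical stand-in for the unmodified pretrained network."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3)
        self.fc = nn.Linear(8, 4)


class ModifiedEncoder(nn.Module):
    """Hypothetical stand-in for the modified encoder: 'fc' removed, 'extra' added."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3)
        self.extra = nn.Linear(8, 2)


# In the real code this would come from torch.load(...); here it is built in memory.
weight_dict = OriginalEncoder().state_dict()

encoder = ModifiedEncoder()
model_dict = encoder.state_dict()

# Keep only entries whose name exists in the model and whose shape matches,
# so that a shape mismatch cannot raise during loading.
compatible = {
    name: tensor
    for name, tensor in weight_dict.items()
    if name in model_dict and tensor.shape == model_dict[name].shape
}

# strict=False lets the load proceed even though some model keys have no
# counterpart in the checkpoint; the return value reports what was skipped.
result = encoder.load_state_dict(compatible, strict=False)
print('Missing keys (left at their initial values):', result.missing_keys)
print('Unexpected keys (ignored):', result.unexpected_keys)
```

Inside MyClass.__init__, the same call would be made on self.encoder in place of the two for loops. Checking missing_keys afterwards shows which layers kept their random initialization.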
from Loading a modified pretrained model using strict=False in PyTorch