training problem #11

Open
meroluo opened this issue Mar 13, 2019 · 1 comment

Comments


meroluo commented Mar 13, 2019

Thank you for your wonderful work!
I want to train a model on my own dataset, but something goes wrong during the process.
The error is shown below:

===> Loading datasets
===> Building model
/home/luomeilu/anaconda3/envs/py2/lib/python2.7/site-packages/torch/nn/_reduction.py:49: UserWarning: size_average and reduce args will be deprecated, please use reduction='sum' instead.
  warnings.warn(warning.format(ret))
===> Setting GPU
===> load model model/model_epoch_28.pth
/home/luomeilu/anaconda3/envs/py2/lib/python2.7/site-packages/torch/serialization.py:435: SourceChangeWarning: source code of class 'torch.nn.parallel.data_parallel.DataParallel' has changed. you can retrieve the original source code by accessing the object's source attribute or set torch.nn.Module.dump_patches = True and use the patch tool to revert the changes.
  warnings.warn(msg, SourceChangeWarning)
/home/luomeilu/anaconda3/envs/py2/lib/python2.7/site-packages/torch/serialization.py:435: SourceChangeWarning: source code of class 'torch.nn.modules.conv.Conv2d' has changed. you can retrieve the original source code by accessing the object's source attribute or set torch.nn.Module.dump_patches = True and use the patch tool to revert the changes.
  warnings.warn(msg, SourceChangeWarning)
/home/luomeilu/anaconda3/envs/py2/lib/python2.7/site-packages/torch/serialization.py:435: SourceChangeWarning: source code of class 'torch.nn.modules.activation.ReLU' has changed. you can retrieve the original source code by accessing the object's source attribute or set torch.nn.Module.dump_patches = True and use the patch tool to revert the changes.
  warnings.warn(msg, SourceChangeWarning)
Traceback (most recent call last):
  File "main.py", line 128, in <module>
    main()
  File "main.py", line 69, in main
    model.load_state_dict(weights['model'].state_dict())
  File "/home/luomeilu/anaconda3/envs/py2/lib/python2.7/site-packages/torch/nn/modules/module.py", line 769, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for DRRN:
	Missing key(s) in state_dict: "input.weight", "conv1.weight", "conv2.weight", "output.weight".
	Unexpected key(s) in state_dict: "module.input.weight", "module.conv1.weight", "module.conv2.weight", "module.output.weight".

It seems that some of the model's parameters are missing? I don't understand the error and would appreciate any suggestions. Thank you very much for your reply.


xsacha commented May 10, 2019

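You need to remove the "module." prefix from the keys in the state_dict, because the checkpoint was saved from a model wrapped in DataParallel.
Alternatively, I believe you can unwrap the saved DataParallel object with .module (model = model.module) to achieve the same thing.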
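A minimal sketch of the prefix-stripping fix, assuming (as line 69 of main.py in the traceback suggests) that the checkpoint stores the whole DataParallel-wrapped model under weights['model']. The helper name strip_data_parallel_prefix is hypothetical; it is not part of this repository or of PyTorch. The demo at the bottom runs without PyTorch, using the key names from the error message with placeholder values instead of tensors.

```python
from collections import OrderedDict


def strip_data_parallel_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with a leading DataParallel prefix removed.

    nn.DataParallel keeps the wrapped network in its .module attribute, so a
    state_dict taken from the wrapper has keys like "module.conv1.weight",
    while the bare network expects "conv1.weight". Keys without the prefix
    are copied unchanged, and the input dict is not modified.
    """
    cleaned = OrderedDict()
    for key, value in state_dict.items():
        if key.startswith(prefix):
            key = key[len(prefix):]  # slicing works on both Python 2.7 and 3
        cleaned[key] = value
    return cleaned


# Hypothetical usage with the checkpoint from the traceback (assumes
# weights["model"] is the pickled DataParallel-wrapped model):
#
#   weights = torch.load("model/model_epoch_28.pth")
#   model.load_state_dict(strip_data_parallel_prefix(weights["model"].state_dict()))
#
# or, unwrapping the DataParallel object instead of renaming the keys:
#
#   model.load_state_dict(weights["model"].module.state_dict())

if __name__ == "__main__":
    # Keys copied from the "Unexpected key(s)" line of the error message;
    # None stands in for the weight tensors.
    saved = OrderedDict(
        (k, None)
        for k in ["module.input.weight", "module.conv1.weight",
                  "module.conv2.weight", "module.output.weight"]
    )
    print(list(strip_data_parallel_prefix(saved)))
    # prints ['input.weight', 'conv1.weight', 'conv2.weight', 'output.weight']
```

Renaming the keys works however the checkpoint was produced, while unwrapping with .module only applies when the saved object is itself a DataParallel instance.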
