이 코드로 PyTorch 모델을 저장하려고 할 때 :
checkpoint = {'model': Net(), 'state_dict': model.state_dict(),'optimizer' :optimizer.state_dict()}
torch.save(checkpoint, 'Checkpoint.pth')
다음과 같은 오류가 발생합니다.
E:\PROGRAM FILES\Anaconda\envs\staj_projesi\lib\site-packages\torch\serialization.py:251: UserWarning: Couldn't retrieve source code for container of type Net. It won't be checked for correctness upon loading.
...
"type " + obj.__name__ + ". It won't be checked "
Can't pickle local object 'trainModel.<locals>.Net'
이 코드로 PyTorch 모델을 저장하려고 할 때 :
checkpoint = {'state_dict': model.state_dict(),'optimizer' :optimizer.state_dict()}
torch.save(checkpoint, 'Checkpoint.pth')
오류가 발생하지 않지만 ANN 클래스를 저장하고 싶습니다. 이 문제를 어떻게 해결할 수 있습니까? 또한 이전에 다른 프로젝트에서 첫 번째 구조로 모델을 저장할 수 있습니다.
당신은 할 수 없습니다! torch.save
개체 state_dict()
만 저장하는 것입니다.
다음을 사용하는 경우 :
checkpoint = {'model': Net(), 'state_dict': model.state_dict(),'optimizer' :optimizer.state_dict()}
torch.save(checkpoint, 'Checkpoint.pth')
모델 자체를 저장하려고하지만이 데이터는에 저장되며 model.state_dict()
모델을로드 할 때 state_dict
먼저 모델 개체를 시작해야합니다.
이것이 바로 두 번째 방법이 제대로 작동하는 이유입니다.
checkpoint = {'state_dict': model.state_dict(),'optimizer' :optimizer.state_dict()}
torch.save(checkpoint, 'Checkpoint.pth')
다음 링크에서 모델을 올바르게 저장 /로드하는 방법에 대한 pytorch 문서를 읽는 것이 좋습니다. https://pytorch.org/tutorials/beginner/saving_loading_models.html
이 기사는 인터넷에서 수집됩니다. 재 인쇄 할 때 출처를 알려주십시오.
침해가 발생한 경우 연락 주시기 바랍니다[email protected] 삭제
몇 마디 만하겠습니다