GPU에서 훈련, num_gpus는 1로 설정됩니다.
device_ids = list(range(num_gpus))
model = NestedUNet(opt.num_channel, 2).to(device)
model = nn.DataParallel(model, device_ids=device_ids)
CPU에서 테스트 :
model = NestedUNet_Purn2(opt.num_channel, 2).to(dev)
device_ids = list(range(num_gpus))
model = torch.nn.DataParallel(model, device_ids=device_ids)
model_old = torch.load(path, map_location=dev)
pretrained_dict = model_old.state_dict()
model_dict = model.state_dict()
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)
정확한 결과를 얻을 수 있지만 삭제하면 :
device_ids = list(range(num_gpus))
model = torch.nn.DataParallel(model, device_ids=device_ids)
결과가 잘못되었습니다.
nn.DataParallel
실제 모델이 module
속성에 할당되는 모델을 래핑 합니다. 이는 상태 사전의 키에 module.
접두사 가 있음을 의미합니다 .
차이를 확인하기 위해 하나의 컨볼 루션으로 매우 단순화 된 버전을 살펴 보겠습니다.
class NestedUNet(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
model = NestedUNet()
model.state_dict().keys() # => odict_keys(['conv1.weight', 'conv1.bias'])
# Wrap the model in DataParallel
model_dp = nn.DataParallel(model, device_ids=range(num_gpus))
model_dp.state_dict().keys() # => odict_keys(['module.conv1.weight', 'module.conv1.bias'])
저장 한 상태 사전이 nn.DataParallel
일반 모델의 상태와 일치하지 않습니다. 현재 상태 dict를로드 된 상태 dict와 병합합니다. 즉, 모델에 키에 속하는 속성이없고 대신 무작위로 초기화 된 모델이 남아 있기 때문에로드 된 상태가 무시됩니다.
이러한 실수를 방지하려면 상태 사전을 병합하지 말고 모델에 직접 적용해야합니다.이 경우 키가 일치하지 않으면 오류가 발생합니다.
RuntimeError: Error(s) in loading state_dict for NestedUNet:
Missing key(s) in state_dict: "conv1.weight", "conv1.bias".
Unexpected key(s) in state_dict: "module.conv1.weight", "module.conv1.bias".
저장 한 상태 사전을 호환 가능하게하려면 module.
접두어를 제거 할 수 있습니다 .
pretrained_dict = {key.replace("module.", ""): value for key, value in pretrained_dict.items()}
model.load_state_dict(pretrained_dict)
또한 nn.DataParallel
상태를 저장하기 전에 모델의 래핑을 해제하여 향후이 문제를 방지 할 수 있습니다 model.module.state_dict()
. 따라서 항상 상태와 함께 모델을 먼저로드 한 다음 나중에 nn.DataParallel
여러 GPU를 사용하려는 경우 모델을 넣기로 결정할 수 있습니다.
이 기사는 인터넷에서 수집됩니다. 재 인쇄 할 때 출처를 알려주십시오.
침해가 발생한 경우 연락 주시기 바랍니다[email protected] 삭제
몇 마디 만하겠습니다