테스트 할 때 왜 DataParallel을 사용해야합니까?

리키 오

GPU에서 훈련, num_gpus는 1로 설정됩니다.

device_ids = list(range(num_gpus))
model = NestedUNet(opt.num_channel, 2).to(device)
model = nn.DataParallel(model, device_ids=device_ids)

CPU에서 테스트 :

model = NestedUNet_Purn2(opt.num_channel, 2).to(dev)
device_ids = list(range(num_gpus))
model = torch.nn.DataParallel(model, device_ids=device_ids)
model_old = torch.load(path, map_location=dev)
pretrained_dict = model_old.state_dict()
model_dict = model.state_dict()
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)

정확한 결과를 얻을 수 있지만 삭제하면 :

device_ids = list(range(num_gpus))
model = torch.nn.DataParallel(model, device_ids=device_ids)

결과가 잘못되었습니다.

마이클 정고

nn.DataParallel실제 모델이 module속성에 할당되는 모델을 래핑 합니다. 이는 상태 사전의 키에 module.접두사 가 있음을 의미합니다 .

차이를 확인하기 위해 하나의 컨볼 루션으로 매우 단순화 된 버전을 살펴 보겠습니다.

class NestedUNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)

model = NestedUNet()

model.state_dict().keys() # => odict_keys(['conv1.weight', 'conv1.bias'])

# Wrap the model in DataParallel
model_dp = nn.DataParallel(model, device_ids=range(num_gpus))

model_dp.state_dict().keys() # => odict_keys(['module.conv1.weight', 'module.conv1.bias'])

저장 한 상태 사전이 nn.DataParallel일반 모델의 상태와 일치하지 않습니다. 현재 상태 dict를로드 된 상태 dict와 병합합니다. 즉, 모델에 키에 속하는 속성이없고 대신 무작위로 초기화 된 모델이 남아 있기 때문에로드 된 상태가 무시됩니다.

이러한 실수를 방지하려면 상태 사전을 병합하지 말고 모델에 직접 적용해야합니다.이 경우 키가 일치하지 않으면 오류가 발생합니다.

RuntimeError: Error(s) in loading state_dict for NestedUNet:
        Missing key(s) in state_dict: "conv1.weight", "conv1.bias".
        Unexpected key(s) in state_dict: "module.conv1.weight", "module.conv1.bias".

저장 한 상태 사전을 호환 가능하게하려면 module.접두어를 제거 할 수 있습니다 .

pretrained_dict = {key.replace("module.", ""): value for key, value in pretrained_dict.items()}
model.load_state_dict(pretrained_dict)

또한 nn.DataParallel상태를 저장하기 전에 모델의 래핑을 해제하여 향후이 문제를 방지 할 수 있습니다 model.module.state_dict(). 따라서 항상 상태와 함께 모델을 먼저로드 한 다음 나중에 nn.DataParallel여러 GPU를 사용하려는 경우 모델을 넣기로 결정할 수 있습니다.

이 기사는 인터넷에서 수집됩니다. 재 인쇄 할 때 출처를 알려주십시오.

침해가 발생한 경우 연락 주시기 바랍니다[email protected] 삭제

에서 수정
0

몇 마디 만하겠습니다

0리뷰
로그인참여 후 검토

관련 기사

분류에서Dev

문자열을 분할 할 때 왜 stringstream을 사용해야합니까?

분류에서Dev

"apt-key adv"를 사용할 때 왜 ": 80"을 지정해야합니까?

분류에서Dev

Ruby를 사용할 때 왜`source / .rvm / scripts / rvm`을 입력해야합니까?

분류에서Dev

say 함수를 사용할 때 왜 use 문을 지정해야합니까?

분류에서Dev

RGR 방법론을 사용할 때 속성 테스트를 단위 테스트로 실행해야합니까?

분류에서Dev

maven-failsafe-plugin을 사용할 때 통합 테스트를 어디에 저장해야합니까?

분류에서Dev

Scikit-learn 및 데이터 시각화 : 예측을 사용할 때 왜 Ravel을 사용해야합니까?

분류에서Dev

각도에서 ng-src를 사용할 때 왜 표현식을 사용해야합니까?

분류에서Dev

scalatest를 사용하여 akka를 테스트 할 때 테스트 케이스를 찾을 수 없습니다. 수정하려면 어떻게해야합니까?

분류에서Dev

RAM에 여유 공간이 충분할 때 왜 스왑을 사용합니까?

분류에서Dev

요구 사항이 명확 할 때 스크럼을 사용해야합니까?

분류에서Dev

CUDA를 사용할 때 왜 memset을 사용합니까?

분류에서Dev

단위 테스트-인스턴스화가 다를 때 setUp 및 tearDown을 사용할 수 있거나 사용해야합니까?

분류에서Dev

pytest-qt를 사용하여 테스트 할 때 왜 치명적인 Python 오류가 발생합니까?

분류에서Dev

Gradle을 사용할 때 버전 및 SDK 정보를 Android 매니페스트에 복제해야합니까?

분류에서Dev

신용 카드를 확인하기 위해 Luhn의 알고리즘을 사용할 때 왜 숫자를 반대로해야합니까?

분류에서Dev

Java8에서 스트림을 사용할 때 항상 Optional을 확인해야합니까?

분류에서Dev

Karate UI 테스트 자동화 정보, karate-chrome을 사용할 때 파일을 업로드하려면 어떻게해야합니까?

분류에서Dev

Observable을 반환하는 HTTP 호출을 할 때 Angular 단위 테스트에서 fakesAync를 어떻게 사용해야합니까?

분류에서Dev

네트워크를 통해 컴퓨터의 드라이브에 액세스 할 때 왜 '$'를 사용합니까?

분류에서Dev

네트워크를 통해 컴퓨터의 드라이브에 액세스 할 때 왜 '$'를 사용합니까?

분류에서Dev

mochai와 chai를 통해 테스트 할 때 왜 던지기 오류 테스트에 실패 했습니까?

분류에서Dev

기능 테스트 매크로 및 Clang을 사용할 때 C ++ Future-extensions에 대한 경고를 피하려면 어떻게해야합니까?

분류에서Dev

우리가 족과 같은 일을 할 수있을 때 우리는 왜 빌더 디자인 패턴을 사용해야합니까?

분류에서Dev

dd 명령을 사용할 때 드라이브를 마운트 해제해야합니까?

분류에서Dev

dd 명령을 사용할 때 드라이브를 마운트 해제해야합니까?

분류에서Dev

방랑 함을 사용할 때 livereload를 위해 전달 된 포트를 추가해야합니까?

분류에서Dev

방랑 함을 사용할 때 livereload를 위해 전달 된 포트를 추가해야합니까?

분류에서Dev

부트 스트랩을 사용할 때 왜곡 된 탐색 모음

Related 관련 기사

  1. 1

    문자열을 분할 할 때 왜 stringstream을 사용해야합니까?

  2. 2

    "apt-key adv"를 사용할 때 왜 ": 80"을 지정해야합니까?

  3. 3

    Ruby를 사용할 때 왜`source / .rvm / scripts / rvm`을 입력해야합니까?

  4. 4

    say 함수를 사용할 때 왜 use 문을 지정해야합니까?

  5. 5

    RGR 방법론을 사용할 때 속성 테스트를 단위 테스트로 실행해야합니까?

  6. 6

    maven-failsafe-plugin을 사용할 때 통합 테스트를 어디에 저장해야합니까?

  7. 7

    Scikit-learn 및 데이터 시각화 : 예측을 사용할 때 왜 Ravel을 사용해야합니까?

  8. 8

    각도에서 ng-src를 사용할 때 왜 표현식을 사용해야합니까?

  9. 9

    scalatest를 사용하여 akka를 테스트 할 때 테스트 케이스를 찾을 수 없습니다. 수정하려면 어떻게해야합니까?

  10. 10

    RAM에 여유 공간이 충분할 때 왜 스왑을 사용합니까?

  11. 11

    요구 사항이 명확 할 때 스크럼을 사용해야합니까?

  12. 12

    CUDA를 사용할 때 왜 memset을 사용합니까?

  13. 13

    단위 테스트-인스턴스화가 다를 때 setUp 및 tearDown을 사용할 수 있거나 사용해야합니까?

  14. 14

    pytest-qt를 사용하여 테스트 할 때 왜 치명적인 Python 오류가 발생합니까?

  15. 15

    Gradle을 사용할 때 버전 및 SDK 정보를 Android 매니페스트에 복제해야합니까?

  16. 16

    신용 카드를 확인하기 위해 Luhn의 알고리즘을 사용할 때 왜 숫자를 반대로해야합니까?

  17. 17

    Java8에서 스트림을 사용할 때 항상 Optional을 확인해야합니까?

  18. 18

    Karate UI 테스트 자동화 정보, karate-chrome을 사용할 때 파일을 업로드하려면 어떻게해야합니까?

  19. 19

    Observable을 반환하는 HTTP 호출을 할 때 Angular 단위 테스트에서 fakesAync를 어떻게 사용해야합니까?

  20. 20

    네트워크를 통해 컴퓨터의 드라이브에 액세스 할 때 왜 '$'를 사용합니까?

  21. 21

    네트워크를 통해 컴퓨터의 드라이브에 액세스 할 때 왜 '$'를 사용합니까?

  22. 22

    mochai와 chai를 통해 테스트 할 때 왜 던지기 오류 테스트에 실패 했습니까?

  23. 23

    기능 테스트 매크로 및 Clang을 사용할 때 C ++ Future-extensions에 대한 경고를 피하려면 어떻게해야합니까?

  24. 24

    우리가 족과 같은 일을 할 수있을 때 우리는 왜 빌더 디자인 패턴을 사용해야합니까?

  25. 25

    dd 명령을 사용할 때 드라이브를 마운트 해제해야합니까?

  26. 26

    dd 명령을 사용할 때 드라이브를 마운트 해제해야합니까?

  27. 27

    방랑 함을 사용할 때 livereload를 위해 전달 된 포트를 추가해야합니까?

  28. 28

    방랑 함을 사용할 때 livereload를 위해 전달 된 포트를 추가해야합니까?

  29. 29

    부트 스트랩을 사용할 때 왜곡 된 탐색 모음

뜨겁다태그

보관