-
모델 저장파이토치&텐서플로우&관련코딩 2021. 4. 22. 15:46
1. 2021-04-22 hubmap 모델 5개 앙상블 병렬처리 하다가 생긴 문제.
model을 저장하고 로드하는데에 key가 안맞는 오류가 생겨서 헤맸다
=> GPU 병렬처리하면 key들의 맨 앞에 module.이 추가로 붙는다 따라서 아래 코드로 이걸 다 없얘줘야 로드가 가능하다.
for key in list(state_dict.keys):
if "module." in key:
state_dict[key.replace("module.", "") = state_dict[key]
del state_dict[key]
del을 안하면 state_dict 는 OrderedDict 이므로 이전의 module.이 붙은 key들이 안사라지고 새로운 module.안붙은 것이 추가만 되므로 꼭 삭제 해 줘야한다.
'파이토치&텐서플로우&관련코딩' 카테고리의 다른 글
view vs reshape (0) 2021.04.23 파이토치 torchvision.utils.save_image (0) 2021.04.22 cv2.gaussianblur 계산하는법 (0) 2021.04.17 BCE & BCEwithLOGIT & CE (0) 2021.02.15 progress bar 구현하기. (Jupyter, Pycharm) (0) 2021.02.03