pytorch로 학습한 모델을 저장하고 불러오는 코드를 공유합니다.
모델을 저장할 때에는 학습모델과 사용한 Optimizer를 함께 저장해줘야 합니다.
Pytorch 모델 저장 Code
torch.save(model, './model_name.pt')
torch.save(model.state_dict(), './model_state_dict.pt')
torch.save({'model': model.state_dict(), 'optimizer': optimizer.state_dict()}, './all.tar')
저장한 모델 불러오기
model = torch.load('./model_name.pt')
model.load_state_dict(torch.load('./model_state_dict.pt')
checkpoint = torch.load('./all.tar')
model.loaded_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
'Python' 카테고리의 다른 글
[Python] 시간 축 간격 조절 (0) | 2021.05.17 |
---|---|
[Python] 특정 시간 조건 행 추출 (0) | 2021.05.17 |
[Python] matplotlib font 속성 (0) | 2021.01.15 |
[Python] glob을 이용하여 csv 파일 불러오기 (0) | 2021.01.15 |
[Python] Correlation & Heatmap (0) | 2021.01.12 |