본문 바로가기

Python

[Python][Pytorch] 학습모델 save & Load

pytorch로 학습한 모델을 저장하고 불러오는 코드를 공유합니다.

모델을 저장할 때에는 학습모델과 사용한 Optimizer를 함께 저장해줘야 합니다.

 

Pytorch 모델 저장 Code 

 

 torch.save(model, './model_name.pt')

 torch.save(model.state_dict(), './model_state_dict.pt')

 torch.save({'model': model.state_dict(), 'optimizer': optimizer.state_dict()}, './all.tar')

 

저장한 모델 불러오기

 

 model = torch.load('./model_name.pt')

 model.load_state_dict(torch.load('./model_state_dict.pt')

 

 checkpoint = torch.load('./all.tar')

 model.loaded_state_dict(checkpoint['model'])

 optimizer.load_state_dict(checkpoint['optimizer'])