diff --git a/src/train.py b/src/train.py index 75d8842..4eb6f42 100644 --- a/src/train.py +++ b/src/train.py @@ -131,7 +131,7 @@ def quantize_model(self): return print('Load Best Model') - ckpt = f'{self.checkpoint_dir}/mobilenetv2.pt' + ckpt = f'{self.checkpoint_dir}/best.pt' save_info = torch.load(ckpt, map_location=self.device) self.net = save_info['model'] self.net.load_state_dict(save_info['state_dict'])