diff --git a/requirements.txt b/requirements.txt index 20c8b3e..49c737c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,7 +3,7 @@ matplotlib==2.1.0 tensorflow numpy==1.13.3 inflect==0.2.5 -librosa==0.6.0 +librosa==0.6.2 scipy==1.0.0 tensorboardX==1.1 Unidecode==1.0.22 diff --git a/train.py b/train.py index e8035c1..d83095d 100644 --- a/train.py +++ b/train.py @@ -41,8 +41,9 @@ def load_checkpoint(checkpoint_path, model, optimizer): assert os.path.isfile(checkpoint_path) checkpoint_dict = torch.load(checkpoint_path, map_location='cpu') - iteration = checkpoint_dict['iteration'] - optimizer.load_state_dict(checkpoint_dict['optimizer']) + iteration = checkpoint_dict.get('iteration', 0) + if 'optimizer' in checkpoint_dict: + optimizer.load_state_dict(checkpoint_dict['optimizer']) model_for_loading = checkpoint_dict['model'] model.load_state_dict(model_for_loading.state_dict()) print("Loaded checkpoint '{}' (iteration {})" .format(