-
Notifications
You must be signed in to change notification settings - Fork 13
Open
Description
def load(self, args):
# Load a trained model and vocabulary that you have fine-tuned
assert args.reload_from>=0, "please specify the checkpoint iteration in args.reload_from"
output_dir = os.path.join(f"./output/{args.model}/{args.model_size}/models/", f'checkpoint-{args.reload_from}')
self.model = DialogBERT.from_pretrained(output_dir)
self.model.to(args.device)
def from_pretrained(self, model_dir):
self.encoder_config = BertConfig.from_pretrained(model_dir)
self.tokenizer = BertTokenizer.from_pretrained(path.join(model_dir, 'tokenizer'), do_lower_case=True)
self.utt_encoder = BertForPreTraining.from_pretrained(path.join(model_dir, 'utt_encoder'))
self.context_encoder = BertForSequenceClassification.from_pretrained(path.join(model_dir, 'context_encoder'))
self.context_mlm_trans = BertPredictionHeadTransform(self.encoder_config)
self.context_mlm_trans.load_state_dict(torch.load(path.join(model_dir, 'context_mlm_trans.pkl')),strict= False)
self.context_order_trans = SelfSorting(self.encoder_config.hidden_size)
self.context_order_trans.load_state_dict(torch.load(path.join(model_dir, 'context_order_trans.pkl')), strict= False)
self.decoder_config = BertConfig.from_pretrained(model_dir)
self.decoder = BertLMHeadModel.from_pretrained(path.join(model_dir, 'decoder'))
File "D:\NLP\DialogBERT-master\solvers.py", line 77, in load
self.model.to(args.device)
AttributeError: 'NoneType' object has no attribute 'to'
DialogBERT.from_pretrained is none ,how can i solve it?
Metadata
Metadata
Assignees
Labels
No labels