From c5d94aeaadadb7128b56866bac6a305e3e931e4e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B9=94=E5=AD=90=E5=8D=BF?= <2020011015@secoder.net> Date: Mon, 18 Apr 2022 06:43:36 +0800 Subject: [PATCH 1/3] ModelForLM for CPM1 CPM2 Bert GPT2 GPTJ T5 and corresponding return datatype --- model_center/model/__init__.py | 12 +- model_center/model/bert.py | 220 +++++++++++------ model_center/model/cpm1.py | 179 ++++++++------ model_center/model/cpm2.py | 225 +++++++++++------- model_center/model/gpt2.py | 235 +++++++++++------- model_center/model/gptj.py | 202 ++++++++++------ model_center/model/modeling_output.py | 89 +++++++ model_center/model/t5.py | 327 ++++++++++++++++---------- 8 files changed, 957 insertions(+), 532 deletions(-) create mode 100644 model_center/model/modeling_output.py diff --git a/model_center/model/__init__.py b/model_center/model/__init__.py index fa904c8..c4fda46 100644 --- a/model_center/model/__init__.py +++ b/model_center/model/__init__.py @@ -16,9 +16,9 @@ from .config import * # Model Architecture -from .cpm1 import CPM1 -from .cpm2 import CPM2 -from .t5 import T5 -from .gpt2 import GPT2 -from .gptj import GPTj -from .bert import Bert +from .cpm1 import CPM1, CPM1ForLM +from .cpm2 import CPM2, CPM2ForLM +from .t5 import T5, T5ForLM +from .gpt2 import GPT2, GPT2ForLM +from .gptj import GPTj, GPTjForLM +from .bert import Bert, BertForLM diff --git a/model_center/model/bert.py b/model_center/model/bert.py index be11bd3..1a0b5df 100644 --- a/model_center/model/bert.py +++ b/model_center/model/bert.py @@ -17,7 +17,10 @@ from ..layer import Encoder, Embedding, Linear, LayerNorm from .basemodel import BaseModel from .config import BertConfig -from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions +from .modeling_output import ModelOutput,BaseModelOutputWithPoolingAndCrossAttentions +from typing import Optional, Tuple +import bmtrain as bmt + class BertPooler(torch.nn.Module): def __init__(self, dim_model): @@ -30,7 +33,7 @@ def forward(self, hidden_states): pooled_output = self.activation(pooled_output) return pooled_output - + class BertLMHead(torch.nn.Module): def __init__(self, dim_model, vocab_size, norm_eps): super().__init__() @@ -47,89 +50,96 @@ def forward(self, hidden_states): return logits -class Bert(BaseModel): +class BertForPreTrainingOutput(ModelOutput): + loss: Optional[torch.FloatTensor] = None + prediction_logits: torch.FloatTensor = None + seq_relationship_logits: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + +class Bert(BaseModel): _CONFIG_TYPE = BertConfig def __init__(self, config: BertConfig): super().__init__() self.input_embedding = Embedding( - vocab_size = config.vocab_size, - embedding_size = config.dim_model, - length_scale = config.length_scale, - dtype = config.dtype, - int8 = config.int8, - init_mean = config.emb_init_mean, - init_std = config.emb_init_std, + vocab_size=config.vocab_size, + embedding_size=config.dim_model, + length_scale=config.length_scale, + dtype=config.dtype, + int8=config.int8, + init_mean=config.emb_init_mean, + init_std=config.emb_init_std, ) self.position_embedding = Embedding( - vocab_size = config.position_size, - embedding_size = config.dim_model, - length_scale = config.length_scale, - dtype = config.dtype, - int8 = config.int8, - init_mean = config.emb_init_mean, - init_std = config.emb_init_std, + vocab_size=config.position_size, + embedding_size=config.dim_model, + 
length_scale=config.length_scale, + dtype=config.dtype, + int8=config.int8, + init_mean=config.emb_init_mean, + init_std=config.emb_init_std, ) self.token_type_embedding = Embedding( - vocab_size = config.type_size, - embedding_size = config.dim_model, - length_scale = config.length_scale, - dtype = config.dtype, - int8 = config.int8, - init_mean = config.emb_init_mean, - init_std = config.emb_init_std, + vocab_size=config.type_size, + embedding_size=config.dim_model, + length_scale=config.length_scale, + dtype=config.dtype, + int8=config.int8, + init_mean=config.emb_init_mean, + init_std=config.emb_init_std, ) self.embed_dropout = torch.nn.Dropout(config.dropout_p) self.encoder = Encoder( - num_layers = config.num_layers, - dim_model = config.dim_model, - dim_ff = config.dim_ff, - num_heads = config.num_heads, - dim_head = config.dim_head, - dtype = config.dtype, - int8 = config.int8, - norm_eps = config.norm_eps, - norm_init_var = config.norm_init_var, - norm_bias = config.norm_bias, - att_init_mean = config.att_init_mean, - att_init_std = config.att_init_std, - att_bias = config.att_bias, - att_mask_value = float(config.att_mask_value), - pos_bias_type = config.pos_bias_type, - ffn_init_mean = config.ffn_init_mean, - ffn_init_std = config.ffn_init_std, - ffn_bias = config.ffn_bias, - ffn_activate_fn = config.ffn_activate_fn, - length_scale = config.length_scale, - attn_scale = config.attn_scale, - dropout_p = config.dropout_p, - post_layer_norm = config.post_layer_norm, + num_layers=config.num_layers, + dim_model=config.dim_model, + dim_ff=config.dim_ff, + num_heads=config.num_heads, + dim_head=config.dim_head, + dtype=config.dtype, + int8=config.int8, + norm_eps=config.norm_eps, + norm_init_var=config.norm_init_var, + norm_bias=config.norm_bias, + att_init_mean=config.att_init_mean, + att_init_std=config.att_init_std, + att_bias=config.att_bias, + att_mask_value=float(config.att_mask_value), + pos_bias_type=config.pos_bias_type, + ffn_init_mean=config.ffn_init_mean, + ffn_init_std=config.ffn_init_std, + ffn_bias=config.ffn_bias, + ffn_activate_fn=config.ffn_activate_fn, + length_scale=config.length_scale, + attn_scale=config.attn_scale, + dropout_p=config.dropout_p, + post_layer_norm=config.post_layer_norm, ) self.tied = config.tied self.cls_head = config.cls_head if self.cls_head: self.cls_projection = Linear( - dim_out = self.cls_head, - dim_in = config.dim_model, - length_scale = config.length_scale, - dtype = config.dtype, - int8 = config.int8, - init_mean = config.proj_init_mean, - init_std = config.proj_init_std, - bias = config.proj_bias, + dim_out=self.cls_head, + dim_in=config.dim_model, + length_scale=config.length_scale, + dtype=config.dtype, + int8=config.int8, + init_mean=config.proj_init_mean, + init_std=config.proj_init_std, + bias=config.proj_bias, ) if not self.tied: self.lm_head = BertLMHead( - dim_model = config.dim_model, - vocab_size = config.vocab_size, - norm_eps = config.norm_eps, + dim_model=config.dim_model, + vocab_size=config.vocab_size, + norm_eps=config.norm_eps, ) self.pooler = BertPooler(config.dim_model) @@ -140,36 +150,38 @@ def forward(self, attention_mask=None, token_type_ids=None, position_ids=None, - head_mask=None, #unused + head_mask=None, # unused inputs_embeds=None, - encoder_hidden_states=None, #unused - encoder_attention_mask=None, #unused - output_attentions=None, #unused - output_hidden_states=None, #unused + encoder_hidden_states=None, # unused + encoder_attention_mask=None, # unused + output_attentions=None, # unused + 
output_hidden_states=None, # unused return_dict=True, - return_logits = False, - ): + # TODO:删除 + # return_logits=False, + ): """ This model inherits from BaseModel. This model is also a PyTorch torch.nn.Module subclass. You can use it as a regular PyTorch Module. You can also select the data and data type that you want the model to return through changing the value of `return_dict` and `return_logits`. Args: input_ids (:obj:`torch.Tensor` of shape ``(batch, seq_length)``): Indices of input sequence tokens. It will be embedded by model's internal embedding lookup matrix. - length (:obj:`torch.Tensor` of shape ``(batch)``): Length of input sequence before padding. + length (:obj:`torch.Tensor` of shape ``(batch)``): Length of input sequence before padding. attention_mask (:obj:`torch.Tensor` of shape ``(batch, seq_length)``): Used to avoid performing attention on padding token indices. - token_type_ids(:obj:`torch.Tensor` of shape ``(batch, seq_length)``): Unused. + token_type_ids(:obj:`torch.Tensor` of shape ``(batch, seq_length)``): Unused. position_ids(:obj:`torch.Tensor` of shape ``(batch, seq_length)``): Unused. head_mask (:obj:`torch.Tensor` of shape ``(num_layers, num_heads)``): Unused. - inputs_embeds (:obj:`torch.Tensor` of shape ``(batch, seq_length, dim_model)``): Embedding of the input. You can choose to directly pass the inputs embedding to control the way of embedding. + inputs_embeds (:obj:`torch.Tensor` of shape ``(batch, seq_length, dim_model)``): Embedding of the input. You can choose to directly pass the inputs embedding to control the way of embedding. encoder_hidden_states(:obj:`torch.Tensor` of shape(batch, seq_length, dim_model)): Unused. - encoder_attention_mask (:obj:`torch.Tensor` of shape ``(batch, seq_length)``): Unused. + encoder_attention_mask (:obj:`torch.Tensor` of shape ``(batch, seq_length)``): Unused. output_attentions (:obj:`torch.Tensor` of shape ``(batch, num_heads, seq_length, seq_length)``): Unused. output_hidden_states (:obj:`torch.Tensor` of shape ``(batch, seq_length, dim_model)``): Unused. return_dict (:obj:`bool`): Whether to return a BaseModelOutputWithPoolingAndCrossAttentions instead of just a tuple. - return_logits (:obj:`bool`): Whether to return the prediction score for each token in vocabulary (before softmax). + TODO:删除 + // return_logits (:obj:`bool`): Whether to return the prediction score for each token in vocabulary (before softmax). Return: - BaseModelOutputWithPoolingAndCrossAttentions or tuple or torch.Tensor of shape (batch, seq_length, vocab_output_size) or (batch, seqlen, cls_head): The Bert output. Depended on the value of `return_dict` and `return_logits` + BaseModelOutputWithPoolingAndCrossAttentions or tuple or torch.Tensor of shape (batch, seq_length, vocab_output_size) or (batch, seqlen, cls_head): The Bert output. 
Depended on the value of `return_dict` and `return_logits` """ assert input_ids is not None or inputs_embeds is not None @@ -209,7 +221,7 @@ def forward(self, hidden_states = self.encoder(hidden_states, attention_mask) - if self.cls_head: + """if self.cls_head: logits = self.cls_projection(hidden_states) elif self.tied: logits = self.input_embedding.projection(hidden_states) @@ -217,7 +229,7 @@ def forward(self, logits = self.lm_head(hidden_states) if return_logits: - return logits + return logits""" pooled_output = self.pooler(hidden_states) @@ -231,4 +243,64 @@ def forward(self, hidden_states=None, attentions=None, cross_attentions=None, - ) \ No newline at end of file + ) + + +class BertForLM(BaseModel): + _CONFIG_TYPE = BertConfig + + def __init__(self, config: BertConfig): + super().__init__() + self.bert = Bert(config) + self.seq_cls = torch.nn.Linear(config.hidden_size, 2) + + def forward(self, + input_ids=None, + length=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, # unused + inputs_embeds=None, + encoder_hidden_states=None, # unused + encoder_attention_mask=None, # unused + labels=None, + next_sentence_label=None, + output_attentions=None, # unused + output_hidden_states=None, # unused + return_dict=True, + ): + outputs = self.bert( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states, pooler_output = outputs[:2] + if self.cls_head: + logits = self.cls_projection(hidden_states) + elif self.tied: + logits = self.input_embedding.projection(hidden_states) + elif not self.tied: + logits = self.lm_head(hidden_states) + + seq_relationship_score = self.seq_cls(pooler_output) + loss_func = bmt.loss.FusedCrossEntropy(ignore_index=-100) + masked_loss = loss_func(logits.view(-1, self.config.vocab_size), labels.view(-1)) + next_sentence_loss = loss_func(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1)) + total_loss = masked_loss + next_sentence_loss + + if not return_dict: + return (total_loss, logits, seq_relationship_score, outputs.hidden_states, outputs.attentions) + return BertForPreTrainingOutput( + loss=total_loss, + prediction_logits=logits, + seq_relationship_logits=seq_relationship_score, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/model_center/model/cpm1.py b/model_center/model/cpm1.py index 53a85da..f07ccbc 100644 --- a/model_center/model/cpm1.py +++ b/model_center/model/cpm1.py @@ -17,96 +17,99 @@ from ..layer import Encoder, Embedding, Linear, RelativePositionEmbedding from .config import CPM1Config from .basemodel import BaseModel +from .modeling_output import CausalLMOutputWithCrossAttentions, BaseModelOutputWithPastAndCrossAttentions +import bmtrain as bmt + class CPM1(BaseModel): _CONFIG_TYPE = CPM1Config - + def __init__(self, config: CPM1Config): - + super().__init__() self.encoder = Encoder( - num_layers = config.num_layers, - dim_model = config.dim_model, - dim_ff = config.dim_ff, - num_heads = config.num_heads, - dim_head = config.dim_head, - dtype = config.dtype, - int8 = config.int8, - norm_eps = config.norm_eps, - norm_init_var = config.norm_init_var, - norm_bias = config.norm_bias, - att_init_mean = config.att_init_mean, - att_init_std = config.att_init_std, - att_bias = config.att_bias, - att_mask_value = 
float(config.att_mask_value), - pos_bias_type = config.pos_bias_type, - ffn_init_mean = config.ffn_init_mean, - ffn_init_std = config.ffn_init_std, - ffn_bias = config.ffn_bias, - ffn_activate_fn = config.ffn_activate_fn, - length_scale = config.length_scale, - attn_scale = config.attn_scale, - dropout_p = config.dropout_p, + num_layers=config.num_layers, + dim_model=config.dim_model, + dim_ff=config.dim_ff, + num_heads=config.num_heads, + dim_head=config.dim_head, + dtype=config.dtype, + int8=config.int8, + norm_eps=config.norm_eps, + norm_init_var=config.norm_init_var, + norm_bias=config.norm_bias, + att_init_mean=config.att_init_mean, + att_init_std=config.att_init_std, + att_bias=config.att_bias, + att_mask_value=float(config.att_mask_value), + pos_bias_type=config.pos_bias_type, + ffn_init_mean=config.ffn_init_mean, + ffn_init_std=config.ffn_init_std, + ffn_bias=config.ffn_bias, + ffn_activate_fn=config.ffn_activate_fn, + length_scale=config.length_scale, + attn_scale=config.attn_scale, + dropout_p=config.dropout_p, ) self.input_embedding = Embedding( - vocab_size = config.vocab_size, - embedding_size = config.dim_model, - length_scale = config.length_scale, - dtype = config.dtype, - int8 = config.int8, - init_mean = config.emb_init_mean, - init_std = config.emb_init_std, + vocab_size=config.vocab_size, + embedding_size=config.dim_model, + length_scale=config.length_scale, + dtype=config.dtype, + int8=config.int8, + init_mean=config.emb_init_mean, + init_std=config.emb_init_std, ) self.position_bias = RelativePositionEmbedding( - num_heads = config.num_heads, - num_buckets = config.position_bias_num_buckets, - max_distance = config.position_bias_max_distance, - bidirectional = True, - dtype = config.dtype, - init_mean = config.pos_init_mean, - init_std = config.pos_init_std, + num_heads=config.num_heads, + num_buckets=config.position_bias_num_buckets, + max_distance=config.position_bias_max_distance, + bidirectional=True, + dtype=config.dtype, + init_mean=config.pos_init_mean, + init_std=config.pos_init_std, ) - + self.tied = config.tied self.cls_head = config.cls_head if self.cls_head: self.output_projection = Linear( - dim_out = config.cls_head, - dim_in = config.dim_model, - length_scale = config.length_scale, - dtype = config.dtype, - int8 = config.int8, - init_mean = config.proj_init_mean, - init_std = config.proj_init_std, - bias = config.proj_bias, + dim_out=config.cls_head, + dim_in=config.dim_model, + length_scale=config.length_scale, + dtype=config.dtype, + int8=config.int8, + init_mean=config.proj_init_mean, + init_std=config.proj_init_std, + bias=config.proj_bias, ) elif not self.tied: self.output_projection = Linear( - dim_out = config.vocab_size, - dim_in = config.dim_model, - length_scale = config.length_scale, - dtype = config.dtype, - int8 = config.int8, - init_mean = config.proj_init_mean, - init_std = config.proj_init_std, - bias = config.proj_bias, + dim_out=config.vocab_size, + dim_in=config.dim_model, + length_scale=config.length_scale, + dtype=config.dtype, + int8=config.int8, + init_mean=config.proj_init_mean, + init_std=config.proj_init_std, + bias=config.proj_bias, ) - def forward(self, input : torch.Tensor, # (batch, seqlen) - length : torch.Tensor, # (batch) - context : torch.Tensor, # (batch, seqlen) - span : torch.Tensor): # (batch, seqlen) + def forward(self, input: torch.Tensor, # (batch, seqlen) + length: torch.Tensor, # (batch) + context: torch.Tensor, # (batch, seqlen) + span: torch.Tensor): # (batch, seqlen) """ This model inherits from BaseModel. 
This model is also a PyTorch torch.nn.Module subclass. You can use it as a regular PyTorch Module. - + Args: - input (:obj:`torch.Tensor` of shape ``(batch, seqlen)``): - length (:obj:`torch.Tensor` of shape ``(batch)``): - context (:obj:`torch.Tensor` of shape ``(batch, seqlen)``): - span (:obj:`torch.Tensor` of shape ``(batch, seqlen)``): + input (:obj:`torch.Tensor` of shape ``(batch, seqlen)``): + length (:obj:`torch.Tensor` of shape ``(batch)``): + context (:obj:`torch.Tensor` of shape ``(batch, seqlen)``): + span (:obj:`torch.Tensor` of shape ``(batch, seqlen)``): Return: torch.Tensor of shape (batch, seqlen, vocab_size) or (batch, seqlen, cls_head): The CPM output. Prediction scores of the language modeling before SoftMax. @@ -116,12 +119,12 @@ def forward(self, input : torch.Tensor, # (batch, seqlen) seqlen = input.size(1) with torch.no_grad(): - device = input.device directional_mask_2d = torch.arange(seqlen, device=device) <= torch.arange(seqlen, device=device).view(-1, 1) # attention_mask = context[:, :, None] | directional_mask_2d.view(1, seqlen, seqlen) - attention_mask = context[:, None, :] | (context[:, :, None].logical_not() & directional_mask_2d.view(1, seqlen, seqlen)) + attention_mask = context[:, None, :] | ( + context[:, :, None].logical_not() & directional_mask_2d.view(1, seqlen, seqlen)) attention_mask = attention_mask & (span[:, None, :] == span[:, :, None]) mask_1d = torch.arange(seqlen, device=device)[None, :].repeat(batch, 1) < length[:, None] @@ -131,12 +134,52 @@ def forward(self, input : torch.Tensor, # (batch, seqlen) hidden_states = self.input_embedding(input) hidden_states = self.encoder(hidden_states, attention_mask, position_bias) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=None, + hidden_states=None, + attentions=None, + cross_attentions=None, + ) + +class CPM1ForLM(BaseModel): + _CONFIG_TYPE = CPM1Config + + def __init__(self, config: CPM1Config): + super().__init__() + self.cpm1 = CPM1(config) + + def forward(self, + input: torch.Tensor, # (batch, seqlen) + length: torch.Tensor, # (batch) + context: torch.Tensor, # (batch, seqlen) + span: torch.Tensor, + labels: torch.Tendor, + ): + outputs = self.cpm1( + input=input, # (batch, seqlen) + length=length, # (batch) + context=context, # (batch, seqlen) + span=span, + ) + hidden_states = outputs[0] if self.cls_head: logits = self.output_projection(hidden_states) elif not self.tied: logits = self.output_projection(hidden_states) else: logits = self.input_embedding.projection(hidden_states) - - return logits + if labels: + _logits = logits[..., :-1, :].contiguous() + _labels = labels[..., 1:].contiguous() + loss_func = bmt.loss.FusedCrossEntropy(ignore_index=-100) + loss = loss_func(_logits.view(-1, _logits.size(-1)), labels.view(-1)) + return CausalLMOutputWithCrossAttentions( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) diff --git a/model_center/model/cpm2.py b/model_center/model/cpm2.py index 9e8b13a..22023c5 100644 --- a/model_center/model/cpm2.py +++ b/model_center/model/cpm2.py @@ -17,140 +17,144 @@ from ..layer import Encoder, Decoder, Embedding, Linear, RelativePositionEmbedding from .config import CPM2Config from .basemodel import BaseModel +from .modeling_output import Seq2SeqModelOutput, Seq2SeqLMOutput +import bmtrain as bmt + class CPM2(BaseModel): _CONFIG_TYPE = CPM2Config - + def 
__init__(self, config: CPM2Config): - super().__init__() self.encoder = Encoder( - num_layers = config.num_encoder_layers, - dim_model = config.dim_model, - dim_ff = config.dim_ff, - num_heads = config.num_heads, - dim_head = config.dim_head, - dtype = config.dtype, - int8 = config.int8, - norm_eps = config.norm_eps, - norm_init_var = config.norm_init_var, - norm_bias = config.norm_bias, - att_init_mean = config.att_init_mean, - att_init_std = config.att_init_std, - att_bias = config.att_bias, - att_mask_value = float(config.att_mask_value), - pos_bias_type = config.pos_bias_type, - ffn_init_mean = config.ffn_init_mean, - ffn_init_std = config.ffn_init_std, - ffn_bias = config.ffn_bias, - ffn_activate_fn = config.ffn_activate_fn, - length_scale = config.length_scale, - attn_scale = config.attn_scale, - dropout_p = config.dropout_p, + num_layers=config.num_encoder_layers, + dim_model=config.dim_model, + dim_ff=config.dim_ff, + num_heads=config.num_heads, + dim_head=config.dim_head, + dtype=config.dtype, + int8=config.int8, + norm_eps=config.norm_eps, + norm_init_var=config.norm_init_var, + norm_bias=config.norm_bias, + att_init_mean=config.att_init_mean, + att_init_std=config.att_init_std, + att_bias=config.att_bias, + att_mask_value=float(config.att_mask_value), + pos_bias_type=config.pos_bias_type, + ffn_init_mean=config.ffn_init_mean, + ffn_init_std=config.ffn_init_std, + ffn_bias=config.ffn_bias, + ffn_activate_fn=config.ffn_activate_fn, + length_scale=config.length_scale, + attn_scale=config.attn_scale, + dropout_p=config.dropout_p, ) self.decoder = Decoder( - num_layers = config.num_decoder_layers, - dim_model = config.dim_model, - dim_ff = config.dim_ff, - num_heads = config.num_heads, - dim_head = config.dim_head, - dtype = config.dtype, - int8 = config.int8, - norm_eps = config.norm_eps, - norm_init_var = config.norm_init_var, - norm_bias = config.norm_bias, - att_init_mean = config.att_init_mean, - att_init_std = config.att_init_std, - att_bias = config.att_bias, - att_mask_value = float(config.att_mask_value), - pos_bias_type = config.pos_bias_type, - ffn_init_mean = config.ffn_init_mean, - ffn_init_std = config.ffn_init_std, - ffn_bias = config.ffn_bias, - ffn_activate_fn = config.ffn_activate_fn, - length_scale = config.length_scale, - attn_scale = config.attn_scale, - dropout_p = config.dropout_p, + num_layers=config.num_decoder_layers, + dim_model=config.dim_model, + dim_ff=config.dim_ff, + num_heads=config.num_heads, + dim_head=config.dim_head, + dtype=config.dtype, + int8=config.int8, + norm_eps=config.norm_eps, + norm_init_var=config.norm_init_var, + norm_bias=config.norm_bias, + att_init_mean=config.att_init_mean, + att_init_std=config.att_init_std, + att_bias=config.att_bias, + att_mask_value=float(config.att_mask_value), + pos_bias_type=config.pos_bias_type, + ffn_init_mean=config.ffn_init_mean, + ffn_init_std=config.ffn_init_std, + ffn_bias=config.ffn_bias, + ffn_activate_fn=config.ffn_activate_fn, + length_scale=config.length_scale, + attn_scale=config.attn_scale, + dropout_p=config.dropout_p, ) self.input_embedding = Embedding( - vocab_size = config.vocab_size, - embedding_size = config.dim_model, - length_scale = False, # TODO not an elegent implementation # config.length_scale, - dtype = config.dtype, - int8 = config.int8, - init_mean = config.emb_init_mean, - init_std = config.emb_init_std, + vocab_size=config.vocab_size, + embedding_size=config.dim_model, + length_scale=False, # TODO not an elegent implementation # config.length_scale, + dtype=config.dtype, + 
int8=config.int8, + init_mean=config.emb_init_mean, + init_std=config.emb_init_std, ) self.position_bias_enc = RelativePositionEmbedding( - num_heads = config.num_heads, - num_buckets = config.position_bias_num_buckets, - max_distance = config.position_bias_max_distance, - bidirectional = True, - dtype = config.dtype, - init_mean = config.pos_init_mean, - init_std = config.pos_init_std, + num_heads=config.num_heads, + num_buckets=config.position_bias_num_buckets, + max_distance=config.position_bias_max_distance, + bidirectional=True, + dtype=config.dtype, + init_mean=config.pos_init_mean, + init_std=config.pos_init_std, ) self.position_bias_dec = RelativePositionEmbedding( - num_heads = config.num_heads, - num_buckets = config.position_bias_num_buckets, - max_distance = config.position_bias_max_distance, - bidirectional = False, - dtype = config.dtype, - init_mean = config.pos_init_mean, - init_std = config.pos_init_std, + num_heads=config.num_heads, + num_buckets=config.position_bias_num_buckets, + max_distance=config.position_bias_max_distance, + bidirectional=False, + dtype=config.dtype, + init_mean=config.pos_init_mean, + init_std=config.pos_init_std, ) self.cls_head = config.cls_head self.output_projection = Linear( - dim_out = self.cls_head if self.cls_head else config.vocab_size, - dim_in = config.dim_model, - length_scale = config.length_scale, - dtype = config.dtype, - int8 = config.int8, - init_mean = config.proj_init_mean, - init_std = config.proj_init_std, - bias = config.proj_bias, + dim_out=self.cls_head if self.cls_head else config.vocab_size, + dim_in=config.dim_model, + length_scale=config.length_scale, + dtype=config.dtype, + int8=config.int8, + init_mean=config.proj_init_mean, + init_std=config.proj_init_std, + bias=config.proj_bias, ) - def forward(self, - enc_input : torch.Tensor, # (batch, seq_enc) - enc_length : torch.Tensor, # (batch) - dec_input : torch.Tensor, # (batch, seq_dec) - dec_length : torch.Tensor, # (batch) - ): + def forward(self, + enc_input: torch.Tensor, # (batch, seq_enc) + enc_length: torch.Tensor, # (batch) + dec_input: torch.Tensor, # (batch, seq_dec) + dec_length: torch.Tensor, # (batch) + ): """ This model inherits from BaseModel. This model is also a PyTorch torch.nn.Module subclass. You can use it as a regular PyTorch Module. Args: enc_input (:obj:`torch.Tensor` of shape ``(batch, seq_enc)``): Indices of input sequence tokens for encoder. It will be embedded by model's internal embedding lookup matrix. - enc_length (:obj:`torch.Tensor` of shape ``(batch)``): Length of input sequence for encoder before padding. + enc_length (:obj:`torch.Tensor` of shape ``(batch)``): Length of input sequence for encoder before padding. dec_input (:obj:`torch.Tensor` of shape ``(batch, seq_dec)``): Indices of input sequence tokens for decoder. It will be embedded by model's internal embedding lookup matrix. dec_length (:obj:`torch.Tensor` of shape ``(batch)``): Length of input sequence for encoder before padding. Return: torch.Tensor of shape (batch, seq_dec, vocab_output_size) or (batch, seqlen, cls_head): The CPM-2 output. Prediction scores of the language modeling before SoftMax. 
""" - + batch = enc_input.size(0) seq_enc = enc_input.size(1) seq_dec = dec_input.size(1) with torch.no_grad(): - device = enc_input.device enc_mask_1d = torch.arange(seq_enc, device=device)[None, :].repeat(batch, 1) < enc_length[:, None] dec_mask_1d = torch.arange(seq_dec, device=device)[None, :].repeat(batch, 1) < dec_length[:, None] - directional_mask_2d = torch.arange(seq_dec, device=device) <= torch.arange(seq_dec, device=device).view(-1, 1) + directional_mask_2d = torch.arange(seq_dec, device=device) <= torch.arange(seq_dec, device=device).view(-1, + 1) # (batch, seq_enc, seq_enc) enc_attention_mask = enc_mask_1d.view(batch, seq_enc, 1) & enc_mask_1d.view(batch, 1, seq_enc) # (batch, seq_dec, seq_dec) - dec_attention_mask = dec_mask_1d.view(batch, seq_dec, 1) & dec_mask_1d.view(batch, 1, seq_dec) & directional_mask_2d.view(1, seq_dec, seq_dec) + dec_attention_mask = dec_mask_1d.view(batch, seq_dec, 1) & dec_mask_1d.view(batch, 1, + seq_dec) & directional_mask_2d.view( + 1, seq_dec, seq_dec) # (batch, seq_dec, seq_enc) cross_attention_mask = enc_mask_1d.view(batch, 1, seq_enc) & dec_mask_1d.view(batch, seq_dec, 1) @@ -171,6 +175,53 @@ def forward(self, hidden_states_enc, cross_attention_mask, None) # (batch, seq_dec, vocab_output_size) - logits = self.output_projection(hidden_states_dec) + return Seq2SeqModelOutput( + last_hidden_state=hidden_states_dec, + encoder_last_hidden_state=hidden_states_enc, + past_key_values=None, + encoder_hidden_states=None, + decoder_hidden_states=None, + decoder_attentions=None, + cross_attentions=None, + encoder_attentions=None, + ) + return logits + +class CPM2ForLM(BaseModel): + _CONFIG_TYPE = CPM2Config + + def __init__(self, config: CPM2Config): + super().__init__() + self.cpm2 = CPM2(config) + + def forward(self, + enc_input: torch.Tensor, # (batch, seq_enc) + enc_length: torch.Tensor, # (batch) + dec_input: torch.Tensor, # (batch, seq_dec) + dec_length: torch.Tensor, # (batch) + labels: torch.Tendor, + ): + outputs = self.cpm1( + enc_input=enc_input, # (batch, seq_enc) + enc_length=enc_length, # (batch) + dec_input=dec_input, # (batch, seq_dec) + dec_length=dec_length, # (batch) + ) + hidden_states = outputs[0] + logits = self.output_projection(hidden_states) + if labels: + loss_func = bmt.loss.FusedCrossEntropy(ignore_index=-100) + loss = loss_func(logits.view(-1, logits.size(-1)), labels.view(-1)) + return Seq2SeqLMOutput( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + ) diff --git a/model_center/model/gpt2.py b/model_center/model/gpt2.py index 232180b..5ad893b 100644 --- a/model_center/model/gpt2.py +++ b/model_center/model/gpt2.py @@ -17,117 +17,116 @@ from ..layer import Encoder, Embedding, Linear from .basemodel import BaseModel from .config import GPT2Config -from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions +from .modeling_output import CausalLMOutputWithCrossAttentions, BaseModelOutputWithPastAndCrossAttentions +import bmtrain as bmt class GPT2(BaseModel): - _CONFIG_TYPE = GPT2Config def __init__(self, config: GPT2Config): - + super().__init__() self.encoder = Encoder( - num_layers = config.num_layers, - dim_model = config.dim_model, - dim_ff = config.dim_ff, - 
num_heads = config.num_heads, - dim_head = config.dim_head, - dtype = config.dtype, - int8 = config.int8, - norm_eps = config.norm_eps, - norm_init_var = config.norm_init_var, - norm_bias = config.norm_bias, - att_init_mean = config.att_init_mean, - att_init_std = config.att_init_std, - att_bias = config.att_bias, - att_mask_value = float(config.att_mask_value), - pos_bias_type = config.pos_bias_type, - ffn_init_mean = config.ffn_init_mean, - ffn_init_std = config.ffn_init_std, - ffn_bias = config.ffn_bias, - ffn_activate_fn = config.ffn_activate_fn, - length_scale = config.length_scale, - attn_scale = config.attn_scale, - dropout_p = config.dropout_p, + num_layers=config.num_layers, + dim_model=config.dim_model, + dim_ff=config.dim_ff, + num_heads=config.num_heads, + dim_head=config.dim_head, + dtype=config.dtype, + int8=config.int8, + norm_eps=config.norm_eps, + norm_init_var=config.norm_init_var, + norm_bias=config.norm_bias, + att_init_mean=config.att_init_mean, + att_init_std=config.att_init_std, + att_bias=config.att_bias, + att_mask_value=float(config.att_mask_value), + pos_bias_type=config.pos_bias_type, + ffn_init_mean=config.ffn_init_mean, + ffn_init_std=config.ffn_init_std, + ffn_bias=config.ffn_bias, + ffn_activate_fn=config.ffn_activate_fn, + length_scale=config.length_scale, + attn_scale=config.attn_scale, + dropout_p=config.dropout_p, ) self.embed_dropout = torch.nn.Dropout(config.dropout_p) self.input_embedding = Embedding( - vocab_size = config.vocab_size, - embedding_size = config.dim_model, - length_scale = config.length_scale, - dtype = config.dtype, - int8 = config.int8, - init_mean = config.emb_init_mean, - init_std = config.emb_init_std, + vocab_size=config.vocab_size, + embedding_size=config.dim_model, + length_scale=config.length_scale, + dtype=config.dtype, + int8=config.int8, + init_mean=config.emb_init_mean, + init_std=config.emb_init_std, ) self.position_embedding = Embedding( - vocab_size = config.position_size, - embedding_size = config.dim_model, - length_scale = config.length_scale, - dtype = config.dtype, - int8 = config.int8, - init_mean = config.emb_init_mean, - init_std = config.emb_init_std, + vocab_size=config.position_size, + embedding_size=config.dim_model, + length_scale=config.length_scale, + dtype=config.dtype, + int8=config.int8, + init_mean=config.emb_init_mean, + init_std=config.emb_init_std, ) self.tied = config.tied self.cls_head = config.cls_head if self.cls_head: self.cls_projection = Linear( - dim_out = self.cls_head, - dim_in = config.dim_model, - length_scale = config.length_scale, - dtype = config.dtype, - int8 = config.int8, - init_mean = config.proj_init_mean, - init_std = config.proj_init_std, - bias = config.proj_bias, + dim_out=self.cls_head, + dim_in=config.dim_model, + length_scale=config.length_scale, + dtype=config.dtype, + int8=config.int8, + init_mean=config.proj_init_mean, + init_std=config.proj_init_std, + bias=config.proj_bias, ) if not self.tied: self.output_projection = Linear( - dim_out = config.vocab_size, - dim_in = config.dim_model, - length_scale = config.length_scale, - dtype = config.dtype, - int8 = config.int8, - init_mean = config.proj_init_mean, - init_std = config.proj_init_std, - bias = config.proj_bias, + dim_out=config.vocab_size, + dim_in=config.dim_model, + length_scale=config.length_scale, + dtype=config.dtype, + int8=config.int8, + init_mean=config.proj_init_mean, + init_std=config.proj_init_std, + bias=config.proj_bias, ) - def forward(self, - input_ids = None, # (batch, seqlen) - length = None, # 
(batch) - attention_mask = None, # (batch, seqlen) - token_type_ids = None, - position_ids = None, - head_mask = None, #unused - inputs_embeds = None, - encoder_hidden_states = None, #unused - encoder_attention_mask = None, #unused - output_attentions = None, #unused - output_hidden_states = None, #unused - return_dict = True, - return_logits = False, - ): + def forward(self, + input_ids=None, # (batch, seqlen) + length=None, # (batch) + attention_mask=None, # (batch, seqlen) + token_type_ids=None, + position_ids=None, + head_mask=None, # unused + inputs_embeds=None, + encoder_hidden_states=None, # unused + encoder_attention_mask=None, # unused + output_attentions=None, # unused + output_hidden_states=None, # unused + return_dict=True, + ): """ The GPT-2 Model transformer outputs raw hidden-states or logits as you want. This model inherits from BaseModel. This model is also a PyTorch torch.nn.Module subclass.You can use it as a regular PyTorch Module. You can also select the data and data type that you want the model to return through changing the value of `return_dict` and `return_logits`. Args: input_ids (:obj:`torch.Tensor` of shape ``(batch, seq_length)``): Indices of input sequence tokens. It will be embedded by model's internal embedding lookup matrix. - length (:obj:`torch.Tensor` of shape ``(batch)``): Length of input sequence before padding. + length (:obj:`torch.Tensor` of shape ``(batch)``): Length of input sequence before padding. attention_mask (:obj:`torch.Tensor` of shape ``(batch, seq_length)``): Used to avoid performing attention on padding token indices. - token_type_ids(:obj:`torch.Tensor` of shape ``(batch, seq_length)``): Unused. - position_ids(:obj:`torch.Tensor` of shape ``(batch, seq_length)``): Unused. + token_type_ids(:obj:`torch.Tensor` of shape ``(batch, seq_length)``): Unused. + position_ids(:obj:`torch.Tensor` of shape ``(batch, seq_length)``): Unused. head_mask (:obj:`torch.Tensor` of shape ``(num_layers, num_heads)``): Unused. - inputs_embeds (:obj:`torch.Tensor` of shape ``(batch, seq_length, dim_model)``): Embedding of the input. You can choose to directly pass the inputs embedding to control the way of embedding. - encoder_hidden_states(:obj:`torch.Tensor` of shape(batch, seq_length, dim_model)):Unused. + inputs_embeds (:obj:`torch.Tensor` of shape ``(batch, seq_length, dim_model)``): Embedding of the input. You can choose to directly pass the inputs embedding to control the way of embedding. + encoder_hidden_states(:obj:`torch.Tensor` of shape(batch, seq_length, dim_model)):Unused. encoder_attention_mask(:obj:`torch.Tensor` of shape(batch, seq_length)):Unused. output_attentions (:obj:`torch.Tensor` of shape ``(batch, num_heads, seq_length, seq_length)``): Unused. output_hidden_states (:obj:`torch.Tensor` of shape ``(batch, seq_dec, dim_model)``): Unused. @@ -135,7 +134,7 @@ def forward(self, return_logits (:obj:`bool`): Whether to return the prediction score for each token in vocabulary (before softmax). Return: - BaseModelOutputWithPastAndCrossAttentions or tuple or torch.Tensor of shape (batch, seq_dec, vocab_output_size) or (batch, seqlen, cls_head): The GPT-2 output. Depended on the value of `return_dict` and `return_logits` + BaseModelOutputWithPastAndCrossAttentions or tuple or torch.Tensor of shape (batch, seq_dec, vocab_output_size) or (batch, seqlen, cls_head): The GPT-2 output. 
Depended on the value of `return_dict` and `return_logits` """ assert input_ids is not None or inputs_embeds is not None @@ -156,8 +155,10 @@ def forward(self, else: attention_mask = torch.arange(seq_length, device=device)[None, :].repeat(batch, 1) < length[:, None] - directional_mask_2d = torch.arange(seq_length, device=device) <= torch.arange(seq_length, device=device).view(-1, 1) - attention_mask = attention_mask.view(batch, 1, seq_length) & directional_mask_2d.view(1, seq_length, seq_length) + directional_mask_2d = torch.arange(seq_length, device=device) <= torch.arange(seq_length, + device=device).view(-1, 1) + attention_mask = attention_mask.view(batch, 1, seq_length) & directional_mask_2d.view(1, seq_length, + seq_length) if position_ids is None: position_ids = torch.arange(seq_length, dtype=torch.int32, device=device)[None, :].repeat(batch, 1) @@ -173,18 +174,6 @@ def forward(self, hidden_states = self.encoder(hidden_states, attention_mask) - if self.cls_head: - logits = self.cls_projection(hidden_states) - elif self.tied: - logits = self.input_embedding.projection(hidden_states) - logits[:, :, -1] = -float("inf") # TODO not an elegant implementation, gpt2 vocab is odd number, expand to even and ignore last - elif not self.tied: - logits = self.output_projection(hidden_states) - logits[:, :, -1] = -float("inf") # TODO not an elegant implementation, gpt2 vocab is odd number, expand to even and ignore last - - if return_logits: - return logits - if not return_dict: return tuple(hidden_states, None, None, None, None) else: @@ -195,3 +184,69 @@ def forward(self, attentions=None, cross_attentions=None, ) + + +class GPT2ForLM(BaseModel): + _CONFIG_TYPE = GPT2Config + + def __init__(self, config: GPT2Config): + super().__init__() + self.gpt2 = GPT2(config) + + def forward(self, + input_ids=None, # (batch, seqlen) + length=None, # (batch) + attention_mask=None, # (batch, seqlen) + token_type_ids=None, + position_ids=None, + head_mask=None, # unused + inputs_embeds=None, + encoder_hidden_states=None, # unused + encoder_attention_mask=None, # unused + labels=None, + output_attentions=None, # unused + output_hidden_states=None, # unused + return_dict=True, + ): + outputs = self.gpt2( + input_ids=input_ids, + length=length, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = outputs[0] + if self.cls_head: + logits = self.cls_projection(hidden_states) + elif self.tied: + logits = self.input_embedding.projection(hidden_states) + logits[:, :, -1] = -float( + "inf") # TODO not an elegant implementation, gpt2 vocab is odd number, expand to even and ignore last + elif not self.tied: + logits = self.output_projection(hidden_states) + logits[:, :, -1] = -float( + "inf") # TODO not an elegant implementation, gpt2 vocab is odd number, expand to even and ignore last + if labels: + loss_func = bmt.loss.FusedCrossEntropy(ignore_index=-100) + # TODO:检查一下fusedCrossEntropy是不是就是 + _logits = logits[..., :-1, :].contiguous() + _labels = labels[..., 1:].contiguous() + loss = loss_func(_logits.view(-1, _logits.size(-1)), labels.view(-1)) + + if not return_dict: + return (loss, logits, outputs.past_key_values, hidden_states, outputs.attentions, outputs.cross_attentions) + return 
CausalLMOutputWithCrossAttentions( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) diff --git a/model_center/model/gptj.py b/model_center/model/gptj.py index 2efeed2..9fa9ac2 100644 --- a/model_center/model/gptj.py +++ b/model_center/model/gptj.py @@ -17,113 +17,113 @@ from ..layer import Encoder, Embedding, Linear, RotaryEmbedding from .basemodel import BaseModel from .config import GPTjConfig -from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions +from .modeling_output import CausalLMOutputWithPast, BaseModelOutputWithPastAndCrossAttentions +import bmtrain as bmt -class GPTj(BaseModel): +class GPTj(BaseModel): _CONFIG_TYPE = GPTjConfig def __init__(self, config: GPTjConfig): - + super().__init__() self.encoder = Encoder( - num_layers = config.num_layers, - dim_model = config.dim_model, - dim_ff = config.dim_ff, - num_heads = config.num_heads, - dim_head = config.dim_head, - dtype = config.dtype, - int8 = config.int8, - norm_eps = config.norm_eps, - norm_init_var = config.norm_init_var, - norm_bias = config.norm_bias, - att_init_mean = config.att_init_mean, - att_init_std = config.att_init_std, - att_bias = config.att_bias, - att_mask_value = float(config.att_mask_value), - pos_bias_type = config.pos_bias_type, - ffn_init_mean = config.ffn_init_mean, - ffn_init_std = config.ffn_init_std, - ffn_bias = config.ffn_bias, - ffn_activate_fn = config.ffn_activate_fn, - length_scale = config.length_scale, - attn_scale = config.attn_scale, - dropout_p = config.dropout_p, - parallel_ffn = True, + num_layers=config.num_layers, + dim_model=config.dim_model, + dim_ff=config.dim_ff, + num_heads=config.num_heads, + dim_head=config.dim_head, + dtype=config.dtype, + int8=config.int8, + norm_eps=config.norm_eps, + norm_init_var=config.norm_init_var, + norm_bias=config.norm_bias, + att_init_mean=config.att_init_mean, + att_init_std=config.att_init_std, + att_bias=config.att_bias, + att_mask_value=float(config.att_mask_value), + pos_bias_type=config.pos_bias_type, + ffn_init_mean=config.ffn_init_mean, + ffn_init_std=config.ffn_init_std, + ffn_bias=config.ffn_bias, + ffn_activate_fn=config.ffn_activate_fn, + length_scale=config.length_scale, + attn_scale=config.attn_scale, + dropout_p=config.dropout_p, + parallel_ffn=True, ) self.input_embedding = Embedding( - vocab_size = config.vocab_size, - embedding_size = config.dim_model, - length_scale = config.length_scale, - dtype = config.dtype, - int8 = config.int8, - init_mean = config.emb_init_mean, - init_std = config.emb_init_std, + vocab_size=config.vocab_size, + embedding_size=config.dim_model, + length_scale=config.length_scale, + dtype=config.dtype, + int8=config.int8, + init_mean=config.emb_init_mean, + init_std=config.emb_init_std, ) self.position_bias = RotaryEmbedding( - rotary_dim = config.pos_rotary_dim, + rotary_dim=config.pos_rotary_dim, ) - + self.tied = config.tied self.cls_head = config.cls_head if self.cls_head: self.cls_projection = Linear( - dim_out = self.cls_head, - dim_in = config.dim_model, - length_scale = config.length_scale, - dtype = config.dtype, - int8 = config.int8, - init_mean = config.proj_init_mean, - init_std = config.proj_init_std, - bias = config.proj_bias, + dim_out=self.cls_head, + dim_in=config.dim_model, + length_scale=config.length_scale, + dtype=config.dtype, + int8=config.int8, + init_mean=config.proj_init_mean, + init_std=config.proj_init_std, + 
bias=config.proj_bias, ) if not self.tied: self.output_projection = Linear( - dim_out = config.vocab_size, - dim_in = config.dim_model, - length_scale = config.length_scale, - dtype = config.dtype, - int8 = config.int8, - init_mean = config.proj_init_mean, - init_std = config.proj_init_std, - bias = config.proj_bias, + dim_out=config.vocab_size, + dim_in=config.dim_model, + length_scale=config.length_scale, + dtype=config.dtype, + int8=config.int8, + init_mean=config.proj_init_mean, + init_std=config.proj_init_std, + bias=config.proj_bias, ) - def forward(self, - input_ids=None, # (batch, seqlen) - length=None, # (batch) + def forward(self, + input_ids=None, # (batch, seqlen) + length=None, # (batch) attention_mask=None, - token_type_ids=None, #unused - position_ids=None, #unused - head_mask=None, #unused + token_type_ids=None, # unused + position_ids=None, # unused + head_mask=None, # unused inputs_embeds=None, - output_attentions=None, #unused - output_hidden_states=None, #unused + output_attentions=None, # unused + output_hidden_states=None, # unused return_dict=True, - return_logits = False, - ): + ): """ The GPT-J Model transformer outputs raw hidden-states or logits as you want. This model inherits from BaseModel. This model is also a PyTorch torch.nn.Module subclass.You can use it as a regular PyTorch Module. You can also select the data and data type that you want the model to return through changing the value of `return_dict` and `return_logits`. Args: input_ids (:obj:`torch.Tensor` of shape ``(batch, seq_length)``): Indices of input sequence tokens. It will be embedded by model's internal embedding lookup matrix. - length (:obj:`torch.Tensor` of shape ``(batch)``): Length of input sequence before padding. + length (:obj:`torch.Tensor` of shape ``(batch)``): Length of input sequence before padding. attention_mask (:obj:`torch.Tensor` of shape ``(batch, seq_length)``): Used to avoid performing attention on padding token indices. - token_type_ids(:obj:`torch.Tensor` of shape ``(batch, seq_length)``): Unused. + token_type_ids(:obj:`torch.Tensor` of shape ``(batch, seq_length)``): Unused. position_ids(:obj:`torch.Tensor` of shape ``(batch, seq_length)``): Unused. head_mask (:obj:`torch.Tensor` of shape ``(num_layers, num_heads)``): Unused. - inputs_embeds (:obj:`torch.Tensor` of shape ``(batch, seq_length, dim_model)``): Embedding of the input. You can choose to directly pass the inputs embedding to control the way of embedding. + inputs_embeds (:obj:`torch.Tensor` of shape ``(batch, seq_length, dim_model)``): Embedding of the input. You can choose to directly pass the inputs embedding to control the way of embedding. output_attentions (:obj:`torch.Tensor` of shape ``(batch, num_heads, seq_length, seq_length)``): Unused. output_hidden_states (:obj:`torch.Tensor` of shape ``(batch, seq_dec, dim_model)``): Unused. return_dict (:obj:`bool`): Whether to return a BaseModelOutputWithPastAndCrossAttentions instead of just a tuple. return_logits (:obj:`bool`): Whether to return the prediction score for each token in vocabulary (before softmax). Return: - BaseModelOutputWithPastAndCrossAttentions or tuple or torch.Tensor of shape (batch, seq_dec, vocab_output_size) or (batch, seqlen, cls_head): The GPT-J output. Depended on the value of `return_dict` and `return_logits` + BaseModelOutputWithPastAndCrossAttentions or tuple or torch.Tensor of shape (batch, seq_dec, vocab_output_size) or (batch, seqlen, cls_head): The GPT-J output. 
Depended on the value of `return_dict` and `return_logits` """ assert input_ids is not None or inputs_embeds is not None @@ -142,8 +142,10 @@ def forward(self, attention_mask = attention_mask.to(torch.bool) else: attention_mask = torch.arange(seq_length, device=device)[None, :].repeat(batch, 1) < length[:, None] - directional_mask_2d = torch.arange(seq_length, device=device) <= torch.arange(seq_length, device=device).view(-1, 1) - attention_mask = attention_mask.view(batch, 1, seq_length) & directional_mask_2d.view(1, seq_length, seq_length) + directional_mask_2d = torch.arange(seq_length, device=device) <= torch.arange(seq_length, + device=device).view(-1, 1) + attention_mask = attention_mask.view(batch, 1, seq_length) & directional_mask_2d.view(1, seq_length, + seq_length) if inputs_embeds is None: hidden_states = self.input_embedding(input_ids) @@ -152,16 +154,6 @@ def forward(self, hidden_states = self.encoder(hidden_states, attention_mask, self.position_bias) - if self.cls_head: - logits = self.cls_projection(hidden_states) - elif self.tied: - logits = self.input_embedding.projection(hidden_states) - elif not self.tied: - logits = self.output_projection(hidden_states) - - if return_logits: - return logits - if not return_dict: return tuple(hidden_states, None, None, None, None) else: @@ -172,3 +164,59 @@ def forward(self, attentions=None, cross_attentions=None, ) + + +class GPTjForLM(BaseModel): + _CONFIG_TYPE = GPTjConfig + + def __init__(self, config: GPTjConfig): + super().__init__() + self.gptj = GPTj(config) + + def forward(self, + input_ids=None, # (batch, seqlen) + length=None, # (batch) + attention_mask=None, + token_type_ids=None, # unused + position_ids=None, # unused + head_mask=None, # unused + inputs_embeds=None, + labels=None, + output_attentions=None, # unused + output_hidden_states=None, # unused + return_dict=True, + ): + outputs = self.gptj( + imput_ids=input_ids, + length = length, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = outputs[0] + if self.cls_head: + logits = self.cls_projection(hidden_states) + elif self.tied: + logits = self.input_embedding.projection(hidden_states) + elif not self.tied: + logits = self.output_projection(hidden_states) + if labels: + _logits = logits[..., :-1,:].contiguous() + _labels = labels[..., 1:].contiguous() + loss_func = bmt.loss.FusedCrossEntropy(ignore_index=-100) + loss = loss_func(_logits.view(-1, _logits.size(-1)), labels.view(-1)) + + if not return_dict: + return (loss, logits, outputs.past_key_values, hidden_states, outputs.attentions) + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) \ No newline at end of file diff --git a/model_center/model/modeling_output.py b/model_center/model/modeling_output.py new file mode 100644 index 0000000..2150776 --- /dev/null +++ b/model_center/model/modeling_output.py @@ -0,0 +1,89 @@ +import os +from typing import Union +import torch +import bmtrain as bmt +from .config.config import Config +from ..utils import check_web_and_convert_path +from collections import OrderedDict +from typing import Optional, Tuple + + +class ModelOutput(OrderedDict): + def __getitem__(self, k): + if isinstance(k, str): + inner_dict = {k: v for (k, v) 
in self.items()} + return inner_dict[k] + else: + return self.to_tuple()[k] + + def __setattr__(self, name, value): + if name in self.keys() and value is not None: + # Don't call self.__setitem__ to avoid recursion errors + super().__setitem__(name, value) + super().__setattr__(name, value) + + def __setitem__(self, key, value): + # Will raise a KeyException if needed + super().__setitem__(key, value) + # Don't call self.__setattr__ to avoid recursion errors + super().__setattr__(key, value) + + def to_tuple(self): + """ + Convert self to a tuple containing all the attributes/keys that are not `None`. + """ + return tuple(self[k] for k in self.keys()) + + +class BaseModelOutputWithPoolingAndCrossAttentions(ModelOutput): + last_hidden_state: torch.FloatTensor = None + pooler_output: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor]] = None + + +class CausalLMOutputWithCrossAttentions(ModelOutput): + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor]] = None + +class BaseModelOutputWithPastAndCrossAttentions(ModelOutput): + last_hidden_state: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor]] = None + +class CausalLMOutputWithPast(ModelOutput): + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +class Seq2SeqModelOutput(ModelOutput): + last_hidden_state: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor]] = None + encoder_last_hidden_state: Optional[torch.FloatTensor] = None + encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None + +class Seq2SeqLMOutput(ModelOutput): + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor]] = None + encoder_last_hidden_state: Optional[torch.FloatTensor] = None + encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None \ No newline at end of file diff --git a/model_center/model/t5.py b/model_center/model/t5.py index 4f88c67..b774234 100644 --- a/model_center/model/t5.py +++ b/model_center/model/t5.py @@ -17,149 +17,148 @@ from ..layer import Encoder, Decoder, Embedding, Linear, RelativePositionEmbedding from .basemodel import BaseModel from 
.config import T5Config -from transformers.modeling_outputs import Seq2SeqModelOutput +from .modeling_output import Seq2SeqModelOutput, Seq2SeqLMOutput +import bmtrain as bmt -class T5(BaseModel): - +class T5(BaseModel): _CONFIG_TYPE = T5Config def __init__(self, config: T5Config): - + super().__init__() self.config = config self.encoder = Encoder( - num_layers = config.num_encoder_layers, - dim_model = config.dim_model, - dim_ff = config.dim_ff, - num_heads = config.num_heads, - dim_head = config.dim_head, - dtype = config.dtype, - int8 = config.int8, - norm_eps = config.norm_eps, - norm_init_var = config.norm_init_var, - norm_bias = config.norm_bias, - att_init_mean = config.att_init_mean, - att_init_std = config.att_init_std, - att_bias = config.att_bias, - att_mask_value = float(config.att_mask_value), - pos_bias_type = config.pos_bias_type, - ffn_init_mean = config.ffn_init_mean, - ffn_init_std = config.ffn_init_std, - ffn_bias = config.ffn_bias, - ffn_activate_fn = config.ffn_activate_fn, - length_scale = config.length_scale, - attn_scale = config.attn_scale, - dropout_p = config.dropout_p, + num_layers=config.num_encoder_layers, + dim_model=config.dim_model, + dim_ff=config.dim_ff, + num_heads=config.num_heads, + dim_head=config.dim_head, + dtype=config.dtype, + int8=config.int8, + norm_eps=config.norm_eps, + norm_init_var=config.norm_init_var, + norm_bias=config.norm_bias, + att_init_mean=config.att_init_mean, + att_init_std=config.att_init_std, + att_bias=config.att_bias, + att_mask_value=float(config.att_mask_value), + pos_bias_type=config.pos_bias_type, + ffn_init_mean=config.ffn_init_mean, + ffn_init_std=config.ffn_init_std, + ffn_bias=config.ffn_bias, + ffn_activate_fn=config.ffn_activate_fn, + length_scale=config.length_scale, + attn_scale=config.attn_scale, + dropout_p=config.dropout_p, ) self.decoder = Decoder( - num_layers = config.num_decoder_layers, - dim_model = config.dim_model, - dim_ff = config.dim_ff, - num_heads = config.num_heads, - dim_head = config.dim_head, - dtype = config.dtype, - int8 = config.int8, - norm_eps = config.norm_eps, - norm_init_var = config.norm_init_var, - norm_bias = config.norm_bias, - att_init_mean = config.att_init_mean, - att_init_std = config.att_init_std, - att_bias = config.att_bias, - att_mask_value = float(config.att_mask_value), - pos_bias_type = config.pos_bias_type, - ffn_init_mean = config.ffn_init_mean, - ffn_init_std = config.ffn_init_std, - ffn_bias = config.ffn_bias, - ffn_activate_fn = config.ffn_activate_fn, - length_scale = config.length_scale, - attn_scale = config.attn_scale, - dropout_p = config.dropout_p, + num_layers=config.num_decoder_layers, + dim_model=config.dim_model, + dim_ff=config.dim_ff, + num_heads=config.num_heads, + dim_head=config.dim_head, + dtype=config.dtype, + int8=config.int8, + norm_eps=config.norm_eps, + norm_init_var=config.norm_init_var, + norm_bias=config.norm_bias, + att_init_mean=config.att_init_mean, + att_init_std=config.att_init_std, + att_bias=config.att_bias, + att_mask_value=float(config.att_mask_value), + pos_bias_type=config.pos_bias_type, + ffn_init_mean=config.ffn_init_mean, + ffn_init_std=config.ffn_init_std, + ffn_bias=config.ffn_bias, + ffn_activate_fn=config.ffn_activate_fn, + length_scale=config.length_scale, + attn_scale=config.attn_scale, + dropout_p=config.dropout_p, ) self.input_embedding = Embedding( - vocab_size = config.vocab_size, - embedding_size = config.dim_model, - length_scale = config.length_scale, - dtype = config.dtype, - int8 = config.int8, - init_mean = 
config.emb_init_mean, - init_std = config.emb_init_std, + vocab_size=config.vocab_size, + embedding_size=config.dim_model, + length_scale=config.length_scale, + dtype=config.dtype, + int8=config.int8, + init_mean=config.emb_init_mean, + init_std=config.emb_init_std, ) self.position_bias_enc = RelativePositionEmbedding( - num_heads = config.num_heads, - num_buckets = config.position_bias_num_buckets, - max_distance = config.position_bias_max_distance, - bidirectional = True, - dtype = config.dtype, - init_mean = config.pos_init_mean, - init_std = config.pos_init_std, + num_heads=config.num_heads, + num_buckets=config.position_bias_num_buckets, + max_distance=config.position_bias_max_distance, + bidirectional=True, + dtype=config.dtype, + init_mean=config.pos_init_mean, + init_std=config.pos_init_std, ) self.position_bias_dec = RelativePositionEmbedding( - num_heads = config.num_heads, - num_buckets = config.position_bias_num_buckets, - max_distance = config.position_bias_max_distance, - bidirectional = False, - dtype = config.dtype, - init_mean = config.pos_init_mean, - init_std = config.pos_init_std, + num_heads=config.num_heads, + num_buckets=config.position_bias_num_buckets, + max_distance=config.position_bias_max_distance, + bidirectional=False, + dtype=config.dtype, + init_mean=config.pos_init_mean, + init_std=config.pos_init_std, ) self.tied = config.tied self.cls_head = config.cls_head if self.cls_head: self.cls_projection = Linear( - dim_out = self.cls_head, - dim_in = config.dim_model, - length_scale = config.length_scale, - dtype = config.dtype, - int8 = config.int8, - init_mean = config.proj_init_mean, - init_std = config.proj_init_std, - bias = config.proj_bias, + dim_out=self.cls_head, + dim_in=config.dim_model, + length_scale=config.length_scale, + dtype=config.dtype, + int8=config.int8, + init_mean=config.proj_init_mean, + init_std=config.proj_init_std, + bias=config.proj_bias, ) if not self.tied: self.output_projection = Linear( - dim_out = config.vocab_size, - dim_in = config.dim_model, - length_scale = config.length_scale, - dtype = config.dtype, - int8 = config.int8, - init_mean = config.proj_init_mean, - init_std = config.proj_init_std, - bias = config.proj_bias, + dim_out=config.vocab_size, + dim_in=config.dim_model, + length_scale=config.length_scale, + dtype=config.dtype, + int8=config.int8, + init_mean=config.proj_init_mean, + init_std=config.proj_init_std, + bias=config.proj_bias, ) - def forward(self, - input_ids = None, # (batch, seq_enc) - length = None, # (batch) - decoder_input_ids = None, # (batch, seq_dec) - decoder_length = None, # (batch) - attention_mask = None, # (batch, seq_enc) - decoder_attention_mask = None, # (batch, seq_dec) - head_mask = None, # unused - decoder_head_mask = None, # unused - cross_attn_head_mask = None, # unused - encoder_outputs = None, - inputs_embeds = None, - decoder_inputs_embeds = None, - output_attentions = None, # unused - output_hidden_states = None, # unused - return_dict = True, - return_logits = False, - ): + def forward(self, + input_ids=None, # (batch, seq_enc) + length=None, # (batch) + decoder_input_ids=None, # (batch, seq_dec) + decoder_length=None, # (batch) + attention_mask=None, # (batch, seq_enc) + decoder_attention_mask=None, # (batch, seq_dec) + head_mask=None, # unused + decoder_head_mask=None, # unused + cross_attn_head_mask=None, # unused + encoder_outputs=None, + inputs_embeds=None, + decoder_inputs_embeds=None, + output_attentions=None, # unused + output_hidden_states=None, # unused + return_dict=True + 
): """ T5 is an encoder-decoder model and converts problems into a text-to-text format. This model inherits from BaseModel. This model is also a PyTorch torch.nn.Module subclass. You can use it as a regular PyTorch Module. You can also select the data and data type that you want the model to return through changing the value of `return_dict` and `return_logits`. - + Args: input_ids (:obj:`torch.Tensor` of shape ``(batch, seq_enc)``): Indices of input sequence tokens. It will be embedded by model's internal embedding lookup matrix. - length (:obj:`torch.Tensor` of shape ``(batch)``): Length of input sequence before padding. + length (:obj:`torch.Tensor` of shape ``(batch)``): Length of input sequence before padding. attention_mask (:obj:`torch.Tensor` of shape ``(batch, seq_enc)``): Used to avoid performing attention on padding token indices in input. decoder_input_ids (:obj:`torch.Tensor` of shape ``(batch, seq_enc)``): Indices of decoder input sequence tokens . decoder_length (:obj:`torch.Tensor` of shape ``(batch)``): Length of decoder input sequence before padding. @@ -167,19 +166,19 @@ def forward(self, head_mask (:obj:`torch.Tensor` of shape ``(num_layers, num_heads)``): Unused. decoder_head_mask (:obj:`torch.Tensor` of shape ``(num_layers, num_heads)``): Unused. cross_attn_head_mask (:obj:`torch.Tensor` of shape ``(num_layers, num_heads)``): Unused. - encoder_outputs (:obj:`torch.Tensor` of shape ``(batch, dim_model, seq_enc)``): Outputs of encoder. - inputs_embeds (:obj:`torch.Tensor` of shape ``(batch, seq_enc, dim_model)``): Embedding of the input. You can choose to directly pass the inputs embedding to control the way of embedding. - decoder_inputs_embeds (:obj:`torch.Tensor` of shape ``(batch, seq_dec, dim_model)``): Embedding of the decoder input. You can choose to directly pass the inputs embedding to control the way of embedding. + encoder_outputs (:obj:`torch.Tensor` of shape ``(batch, dim_model, seq_enc)``): Outputs of encoder. + inputs_embeds (:obj:`torch.Tensor` of shape ``(batch, seq_enc, dim_model)``): Embedding of the input. You can choose to directly pass the inputs embedding to control the way of embedding. + decoder_inputs_embeds (:obj:`torch.Tensor` of shape ``(batch, seq_dec, dim_model)``): Embedding of the decoder input. You can choose to directly pass the inputs embedding to control the way of embedding. output_attentions (:obj:`torch.Tensor` of shape ``(batch, num_heads, seq_enc, seq_enc)``): Unused. output_hidden_states (:obj:`torch.Tensor` of shape ``(batch, seq_dec, dim_model)``): Unused. return_dict (:obj:`bool`): Whether to return a Seq2SeqModelOutput instead of just a tuple. return_logits (:obj:`bool`): Whether to return the prediction score for each token in vocabulary (before softmax). Return: - Seq2SeqModelOutput or tuple or torch.Tensor of shape (batch, seq_dec, vocab_output_size) or (batch, seqlen, cls_head): The T5 output. Depended on the value of `return_dict` and `return_logits` + Seq2SeqModelOutput or tuple or torch.Tensor of shape (batch, seq_dec, vocab_output_size) or (batch, seqlen, cls_head): The T5 output. 
Depended on the value of `return_dict` and `return_logits` + + """ - """ - # encoder if encoder_outputs is None: assert input_ids is not None or inputs_embeds is not None @@ -192,7 +191,7 @@ def forward(self, batch = inputs_embeds.size(0) seq_enc = inputs_embeds.size(1) device = inputs_embeds.device - + with torch.no_grad(): if attention_mask is not None: attention_mask = attention_mask.to(torch.bool) @@ -203,7 +202,7 @@ def forward(self, # (num_heads, seq_enc, seq_enc) enc_position_bias = self.position_bias_enc(seq_enc, seq_enc) - + # (batch, dim_model, seq_enc) if inputs_embeds is None: hidden_states_enc = self.input_embedding(input_ids) @@ -224,17 +223,23 @@ def forward(self, batch = decoder_inputs_embeds.size(0) seq_dec = decoder_inputs_embeds.size(1) device = decoder_inputs_embeds.device - + with torch.no_grad(): if decoder_attention_mask is not None: decoder_attention_mask = decoder_attention_mask.to(torch.bool) else: - decoder_attention_mask = torch.arange(seq_dec, device=device)[None, :].repeat(batch, 1) < decoder_length[:, None] - directional_mask_2d = torch.arange(seq_dec, device=device) <= torch.arange(seq_dec, device=device).view(-1, 1) + decoder_attention_mask = torch.arange(seq_dec, device=device)[None, :].repeat(batch, + 1) < decoder_length[:, + None] + directional_mask_2d = torch.arange(seq_dec, device=device) <= torch.arange(seq_dec, device=device).view(-1, + 1) # (batch, seq_dec, seq_dec) - dec_attention_mask = decoder_attention_mask.view(batch, seq_dec, 1) & decoder_attention_mask.view(batch, 1, seq_dec) & directional_mask_2d.view(1, seq_dec, seq_dec) + dec_attention_mask = decoder_attention_mask.view(batch, seq_dec, 1) & decoder_attention_mask.view(batch, 1, + seq_dec) & directional_mask_2d.view( + 1, seq_dec, seq_dec) # (batch, seq_dec, seq_enc) - cross_attention_mask = attention_mask.view(batch, 1, seq_enc) & decoder_attention_mask.view(batch, seq_dec, 1) + cross_attention_mask = attention_mask.view(batch, 1, seq_enc) & decoder_attention_mask.view(batch, seq_dec, + 1) # (num_heads, seq_dec, seq_dec) dec_position_bias = self.position_bias_dec(seq_dec, seq_dec) @@ -248,17 +253,6 @@ def forward(self, decoder_outputs = self.decoder(hidden_states_dec, dec_attention_mask, dec_position_bias, encoder_outputs, cross_attention_mask, None) - # (batch, seq_dec, vocab_output_size) - if self.cls_head: - logits = self.cls_projection(decoder_outputs) - elif self.tied: - logits = self.input_embedding.projection(decoder_outputs) - elif not self.tied: - logits = self.output_projection(decoder_outputs) - - if return_logits: - return logits#*(100*self.config.dim_model**-0.5) - if not return_dict: return tuple(decoder_outputs, None, None, None, None) else: @@ -271,4 +265,77 @@ def forward(self, decoder_attentions=None, cross_attentions=None, encoder_attentions=None, - ) \ No newline at end of file + ) + + +class T5ForLM(BaseModel): + _CONFIG_TYPE = T5Config + + def __init__(self, config: T5Config): + super().__init__() + self.t5 = T5(config) + + def forward(self, + input_ids=None, # (batch, seq_enc) + length=None, # (batch) + decoder_input_ids=None, # (batch, seq_dec) + decoder_length=None, # (batch) + attention_mask=None, # (batch, seq_enc) + decoder_attention_mask=None, # (batch, seq_dec) + head_mask=None, # unused + decoder_head_mask=None, # unused + cross_attn_head_mask=None, # unused + encoder_outputs=None, + inputs_embeds=None, + decoder_inputs_embeds=None, + labels=None, + output_attentions=None, # unused + output_hidden_states=None, # unused + return_dict=True, + ): + outputs = 
self.t5( + input_ids=input_ids, # (batch, seq_enc) + length=length, # (batch) + decoder_input_ids=decoder_input_ids, # (batch, seq_dec) + decoder_length=decoder_length, # (batch) + attention_mask=attention_mask, # (batch, seq_enc) + decoder_attention_mask=decoder_attention_mask, # (batch, seq_dec) + head_mask=head_mask, # unused + decoder_head_mask=decoder_head_mask, # unused + cross_attn_head_mask=cross_attn_head_mask, # unused + encoder_outputs=encoder_outputs, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + output_attentions=output_attentions, # unused + output_hidden_states=output_hidden_states, # unused + return_dict=return_dict, + ) + last_hidden_state = outputs[0] + # (batch, seq_dec, vocab_output_size) + if self.cls_head: + logits = self.cls_projection(last_hidden_state) + elif self.tied: + logits = self.input_embedding.projection(last_hidden_state) + elif not self.tied: + logits = self.output_projection(last_hidden_state) + + if labels: + loss_func = bmt.FusedCrossEntropy(ignore_index=-100) + loss = loss_func(logits.view(-1, logits.size(-1)), labels.view(-1)) + + if not return_dict: + return (loss, logits, outputs.past_key_values, outputs.decoder_hidden_states, outputs.decoder_attentions, + outputs.cross_attentions, outputs.encoder_last_hidden_state, + outputs.encoder_hidden_states, outputs.encoder_attentions) + + return Seq2SeqLMOutput( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + ) From bfdda9675c82648af0df2a6095ea0ef0ea36d035 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B9=94=E5=AD=90=E5=8D=BF?= <2020011015@secoder.net> Date: Tue, 26 Apr 2022 20:22:45 +0800 Subject: [PATCH 2/3] fix bugs in ModelForLM and add Roberta model --- model_center/model/bert.py | 50 ++-- model_center/model/config/roberta_config.py | 107 +++++++ model_center/model/cpm1.py | 50 ++-- model_center/model/cpm2.py | 23 +- model_center/model/gpt2.py | 50 ++-- model_center/model/gptj.py | 50 ++-- model_center/model/roberta.py | 296 ++++++++++++++++++++ model_center/model/t5.py | 50 ++-- tests/test.sh | 1 + tests/test_roberta.py | 34 +++ transfer/hugRoberta_bmtrainRoberta.py | 130 +++++++++ 11 files changed, 700 insertions(+), 141 deletions(-) create mode 100644 model_center/model/config/roberta_config.py create mode 100644 model_center/model/roberta.py create mode 100644 tests/test_roberta.py create mode 100644 transfer/hugRoberta_bmtrainRoberta.py diff --git a/model_center/model/bert.py b/model_center/model/bert.py index 1a0b5df..51a76b6 100644 --- a/model_center/model/bert.py +++ b/model_center/model/bert.py @@ -122,25 +122,6 @@ def __init__(self, config: BertConfig): post_layer_norm=config.post_layer_norm, ) - self.tied = config.tied - self.cls_head = config.cls_head - if self.cls_head: - self.cls_projection = Linear( - dim_out=self.cls_head, - dim_in=config.dim_model, - length_scale=config.length_scale, - dtype=config.dtype, - int8=config.int8, - init_mean=config.proj_init_mean, - init_std=config.proj_init_std, - bias=config.proj_bias, - ) - if not self.tied: - self.lm_head = BertLMHead( - dim_model=config.dim_model, - vocab_size=config.vocab_size, - norm_eps=config.norm_eps, - ) self.pooler = BertPooler(config.dim_model) @@ -221,16 
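One caveat in the T5ForLM.forward above: the loss branch is guarded by `if labels:`, which raises for a batched label tensor (its boolean value is ambiguous), and `loss` is never assigned when labels is omitted even though both return statements reference it. A more defensive version of that block, shown only as a sketch with logits, labels and the bmt import taken from the surrounding code; note the patch spells the loss class both as bmt.FusedCrossEntropy (here) and bmt.loss.FusedCrossEntropy (in the RoBERTa file), so whichever name matches the installed bmtrain should be used consistently:

loss = None
if labels is not None:
    loss_func = bmt.loss.FusedCrossEntropy(ignore_index=-100)
    loss = loss_func(logits.view(-1, logits.size(-1)), labels.view(-1))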
+202,6 @@ def forward(self, hidden_states = self.encoder(hidden_states, attention_mask) - """if self.cls_head: - logits = self.cls_projection(hidden_states) - elif self.tied: - logits = self.input_embedding.projection(hidden_states) - elif not self.tied: - logits = self.lm_head(hidden_states) - - if return_logits: - return logits""" - pooled_output = self.pooler(hidden_states) if not return_dict: @@ -253,6 +224,25 @@ def __init__(self, config: BertConfig): super().__init__() self.bert = Bert(config) self.seq_cls = torch.nn.Linear(config.hidden_size, 2) + self.tied = config.tied + self.cls_head = config.cls_head + if self.cls_head: + self.cls_projection = Linear( + dim_out=self.cls_head, + dim_in=config.dim_model, + length_scale=config.length_scale, + dtype=config.dtype, + int8=config.int8, + init_mean=config.proj_init_mean, + init_std=config.proj_init_std, + bias=config.proj_bias, + ) + if not self.tied: + self.lm_head = BertLMHead( + dim_model=config.dim_model, + vocab_size=config.vocab_size, + norm_eps=config.norm_eps, + ) def forward(self, input_ids=None, @@ -285,7 +275,7 @@ def forward(self, if self.cls_head: logits = self.cls_projection(hidden_states) elif self.tied: - logits = self.input_embedding.projection(hidden_states) + logits = self.bert.input_embedding.projection(hidden_states) elif not self.tied: logits = self.lm_head(hidden_states) diff --git a/model_center/model/config/roberta_config.py b/model_center/model/config/roberta_config.py new file mode 100644 index 0000000..1ee4bbe --- /dev/null +++ b/model_center/model/config/roberta_config.py @@ -0,0 +1,107 @@ +# coding=utf-8 +# Copyright 2022 The OpenBMB team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .config import Config +import torch + +class RobertaConfig(Config): + """ + This is a configuration class that stores the configuration of the BERT model, which inherits from the Config class. + It is used to instantiate the Bert model according to the specified parameters and define the model architecture. + You can set specific parameters to control the output of the model. + + For example: + [`dim_model`] is used to determine the Dimension of the encoder layers and the pooler layer. + You can choose to use the default value of 768 or customize their dimensions. 
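A minimal construction sketch (the base-sized numbers below are purely illustrative; the defaults in __init__ below correspond to roberta-large, i.e. dim_model=1024, num_heads=16, dim_ff=4096, num_layers=24):

        config = RobertaConfig(dim_model=768, num_heads=12, dim_head=64, dim_ff=3072, num_layers=12)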
+ + """ + + def __init__(self, vocab_size=50265, + type_size=1, + position_size=514, + dim_model=1024, + num_heads=16, + dim_head=64, + dim_ff=4096, + num_layers=24, + dropout_p=0.1, + emb_init_mean = 0.0, + emb_init_std = 1, + pos_bias_type = "none", + position_bias_max_distance = 1024, + norm_init_var = 1.0, + norm_bias = True, + norm_eps = 1e-05, + att_init_mean = 0.0, + att_init_std = 0.02, + att_bias = True, + att_mask_value = float("-1e4"), + ffn_init_mean = 0.0, + ffn_init_std = 0.02, + ffn_bias = True, + ffn_activate_fn = "gelu", + proj_init_mean = 0.0, + proj_init_std = 1, + proj_bias = True, + length_scale = False, + attn_scale = True, + half = True, + int8 = False, + tied = False, + cls_head = None, + post_layer_norm = True, + pad_token_id = 1, + ): + + super().__init__() + + self.vocab_size = vocab_size + self.type_size = type_size + self.position_size = position_size + self.position_size = position_size + self.dim_model = dim_model + self.num_heads = num_heads + self.dim_head = dim_head + self.dim_ff = dim_ff + self.num_layers = num_layers + self.dropout_p = dropout_p + self.emb_init_mean = emb_init_mean + self.emb_init_std = emb_init_std + self.pos_bias_type = pos_bias_type + self.position_bias_max_distance = position_bias_max_distance + self.norm_init_var = norm_init_var + self.norm_bias = norm_bias + self.norm_eps = norm_eps + self.att_init_mean = att_init_mean + self.att_init_std = att_init_std + self.att_bias = att_bias + self.att_mask_value = att_mask_value + self.ffn_init_mean = ffn_init_mean + self.ffn_init_std = ffn_init_std + self.ffn_bias = ffn_bias + self.ffn_activate_fn = ffn_activate_fn + self.proj_init_mean = proj_init_mean + self.proj_init_std = proj_init_std + self.proj_bias = proj_bias + self.length_scale = length_scale + self.attn_scale = attn_scale + self.int8 = int8 + self.tied = tied + if half: + self.dtype = torch.half + else: + self.dtype = torch.float + self.cls_head = cls_head + self.post_layer_norm = post_layer_norm + self.pad_token_id = pad_token_id diff --git a/model_center/model/cpm1.py b/model_center/model/cpm1.py index f07ccbc..ba9edbd 100644 --- a/model_center/model/cpm1.py +++ b/model_center/model/cpm1.py @@ -73,30 +73,6 @@ def __init__(self, config: CPM1Config): init_std=config.pos_init_std, ) - self.tied = config.tied - self.cls_head = config.cls_head - if self.cls_head: - self.output_projection = Linear( - dim_out=config.cls_head, - dim_in=config.dim_model, - length_scale=config.length_scale, - dtype=config.dtype, - int8=config.int8, - init_mean=config.proj_init_mean, - init_std=config.proj_init_std, - bias=config.proj_bias, - ) - elif not self.tied: - self.output_projection = Linear( - dim_out=config.vocab_size, - dim_in=config.dim_model, - length_scale=config.length_scale, - dtype=config.dtype, - int8=config.int8, - init_mean=config.proj_init_mean, - init_std=config.proj_init_std, - bias=config.proj_bias, - ) def forward(self, input: torch.Tensor, # (batch, seqlen) length: torch.Tensor, # (batch) @@ -149,6 +125,30 @@ class CPM1ForLM(BaseModel): def __init__(self, config: CPM1Config): super().__init__() self.cpm1 = CPM1(config) + self.tied = config.tied + self.cls_head = config.cls_head + if self.cls_head: + self.output_projection = Linear( + dim_out=config.cls_head, + dim_in=config.dim_model, + length_scale=config.length_scale, + dtype=config.dtype, + int8=config.int8, + init_mean=config.proj_init_mean, + init_std=config.proj_init_std, + bias=config.proj_bias, + ) + elif not self.tied: + self.output_projection = Linear( + 
dim_out=config.vocab_size, + dim_in=config.dim_model, + length_scale=config.length_scale, + dtype=config.dtype, + int8=config.int8, + init_mean=config.proj_init_mean, + init_std=config.proj_init_std, + bias=config.proj_bias, + ) def forward(self, input: torch.Tensor, # (batch, seqlen) @@ -169,7 +169,7 @@ def forward(self, elif not self.tied: logits = self.output_projection(hidden_states) else: - logits = self.input_embedding.projection(hidden_states) + logits = self.cpm1.input_embedding.projection(hidden_states) if labels: _logits = logits[..., :-1, :].contiguous() _labels = labels[..., 1:].contiguous() diff --git a/model_center/model/cpm2.py b/model_center/model/cpm2.py index 22023c5..f1a6be1 100644 --- a/model_center/model/cpm2.py +++ b/model_center/model/cpm2.py @@ -107,17 +107,7 @@ def __init__(self, config: CPM2Config): init_std=config.pos_init_std, ) - self.cls_head = config.cls_head - self.output_projection = Linear( - dim_out=self.cls_head if self.cls_head else config.vocab_size, - dim_in=config.dim_model, - length_scale=config.length_scale, - dtype=config.dtype, - int8=config.int8, - init_mean=config.proj_init_mean, - init_std=config.proj_init_std, - bias=config.proj_bias, - ) + def forward(self, enc_input: torch.Tensor, # (batch, seq_enc) @@ -195,6 +185,17 @@ class CPM2ForLM(BaseModel): def __init__(self, config: CPM2Config): super().__init__() self.cpm2 = CPM2(config) + self.cls_head = config.cls_head + self.output_projection = Linear( + dim_out=self.cls_head if self.cls_head else config.vocab_size, + dim_in=config.dim_model, + length_scale=config.length_scale, + dtype=config.dtype, + int8=config.int8, + init_mean=config.proj_init_mean, + init_std=config.proj_init_std, + bias=config.proj_bias, + ) def forward(self, enc_input: torch.Tensor, # (batch, seq_enc) diff --git a/model_center/model/gpt2.py b/model_center/model/gpt2.py index 5ad893b..e56540f 100644 --- a/model_center/model/gpt2.py +++ b/model_center/model/gpt2.py @@ -75,30 +75,6 @@ def __init__(self, config: GPT2Config): init_std=config.emb_init_std, ) - self.tied = config.tied - self.cls_head = config.cls_head - if self.cls_head: - self.cls_projection = Linear( - dim_out=self.cls_head, - dim_in=config.dim_model, - length_scale=config.length_scale, - dtype=config.dtype, - int8=config.int8, - init_mean=config.proj_init_mean, - init_std=config.proj_init_std, - bias=config.proj_bias, - ) - if not self.tied: - self.output_projection = Linear( - dim_out=config.vocab_size, - dim_in=config.dim_model, - length_scale=config.length_scale, - dtype=config.dtype, - int8=config.int8, - init_mean=config.proj_init_mean, - init_std=config.proj_init_std, - bias=config.proj_bias, - ) def forward(self, input_ids=None, # (batch, seqlen) @@ -192,6 +168,30 @@ class GPT2ForLM(BaseModel): def __init__(self, config: GPT2Config): super().__init__() self.gpt2 = GPT2(config) + self.tied = config.tied + self.cls_head = config.cls_head + if self.cls_head: + self.cls_projection = Linear( + dim_out=self.cls_head, + dim_in=config.dim_model, + length_scale=config.length_scale, + dtype=config.dtype, + int8=config.int8, + init_mean=config.proj_init_mean, + init_std=config.proj_init_std, + bias=config.proj_bias, + ) + if not self.tied: + self.output_projection = Linear( + dim_out=config.vocab_size, + dim_in=config.dim_model, + length_scale=config.length_scale, + dtype=config.dtype, + int8=config.int8, + init_mean=config.proj_init_mean, + init_std=config.proj_init_std, + bias=config.proj_bias, + ) def forward(self, input_ids=None, # (batch, seqlen) @@ -226,7 
+226,7 @@ def forward(self, if self.cls_head: logits = self.cls_projection(hidden_states) elif self.tied: - logits = self.input_embedding.projection(hidden_states) + logits = self.gpt2.input_embedding.projection(hidden_states) logits[:, :, -1] = -float( "inf") # TODO not an elegant implementation, gpt2 vocab is odd number, expand to even and ignore last elif not self.tied: diff --git a/model_center/model/gptj.py b/model_center/model/gptj.py index 9fa9ac2..811c48e 100644 --- a/model_center/model/gptj.py +++ b/model_center/model/gptj.py @@ -68,30 +68,6 @@ def __init__(self, config: GPTjConfig): rotary_dim=config.pos_rotary_dim, ) - self.tied = config.tied - self.cls_head = config.cls_head - if self.cls_head: - self.cls_projection = Linear( - dim_out=self.cls_head, - dim_in=config.dim_model, - length_scale=config.length_scale, - dtype=config.dtype, - int8=config.int8, - init_mean=config.proj_init_mean, - init_std=config.proj_init_std, - bias=config.proj_bias, - ) - if not self.tied: - self.output_projection = Linear( - dim_out=config.vocab_size, - dim_in=config.dim_model, - length_scale=config.length_scale, - dtype=config.dtype, - int8=config.int8, - init_mean=config.proj_init_mean, - init_std=config.proj_init_std, - bias=config.proj_bias, - ) def forward(self, input_ids=None, # (batch, seqlen) @@ -172,6 +148,30 @@ class GPTjForLM(BaseModel): def __init__(self, config: GPTjConfig): super().__init__() self.gptj = GPTj(config) + self.tied = config.tied + self.cls_head = config.cls_head + if self.cls_head: + self.cls_projection = Linear( + dim_out=self.cls_head, + dim_in=config.dim_model, + length_scale=config.length_scale, + dtype=config.dtype, + int8=config.int8, + init_mean=config.proj_init_mean, + init_std=config.proj_init_std, + bias=config.proj_bias, + ) + if not self.tied: + self.output_projection = Linear( + dim_out=config.vocab_size, + dim_in=config.dim_model, + length_scale=config.length_scale, + dtype=config.dtype, + int8=config.int8, + init_mean=config.proj_init_mean, + init_std=config.proj_init_std, + bias=config.proj_bias, + ) def forward(self, input_ids=None, # (batch, seqlen) @@ -202,7 +202,7 @@ def forward(self, if self.cls_head: logits = self.cls_projection(hidden_states) elif self.tied: - logits = self.input_embedding.projection(hidden_states) + logits = self.gptj.input_embedding.projection(hidden_states) elif not self.tied: logits = self.output_projection(hidden_states) if labels: diff --git a/model_center/model/roberta.py b/model_center/model/roberta.py new file mode 100644 index 0000000..380111a --- /dev/null +++ b/model_center/model/roberta.py @@ -0,0 +1,296 @@ +# coding=utf-8 +# Copyright 2022 The OpenBMB team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
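The hunks above repeat a single refactor across GPT-2 and GPT-J (and, elsewhere in this patch, BERT, CPM1, CPM2 and T5): the cls_projection / output_projection heads move out of the backbone into the *ForLM wrapper, and the tied-embedding branch is corrected to reach through the wrapped module (self.gpt2.input_embedding, self.gptj.input_embedding) instead of an input_embedding attribute the wrapper never had. The GPT-2 branch additionally fills the last vocabulary column with -inf because, per the in-line TODO, the vocabulary was padded to an even size. Conceptually, the tied branch reuses the input embedding matrix as the output projection; a plain-PyTorch sketch of that idea with illustrative sizes (this is not the model_center Embedding.projection implementation):

import torch

vocab_size, dim_model = 30522, 768
W_emb = torch.nn.Parameter(torch.randn(vocab_size, dim_model))  # shared embedding matrix

hidden_states = torch.randn(2, 128, dim_model)
logits = hidden_states @ W_emb.t()   # weight tying: the matrix that embeds tokens also scores them
print(logits.shape)                  # torch.Size([2, 128, 30522])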
+import torch + +from ..layer import Encoder, Embedding, Linear, LayerNorm +from .basemodel import BaseModel +from .config import RobertaConfig +from .modeling_output import ModelOutput, BaseModelOutputWithPoolingAndCrossAttentions +from typing import Optional, Tuple +import bmtrain as bmt + + +class RoertaPooler(torch.nn.Module): + def __init__(self, dim_model): + super().__init__() + self.dense = Linear(dim_model, dim_model, bias=True) + self.activation = torch.nn.Tanh() + + def forward(self, hidden_states): + pooled_output = self.dense(hidden_states[:, 0, :]) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class RoertaLMHead(torch.nn.Module): + def __init__(self, dim_model, vocab_size, norm_eps): + super().__init__() + self.dense = Linear(dim_model, dim_model, bias=True) + self.act_fn = torch.nn.functional.gelu + self.layer_norm = LayerNorm(dim_model, eps=norm_eps) + self.decoder = Linear(dim_model, vocab_size, bias=True) + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.act_fn(hidden_states) + hidden_states = self.layer_norm(hidden_states) + logits = self.decoder(hidden_states) + return logits + + +class RobertaForPreTrainingOutput(ModelOutput): + loss: Optional[torch.FloatTensor] = None + prediction_logits: torch.FloatTensor = None + past_key_values: Optional[Tuple[torch.FloatTensor]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + cross_attention: Optional[Tuple[torch.FloatTensor]] = None + + +class Roberta(BaseModel): + _CONFIG_TYPE = RobertaConfig + + def __init__(self, config: RobertaConfig): + super().__init__() + + self.input_embedding = Embedding( + vocab_size=config.vocab_size, + embedding_size=config.dim_model, + length_scale=config.length_scale, + dtype=config.dtype, + int8=config.int8, + init_mean=config.emb_init_mean, + init_std=config.emb_init_std, + ) + + self.position_embedding = Embedding( + vocab_size=config.position_size, + embedding_size=config.dim_model, + length_scale=config.length_scale, + dtype=config.dtype, + int8=config.int8, + init_mean=config.emb_init_mean, + init_std=config.emb_init_std, + ) + + self.token_type_embedding = Embedding( + vocab_size=config.type_size, + embedding_size=config.dim_model, + length_scale=config.length_scale, + dtype=config.dtype, + int8=config.int8, + init_mean=config.emb_init_mean, + init_std=config.emb_init_std, + ) + + self.embed_dropout = torch.nn.Dropout(config.dropout_p) + + self.encoder = Encoder( + num_layers=config.num_layers, + dim_model=config.dim_model, + dim_ff=config.dim_ff, + num_heads=config.num_heads, + dim_head=config.dim_head, + dtype=config.dtype, + int8=config.int8, + norm_eps=config.norm_eps, + norm_init_var=config.norm_init_var, + norm_bias=config.norm_bias, + att_init_mean=config.att_init_mean, + att_init_std=config.att_init_std, + att_bias=config.att_bias, + att_mask_value=float(config.att_mask_value), + pos_bias_type=config.pos_bias_type, + ffn_init_mean=config.ffn_init_mean, + ffn_init_std=config.ffn_init_std, + ffn_bias=config.ffn_bias, + ffn_activate_fn=config.ffn_activate_fn, + length_scale=config.length_scale, + attn_scale=config.attn_scale, + dropout_p=config.dropout_p, + post_layer_norm=config.post_layer_norm, + ) + + self.pooler = RoertaPooler(config.dim_model) + self.padding_idx = config.pad_token_id + + def forward(self, + input_ids=None, + length=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, # 
unused + inputs_embeds=None, + encoder_hidden_states=None, # unused + encoder_attention_mask=None, # unused + output_attentions=None, # unused + output_hidden_states=None, # unused + return_dict=True, + ): + """ This model inherits from BaseModel. This model is also a PyTorch torch.nn.Module subclass. + You can use it as a regular PyTorch Module. + You can also select the data and data type that you want the model to return through changing the value of `return_dict` and `return_logits`. + + Args: + input_ids (:obj:`torch.Tensor` of shape ``(batch, seq_length)``): Indices of input sequence tokens. It will be embedded by model's internal embedding lookup matrix. + length (:obj:`torch.Tensor` of shape ``(batch)``): Length of input sequence before padding. + attention_mask (:obj:`torch.Tensor` of shape ``(batch, seq_length)``): Used to avoid performing attention on padding token indices. + token_type_ids(:obj:`torch.Tensor` of shape ``(batch, seq_length)``): Unused. + position_ids(:obj:`torch.Tensor` of shape ``(batch, seq_length)``): Unused. + head_mask (:obj:`torch.Tensor` of shape ``(num_layers, num_heads)``): Unused. + inputs_embeds (:obj:`torch.Tensor` of shape ``(batch, seq_length, dim_model)``): Embedding of the input. You can choose to directly pass the inputs embedding to control the way of embedding. + encoder_hidden_states(:obj:`torch.Tensor` of shape(batch, seq_length, dim_model)): Unused. + encoder_attention_mask (:obj:`torch.Tensor` of shape ``(batch, seq_length)``): Unused. + output_attentions (:obj:`torch.Tensor` of shape ``(batch, num_heads, seq_length, seq_length)``): Unused. + output_hidden_states (:obj:`torch.Tensor` of shape ``(batch, seq_length, dim_model)``): Unused. + return_dict (:obj:`bool`): Whether to return a BaseModelOutputWithPoolingAndCrossAttentions instead of just a tuple. + + Return: + BaseModelOutputWithPoolingAndCrossAttentions or tuple or torch.Tensor of shape (batch, seq_length, vocab_output_size) or (batch, seqlen, cls_head): The Bert output. 
Depended on the value of `return_dict` and `return_logits` + + """ + assert input_ids is not None or inputs_embeds is not None + + if input_ids is not None: + batch = input_ids.size(0) + seq_length = input_ids.size(1) + device = input_ids.device + else: + batch = inputs_embeds.size(0) + seq_length = inputs_embeds.size(1) + device = inputs_embeds.device + + with torch.no_grad(): + + if attention_mask is not None: + attention_mask = attention_mask.to(torch.bool) + else: + attention_mask = torch.arange(seq_length, device=device)[None, :].repeat(batch, 1) < length[:, None] + attention_mask = attention_mask.view(batch, seq_length, 1) & attention_mask.view(batch, 1, seq_length) + + if position_ids is None: + position_ids = torch.arange(self.padding_idx + 1, seq_length + self.padding_idx + 1, dtype=torch.int32, + device=device)[None, :].repeat(batch, 1) + + if token_type_ids is None: + token_type_ids = torch.zeros(seq_length, dtype=torch.int32, device=device)[None, :].repeat(batch, 1) + + if inputs_embeds is None: + hidden_states = self.input_embedding(input_ids.to(torch.int32)) + else: + hidden_states = inputs_embeds + + position_embeds = self.position_embedding(position_ids.to(torch.int32)) + token_type_embeds = self.token_type_embedding(token_type_ids.to(torch.int32)) + hidden_states = hidden_states + token_type_embeds + position_embeds + + hidden_states = self.embed_dropout(hidden_states) + + hidden_states = self.encoder(hidden_states, attention_mask) + + pooled_output = self.pooler(hidden_states) + + if not return_dict: + return (hidden_states, pooled_output, None, None, None, None) + else: + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=hidden_states, + pooler_output=pooled_output, + past_key_values=None, + hidden_states=None, + attentions=None, + cross_attentions=None, + ) + + +class RobertaForLM(BaseModel): + _CONFIG_TYPE = RobertaConfig + + def __init__(self, config: RobertaConfig): + super().__init__() + self.roberta = Roberta(config) + self.cls_head = config.cls_head + self.tied = config.tied + self.vocab_size = config.vocab_size + if self.cls_head: + self.cls_projection = Linear( + dim_out=self.cls_head, + dim_in=config.dim_model, + length_scale=config.length_scale, + dtype=config.dtype, + int8=config.int8, + init_mean=config.proj_init_mean, + init_std=config.proj_init_std, + bias=config.proj_bias, + ) + + if not self.tied: + self.lm_head = RoertaLMHead( + dim_model=config.dim_model, + vocab_size=config.vocab_size, + norm_eps=config.norm_eps, + ) + + def forward(self, + input_ids=None, + length=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, # unused + inputs_embeds=None, + encoder_hidden_states=None, # unused + encoder_attention_mask=None, # unused + labels=None, + next_sentence_label=None, + output_attentions=None, # unused + output_hidden_states=None, # unused + return_dict=True, + ): + outputs = self.roberta( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = outputs[0] + if self.cls_head: + logits = self.cls_projection(hidden_states) + elif self.tied: + logits = self.roberta.input_embedding.projection(hidden_states) + elif not self.tied: + logits = self.lm_head(hidden_states) + + lm_logits = logits[:, :-1, :].contiguous() + labels = labels[:, 1:].contiguous() + 
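# The two slices above pair each position's logits with the next token, the same
# next-token convention the other *ForLM wrappers in this patch use (see CPM1ForLM);
# Hugging Face's RobertaForMaskedLM instead compares logits and labels at the same
# positions, with -100 marking unmasked tokens. Note also that labels is assumed to
# be given here: labels[:, 1:] raises a TypeError when labels is None.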
loss_func = bmt.loss.FusedCrossEntropy(ignore_index=-100) + total_loss = loss_func(lm_logits.view(-1, self.vocab_size), labels.view(-1)) + + if not return_dict: + return (total_loss, logits, outputs.hidden_states, outputs.attentions) + return RobertaForPreTrainingOutput( + loss=total_loss, + prediction_logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) \ No newline at end of file diff --git a/model_center/model/t5.py b/model_center/model/t5.py index b774234..9294ba2 100644 --- a/model_center/model/t5.py +++ b/model_center/model/t5.py @@ -110,30 +110,6 @@ def __init__(self, config: T5Config): init_std=config.pos_init_std, ) - self.tied = config.tied - self.cls_head = config.cls_head - if self.cls_head: - self.cls_projection = Linear( - dim_out=self.cls_head, - dim_in=config.dim_model, - length_scale=config.length_scale, - dtype=config.dtype, - int8=config.int8, - init_mean=config.proj_init_mean, - init_std=config.proj_init_std, - bias=config.proj_bias, - ) - if not self.tied: - self.output_projection = Linear( - dim_out=config.vocab_size, - dim_in=config.dim_model, - length_scale=config.length_scale, - dtype=config.dtype, - int8=config.int8, - init_mean=config.proj_init_mean, - init_std=config.proj_init_std, - bias=config.proj_bias, - ) def forward(self, input_ids=None, # (batch, seq_enc) @@ -274,6 +250,30 @@ class T5ForLM(BaseModel): def __init__(self, config: T5Config): super().__init__() self.t5 = T5(config) + self.tied = config.tied + self.cls_head = config.cls_head + if self.cls_head: + self.cls_projection = Linear( + dim_out=self.cls_head, + dim_in=config.dim_model, + length_scale=config.length_scale, + dtype=config.dtype, + int8=config.int8, + init_mean=config.proj_init_mean, + init_std=config.proj_init_std, + bias=config.proj_bias, + ) + if not self.tied: + self.output_projection = Linear( + dim_out=config.vocab_size, + dim_in=config.dim_model, + length_scale=config.length_scale, + dtype=config.dtype, + int8=config.int8, + init_mean=config.proj_init_mean, + init_std=config.proj_init_std, + bias=config.proj_bias, + ) def forward(self, input_ids=None, # (batch, seq_enc) @@ -315,7 +315,7 @@ def forward(self, if self.cls_head: logits = self.cls_projection(last_hidden_state) elif self.tied: - logits = self.input_embedding.projection(last_hidden_state) + logits = self.t5.input_embedding.projection(last_hidden_state) elif not self.tied: logits = self.output_projection(last_hidden_state) diff --git a/tests/test.sh b/tests/test.sh index 4e5fb4e..858ce0e 100755 --- a/tests/test.sh +++ b/tests/test.sh @@ -14,3 +14,4 @@ python3 -m torch.distributed.launch ${DISTRIBUTED_ARGS} test_bert.py python3 -m torch.distributed.launch ${DISTRIBUTED_ARGS} test_t5.py python3 -m torch.distributed.launch ${DISTRIBUTED_ARGS} test_gpt2.py python3 -m torch.distributed.launch ${DISTRIBUTED_ARGS} test_gptj.py +python3 -m torch.distributed.launch ${DISTRIBUTED_ARGS} test_roberta.py diff --git a/tests/test_roberta.py b/tests/test_roberta.py new file mode 100644 index 0000000..dfd8a5f --- /dev/null +++ b/tests/test_roberta.py @@ -0,0 +1,34 @@ +#coding:utf-8 + +import torch +import bmtrain as bmt +from model_center.model.config import RobertaConfig +from model_center.model import RobertaForLM + +from transformers import RobertaForMaskedLM as hugRoberta + +def main(): + bmt.init_distributed() + + path = "roberta-large" + config = RobertaConfig.from_pretrained(path) + config.dropout_p = 0 + 
bmt_roberta = RobertaForLM.from_pretrained(path, config=config) + hug_roberta = hugRoberta.from_pretrained(path).cuda().eval().half() + + for _ in range(10): + batch = 1 + max_encoder_length = 512 + input_ids = torch.randint(config.vocab_size, (batch, max_encoder_length,), dtype=torch.int32).cuda() + length = torch.randint(max_encoder_length, (batch, ), dtype=torch.int32).cuda() + attention_mask = torch.arange(input_ids.shape[1], device=input_ids.device)[None, :].repeat(input_ids.shape[0], 1) < length[:, None] + labels = torch.arange(input_ids.shape[1], device=input_ids.device)[None, :].repeat(input_ids.shape[0], 1) < length[:, None] + bmt_logits = bmt_roberta(input_ids = input_ids, attention_mask = attention_mask,labels=labels).prediction_logits + hug_logits = hug_roberta(input_ids = input_ids, attention_mask = attention_mask).logits + b = bmt_logits*attention_mask[:,:,None] + h = hug_logits*attention_mask[:,:,None] + d = (h - b).abs() + print(d.max()) + +if __name__ == "__main__": + main() diff --git a/transfer/hugRoberta_bmtrainRoberta.py b/transfer/hugRoberta_bmtrainRoberta.py new file mode 100644 index 0000000..c68d8e1 --- /dev/null +++ b/transfer/hugRoberta_bmtrainRoberta.py @@ -0,0 +1,130 @@ +# coding=utf-8 +# Copyright 2020 The OpenBMB team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
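The loop in tests/test_roberta.py above feeds the same random token ids through both the bmtrain port and the Hugging Face reference (both in half precision), masks the padded positions, and prints the maximum absolute difference of the logits; the labels it passes are just the boolean attention mask reused as class indices, which mainly keeps RobertaForLM's unconditional loss computation supplied rather than exercising a meaningful target. If the comparison should fail loudly instead of relying on reading the printed maxima, an assertion could follow the existing d.max() computation inside the loop; the 1e-1 half-precision tolerance below is an assumption, not a project standard:

        assert d.max().item() < 1e-1, \
            f"bmtrain and huggingface RoBERTa logits diverge: {d.max().item():.4f}"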
+ +import os +import torch +import json + +from collections import OrderedDict +from transformers import RobertaModel, RobertaConfig, RobertaTokenizer, RobertaForMaskedLM +from model_center.model.config import RobertaConfig as myConfig + +base_path = '/home/hx/qzq/ModelCenter_roberta' + + +def convert_tokenizer(version: str): + tokenizer: RobertaTokenizer = RobertaTokenizer.from_pretrained(version) + vocab_size = tokenizer.vocab_size + s = [''] * vocab_size + for word in tokenizer.encoder: + id = tokenizer.encoder[word] + s[id] = word + fo = open(os.path.join(base_path, 'configs', 'roberta', version, 'vocab.txt'), 'w') + for word in s: + print(word, file=fo) + fo.close() + + +def convert_model(version: str): + config: RobertaConfig = RobertaConfig.from_pretrained(version) + default_config = myConfig() + config_json = {} + config_json['dim_head'] = int(config.hidden_size / config.num_attention_heads) + if default_config.dim_model != config.hidden_size: + config_json['dim_model'] = config.hidden_size + if default_config.dim_ff != config.intermediate_size: + config_json['dim_ff'] = config.intermediate_size + if default_config.num_heads != config.num_attention_heads: + config_json['num_heads'] = config.num_attention_heads + if default_config.num_layers != config.num_hidden_layers: + config_json['num_layers'] = config.num_hidden_layers + if default_config.vocab_size != config.vocab_size: + config_json['vocab_size'] = config.vocab_size + + try: + os.mkdir(os.path.join(base_path, 'configs', 'roberta', version)) + except: + pass + print(json.dumps(config_json), + file=open(os.path.join(base_path, 'configs', 'roberta', version, 'config.json'), 'w')) + + num_layers = config.num_hidden_layers + lmhead_bert = RobertaForMaskedLM.from_pretrained(version) + dict = lmhead_bert.state_dict() + new_dict = OrderedDict() + + new_dict['roberta.input_embedding.weight'] = dict['roberta.embeddings.word_embeddings.weight'] + new_dict['roberta.position_embedding.weight'] = dict['roberta.embeddings.position_embeddings.weight'] + new_dict['roberta.token_type_embedding.weight'] = dict['roberta.embeddings.token_type_embeddings.weight'] + + for i in range(num_layers): + new_dict['roberta.encoder.layers.' + str(i) + '.self_att.layernorm_before_attention.weight'] = ( + dict['roberta.embeddings.LayerNorm.weight'] if i == 0 + else dict['roberta.encoder.layer.' + str(i - 1) + '.output.LayerNorm.weight']) + new_dict['roberta.encoder.layers.' + str(i) + '.self_att.layernorm_before_attention.bias'] = ( + dict['roberta.embeddings.LayerNorm.bias'] if i == 0 + else dict['roberta.encoder.layer.' + str(i - 1) + '.output.LayerNorm.bias']) + new_dict['roberta.encoder.layers.' + str(i) + '.self_att.self_attention.project_q.weight'] = dict[ + 'roberta.encoder.layer.' + str(i) + '.attention.self.query.weight'] + new_dict['roberta.encoder.layers.' + str(i) + '.self_att.self_attention.project_q.bias'] = dict[ + 'roberta.encoder.layer.' + str(i) + '.attention.self.query.bias'] + new_dict['roberta.encoder.layers.' + str(i) + '.self_att.self_attention.project_k.weight'] = dict[ + 'roberta.encoder.layer.' + str(i) + '.attention.self.key.weight'] + new_dict['roberta.encoder.layers.' + str(i) + '.self_att.self_attention.project_k.bias'] = dict[ + 'roberta.encoder.layer.' + str(i) + '.attention.self.key.bias'] + new_dict['roberta.encoder.layers.' + str(i) + '.self_att.self_attention.project_v.weight'] = dict[ + 'roberta.encoder.layer.' + str(i) + '.attention.self.value.weight'] + new_dict['roberta.encoder.layers.' 
+ str(i) + '.self_att.self_attention.project_v.bias'] = dict[ + 'roberta.encoder.layer.' + str(i) + '.attention.self.value.bias'] + new_dict['roberta.encoder.layers.' + str(i) + '.self_att.self_attention.attention_out.weight'] = dict[ + 'roberta.encoder.layer.' + str(i) + '.attention.output.dense.weight'] + new_dict['roberta.encoder.layers.' + str(i) + '.self_att.self_attention.attention_out.bias'] = dict[ + 'roberta.encoder.layer.' + str(i) + '.attention.output.dense.bias'] + new_dict['roberta.encoder.layers.' + str(i) + '.ffn.layernorm_before_ffn.weight'] = dict[ + 'roberta.encoder.layer.' + str(i) + '.attention.output.LayerNorm.weight'] + new_dict['roberta.encoder.layers.' + str(i) + '.ffn.layernorm_before_ffn.bias'] = dict[ + 'roberta.encoder.layer.' + str(i) + '.attention.output.LayerNorm.bias'] + new_dict['roberta.encoder.layers.' + str(i) + '.ffn.ffn.w_in.w.weight'] = dict[ + 'roberta.encoder.layer.' + str(i) + '.intermediate.dense.weight'] + new_dict['roberta.encoder.layers.' + str(i) + '.ffn.ffn.w_in.w.bias'] = dict[ + 'roberta.encoder.layer.' + str(i) + '.intermediate.dense.bias'] + new_dict['roberta.encoder.layers.' + str(i) + '.ffn.ffn.w_out.weight'] = dict[ + 'roberta.encoder.layer.' + str(i) + '.output.dense.weight'] + new_dict['roberta.encoder.layers.' + str(i) + '.ffn.ffn.w_out.bias'] = dict[ + 'roberta.encoder.layer.' + str(i) + '.output.dense.bias'] + + new_dict['roberta.encoder.output_layernorm.weight'] = dict[ + 'roberta.encoder.layer.' + str(num_layers - 1) + '.output.LayerNorm.weight'] + new_dict['roberta.encoder.output_layernorm.bias'] = dict[ + 'roberta.encoder.layer.' + str(num_layers - 1) + '.output.LayerNorm.bias'] + + new_dict['lm_head.dense.weight'] = dict['lm_head.dense.weight'] + new_dict['lm_head.dense.bias'] = dict['lm_head.dense.bias'] + new_dict['lm_head.layer_norm.weight'] = dict['lm_head.layer_norm.weight'] + new_dict['lm_head.layer_norm.bias'] = dict['lm_head.layer_norm.bias'] + new_dict['lm_head.decoder.weight'] = dict['lm_head.decoder.weight'] + new_dict['lm_head.decoder.bias'] = dict['lm_head.decoder.bias'] + + roberta = RobertaModel.from_pretrained(version) + dict = roberta.state_dict() + new_dict['roberta.pooler.dense.weight'] = dict['pooler.dense.weight'] + new_dict['roberta.pooler.dense.bias'] = dict['pooler.dense.bias'] + + torch.save(new_dict, os.path.join(base_path, 'configs', 'roberta', version, 'pytorch_model.pt')) + + +if __name__ == "__main__": + convert_model("roberta-large") + convert_tokenizer("roberta-large") From ad2991d66f012fec88c630ac082882887d965b13 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B9=94=E5=AD=90=E5=8D=BF?= <2020011015@secoder.net> Date: Tue, 26 Apr 2022 20:31:08 +0800 Subject: [PATCH 3/3] add corresponding lines in __init__.py --- model_center/model/__init__.py | 1 + model_center/model/config/__init__.py | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/model_center/model/__init__.py b/model_center/model/__init__.py index c4fda46..7dd7e9a 100644 --- a/model_center/model/__init__.py +++ b/model_center/model/__init__.py @@ -22,3 +22,4 @@ from .gpt2 import GPT2, GPT2ForLM from .gptj import GPTj, GPTjForLM from .bert import Bert, BertForLM +from .roberta import Roberta, RobertaForLM diff --git a/model_center/model/config/__init__.py b/model_center/model/config/__init__.py index 6ec69f5..e6ed4bc 100644 --- a/model_center/model/config/__init__.py +++ b/model_center/model/config/__init__.py @@ -17,4 +17,5 @@ from .t5_config import T5Config from .gpt2_config import GPT2Config from .gptj_config import 
GPTjConfig -from .bert_config import BertConfig \ No newline at end of file +from .bert_config import BertConfig +from .roberta_config import RobertaConfig \ No newline at end of file
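For reference, the new transfer/hugRoberta_bmtrainRoberta.py is driven by its __main__ block and the hard-coded base_path: convert_model() writes config.json and pytorch_model.pt, and convert_tokenizer() writes vocab.txt, all under base_path/configs/roberta/<version>. A sketch of converting a different checkpoint; importing the module like this assumes transfer/ is reachable from the working directory, and since convert_model() only creates the final <version> folder, the configs/roberta parent directories under base_path must already exist:

from transfer.hugRoberta_bmtrainRoberta import convert_model, convert_tokenizer

convert_model("roberta-base")       # writes config.json and pytorch_model.pt for this version
convert_tokenizer("roberta-base")   # writes vocab.txt; run after convert_model so the folder exists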