13 changes: 7 additions & 6 deletions model_center/model/__init__.py
@@ -16,9 +16,10 @@
from .config import *

# Model Architecture
from .cpm1 import CPM1
from .cpm2 import CPM2
from .t5 import T5
from .gpt2 import GPT2
from .gptj import GPTj
from .bert import Bert
from .cpm1 import CPM1, CPM1ForLM
from .cpm2 import CPM2, CPM2ForLM
from .t5 import T5, T5ForLM
from .gpt2 import GPT2, GPT2ForLM
from .gptj import GPTj, GPTjForLM
from .bert import Bert, BertForLM
from .roberta import Roberta, RobertaForLM
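With these exports, each architecture's *ForLM wrapper is importable straight from model_center.model. A minimal loading sketch, assuming bmtrain is initialized first and that the checkpoint name resolves through from_pretrained as in the repo's other examples (the name below is illustrative):

import bmtrain as bmt
from model_center.model import BertConfig, BertForLM

bmt.init_distributed(seed=0)  # bmtrain parameters must be created inside an initialized bmt context
config = BertConfig.from_pretrained("bert-base-uncased")  # illustrative checkpoint name
model = BertForLM(config)  # Bert backbone plus the MLM and NSP heads defined in bert.py below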
242 changes: 152 additions & 90 deletions model_center/model/bert.py
@@ -17,7 +17,10 @@
from ..layer import Encoder, Embedding, Linear, LayerNorm
from .basemodel import BaseModel
from .config import BertConfig
from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions
from .modeling_output import ModelOutput, BaseModelOutputWithPoolingAndCrossAttentions
from typing import Optional, Tuple
import bmtrain as bmt


class BertPooler(torch.nn.Module):
def __init__(self, dim_model):
@@ -30,7 +33,7 @@ def forward(self, hidden_states):
pooled_output = self.activation(pooled_output)
return pooled_output


class BertLMHead(torch.nn.Module):
def __init__(self, dim_model, vocab_size, norm_eps):
super().__init__()
@@ -47,90 +50,78 @@ def forward(self, hidden_states):
return logits


class Bert(BaseModel):
class BertForPreTrainingOutput(ModelOutput):
loss: Optional[torch.FloatTensor] = None
prediction_logits: torch.FloatTensor = None
seq_relationship_logits: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None


class Bert(BaseModel):
_CONFIG_TYPE = BertConfig

def __init__(self, config: BertConfig):
super().__init__()

self.input_embedding = Embedding(
vocab_size = config.vocab_size,
embedding_size = config.dim_model,
length_scale = config.length_scale,
dtype = config.dtype,
int8 = config.int8,
init_mean = config.emb_init_mean,
init_std = config.emb_init_std,
vocab_size=config.vocab_size,
embedding_size=config.dim_model,
length_scale=config.length_scale,
dtype=config.dtype,
int8=config.int8,
init_mean=config.emb_init_mean,
init_std=config.emb_init_std,
)

self.position_embedding = Embedding(
vocab_size = config.position_size,
embedding_size = config.dim_model,
length_scale = config.length_scale,
dtype = config.dtype,
int8 = config.int8,
init_mean = config.emb_init_mean,
init_std = config.emb_init_std,
vocab_size=config.position_size,
embedding_size=config.dim_model,
length_scale=config.length_scale,
dtype=config.dtype,
int8=config.int8,
init_mean=config.emb_init_mean,
init_std=config.emb_init_std,
)

self.token_type_embedding = Embedding(
vocab_size = config.type_size,
embedding_size = config.dim_model,
length_scale = config.length_scale,
dtype = config.dtype,
int8 = config.int8,
init_mean = config.emb_init_mean,
init_std = config.emb_init_std,
vocab_size=config.type_size,
embedding_size=config.dim_model,
length_scale=config.length_scale,
dtype=config.dtype,
int8=config.int8,
init_mean=config.emb_init_mean,
init_std=config.emb_init_std,
)

self.embed_dropout = torch.nn.Dropout(config.dropout_p)

self.encoder = Encoder(
num_layers = config.num_layers,
dim_model = config.dim_model,
dim_ff = config.dim_ff,
num_heads = config.num_heads,
dim_head = config.dim_head,
dtype = config.dtype,
int8 = config.int8,
norm_eps = config.norm_eps,
norm_init_var = config.norm_init_var,
norm_bias = config.norm_bias,
att_init_mean = config.att_init_mean,
att_init_std = config.att_init_std,
att_bias = config.att_bias,
att_mask_value = float(config.att_mask_value),
pos_bias_type = config.pos_bias_type,
ffn_init_mean = config.ffn_init_mean,
ffn_init_std = config.ffn_init_std,
ffn_bias = config.ffn_bias,
ffn_activate_fn = config.ffn_activate_fn,
length_scale = config.length_scale,
attn_scale = config.attn_scale,
dropout_p = config.dropout_p,
post_layer_norm = config.post_layer_norm,
num_layers=config.num_layers,
dim_model=config.dim_model,
dim_ff=config.dim_ff,
num_heads=config.num_heads,
dim_head=config.dim_head,
dtype=config.dtype,
int8=config.int8,
norm_eps=config.norm_eps,
norm_init_var=config.norm_init_var,
norm_bias=config.norm_bias,
att_init_mean=config.att_init_mean,
att_init_std=config.att_init_std,
att_bias=config.att_bias,
att_mask_value=float(config.att_mask_value),
pos_bias_type=config.pos_bias_type,
ffn_init_mean=config.ffn_init_mean,
ffn_init_std=config.ffn_init_std,
ffn_bias=config.ffn_bias,
ffn_activate_fn=config.ffn_activate_fn,
length_scale=config.length_scale,
attn_scale=config.attn_scale,
dropout_p=config.dropout_p,
post_layer_norm=config.post_layer_norm,
)

self.tied = config.tied
self.cls_head = config.cls_head
if self.cls_head:
self.cls_projection = Linear(
dim_out = self.cls_head,
dim_in = config.dim_model,
length_scale = config.length_scale,
dtype = config.dtype,
int8 = config.int8,
init_mean = config.proj_init_mean,
init_std = config.proj_init_std,
bias = config.proj_bias,
)
if not self.tied:
self.lm_head = BertLMHead(
dim_model = config.dim_model,
vocab_size = config.vocab_size,
norm_eps = config.norm_eps,
)

self.pooler = BertPooler(config.dim_model)

@@ -140,36 +131,38 @@ def forward(self,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None, #unused
head_mask=None, # unused
inputs_embeds=None,
encoder_hidden_states=None, #unused
encoder_attention_mask=None, #unused
output_attentions=None, #unused
output_hidden_states=None, #unused
encoder_hidden_states=None, # unused
encoder_attention_mask=None, # unused
output_attentions=None, # unused
output_hidden_states=None, # unused
return_dict=True,
return_logits = False,
):
# TODO: remove
# return_logits=False,
):
""" This model inherits from BaseModel. This model is also a PyTorch torch.nn.Module subclass.
You can use it as a regular PyTorch Module.
You can also control what the model returns by changing the values of `return_dict` and `return_logits`.

Args:
input_ids (:obj:`torch.Tensor` of shape ``(batch, seq_length)``): Indices of input sequence tokens. They will be embedded by the model's internal embedding lookup matrix.
length (:obj:`torch.Tensor` of shape ``(batch)``): Length of input sequence before padding.
attention_mask (:obj:`torch.Tensor` of shape ``(batch, seq_length)``): Used to avoid performing attention on padding token indices.
token_type_ids (:obj:`torch.Tensor` of shape ``(batch, seq_length)``): Unused.
position_ids(:obj:`torch.Tensor` of shape ``(batch, seq_length)``): Unused.
head_mask (:obj:`torch.Tensor` of shape ``(num_layers, num_heads)``): Unused.
inputs_embeds (:obj:`torch.Tensor` of shape ``(batch, seq_length, dim_model)``): Embedding of the input. You can pass the input embedding directly instead of `input_ids` to control how the input is embedded.
encoder_hidden_states (:obj:`torch.Tensor` of shape ``(batch, seq_length, dim_model)``): Unused.
encoder_attention_mask (:obj:`torch.Tensor` of shape ``(batch, seq_length)``): Unused.
output_attentions (:obj:`torch.Tensor` of shape ``(batch, num_heads, seq_length, seq_length)``): Unused.
output_hidden_states (:obj:`torch.Tensor` of shape ``(batch, seq_length, dim_model)``): Unused.
return_dict (:obj:`bool`): Whether to return a BaseModelOutputWithPoolingAndCrossAttentions instead of just a tuple.
return_logits (:obj:`bool`): Whether to return the prediction score for each token in vocabulary (before softmax).
TODO: remove
// return_logits (:obj:`bool`): Whether to return the prediction score for each token in vocabulary (before softmax).

Return:
BaseModelOutputWithPoolingAndCrossAttentions or tuple or torch.Tensor of shape ``(batch, seq_length, vocab_output_size)`` or ``(batch, seq_length, cls_head)``: The Bert output, depending on the values of `return_dict` and `return_logits`.

"""
assert input_ids is not None or inputs_embeds is not None
@@ -209,16 +202,6 @@ def forward(self,

hidden_states = self.encoder(hidden_states, attention_mask)

if self.cls_head:
logits = self.cls_projection(hidden_states)
elif self.tied:
logits = self.input_embedding.projection(hidden_states)
elif not self.tied:
logits = self.lm_head(hidden_states)

if return_logits:
return logits

pooled_output = self.pooler(hidden_states)

if not return_dict:
@@ -231,4 +214,83 @@ def forward(self,
hidden_states=None,
attentions=None,
cross_attentions=None,
)


class BertForLM(BaseModel):
_CONFIG_TYPE = BertConfig

def __init__(self, config: BertConfig):
super().__init__()
self.bert = Bert(config)
self.seq_cls = torch.nn.Linear(config.dim_model, 2, dtype=config.dtype)  # next-sentence-prediction head over the pooled output
self.tied = config.tied
self.cls_head = config.cls_head
if self.cls_head:
self.cls_projection = Linear(
dim_out=self.cls_head,
dim_in=config.dim_model,
length_scale=config.length_scale,
dtype=config.dtype,
int8=config.int8,
init_mean=config.proj_init_mean,
init_std=config.proj_init_std,
bias=config.proj_bias,
)
if not self.tied:
self.lm_head = BertLMHead(
dim_model=config.dim_model,
vocab_size=config.vocab_size,
norm_eps=config.norm_eps,
)

def forward(self,
input_ids=None,
length=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None, # unused
inputs_embeds=None,
encoder_hidden_states=None, # unused
encoder_attention_mask=None, # unused
labels=None,
next_sentence_label=None,
output_attentions=None, # unused
output_hidden_states=None, # unused
return_dict=True,
):
outputs = self.bert(
input_ids=input_ids,
length=length,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states, pooler_output = outputs[:2]
if self.cls_head:
logits = self.cls_projection(hidden_states)
elif self.tied:
logits = self.bert.input_embedding.projection(hidden_states)
elif not self.tied:
logits = self.lm_head(hidden_states)

seq_relationship_score = self.seq_cls(pooler_output)
# labels and next_sentence_label are both required here; the MLM and NSP losses are summed
loss_func = bmt.loss.FusedCrossEntropy(ignore_index=-100)
masked_loss = loss_func(logits.view(-1, logits.size(-1)), labels.view(-1))
next_sentence_loss = loss_func(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
total_loss = masked_loss + next_sentence_loss

if not return_dict:
return (total_loss, logits, seq_relationship_score) + outputs[2:]  # outputs is a plain tuple in this branch
return BertForPreTrainingOutput(
loss=total_loss,
prediction_logits=logits,
seq_relationship_logits=seq_relationship_score,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
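To illustrate the new BertForLM.forward contract (MLM labels plus a next-sentence label, a joint loss, and a BertForPreTrainingOutput return), a small smoke-test sketch follows. The random inputs, tensor dtypes, and checkpoint name are illustrative assumptions rather than part of this diff:

import torch
import bmtrain as bmt
from model_center.model import BertConfig, BertForLM

bmt.init_distributed(seed=0)
config = BertConfig.from_pretrained("bert-base-uncased")  # illustrative checkpoint name
model = BertForLM(config)

batch, seq_len = 2, 16
input_ids = torch.randint(0, config.vocab_size, (batch, seq_len), dtype=torch.int32).cuda()
attention_mask = torch.ones(batch, seq_len, dtype=torch.int32).cuda()
labels = input_ids.long()  # MLM targets; positions set to -100 are ignored by the loss
next_sentence_label = torch.zeros(batch, dtype=torch.long).cuda()

out = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels, next_sentence_label=next_sentence_label, return_dict=True)
print(out.loss, out.prediction_logits.shape, out.seq_relationship_logits.shape)

With return_dict=False, the same call returns a plain tuple starting with the total loss.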
3 changes: 2 additions & 1 deletion model_center/model/config/__init__.py
@@ -17,4 +17,5 @@
from .t5_config import T5Config
from .gpt2_config import GPT2Config
from .gptj_config import GPTjConfig
from .bert_config import BertConfig
from .bert_config import BertConfig
from .roberta_config import RobertaConfig
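RobertaConfig is now re-exported alongside the other configs, so it can be reached through either package path; assuming it follows the same from_pretrained pattern as the other configs:

from model_center.model.config import RobertaConfig  # or: from model_center.model import RobertaConfig
config = RobertaConfig.from_pretrained("roberta-base")  # illustrative checkpoint name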