Description
import torch.nn as nn

# ReLU2GLU, TopKSparsity, and QuantizedTopKSparsity are defined elsewhere in this repository.

class QSparseDecoderLayer(nn.Module):
    def __init__(self, d_model, nhead, dim_feedforward, dropout=0.1, k_ratio=0.5, quantized=False):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)
        self.feed_forward = ReLU2GLU(d_model, dim_feedforward)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.sparsity = TopKSparsity(k_ratio) if not quantized else QuantizedTopKSparsity(k_ratio)

    def forward(self, x, mask=None):
        # Apply top-k sparsity to the input activations
        x = self.sparsity(x)
        # Self-attention block with residual connection and layer norm
        attn_output, _ = self.self_attn(x, x, x, attn_mask=mask)
        x = x + self.dropout(attn_output)
        x = self.norm1(x)
        # Feed-forward block with residual connection and layer norm
        x = x + self.dropout(self.linear2(self.feed_forward(x)))
        x = self.norm2(x)
        return x
The code has an error with self.self_attn: it lacks q_proj, k_proj, v_proj, and o_proj. nn.MultiheadAttention keeps a fused in_proj_weight and a single out_proj rather than exposing these as separate projection modules.
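For reference, below is a minimal sketch of a self-attention module that does expose q_proj, k_proj, v_proj, and o_proj as separate nn.Linear layers, so each projection could be quantized or sparsified individually. The class name ExplicitProjectionSelfAttention and its interface are hypothetical, not part of this repository; it only illustrates one possible replacement for nn.MultiheadAttention, assuming an additive attention mask and batch-first tensors.

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class ExplicitProjectionSelfAttention(nn.Module):
    """Hypothetical alternative to nn.MultiheadAttention (batch_first=True) that exposes
    q_proj, k_proj, v_proj, and o_proj as separate nn.Linear modules."""

    def __init__(self, d_model, nhead, dropout=0.1):
        super().__init__()
        assert d_model % nhead == 0, "d_model must be divisible by nhead"
        self.nhead = nhead
        self.head_dim = d_model // nhead
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.o_proj = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, attn_mask=None):
        B, T, D = query.shape
        S = key.shape[1]
        # Project inputs and split into heads: (B, nhead, seq_len, head_dim)
        q = self.q_proj(query).view(B, T, self.nhead, self.head_dim).transpose(1, 2)
        k = self.k_proj(key).view(B, S, self.nhead, self.head_dim).transpose(1, 2)
        v = self.v_proj(value).view(B, S, self.nhead, self.head_dim).transpose(1, 2)
        # Scaled dot-product attention
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        if attn_mask is not None:
            scores = scores + attn_mask  # additive mask, e.g. -inf on disallowed positions
        attn = self.dropout(F.softmax(scores, dim=-1))
        # Merge heads and apply the output projection
        out = torch.matmul(attn, v).transpose(1, 2).contiguous().view(B, T, D)
        return self.o_proj(out), attn

With a module shaped like this, each of q_proj, k_proj, v_proj, and o_proj can be swapped for (or wrapped by) a quantized Linear when quantized=True, which appears to be what the reported error is about. Note that, unlike nn.MultiheadAttention, this sketch returns per-head attention weights rather than head-averaged ones.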