From c003e294e8767c57d2c2839a03f9d2597568814a Mon Sep 17 00:00:00 2001 From: Haixin Liu Date: Fri, 7 Jun 2024 09:28:56 -0700 Subject: [PATCH 1/4] use te dpa for grok mqa --- praxis/layers/grok.py | 4 ++ praxis/layers/multi_query_attention.py | 65 +++++++++++++++++--------- 2 files changed, 46 insertions(+), 23 deletions(-) diff --git a/praxis/layers/grok.py b/praxis/layers/grok.py index 265a3132..ad1ff5fc 100644 --- a/praxis/layers/grok.py +++ b/praxis/layers/grok.py @@ -59,6 +59,7 @@ def GrokStackedTransformerHParams( combine_qkv=False, bidirectional=False, use_fp8=False, + use_te_dpa=True, ) -> pax_fiddle.Config[transformers.StackedTransformer]: """Common setup for Grok-1 Transformer layers. @@ -169,6 +170,7 @@ def GrokStackedTransformerHParams( p.transformer_layer_params_tpl.tr_atten_tpl = pax_fiddle.Config( multi_query_attention.MultiQueryDotProductAttention, num_kv_heads=attention_num_groups, + use_te_dpa=use_te_dpa, ) tr_atten_tpl = p.transformer_layer_params_tpl.tr_atten_tpl tr_atten_tpl.combine_qkv = False @@ -228,6 +230,7 @@ def GrokUniTransformerLmHParams( model_type=LanguageModelType.CAUSAL, checkpoint_policy=AutodiffCheckpointType.SAVE_NOTHING, use_fp8=False, + use_te_dpa=True, ) -> pax_fiddle.Config[transformer_models.TransformerLm]: """Common setup for Grok-1 Decoder-only Transformer Model. @@ -331,6 +334,7 @@ def GrokUniTransformerLmHParams( bidirectional=bidirectional, moe_gating_embedding_level=moe_gating_embedding_level, use_fp8=use_fp8, + use_te_dpa=use_te_dpa, ) num_blocks = num_transformer_layers diff --git a/praxis/layers/multi_query_attention.py b/praxis/layers/multi_query_attention.py index acd69598..33b7ab12 100644 --- a/praxis/layers/multi_query_attention.py +++ b/praxis/layers/multi_query_attention.py @@ -31,7 +31,7 @@ from praxis.layers import base_ops from praxis.layers import embedding_softmax from praxis.layers import stochastics - +from transformer_engine.jax.flax.transformer import DotProductAttention as TEDotProductAttention WeightInit = base_layer.WeightInit WeightHParams = base_layer.WeightHParams @@ -209,6 +209,7 @@ class MultiQueryDotProductAttention(base_layer.BaseLayer): pv_einsum_tpl: LayerTpl = template_field(base_ops.EinsumOp) scale_query_by_dim_per_head: bool = False chunked_attn_num_seq_split: int = 1 + use_te_dpa: bool = False # SPMD partition related params. # @@ -347,6 +348,20 @@ def project_input_kv(input_dim, dim_per_head): self.create_child('post', post_proj_p) self.create_child('qk_einsum', self.qk_einsum_tpl.clone()) self.create_child('pv_einsum', self.pv_einsum_tpl.clone()) + self.dpa_layer = TEDotProductAttention( + head_dim=dim_per_head, + num_attention_heads=self.num_heads, + num_gqa_groups=self.num_kv_heads, + attn_mask_type='causal', # 'causal' or 'padding' + attn_bias_type='no_bias', # 'no_bias', 'pre_scale_bias' or 'post_scale_bias' + attention_dropout=0., + dropout_rng_name='aqt', + dtype=jnp.bfloat16, + float32_logits=False, + qkv_layout='BSHD_BSHD_BSHD', # 'BS3HD', 'BSHD_BS2HD' or 'BSHD_BSHD_BSHD' + scale_factor=1.0/math.sqrt(self.num_heads), + transpose_batch_sequence=False + ) def _shard_bnh(self, x: JTensor) -> JTensor: """Shards tensors of shape [b, n, h]. 
@@ -828,29 +843,33 @@ def _rep_d(x):
     else:
       key_proj = self._shard_blnh(key_proj)
       value_proj = self._shard_blnh(value_proj)
-      b, t, n, h = query_proj.shape
-      _, s, nk, _ = key_proj.shape
-      assert n % nk == 0
-      v_q = jnp.reshape(query_proj, (b, t, nk, n // nk, h))
-      if relative_bias is not None:
-        v_rb = jnp.reshape(relative_bias, (b, nk, n // nk, t, s))
+      if self.use_te_dpa:
+        atten_probs = None
+        encoded = self.dpa_layer(query_proj, key_proj, value_proj)
       else:
-        v_rb = None
-      with self._context_for_kv_vmap():
-        encoded, atten_probs = jax.vmap(
-            self._dot_atten,
-            in_axes=(2, 2, 2, None, 1),
-            out_axes=(2, 1),
-        )(
-            v_q,
-            key_proj,
-            value_proj,
-            atten_mask,
-            v_rb,
-        )
-      encoded = self._shard_blnh(jnp.reshape(encoded, (b, t, n, h)))
-      if atten_probs is not None:
-        atten_probs = jnp.reshape(atten_probs, (b, t, n, s))
+        b, t, n, h = query_proj.shape
+        _, s, nk, _ = key_proj.shape
+        assert n % nk == 0
+        v_q = jnp.reshape(query_proj, (b, t, nk, n // nk, h))
+        if relative_bias is not None:
+          v_rb = jnp.reshape(relative_bias, (b, nk, n // nk, t, s))
+        else:
+          v_rb = None
+        with self._context_for_kv_vmap():
+          encoded, atten_probs = jax.vmap(
+              self._dot_atten,
+              in_axes=(2, 2, 2, None, 1),
+              out_axes=(2, 1),
+          )(
+              v_q,
+              key_proj,
+              value_proj,
+              atten_mask,
+              v_rb,
+          )
+        encoded = self._shard_blnh(jnp.reshape(encoded, (b, t, n, h)))
+        if atten_probs is not None:
+          atten_probs = jnp.reshape(atten_probs, (b, t, n, s))
 
     # Post projection
     encoded = self.post(encoded)

From 99bff322268ccd38d8b18d031eba66605d342172 Mon Sep 17 00:00:00 2001
From: Haixin Liu
Date: Fri, 7 Jun 2024 10:37:20 -0700
Subject: [PATCH 2/4] add doc string and warning

---
 praxis/layers/multi_query_attention.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/praxis/layers/multi_query_attention.py b/praxis/layers/multi_query_attention.py
index 33b7ab12..c11ffaca 100644
--- a/praxis/layers/multi_query_attention.py
+++ b/praxis/layers/multi_query_attention.py
@@ -17,7 +17,7 @@
 
 import math
 from typing import Callable, Mapping, Sequence
-
+from absl import logging
 from flax import linen as nn
 import jax
 from jax import numpy as jnp
@@ -209,7 +209,7 @@ class MultiQueryDotProductAttention(base_layer.BaseLayer):
   pv_einsum_tpl: LayerTpl = template_field(base_ops.EinsumOp)
   scale_query_by_dim_per_head: bool = False
   chunked_attn_num_seq_split: int = 1
-  use_te_dpa: bool = False
+  use_te_dpa: bool = False  # Experimental way to use TE flash attention when the standard TE path can't be used.
 
   # SPMD partition related params.
   #
@@ -844,6 +844,9 @@ def _rep_d(x):
       key_proj = self._shard_blnh(key_proj)
       value_proj = self._shard_blnh(value_proj)
       if self.use_te_dpa:
+        logging.warning(
+            'use_te_dpa is set to True; using TE DotProductAttention as an experimental flash-attention path.'
+        )
         atten_probs = None
         encoded = self.dpa_layer(query_proj, key_proj, value_proj)
       else:

From 4eece5a90da2317994432c530ff37df42d6e617b Mon Sep 17 00:00:00 2001
From: Haixin Liu
Date: Fri, 7 Jun 2024 16:47:13 -0700
Subject: [PATCH 3/4] default to not use dpa in grok

---
 praxis/layers/grok.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/praxis/layers/grok.py b/praxis/layers/grok.py
index ad1ff5fc..2c1fd727 100644
--- a/praxis/layers/grok.py
+++ b/praxis/layers/grok.py
@@ -59,7 +59,7 @@ def GrokStackedTransformerHParams(
     combine_qkv=False,
     bidirectional=False,
     use_fp8=False,
-    use_te_dpa=True,
+    use_te_dpa=False,
 ) -> pax_fiddle.Config[transformers.StackedTransformer]:
   """Common setup for Grok-1 Transformer layers.
 
@@ -230,7 +230,7 @@ def GrokUniTransformerLmHParams( model_type=LanguageModelType.CAUSAL, checkpoint_policy=AutodiffCheckpointType.SAVE_NOTHING, use_fp8=False, - use_te_dpa=True, + use_te_dpa=False, ) -> pax_fiddle.Config[transformer_models.TransformerLm]: """Common setup for Grok-1 Decoder-only Transformer Model. From e7221ed7c336098cbab3add67f8528e36eb959c9 Mon Sep 17 00:00:00 2001 From: Haixin Liu Date: Mon, 17 Jun 2024 17:45:21 -0700 Subject: [PATCH 4/4] fix checkpoint for grok --- praxis/layers/attentions.py | 3 +++ praxis/layers/grok.py | 1 + 2 files changed, 4 insertions(+) diff --git a/praxis/layers/attentions.py b/praxis/layers/attentions.py index 402830fa..bb63469a 100644 --- a/praxis/layers/attentions.py +++ b/praxis/layers/attentions.py @@ -1698,6 +1698,9 @@ def __call__( query_proj = self.query(query_vec) key_proj = self.key(key_vec) value_proj = self.value(value_vec) + query_proj = checkpoint_name(query_proj, 'query_proj') + key_proj = checkpoint_name(key_proj, 'key_proj') + value_proj = checkpoint_name(value_proj, 'value_proj') if not self.consolidate_rope_key_state: self._fprop_update_decode_state('key_state', key_proj) diff --git a/praxis/layers/grok.py b/praxis/layers/grok.py index 2c1fd727..2adfa76b 100644 --- a/praxis/layers/grok.py +++ b/praxis/layers/grok.py @@ -357,5 +357,6 @@ def GrokUniTransformerLmHParams( num_pipeline_stages=num_pipeline_stages, num_pipeline_microbatches=num_pipeline_microbatches, stream_io=True, + checkpoint_policy=checkpoint_policy, ) return p
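Usage note (illustrative sketch, not part of the patch series): after PATCH 3/4 the TE flash-attention path is opt-in and defaults to off. The snippet below shows how the new flag can be set on the attention layer template, assuming praxis and transformer_engine are installed (the module-level import added to multi_query_attention.py makes transformer_engine a hard dependency of that module). The num_kv_heads value is illustrative only; Grok-1 configs pass attention_num_groups there.

from praxis import pax_fiddle
from praxis.layers import multi_query_attention

# Attention layer template with the experimental TE DotProductAttention path enabled.
mqa_tpl = pax_fiddle.Config(
    multi_query_attention.MultiQueryDotProductAttention,
    num_kv_heads=8,     # illustrative; Grok-1 configs use attention_num_groups here
    use_te_dpa=True,    # route the attention computation through TEDotProductAttention
)

At the model level the same switch is exposed as GrokStackedTransformerHParams(..., use_te_dpa=True) and GrokUniTransformerLmHParams(..., use_te_dpa=True); with the default of False, the existing vmapped dot-product attention is used.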
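Background for PATCH 4/4 (a standalone JAX sketch of the underlying mechanism, not code from praxis): checkpoint_name tags the q/k/v projection outputs so that a name-based rematerialization policy can choose to save them instead of recomputing them in the backward pass; the checkpoint_policy argument that the Grok config now forwards selects among such policies. The function, shapes, and weights below are made up for illustration.

import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint, checkpoint_name

def qkv_projections(x, w_q, w_k, w_v):
  # Tag each projection so a checkpoint policy can refer to it by name.
  q = checkpoint_name(x @ w_q, 'query_proj')
  k = checkpoint_name(x @ w_k, 'key_proj')
  v = checkpoint_name(x @ w_v, 'value_proj')
  return q, k, v

# A policy that saves exactly the named residuals during rematerialization.
policy = jax.checkpoint_policies.save_only_these_names(
    'query_proj', 'key_proj', 'value_proj')
remat_qkv = checkpoint(qkv_projections, policy=policy)

x = jnp.ones((2, 16, 64))
w = jnp.ones((64, 64))
q, k, v = remat_qkv(x, w, w, w)  # gradients through remat_qkv honor the policy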