diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py
index 9c9a055b55..9577f8d891 100644
--- a/transformer_engine/pytorch/attention.py
+++ b/transformer_engine/pytorch/attention.py
@@ -3170,6 +3170,7 @@ def __init__(
         qkv_weight_interleaved: bool = True,
         ub_bulk_wgrad: bool = False,
         ub_bulk_dgrad: bool = False,
+        ub_overlap_rs_dgrad: bool = False,
         ub_overlap_rs: bool = False,
         ub_overlap_ag: bool = False,
         bias: bool = True,
@@ -3258,6 +3259,7 @@ def __init__(
                 zero_centered_gamma=zero_centered_gamma,
                 ub_bulk_wgrad=ub_bulk_wgrad,
                 ub_bulk_dgrad=ub_bulk_dgrad,
+                ub_overlap_rs_dgrad=ub_overlap_rs_dgrad,
                 ub_overlap_ag=ub_overlap_ag,
                 normalization=normalization,
                 ub_name="qkv",
@@ -3289,6 +3291,7 @@ def __init__(
                 zero_centered_gamma=zero_centered_gamma,
                 ub_bulk_wgrad=ub_bulk_wgrad,
                 ub_bulk_dgrad=ub_bulk_dgrad,
+                ub_overlap_rs_dgrad=ub_overlap_rs_dgrad,
                 ub_overlap_ag=ub_overlap_ag,
                 normalization=normalization,
                 ub_name="qkv",
diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py
index 59e5949e06..72fc350849 100644
--- a/transformer_engine/pytorch/module/base.py
+++ b/transformer_engine/pytorch/module/base.py
@@ -137,6 +137,7 @@ def initialize_ub(
         "bulk":["qkv_dgrad", "qkv_wgrad", "fc1_dgrad", "fc1_wgrad"],
     }
     layers_reduce_scatter_overlap = ["proj_fprop", "fc2_fprop", "qkv_wgrad", "fc1_wgrad"]
+    dgrad_reduce_scatter_overlap = ["qkv_dgrad", "fc1_dgrad"]

     def get_method(name):
         for method, names in methods.items():
@@ -207,6 +208,14 @@ def add_ub(
         )
         _ub_communicators[name] = ub_obj

+    if ub_cfgs is not None:
+        for name in dgrad_reduce_scatter_overlap:
+            if name in ub_cfgs and 'method' in ub_cfgs[name] and ub_cfgs[name]['method'] != 'bulk':
+                wgrad_name = name.replace('dgrad','wgrad')
+                assert wgrad_name not in ub_cfgs
+                layers_reduce_scatter_overlap.remove(wgrad_name)
+                layers_reduce_scatter_overlap.append(name)
+
     for name in (methods["ring_exchange"]+methods["pipeline"]+methods["bulk"]):
         if ub_cfgs is not None and name in ub_cfgs:
             ub_cfg = ub_cfgs[name]
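With the remapping added to initialize_ub() above, a user config can move the backward reduce-scatter overlap from the wgrad GEMM to the dgrad GEMM by giving qkv_dgrad/fc1_dgrad a non-"bulk" method. A minimal sketch, assuming the initialize_ub(shape, tp_size, use_fp8, dtype, ub_cfgs) signature in this version of module/base.py and an already-initialized torch.distributed job; the sizes are illustrative:

    import torch
    import transformer_engine.pytorch as te

    seq_len, micro_batch, hidden_size, tp_size = 2048, 2, 4096, 8  # illustrative sizes

    # A non-"bulk" method on the dgrad GEMMs triggers the new branch above: the
    # matching *_wgrad entries must then be absent from ub_cfgs (asserted), and
    # the dgrad GEMMs take their place in layers_reduce_scatter_overlap.
    ub_cfgs = {
        "qkv_dgrad": {"method": "ring_exchange"},
        "fc1_dgrad": {"method": "ring_exchange"},
    }

    te.initialize_ub(
        [seq_len * micro_batch, hidden_size],  # global communication buffer shape
        tp_size,
        use_fp8=False,
        dtype=torch.bfloat16,
        ub_cfgs=ub_cfgs,
    )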
diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py
index 3711d9898f..b933734b25 100644
--- a/transformer_engine/pytorch/module/layernorm_linear.py
+++ b/transformer_engine/pytorch/module/layernorm_linear.py
@@ -86,6 +86,7 @@ def forward(
         primary_weights_in_fp8: bool,
         ub_bulk_wgrad: bool,
         ub_bulk_dgrad: bool,
+        ub_overlap_rs_dgrad: bool,
         ub_overlap_ag: bool,
         ub_name: str,
     ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
@@ -293,6 +294,7 @@ def forward(
         ctx.zero_centered_gamma = zero_centered_gamma
         ctx.ub_bulk_wgrad = ub_bulk_wgrad
         ctx.ub_bulk_dgrad = ub_bulk_dgrad
+        ctx.ub_overlap_rs_dgrad = ub_overlap_rs_dgrad
         ctx.ub_name = ub_name
         ctx.requires_dgrad = inp.requires_grad
         ctx.normalization = normalization
@@ -344,9 +346,15 @@ def backward(
                 update_cache="reuse_only" if ctx.is_first_microbatch is None else "lazy",
             )

-        if ctx.ub_bulk_dgrad:
+        if ctx.ub_overlap_rs_dgrad:
+            ctx.ub_bulk_dgrad = False
+            ctx.ub_bulk_wgrad = False
             tp_world_size = get_distributed_world_size(ctx.tp_group)
             if tp_world_size == 1:
+                ctx.ub_overlap_rs_dgrad = False
+        if ctx.ub_bulk_dgrad:
+            tp_world_size = get_distributed_world_size(ctx.tp_group)
+            if tp_world_size == 1 or not weight.requires_grad:
                 ctx.ub_bulk_dgrad = False
         if ctx.ub_bulk_dgrad:
             dim_size = list(ln_out.size())
@@ -393,9 +401,35 @@ def backward(
         if ctx.ub_bulk_wgrad: # allocate dgrad output
             ub_obj_dgrad = get_ub(ctx.ub_name+"_wgrad")
             dgrad = ub_obj_dgrad.get_ubuf_output(1) # AllGather output
+        elif ctx.ub_overlap_rs_dgrad:
+            ub_obj_dgrad = get_ub(ctx.ub_name+"_dgrad")
+            dgrad = ub_obj_dgrad.get_ubuf_output(1) # AllGather output
         else:
             dgrad = torch.empty(dgrad_size, dtype=ctx.activation_dtype, device=weight.device)

+        if ctx.ub_bulk_dgrad:
+            ub_algo = tex.UbufOverlapAlgo.BULK_OVERLAP_AG
+            ub_obj = ub_obj_lnout
+        elif ctx.ub_overlap_rs_dgrad:
+            dim_size = list(grad_output.size())
+            dim_size[0] = dim_size[0] // tp_world_size
+            dim_size[1] = weight.size(1)
+            rs_out = torch.empty(dim_size, dtype=ctx.activation_dtype, device=grad_output.device)
+            if ub_obj_dgrad.is_p2p_overlap():
+                if ctx.fp8 and ub_obj_dgrad.is_atomic_gemm():
+                    ub_algo=tex.UbufOverlapAlgo.ATOMIC_GEMM_RS_P2P
+                else:
+                    ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P
+            else:
+                if ctx.fp8 and ub_obj_dgrad.is_atomic_gemm():
+                    ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_RS
+                else:
+                    ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS
+            ub_obj = ub_obj_dgrad
+        else:
+            ub_algo = None
+            ub_obj = None
+
         if ctx.fp8:
             fp8_dtype_forward = get_fp8_te_dtype(
                 ctx.fp8_meta["recipe"], fprop_tensor=True
@@ -405,14 +439,14 @@
             )
             out_index, meta_tensor, out_te_type, out_type = (
                 None, None, None, ctx.activation_dtype)
-            if ctx.ub_bulk_wgrad and ub_obj_dgrad.is_fp8_ubuf():
+            if (ctx.ub_bulk_wgrad or ctx.ub_overlap_rs_dgrad) and ub_obj_dgrad.is_fp8_ubuf():
                 out_index = tex.FP8BwdTensors.GRAD_INPUT1
                 meta_tensor = ctx.fp8_meta["scaling_bwd"]
                 out_te_type = fp8_dtype_backward
                 out_type = torch.uint8
                 ub_obj_dgrad.set_ubuf_scale_inv(meta_tensor.scale_inv[out_index])

             # DGRAD: Evaluated unconditionally to feed into Linear backward
             _ = tex.fp8_gemm(
                 weight_t_fp8._data,
                 fwd_scale_inverses,
@@ -426,8 +460,9 @@
                 get_workspace(),
                 out=dgrad,
                 use_split_accumulator=_2X_ACC_DGRAD,
-                ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_AG if ctx.ub_bulk_dgrad else None,
-                ub=ub_obj_lnout if ctx.ub_bulk_dgrad else None,
+                ub_algo=ub_algo,
+                ub=ub_obj,
+                extra_output_tensor=rs_out if ctx.ub_overlap_rs_dgrad else None,
                 out_index=out_index,
                 fp8_meta_tensor = meta_tensor,
                 D_dtype = out_te_type,
@@ -443,8 +478,9 @@
                 out=dgrad,
                 layout="NN",
                 grad=True,
-                ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_AG if ctx.ub_bulk_dgrad else None,
-                ub=ub_obj_lnout if ctx.ub_bulk_dgrad else None
+                ub_algo=ub_algo,
+                ub=ub_obj,
+                extra_output_tensor=rs_out if ctx.ub_overlap_rs_dgrad else None,
             )
         if ctx.ub_bulk_dgrad:
             ln_out_total = ub_obj_lnout.get_ubuf_output(1)
@@ -453,7 +489,7 @@
         if ctx.parallel_mode == "column" and ctx.sequence_parallel:
             if not ctx.ub_bulk_dgrad and handle is not None:
                 handle.wait()
-            if not ctx.ub_bulk_wgrad:
+            if not ctx.ub_bulk_wgrad and not ctx.ub_overlap_rs_dgrad:
                 if ctx.return_layernorm_output and ctx.return_layernorm_output_gathered:
                     dgrad = dgrad + grad_outputs[1].view_as(dgrad)
                 dgrad, handle = reduce_scatter_along_first_dim(
@@ -546,7 +582,10 @@
             handle.wait()

         # LayerNorm gradient
-        dgrad = dgrad.view(inputmat.shape)
+        if ctx.ub_overlap_rs_dgrad:
+            dgrad = rs_out.view(inputmat.shape)
+        else:
+            dgrad = dgrad.view(inputmat.shape)

         # Residual gradient
         if ctx.return_layernorm_output and not ctx.return_layernorm_output_gathered:
@@ -622,6 +661,7 @@
             None,
             None,
             None,
+            None,
         )
@@ -735,6 +775,7 @@
         ub_bulk_wgrad: bool = False,
         ub_bulk_dgrad: bool = False,
         ub_overlap_ag: bool = False,
+        ub_overlap_rs_dgrad: bool = False,
         ub_name: Optional[str] = None,
     ) -> None:
         super().__init__()
@@ -755,7 +796,8 @@ def __init__(
         self.ub_bulk_wgrad = ub_bulk_wgrad
         self.ub_bulk_dgrad = ub_bulk_dgrad
         self.ub_overlap_ag = ub_overlap_ag
-        if any([ub_bulk_wgrad, ub_bulk_dgrad, ub_overlap_ag]):
+        self.ub_overlap_rs_dgrad = ub_overlap_rs_dgrad
+        if any([ub_bulk_wgrad, ub_bulk_dgrad, ub_overlap_ag, ub_overlap_rs_dgrad]):
             assert ub_name is not None, "Userbuffer name [string] is not set."
         self.ub_name = ub_name

@@ -1087,6 +1129,7 @@ def forward(
             self.primary_weights_in_fp8,
             self.ub_bulk_wgrad,
             self.ub_bulk_dgrad,
+            self.ub_overlap_rs_dgrad,
             self.ub_overlap_ag,
             self.ub_name,
         )
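Both backward passes resolve the overlap flags with the same precedence. Restated as a plain-Python sketch (not library code) to make the fallback explicit: ub_overlap_rs_dgrad turns off both bulk overlaps, and each overlap degrades to the non-overlapped path when there is no tensor parallelism or when the weight needs no gradient:

    def resolve_dgrad_overlap(ub_overlap_rs_dgrad, ub_bulk_dgrad, ub_bulk_wgrad,
                              tp_world_size, weight_requires_grad):
        # Pipelined dgrad+RS overlap supersedes the bulk AG/RS overlaps.
        if ub_overlap_rs_dgrad:
            ub_bulk_dgrad = False
            ub_bulk_wgrad = False
            if tp_world_size == 1:
                ub_overlap_rs_dgrad = False
        # Bulk dgrad overlap is pointless without TP or without a wgrad GEMM.
        if ub_bulk_dgrad and (tp_world_size == 1 or not weight_requires_grad):
            ub_bulk_dgrad = False
        return ub_overlap_rs_dgrad, ub_bulk_dgrad, ub_bulk_wgrad

    assert resolve_dgrad_overlap(True, True, True, tp_world_size=8,
                                 weight_requires_grad=True) == (True, False, False)
    assert resolve_dgrad_overlap(True, True, True, tp_world_size=1,
                                 weight_requires_grad=True) == (False, False, False)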
diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py
index c9338aee90..31c8fbad7e 100644
--- a/transformer_engine/pytorch/module/layernorm_mlp.py
+++ b/transformer_engine/pytorch/module/layernorm_mlp.py
@@ -117,6 +117,7 @@ def forward(
         primary_weights_in_fp8: bool,
         ub_bulk_wgrad: bool,
         ub_bulk_dgrad: bool,
+        ub_overlap_rs_dgrad: bool,
         ub_overlap_rs: bool,
         ub_overlap_ag: bool,
         gemm_gelu_fusion: bool,
@@ -518,6 +519,7 @@ def forward(
         ctx.zero_centered_gamma = zero_centered_gamma
         ctx.ub_bulk_wgrad = ub_bulk_wgrad
         ctx.ub_bulk_dgrad = ub_bulk_dgrad
+        ctx.ub_overlap_rs_dgrad = ub_overlap_rs_dgrad
         ctx.ub_overlap_ag = ub_overlap_ag
         ctx.requires_dgrad = inp.requires_grad
         ctx.normalization = normalization
@@ -583,9 +585,15 @@ def backward(
         activation_func = _act_func(ctx.activation)[1]

-        if ctx.ub_bulk_dgrad:
+        if ctx.ub_overlap_rs_dgrad:
+            ctx.ub_bulk_dgrad = False
+            ctx.ub_bulk_wgrad = False
             tp_world_size = get_distributed_world_size(ctx.tp_group)
             if tp_world_size == 1:
+                ctx.ub_overlap_rs_dgrad = False
+        if ctx.ub_bulk_dgrad:
+            tp_world_size = get_distributed_world_size(ctx.tp_group)
+            if tp_world_size == 1 or not fc1_weight.requires_grad:
                 ctx.ub_bulk_dgrad = False
         if ctx.ub_bulk_dgrad:
             dim_size = list(ln_out.size())
@@ -758,19 +766,49 @@
                 None, None, None, ctx.activation_dtype)
             fc1_dgrad_size = list(dgelu.size())
             fc1_dgrad_size[1] = fc1_weight.size(1)
+            # Get/alloc fc1_dgrad
             if ctx.ub_bulk_wgrad: # allocate dgrad output
                 ub_obj_dgrad = get_ub("fc1_wgrad")
                 fc1_dgrad = ub_obj_dgrad.get_ubuf_output(1) # AllGather output
-                if ub_obj_dgrad.is_fp8_ubuf():
-                    out_index = tex.FP8BwdTensors.GRAD_INPUT2
-                    meta_tensor = ctx.fp8_meta["scaling_bwd"]
-                    out_te_type = fp8_dtype_backward
-                    out_type = torch.uint8
-                    ub_obj_dgrad.set_ubuf_scale_inv(meta_tensor.scale_inv[out_index])
+            elif ctx.ub_overlap_rs_dgrad:
+                ub_obj_dgrad = get_ub("fc1_dgrad")
+                fc1_dgrad = ub_obj_dgrad.get_ubuf_output(1) # AllGather output
             else:
                 fc1_dgrad = torch.empty(
                     fc1_dgrad_size, dtype=ctx.activation_dtype, device=fc1_weight.device
                 )
+
+            # FP8 RS
+            if (ctx.ub_bulk_wgrad or ctx.ub_overlap_rs_dgrad) and ub_obj_dgrad.is_fp8_ubuf():
+                out_index = tex.FP8BwdTensors.GRAD_INPUT2
+                meta_tensor = ctx.fp8_meta["scaling_bwd"]
+                out_te_type = fp8_dtype_backward
+                out_type = torch.uint8
+                ub_obj_dgrad.set_ubuf_scale_inv(meta_tensor.scale_inv[out_index])
+
+            # Set UB algo and UB obj for fc1_dgrad bulk/pipelined overlap
+            if ctx.ub_bulk_dgrad:
+                ub_algo = tex.UbufOverlapAlgo.BULK_OVERLAP_AG
+                ub_obj = ub_obj_lnout
+            elif ctx.ub_overlap_rs_dgrad:
+                dim_size = list(dgelu.size())
+                dim_size[0] = dim_size[0] // tp_world_size
+                dim_size[1] = fc1_weight_t_fp8.size(0)
+                rs_out = torch.empty(dim_size, dtype=ctx.activation_dtype, device=dgelu.device)
+                if ub_obj_dgrad.is_p2p_overlap():
+                    if ub_obj_dgrad.is_atomic_gemm():
+                        ub_algo=tex.UbufOverlapAlgo.ATOMIC_GEMM_RS_P2P
+                    else:
+                        ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P
+                else:
+                    if ub_obj_dgrad.is_atomic_gemm():
+                        ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_RS
+                    else:
+                        ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS
+                ub_obj = ub_obj_dgrad
+            else:
+                ub_algo = None
+                ub_obj = None
             # FC1 DGRAD: Unconditional
             _ = tex.fp8_gemm(
                 fc1_weight_t_fp8._data,
@@ -785,8 +823,9 @@
                 get_workspace(),
                 out=fc1_dgrad,
                 use_split_accumulator=_2X_ACC_DGRAD,
-                ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_AG if ctx.ub_bulk_dgrad else None,
-                ub=ub_obj_lnout if ctx.ub_bulk_dgrad else None,
+                ub_algo=ub_algo,
+                ub=ub_obj,
+                extra_output_tensor=rs_out if ctx.ub_overlap_rs_dgrad else None,
                 out_index=out_index,
                 fp8_meta_tensor = meta_tensor,
                 D_dtype = out_te_type,
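The ub_algo choice repeated in these backward paths can be summarized without the enums; the sketch below (plain Python, UbufOverlapAlgo members replaced by their names) mirrors the branching on is_p2p_overlap() and is_atomic_gemm(), with the atomic-GEMM variants only reachable from the FP8 paths:

    def pick_rs_dgrad_algo(p2p_overlap: bool, atomic_gemm: bool, fp8: bool) -> str:
        # Returns the name of the UbufOverlapAlgo member selected above.
        if p2p_overlap:
            return "ATOMIC_GEMM_RS_P2P" if (fp8 and atomic_gemm) else "SPLIT_PIPELINED_RS_P2P"
        return "ATOMIC_GEMM_RS" if (fp8 and atomic_gemm) else "SPLIT_PIPELINED_RS"

    assert pick_rs_dgrad_algo(p2p_overlap=True, atomic_gemm=False, fp8=True) == "SPLIT_PIPELINED_RS_P2P"
    assert pick_rs_dgrad_algo(p2p_overlap=False, atomic_gemm=True, fp8=False) == "SPLIT_PIPELINED_RS"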
@@ -843,11 +882,31 @@
             if ctx.ub_bulk_wgrad: # allocate dgrad output
                 ub_obj_dgrad = get_ub("fc1_wgrad")
                 fc1_dgrad = ub_obj_dgrad.get_ubuf_output(1) # AllGather output
+            elif ctx.ub_overlap_rs_dgrad:
+                ub_obj_dgrad = get_ub("fc1_dgrad")
+                fc1_dgrad = ub_obj_dgrad.get_ubuf_output(1) # AllGather output
             else:
                 fc1_dgrad = torch.empty(
                     fc1_dgrad_size, dtype=ctx.activation_dtype, device=fc1_weight.device
                 )

+            # Set UB algo and UB obj for fc1_dgrad bulk/pipelined overlap
+            if ctx.ub_bulk_dgrad:
+                ub_algo = tex.UbufOverlapAlgo.BULK_OVERLAP_AG
+                ub_obj = ub_obj_lnout
+            elif ctx.ub_overlap_rs_dgrad:
+                dim_size = list(dgelu.size())
+                dim_size[0] = dim_size[0] // tp_world_size
+                dim_size[1] = fc1_weight.size(1)
+                rs_out = torch.empty(dim_size, dtype=ctx.activation_dtype, device=dgelu.device)
+                if ub_obj_dgrad.is_p2p_overlap():
+                    ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P
+                else:
+                    ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS
+                ub_obj = ub_obj_dgrad
+            else:
+                ub_algo = None
+                ub_obj = None
             # FC1 DGRAD: Unconditional
             _ = tex.gemm(
                 fc1_weight,
@@ -857,8 +916,9 @@
                 out=fc1_dgrad,
                 layout="NN",
                 grad=True,
-                ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_AG if ctx.ub_bulk_dgrad else None,
-                ub=ub_obj_lnout if ctx.ub_bulk_dgrad else None
+                ub_algo=ub_algo,
+                ub=ub_obj,
+                extra_output_tensor=rs_out if ctx.ub_overlap_rs_dgrad else None,
             )

         if ctx.ub_bulk_dgrad:
@@ -867,7 +927,7 @@
         if ctx.set_parallel_mode and ctx.sequence_parallel:
             if not ctx.ub_bulk_dgrad and handle is not None:
                 handle.wait()
-            if not ctx.ub_bulk_wgrad:
+            if not ctx.ub_bulk_wgrad and not ctx.ub_overlap_rs_dgrad:
                 if ctx.return_layernorm_output and ctx.return_layernorm_output_gathered:
                     fc1_dgrad = fc1_dgrad + grad_outputs[1].view_as(fc1_dgrad)
                 fc1_dgrad, handle = reduce_scatter_along_first_dim(
@@ -969,7 +1029,10 @@
             handle.wait()

         # LayerNorm gradient
-        dgrad = fc1_dgrad.view(inputmat.shape)
+        if ctx.ub_overlap_rs_dgrad:
+            dgrad = rs_out.view(inputmat.shape)
+        else:
+            dgrad = fc1_dgrad.view(inputmat.shape)

         # Residual gradient
         if ctx.return_layernorm_output and not ctx.return_layernorm_output_gathered:
@@ -1071,6 +1134,7 @@
             None,
             None,
             None,
+            None,
         )
@@ -1193,6 +1257,7 @@
         device: Union[torch.device, str] = "cuda",
         ub_bulk_wgrad: bool = False,
         ub_bulk_dgrad: bool = False,
+        ub_overlap_rs_dgrad: bool = False,
         ub_overlap_rs: bool = False,
         ub_overlap_ag: bool = False,
     ) -> None:
@@ -1215,13 +1280,14 @@
         self.primary_weights_in_fp8 = FP8GlobalStateManager.with_fp8_parameters()
         self.ub_bulk_wgrad = ub_bulk_wgrad
         self.ub_bulk_dgrad = ub_bulk_dgrad
+        self.ub_overlap_rs_dgrad = ub_overlap_rs_dgrad
         self.ub_overlap_rs = ub_overlap_rs
         self.ub_overlap_ag = ub_overlap_ag

         # GEMM-GELU fusion is currently only supported with split GEMM-AG overlap
         self.gemm_gelu_fusion = (bool(int(os.getenv("NVTE_GEMM_GELU_FUSION", "0"))) and
                                  self.activation == 'gelu' and not get_ub("fc1_fprop").is_atomic_gemm())
-        if any([ub_bulk_wgrad, ub_bulk_dgrad, ub_overlap_rs, ub_overlap_ag]):
+        if any([ub_bulk_wgrad, ub_bulk_dgrad, ub_overlap_rs, ub_overlap_ag, ub_overlap_rs_dgrad]):
             assert (
                 tex.userbuf_comm_available()
             ), "Userbuffer communication backend not available."
@@ -1475,6 +1541,7 @@ def forward(
             self.primary_weights_in_fp8,
             self.ub_bulk_wgrad,
             self.ub_bulk_dgrad,
+            self.ub_overlap_rs_dgrad,
             self.ub_overlap_rs,
             self.ub_overlap_ag,
             self.gemm_gelu_fusion,
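For a standalone module the new flag is just another constructor argument. A construction sketch with illustrative sizes; it assumes userbuffers were already set up via initialize_ub() and that the caller owns a tensor-parallel process group named tp_group:

    import transformer_engine.pytorch as te

    # tp_group is assumed to be a torch.distributed process group created by the caller.
    mlp = te.LayerNormMLP(
        hidden_size=4096,
        ffn_hidden_size=16384,
        set_parallel_mode=True,
        sequence_parallel=True,
        tp_group=tp_group,
        ub_overlap_rs_dgrad=True,   # pipelined RS fused into the fc1 dgrad GEMM
        ub_bulk_dgrad=False,        # superseded by ub_overlap_rs_dgrad in backward anyway
        ub_bulk_wgrad=False,
        ub_overlap_rs=True,
        ub_overlap_ag=True,
    )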
diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py
index a0fd231913..2e00333fa0 100644
--- a/transformer_engine/pytorch/transformer.py
+++ b/transformer_engine/pytorch/transformer.py
@@ -261,6 +261,7 @@ def __init__(
         ub_bulk_dgrad: bool = True,
         ub_overlap_ag: bool = True,
         ub_overlap_rs: bool = True,
+        ub_overlap_rs_dgrad: bool = False,
         bias: bool = True,
         activation: str = 'gelu',
         normalization: str = "LayerNorm",
@@ -282,6 +283,7 @@ def __init__(
         ub_bulk_dgrad = ub_tp_comm_overlap and ub_bulk_dgrad
         ub_overlap_ag = ub_tp_comm_overlap and ub_overlap_ag
         ub_overlap_rs = ub_tp_comm_overlap and ub_overlap_rs
+        ub_overlap_rs_dgrad = ub_tp_comm_overlap and ub_overlap_rs_dgrad
         bias_dropout_fusion = bool(int(os.getenv("NVTE_BIAS_DROPOUT_FUSION", "1")))

         self.layer_number = layer_number
@@ -357,6 +359,7 @@ def __init__(
             "ub_bulk_dgrad" : ub_bulk_dgrad,
             "ub_overlap_ag" : ub_overlap_ag,
             "ub_overlap_rs" : ub_overlap_rs,
+            "ub_overlap_rs_dgrad" : ub_overlap_rs_dgrad,
             "qkv_format" : self.attn_input_format,
         }

@@ -410,6 +413,7 @@ def __init__(
             zero_centered_gamma=zero_centered_gamma,
             ub_bulk_wgrad=ub_bulk_wgrad,
             ub_bulk_dgrad=ub_bulk_dgrad,
+            ub_overlap_rs_dgrad=ub_overlap_rs_dgrad,
             ub_overlap_rs=ub_overlap_rs,
             ub_overlap_ag=ub_overlap_ag,
             activation=activation,
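At the TransformerLayer level the new option defaults to False and, like the other ub_* flags, is only honored when ub_tp_comm_overlap is enabled and initialize_ub() has been called. A rough end-to-end sketch with illustrative sizes; distributed setup and the tensor-parallel group wiring are omitted:

    import torch
    import transformer_engine.pytorch as te

    seq_len, micro_batch, hidden, ffn, heads, tp_size = 2048, 2, 4096, 16384, 32, 8

    # Userbuffers must exist before the layer is built when any ub_* flag is on.
    te.initialize_ub([seq_len * micro_batch, hidden], tp_size, dtype=torch.bfloat16)

    layer = te.TransformerLayer(
        hidden_size=hidden,
        ffn_hidden_size=ffn,
        num_attention_heads=heads,
        set_parallel_mode=True,
        sequence_parallel=True,
        ub_tp_comm_overlap=True,
        ub_overlap_rs_dgrad=True,   # overlap the backward reduce-scatter with the dgrad GEMMs
    )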