From b1dfe2b94575cc1a797191c49d659615f6f34497 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Fri, 26 Jul 2024 11:00:27 -0700 Subject: [PATCH] [5/x] clean up casting: cast_to_float8_e4m3_dynamic -> cast_to_float8_dynamic Summary: Moves the dtype from function name to argument, to match delayed scaling version. Test Plan: ``` ./test/test_everything.sh ``` Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned] --- float8_experimental/float8_linear.py | 7 ++++--- float8_experimental/float8_scaling_utils.py | 8 ++++---- float8_experimental/float8_tensor_parallel.py | 12 ++++++++---- float8_experimental/fsdp_utils.py | 5 +++-- test/test_base.py | 1 - 5 files changed, 19 insertions(+), 14 deletions(-) diff --git a/float8_experimental/float8_linear.py b/float8_experimental/float8_linear.py index 6e184c2..7a7adf2 100644 --- a/float8_experimental/float8_linear.py +++ b/float8_experimental/float8_linear.py @@ -19,7 +19,7 @@ from float8_experimental.float8_scaling_utils import ( _maybe_initialize_amaxes_scales_for_float8_cast, cast_to_float8_delayed, - cast_to_float8_e4m3_dynamic, + cast_to_float8_dynamic, NoopFwToFloat8E5M2BwDelayed, NoopFwToFloat8E5M2BwDynamic, ) @@ -270,7 +270,7 @@ def cast_input_to_float8( ) else: assert self.scaling_type_input is ScalingType.DYNAMIC - input_fp8 = cast_to_float8_e4m3_dynamic(input, self.linear_mm_config) + input_fp8 = cast_to_float8_dynamic(input, e4m3_dtype, self.linear_mm_config) return input_fp8 def cast_weight_to_float8( @@ -305,8 +305,9 @@ def cast_weight_to_float8( if isinstance(self.weight, Float8Tensor): # cast by FSDP weight_fp8 = self.weight else: - weight_fp8 = cast_to_float8_e4m3_dynamic( + weight_fp8 = cast_to_float8_dynamic( self.weight, + e4m3_dtype, self.linear_mm_config, gemm_input_role=GemmInputRole.WEIGHT, ) diff --git a/float8_experimental/float8_scaling_utils.py b/float8_experimental/float8_scaling_utils.py index a590d62..7c387bf 100644 --- a/float8_experimental/float8_scaling_utils.py +++ b/float8_experimental/float8_scaling_utils.py @@ -30,25 +30,25 @@ ) -def cast_to_float8_e4m3_dynamic( +def cast_to_float8_dynamic( inpt_tensor: torch.Tensor, + float8_dtype: torch.dtype, linear_mm_config: LinearMMConfig, reduce_amax: bool = False, gemm_input_role: GemmInputRole = GemmInputRole.INPUT, ) -> Float8Tensor: if tensor_already_casted_to_fp8(inpt_tensor): return inpt_tensor - scale = tensor_to_scale(inpt_tensor, e4m3_dtype, reduce_amax) + scale = tensor_to_scale(inpt_tensor, float8_dtype, reduce_amax) return hp_tensor_and_scale_to_float8( inpt_tensor, scale, - e4m3_dtype, + float8_dtype, linear_mm_config, gemm_input_role, ) -# TODO(future PR): align name with cast_to_float8_e4m3_dynamic def cast_to_float8_delayed( tensor: torch.Tensor, scale: torch.Tensor, diff --git a/float8_experimental/float8_tensor_parallel.py b/float8_experimental/float8_tensor_parallel.py index 54127af..4a77a45 100644 --- a/float8_experimental/float8_tensor_parallel.py +++ b/float8_experimental/float8_tensor_parallel.py @@ -2,10 +2,11 @@ import torch.nn as nn from float8_experimental.config import ScalingType from float8_experimental.float8_scaling_utils import ( - cast_to_float8_e4m3_dynamic, + cast_to_float8_dynamic, NoopFwToFloat8E5M2BwDynamic, ) from float8_experimental.float8_tensor import GemmInputRole +from float8_experimental.float8_utils import e4m3_dtype from torch.distributed._tensor import DTensor from torch.distributed.device_mesh import DeviceMesh from torch.distributed.tensor.parallel import ( @@ -45,8 +46,9 @@ def _prepare_input_fn( input_tensor, device_mesh, input_layouts, run_check=False ) - input_tensor = cast_to_float8_e4m3_dynamic( + input_tensor = cast_to_float8_dynamic( input_tensor, + e4m3_dtype, mod.linear_mm_config, gemm_input_role=GemmInputRole.INPUT, ) # DTensor(Float8Tensor) @@ -98,8 +100,9 @@ def _prepare_input_fn( input_tensor, device_mesh, input_layouts, run_check=False ) - input_tensor = cast_to_float8_e4m3_dynamic( + input_tensor = cast_to_float8_dynamic( input_tensor, + e4m3_dtype, mod.linear_mm_config, gemm_input_role=GemmInputRole.INPUT, ) # DTensor(Float8Tensor) @@ -196,8 +199,9 @@ def _prepare_input_arg(self, input, mesh, input_layout, desired_layout): input, mesh, (input_layout,), run_check=False ) - dt_inp = cast_to_float8_e4m3_dynamic( + dt_inp = cast_to_float8_dynamic( dt_inp, + e4m3_dtype, self.linear_mm_config, gemm_input_role=GemmInputRole.INPUT, ) # DTensor(Float8Tensor) diff --git a/float8_experimental/fsdp_utils.py b/float8_experimental/fsdp_utils.py index 702df26..c5424ac 100644 --- a/float8_experimental/fsdp_utils.py +++ b/float8_experimental/fsdp_utils.py @@ -12,7 +12,7 @@ import torch.utils._pytree as pytree from float8_experimental.float8_scaling_utils import ( cast_to_float8_delayed, - cast_to_float8_e4m3_dynamic, + cast_to_float8_dynamic, ) from float8_experimental.float8_tensor import ( @@ -175,8 +175,9 @@ def fsdp_pre_all_gather(self, mesh): GemmInputRole.WEIGHT, ) else: - float8_tensor = cast_to_float8_e4m3_dynamic( + float8_tensor = cast_to_float8_dynamic( self._tensor, + e4m3_dtype, self._linear_mm_config, reduce_amax=True, gemm_input_role=GemmInputRole.WEIGHT, diff --git a/test/test_base.py b/test/test_base.py index 82d0c60..4e0c685 100644 --- a/test/test_base.py +++ b/test/test_base.py @@ -24,7 +24,6 @@ sync_float8_amax_and_scale_history, ) from float8_experimental.float8_python_api import addmm_float8_unwrapped -from float8_experimental.float8_scaling_utils import cast_to_float8_e4m3_dynamic from float8_experimental.float8_tensor import ( Float8Tensor, GemmInputRole,