From 183dbec75635482fcecd57839f0fd571651d176b Mon Sep 17 00:00:00 2001 From: vasiliy Date: Mon, 23 Sep 2024 09:52:24 -0700 Subject: [PATCH 1/8] Update [ghstack-poisoned] --- test/float8/test_base.py | 119 +++++++++++++++++++++++-- torchao/float8/config.py | 12 +++ torchao/float8/float8_ops.py | 93 ++++++++++++++++++- torchao/float8/float8_scaling_utils.py | 15 +++- torchao/float8/float8_tensor.py | 34 ++++--- torchao/float8/float8_utils.py | 25 ++++-- 6 files changed, 275 insertions(+), 23 deletions(-) diff --git a/test/float8/test_base.py b/test/float8/test_base.py index 2a875c44d6..e6dd67951c 100644 --- a/test/float8/test_base.py +++ b/test/float8/test_base.py @@ -22,7 +22,12 @@ pytest.skip("Unsupported PyTorch version", allow_module_level=True) -from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingType +from torchao.float8.config import ( + CastConfig, + Float8LinearConfig, + ScalingGranularity, + ScalingType, +) from torchao.float8.float8_linear import Float8Linear from torchao.float8.float8_linear_utils import ( convert_to_float8_training, @@ -30,6 +35,7 @@ sync_float8_amax_and_scale_history, ) from torchao.float8.float8_python_api import addmm_float8_unwrapped +from torchao.float8.float8_scaling_utils import hp_tensor_to_float8_dynamic from torchao.float8.float8_tensor import ( Float8Tensor, GemmInputRole, @@ -58,7 +64,7 @@ def bitwise_identical(a: Float8Tensor, b: Float8Tensor) -> bool: return True -class TestFloat8Tensor(unittest.TestCase): +class TestFloat8Tensor: def test_preserves_dtype(self) -> None: # hp means high precision, lp means low precision hp_dtypes = (torch.float32, torch.float16, torch.bfloat16) @@ -68,7 +74,7 @@ def test_preserves_dtype(self) -> None: x1_s = tensor_to_scale(x1_hp, lp_dtype) x2_lp = hp_tensor_and_scale_to_float8(x1_hp, x1_s, lp_dtype) x3_hp = x2_lp.to_original_precision() - self.assertTrue(x3_hp.dtype == hp_dtype) + assert x3_hp.dtype == hp_dtype def test_differentiable_casts(self) -> None: lp_dtypes = (e4m3_dtype, e5m2_dtype) @@ -103,7 +109,7 @@ def test_index_put(self): fp8_b = hp_tensor_and_scale_to_float8(b, scale_a, torch.float8_e4m3fn) fp8_b_bad = hp_tensor_and_scale_to_float8(b, scale_b, torch.float8_e4m3fn) - with self.assertRaises(AssertionError): + with pytest.raises(AssertionError): b[index] = fp8_a fp8_b[index] = a fp8_b_bad[index] = fp8_a @@ -117,7 +123,7 @@ def test_copy_(self): b = torch.empty(16, dtype=torch.bfloat16) b.copy_(fp8_a) # Should work torch.testing.assert_close(b, fp8_a.to_original_precision()) - with self.assertRaises(RuntimeError): + with pytest.raises(RuntimeError): fp8_a.copy_(b) # Should fail fp8_b = Float8Tensor( @@ -129,6 +135,109 @@ def test_copy_(self): fp8_b.copy_(fp8_a) torch.testing.assert_close(fp8_a._data, fp8_b._data) + @pytest.mark.parametrize("shape", [(8, 16), (4, 8, 16), (2, 4, 8, 16)]) + @pytest.mark.parametrize("dim_name", ["first", "last"]) + def test_axiswise_dynamic_cast(self, shape, dim_name): + a = torch.randn(*shape, dtype=torch.bfloat16) + + if dim_name == "first": + dim = 0 + elif dim_name == "last": + dim = len(a.shape) - 1 + + linear_mm_config = LinearMMConfig() + a_fp8 = hp_tensor_to_float8_dynamic( + a, + e4m3_dtype, + linear_mm_config, + scaling_granularity=ScalingGranularity.AXISWISE, + axiswise_dim=dim, + ) + a_dq = a_fp8.to_original_precision() + sqnr = compute_error(a, a_dq) + assert sqnr >= 25.0 + + def test_axiswise_reshape(self): + a = torch.randn(3, 5, 7, dtype=torch.bfloat16, device="cuda") + linear_mm_config = LinearMMConfig() + + # if we scale across dim0, we can only reshape to [3, -1] + a_fp8_d0 = hp_tensor_to_float8_dynamic( + a, + e4m3_dtype, + linear_mm_config, + scaling_granularity=ScalingGranularity.AXISWISE, + axiswise_dim=0, + ) + assert list(a_fp8_d0._data.shape) == [3, 5, 7] + assert list(a_fp8_d0._scale.shape) == [1, 5, 7] + + a_fp8_d0_r = a_fp8_d0.reshape(3, -1) + assert list(a_fp8_d0_r.shape) == [3, 5 * 7] + assert list(a_fp8_d0_r._scale.shape) == [1, 5 * 7] + # verify numerics did not change + assert torch.allclose( + a_fp8_d0.to_original_precision(), + a_fp8_d0_r.to_original_precision().reshape(3, 5, 7), + atol=0, + rtol=0, + ) + with pytest.raises(RuntimeError): + a_fp8_d0_r2 = a_fp8_d0.reshape(-1, 7) + + # if we scale across dim2, we can only reshape to [-1, 7] + a_fp8_d2 = hp_tensor_to_float8_dynamic( + a, + e4m3_dtype, + linear_mm_config, + scaling_granularity=ScalingGranularity.AXISWISE, + axiswise_dim=2, + ) + assert list(a_fp8_d2._data.shape) == [3, 5, 7] + assert list(a_fp8_d2._scale.shape) == [3, 5, 1] + + a_fp8_d2_r = a_fp8_d2.reshape(-1, 7) + assert list(a_fp8_d2_r.shape) == [3 * 5, 7] + assert list(a_fp8_d2_r._scale.shape) == [3 * 5, 1] + # verify numerics did not change + assert torch.allclose( + a_fp8_d2.to_original_precision(), + a_fp8_d2_r.to_original_precision().reshape(3, 5, 7), + atol=0, + rtol=0, + ) + with pytest.raises(RuntimeError): + a_fp8_d2_r2 = a_fp8_d2.reshape(3, -1) + + @pytest.mark.parametrize("a_shape", [(16, 32), (2, 16, 32), (1, 2, 16, 32)]) + def test_axiswise_gemm(self, a_shape): + a = torch.randn(*a_shape, dtype=torch.bfloat16, device="cuda") + b = torch.randn(64, 32, dtype=torch.bfloat16, device="cuda") + + linear_mm_config = LinearMMConfig() + + a_fp8 = hp_tensor_to_float8_dynamic( + a, + e4m3_dtype, + linear_mm_config, + gemm_input_role=GemmInputRole.INPUT, + scaling_granularity=ScalingGranularity.AXISWISE, + axiswise_dim=-1, + ) + a_fp8 = a_fp8.reshape(-1, a_shape[-1]) + b_fp8 = hp_tensor_to_float8_dynamic( + b, + e4m3_dtype, + linear_mm_config, + gemm_input_role=GemmInputRole.WEIGHT, + scaling_granularity=ScalingGranularity.AXISWISE, + axiswise_dim=1, # will be transposed + ) + c_fp8_compute = torch.mm(a_fp8, b_fp8.t()) + a = a.reshape(-1, a_shape[-1]) + c_ref = torch.mm(a, b.t()) + sqnr = compute_error(c_ref, c_fp8_compute) + assert sqnr >= 25.0 class TestFloat8Linear: diff --git a/torchao/float8/config.py b/torchao/float8/config.py index eb28dcbd8e..b24b5ba749 100644 --- a/torchao/float8/config.py +++ b/torchao/float8/config.py @@ -26,6 +26,18 @@ def short_str(self): return "sta" +class ScalingGranularity(enum.Enum): + """ + Defines the granularity of scaling strategies for casting to float8 + """ + + # A single scaling factor for the entire tensor + TENSORWISE = "tensorwise" + # Scaling factors computed along one axis of the tensor, reducing it to + # size 1. + AXISWISE = "axiswise" + + @dataclass(frozen=True) class CastConfig: """ diff --git a/torchao/float8/float8_ops.py b/torchao/float8/float8_ops.py index f8115649b3..1bf9faaa4c 100644 --- a/torchao/float8/float8_ops.py +++ b/torchao/float8/float8_ops.py @@ -19,6 +19,15 @@ FLOAT8_OPS_TABLE: Dict[Any, Any] = {} +def _assert_tensorwise_scale(aten_op, scale): + assert ( + # TODO(future PR): figure out why tensorwise scaling can have + # both rank 0 and rank 1 + len(scale.shape) + in (0, 1) + ), f"{aten_op} with axiswise scaling is not supported yet" + + def implements(aten_ops): """Register aten ops to the float8 op table""" @@ -45,6 +54,7 @@ def decorator(func): ] ) def float8_desugar_op(aten_op, args, kwargs=None): + _assert_tensorwise_scale(aten_op, args[0]._scale) new_data = aten_op(args[0]._data, *args[1:], **kwargs) return Float8Tensor( new_data, @@ -55,10 +65,82 @@ def float8_desugar_op(aten_op, args, kwargs=None): ) +@implements( + [ + aten.t.default, + aten.transpose.int, + ] +) +def float8_desugar_data_and_scale(aten_op, args, kwargs=None): + new_data = aten_op(args[0]._data, *args[1:], **kwargs) + new_scale = aten_op(args[0]._scale, *args[1:], **kwargs) + + if aten_op == aten.transpose.int: + _assert_tensorwise_scale(aten_op, args[0]._scale) + + old_axiswise_dim = args[0]._axiswise_dim + new_axiswise_dim = old_axiswise_dim + if old_axiswise_dim is not None: + if old_axiswise_dim == 0: + new_axiswise_dim == -1 + else: + new_axiswise_dim == 0 + + return Float8Tensor( + new_data, + new_scale, + args[0]._orig_dtype, + args[0]._linear_mm_config, + args[0]._gemm_input_role, + new_axiswise_dim, + ) + + +@implements([aten.view.default]) +def float8_view(aten_op, args, kwargs=None): + if len(args[0]._scale.shape) < 2: + # tensorwise scaling + return float8_desugar_op(aten_op, args, kwargs) + + t, new_shape = args[0], args[1] + # for now, only support reshaping to [-1, dim] or [dim, -1] + axiswise_dim = t._axiswise_dim + if len(new_shape) == 2: + + if axiswise_dim == 0: + new_data = aten_op(t._data, new_shape, **kwargs) + new_scale_shape = [1, new_shape[-1]] + new_scale = aten_op(t._scale, new_scale_shape, **kwargs) + return Float8Tensor( + new_data, + new_scale, + t._orig_dtype, + t._linear_mm_config, + t._gemm_input_role, + t._axiswise_dim, + ) + elif axiswise_dim == -1 or axiswise_dim == (len(t.shape) - 1): + new_data = aten_op(t._data, new_shape, **kwargs) + new_scale_shape = [new_shape[0], 1] + new_scale = aten_op(t._scale, new_scale_shape, **kwargs) + new_axiswise_dim = -1 + return Float8Tensor( + new_data, + new_scale, + t._orig_dtype, + t._linear_mm_config, + t._gemm_input_role, + new_axiswise_dim, + ) + raise AssertionError( + f"{aten_op} with axiswise scaling and t.shape {t.shape} t._scale.shape {t._scale.shape} t._axiswise_dim {t._axiswise_dim} new_shape {new_shape} is not supported yet." + ) + + @implements([aten.split.Tensor]) def float8_split(aten_op, args, kwargs=None): new_data_tensors = aten_op(args[0]._data, *args[1:], **kwargs) - + _assert_tensorwise_scale(aten_op, args[0]._scale) def make_float8(data): return Float8Tensor( data, @@ -102,6 +184,7 @@ def float8_cat(aten_op, args, kwargs=None): assert ( chunk._gemm_input_role is gemm_input_role ), "Expecting all chunks to have the same gemm_input_role as a result of a split" + _assert_tensorwise_scale(aten_op, chunk._scale) chunk_data.append(chunk._data.view(torch.uint8)) new_data = aten_op(chunk_data, *args[1:], **kwargs) @@ -118,6 +201,7 @@ def float8_cast_up_op(aten_op, args, kwargs=None): "addmm" -> out "hp_gradBias" <-"sum" <- "identity" <- gradOut <- "hp_gradOut" """ + _assert_tensorwise_scale(aten_op, args[0]._scale) def unwrap(x): if isinstance(x, Float8Tensor): @@ -230,6 +314,7 @@ def float8_addmm(aten_op, args, kwargs=None): @implements([aten.is_same_size.default]) def float8_is_same_size(aten_op, args, kwargs=None): + _assert_tensorwise_scale(aten_op, args[0]._scale) return args[0].shape == args[1].shape @@ -239,6 +324,7 @@ def autocast_to_copy(aten_op, args, kwargs=None): when the input is a Float8Tensor, presenting as a fp32 tensor. """ + _assert_tensorwise_scale(aten_op, args[0]._scale) assert isinstance(args[0], Float8Tensor) assert ( len(kwargs) == 1 and "dtype" in kwargs @@ -266,6 +352,7 @@ def allgather_fp8(aten_op, args, kwargs=None): """ override funcol with FP8 handling """ + _assert_tensorwise_scale(aten_op, args[0]._scale) fp8_input = args[0] assert isinstance( fp8_input, Float8Tensor @@ -285,6 +372,7 @@ def allgather_fp8(aten_op, args, kwargs=None): @implements([c10d_functional.wait_tensor.default, _c10d_functional.wait_tensor.default]) def wait_tensor_fp8(aten_op, args, kwargs=None): + _assert_tensorwise_scale(aten_op, args[0]._scale) fp8_input = args[0] assert isinstance(fp8_input, Float8Tensor) @@ -305,6 +393,7 @@ def index_put_fp8(aten_op, args, kwargs=None): fp8_values = args[2] assert isinstance(fp8_self, Float8Tensor) assert isinstance(fp8_values, Float8Tensor) + _assert_tensorwise_scale(fp8_self, args[0]._scale) assert fp8_self._scale == fp8_values._scale assert fp8_self.dtype == fp8_values.dtype assert fp8_self._orig_dtype == fp8_values._orig_dtype @@ -335,8 +424,10 @@ def copy_fp8(aten_op, args, kwargs=None): if not isinstance(self, Float8Tensor) and isinstance(src, Float8Tensor): src_hp = src.to_original_precision() + _assert_tensorwise_scale(aten_op, src._scale) return aten_op(self, src_hp, *args[2:], **kwargs) elif isinstance(self, Float8Tensor) and isinstance(src, Float8Tensor): + _assert_tensorwise_scale(aten_op, src._scale) assert ( self._orig_dtype == src._orig_dtype ), "Expecting both Float8Tensors to be of the same dtype" diff --git a/torchao/float8/float8_scaling_utils.py b/torchao/float8/float8_scaling_utils.py index d2ae896320..f46293d616 100644 --- a/torchao/float8/float8_scaling_utils.py +++ b/torchao/float8/float8_scaling_utils.py @@ -12,6 +12,8 @@ import torch +from torchao.float8.config import ScalingGranularity + from torchao.float8.float8_tensor import ( Float8Tensor, GemmInputRole, @@ -36,6 +38,8 @@ def hp_tensor_to_float8_dynamic( linear_mm_config: LinearMMConfig, reduce_amax: bool = False, gemm_input_role: GemmInputRole = GemmInputRole.INPUT, + scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE, + axiswise_dim: Optional[int] = None, ) -> Float8Tensor: """ Given a high precision tensor `hp_tensor`, @@ -49,16 +53,25 @@ def hp_tensor_to_float8_dynamic( reduce_amax: whether to reduce the max(abs(hp_tensor)) value across distributed ranks gemm_input_role: Defines the role of this tensor (input, weight or grad_output) in the 3 fwd/bwd gemms of linear + scaling_granularity: Defines the scaling granularity + axiswise_dim: if axiswise granularity is used, defines the dim to scale across """ if tensor_already_casted_to_fp8(hp_tensor): return hp_tensor - scale = tensor_to_scale(hp_tensor, float8_dtype, reduce_amax) + scale = tensor_to_scale( + hp_tensor, + float8_dtype, + reduce_amax, + scaling_granularity, + axiswise_dim, + ) return hp_tensor_and_scale_to_float8( hp_tensor, scale, float8_dtype, linear_mm_config, gemm_input_role, + axiswise_dim, ) diff --git a/torchao/float8/float8_tensor.py b/torchao/float8/float8_tensor.py index 63110101a5..c8b68586c0 100644 --- a/torchao/float8/float8_tensor.py +++ b/torchao/float8/float8_tensor.py @@ -152,6 +152,7 @@ def forward( float8_dtype=e4m3_dtype, linear_mm_config: Optional[LinearMMConfig] = None, gemm_input_role: Optional[GemmInputRole] = GemmInputRole.INPUT, + axiswise_dim: Optional[int] = None, ): """ This function will apply the scaling, and then convert to a Float8Tensor @@ -180,6 +181,7 @@ def forward( tensor.dtype, linear_mm_config=linear_mm_config, gemm_input_role=gemm_input_role, + axiswise_dim=axiswise_dim, ) return DTensor.from_local( inner_float8_tensor, @@ -196,6 +198,7 @@ def forward( tensor.dtype, linear_mm_config=linear_mm_config, gemm_input_role=gemm_input_role, + axiswise_dim=axiswise_dim, ) @staticmethod @@ -226,6 +229,7 @@ def hp_tensor_and_scale_to_float8( float8_dtype=e4m3_dtype, linear_mm_config: Optional[LinearMMConfig] = None, gemm_input_role: Optional[GemmInputRole] = GemmInputRole.INPUT, + axiswise_dim: Optional[int] = None, ): """ Given a high precision tensor `hp_tensor` and a precalculated scale `s`, @@ -242,9 +246,10 @@ def hp_tensor_and_scale_to_float8( the 3 fwd/bwd gemms of linear gemm_input_role: Defines the role of this tensor (input, weight or grad_output) in the 3 fwd/bwd gemms of linear + axiswise_dim: for rowwise scaling, contains the axis scaled across """ return _ToFloat8ConstrFunc.apply( - hp_tensor, s, float8_dtype, linear_mm_config, gemm_input_role + hp_tensor, s, float8_dtype, linear_mm_config, gemm_input_role, axiswise_dim ) @@ -258,11 +263,19 @@ class Float8Tensor(torch.Tensor): * `_data`: the underlying e4m3 or e5m2 data * `_scale`: the scale used to scale the original fp32 tensor. We multiply by scale to go from fp32 range to fp8 range, and divide by scale to go - from fp8 range to fp32 range. + from fp8 range to fp32 range. Scale is guaranteed to have a shape compatible + with `_data`. For example: + - if scaling is tensorwise, `_scale` is a scalar tensor + - if scaling is axiswise and _data.shape is [3, 5], `_scale` could have + shape [1, 5] or [3, 1]. The dim of the non-one entry defines the scaling + axis. + - if scaling is axiswise and _data.shape is [2, 3, 5], `_scale` could have + shape [1, 1, 5] or [2, 1, 1]. The dim of the non-one entry defines the scaling + axis. Non-one entries which are not the first or last element are not + supported. * `_orig_dtype`: the original dtype of the tensor used to create this tensor. - * `_emulate`: if true using fp32 emulation for the matmuls, helpful - if you don't have access to h100 hardware. + * `_axiswise_dim`: for axiswise scaling only, contains the axis scales across Intended usage of this abstraction: 1. to bundle raw data + fp8 metadata together for easy passing through @@ -277,6 +290,7 @@ class Float8Tensor(torch.Tensor): _scale: torch.Tensor _orig_dtype: torch.dtype _linear_mm_config: LinearMMConfig + _axiswise_dim: Optional[int] __slots__ = ["_data", "_scale", "_orig_dtype", "_linear_mm_config"] def __new__( @@ -286,13 +300,8 @@ def __new__( orig_dtype: torch.dtype, linear_mm_config: Optional[LinearMMConfig], gemm_input_role: Optional[GemmInputRole] = GemmInputRole.INPUT, + axiswise_dim: Optional[int] = None, ): - assert ( - scale.numel() == 1 - ), "Scale should contain a single value, but got: {} elements".format( - scale.numel() - ) - self = torch.Tensor._make_wrapper_subclass( cls, data.size(), @@ -310,17 +319,19 @@ def __new__( linear_mm_config if linear_mm_config is not None else LinearMMConfig() ) self._gemm_input_role = gemm_input_role + self._axiswise_dim = axiswise_dim return self def __repr__(self): - return f"Float8Tensor(dtype={self._data.dtype}, scale={self._scale}, linear_mm_config={self._linear_mm_config}\ngemm_input_role={self._gemm_input_role}\nas_orig_prec={self.to_original_precision()}" + return f"Float8Tensor(dtype={self._data.dtype}, scale={self._scale}, linear_mm_config={self._linear_mm_config}, axiswise_dim={self._axiswise_dim}\ngemm_input_role={self._gemm_input_role}\nas_orig_prec={self.to_original_precision()}" def __tensor_flatten__(self): ctx = { "_orig_dtype": self._orig_dtype, "_linear_mm_config": self._linear_mm_config, "_gemm_input_role": self._gemm_input_role, + "_axiswise_dim": self._axiswise_dim, } return ["_data", "_scale"], ctx @@ -333,6 +344,7 @@ def __tensor_unflatten__(inner_tensors: Dict, metadata, outer_size, outer_stride metadata["_orig_dtype"], metadata["_linear_mm_config"], metadata["_gemm_input_role"], + metadata["_axiswise_dim"], ) def to_original_precision(self): diff --git a/torchao/float8/float8_utils.py b/torchao/float8/float8_utils.py index 54613e5b05..55e520f8ca 100644 --- a/torchao/float8/float8_utils.py +++ b/torchao/float8/float8_utils.py @@ -4,12 +4,13 @@ # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. -from typing import Iterable, Literal, Tuple, Union +from typing import Iterable, Literal, Optional, Tuple, Union import torchao.float8.config as config import torch import torch.distributed as dist +from torchao.float8.config import ScalingGranularity # Helpful visualizer for debugging (only supports fp32): # https://www.h-schmidt.net/FloatConverter/IEEE754.html @@ -98,8 +99,18 @@ def amax_history_to_scale_stack( @torch.no_grad() -def tensor_to_amax(x: torch.Tensor, reduce_amax: bool = False) -> torch.Tensor: - amax = torch.max(torch.abs(x)) +def tensor_to_amax( + x: torch.Tensor, + reduce_amax: bool = False, + scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE, + axiswise_dim: Optional[int] = None, +) -> torch.Tensor: + if scaling_granularity is ScalingGranularity.TENSORWISE: + amax = torch.max(torch.abs(x)) + else: + assert scaling_granularity is ScalingGranularity.AXISWISE, "unsupported" + assert axiswise_dim is not None, "unsupported" + amax = torch.amax(torch.abs(x), dim=axiswise_dim, keepdim=True) # If the user asked for distributed reduction, do it. # If the user did not ask for it, assume that it will @@ -112,9 +123,13 @@ def tensor_to_amax(x: torch.Tensor, reduce_amax: bool = False) -> torch.Tensor: @torch.no_grad() def tensor_to_scale( - x: torch.Tensor, float8_dtype: torch.dtype, reduce_amax: bool = False + x: torch.Tensor, + float8_dtype: torch.dtype, + reduce_amax: bool = False, + scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE, + axiswise_dim: Optional[int] = None, ) -> torch.Tensor: - amax = tensor_to_amax(x, reduce_amax=reduce_amax) + amax = tensor_to_amax(x, reduce_amax, scaling_granularity, axiswise_dim) return amax_to_scale(amax, float8_dtype, x.dtype) From 241f815f73b85f27a59fdf9f8e9d0a80cedb5f2c Mon Sep 17 00:00:00 2001 From: vasiliy Date: Mon, 23 Sep 2024 09:52:27 -0700 Subject: [PATCH 2/8] Update [ghstack-poisoned] --- benchmarks/float8/bench_linear_float8.py | 32 +++- benchmarks/float8/profile_linear_float8.py | 27 ++- test/float8/test_base.py | 32 +++- test/float8/test_compile.py | 103 ++++++++++- test/float8/test_numerics_integration.py | 37 +++- torchao/float8/config.py | 32 ++++ torchao/float8/float8_linear.py | 188 +++++++++++++++++++-- torchao/float8/float8_ops.py | 22 ++- 8 files changed, 434 insertions(+), 39 deletions(-) diff --git a/benchmarks/float8/bench_linear_float8.py b/benchmarks/float8/bench_linear_float8.py index e18006f0e4..f92303c627 100644 --- a/benchmarks/float8/bench_linear_float8.py +++ b/benchmarks/float8/bench_linear_float8.py @@ -14,7 +14,12 @@ import torch import torch.utils.benchmark as benchmark -from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingType +from torchao.float8.config import ( + CastConfig, + Float8LinearConfig, + ScalingType, + ScalingGranularity, +) from torchao.float8.float8_linear import Float8Linear from torchao.float8.float8_linear_utils import ( linear_requires_sync, @@ -107,6 +112,7 @@ def main( scaling_type_input: str = "dynamic", scaling_type_weight: str = "dynamic", scaling_type_grad_output: str = "dynamic", + scaling_granularity: str = "tensorwise", ): device = "cuda" print(f"Compile is set to | {compile}") @@ -114,28 +120,41 @@ def main( scaling_type_input = ScalingType(scaling_type_input) scaling_type_weight = ScalingType(scaling_type_weight) scaling_type_grad_output = ScalingType(scaling_type_grad_output) + scaling_granularity = ScalingGranularity(scaling_granularity) if scaling_type_input is ScalingType.STATIC: cast_config_input=CastConfig( scaling_type=scaling_type_input, static_scale=torch.tensor([1.0], device="cuda"), + scaling_granularity=scaling_granularity, ) else: - cast_config_input=CastConfig(scaling_type=scaling_type_input) + cast_config_input=CastConfig( + scaling_type=scaling_type_input, + scaling_granularity=scaling_granularity, + ) if scaling_type_weight is ScalingType.STATIC: cast_config_weight=CastConfig( scaling_type=scaling_type_weight, static_scale=torch.tensor([1.0], device="cuda"), + scaling_granularity=scaling_granularity, ) else: - cast_config_weight=CastConfig(scaling_type=scaling_type_weight) + cast_config_weight=CastConfig( + scaling_type=scaling_type_weight, + scaling_granularity=scaling_granularity, + ) if scaling_type_grad_output is ScalingType.STATIC: cast_config_grad_output=CastConfig( scaling_type=scaling_type_grad_output, static_scale=torch.tensor([1.0], device="cuda"), + scaling_granularity=scaling_granularity, ) else: - cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output) + cast_config_grad_output=CastConfig( + scaling_type=scaling_type_grad_output, + scaling_granularity=scaling_granularity, + ) config = Float8LinearConfig( cast_config_input=cast_config_input, @@ -167,7 +186,7 @@ def main( copy.deepcopy(linear_ref), config=config, ) - scaling_repr = linear_float8.scaling_repr() + scaling_repr = f"{linear_float8.scaling_type_repr()},{linear_float8.scaling_granularity_repr()}" if fast_accum: linear_float8.forward_config = ScaledMMConfig(False, True, False) @@ -310,6 +329,7 @@ def invoke_main() -> None: parser.add_argument("--scaling_type_input", type=str, required=False) parser.add_argument("--scaling_type_weight", type=str, required=False) parser.add_argument("--scaling_type_grad_output", type=str, required=False) + parser.add_argument("--scaling_granularity", type=str, required=False) args = parser.parse_args() output_path = Path(args.output_path) if args.output_path is not None else None kwargs = {} @@ -327,6 +347,8 @@ def invoke_main() -> None: kwargs["scaling_type_weight"] = args.scaling_type_weight if args.scaling_type_grad_output is not None: kwargs["scaling_type_grad_output"] = args.scaling_type_grad_output + if args.scaling_granularity is not None: + kwargs["scaling_granularity"] = args.scaling_granularity main( output_path, not args.disable_compile, diff --git a/benchmarks/float8/profile_linear_float8.py b/benchmarks/float8/profile_linear_float8.py index c204d49b03..6afefa0096 100644 --- a/benchmarks/float8/profile_linear_float8.py +++ b/benchmarks/float8/profile_linear_float8.py @@ -22,7 +22,12 @@ import torch import torch.nn as nn import torch.nn.functional as F -from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingType +from torchao.float8.config import ( + CastConfig, + Float8LinearConfig, + ScalingType, + ScalingGranularity, +) from torchao.float8.float8_linear_utils import ( convert_to_float8_training, linear_requires_sync, @@ -252,6 +257,7 @@ def main( scaling_type_input: str = "dynamic", scaling_type_weight: str = "dynamic", scaling_type_grad_output: str = "dynamic", + scaling_granularity: str = "tensorwise", model_type: str = "linear", dtype_filter: str = "both", add_inductor_metadata_to_trace: bool = True, @@ -263,28 +269,41 @@ def main( scaling_type_input = ScalingType(scaling_type_input) scaling_type_weight = ScalingType(scaling_type_weight) scaling_type_grad_output = ScalingType(scaling_type_grad_output) + scaling_granularity = ScalingGranularity(scaling_granularity) if scaling_type_input is ScalingType.STATIC: cast_config_input=CastConfig( scaling_type=scaling_type_input, static_scale=torch.tensor([1.0], device="cuda"), + scaling_granularity=scaling_granularity, ) else: - cast_config_input=CastConfig(scaling_type=scaling_type_input) + cast_config_input=CastConfig( + scaling_type=scaling_type_input, + scaling_granularity=scaling_granularity, + ) if scaling_type_weight is ScalingType.STATIC: cast_config_weight=CastConfig( scaling_type=scaling_type_weight, static_scale=torch.tensor([1.0], device="cuda"), + scaling_granularity=scaling_granularity, ) else: - cast_config_weight=CastConfig(scaling_type=scaling_type_weight) + cast_config_weight=CastConfig( + scaling_type=scaling_type_weight, + scaling_granularity=scaling_granularity, + ) if scaling_type_grad_output is ScalingType.STATIC: cast_config_grad_output=CastConfig( scaling_type=scaling_type_grad_output, static_scale=torch.tensor([1.0], device="cuda"), + scaling_granularity=scaling_granularity, ) else: - cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output) + cast_config_grad_output=CastConfig( + scaling_type=scaling_type_grad_output, + scaling_granularity=scaling_granularity, + ) config = Float8LinearConfig( cast_config_input=cast_config_input, diff --git a/test/float8/test_base.py b/test/float8/test_base.py index e6dd67951c..2fa5394b81 100644 --- a/test/float8/test_base.py +++ b/test/float8/test_base.py @@ -327,6 +327,10 @@ def _test_linear_impl( "scaling_type_grad_output", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC], ) + @pytest.mark.parametrize( + "scaling_granularity", + [ScalingGranularity.TENSORWISE, ScalingGranularity.AXISWISE], + ) @pytest.mark.parametrize("linear_dtype", [torch.bfloat16, torch.float32]) @pytest.mark.parametrize("linear_bias", [False, True]) @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") @@ -337,6 +341,7 @@ def test_linear( scaling_type_input: ScalingType, scaling_type_weight: ScalingType, scaling_type_grad_output: ScalingType, + scaling_granularity: ScalingGranularity, linear_dtype: torch.dtype, linear_bias: bool, ): @@ -349,30 +354,51 @@ def test_linear( f"CUDA capability {torch.cuda.get_device_capability()} < (9.0)" ) pytest.skip() + if scaling_granularity is ScalingGranularity.AXISWISE: + if ( + scaling_type_input != ScalingType.DYNAMIC or + scaling_type_weight != ScalingType.DYNAMIC or + scaling_type_grad_output != ScalingType.DYNAMIC or + linear_dtype != torch.bfloat16 + ): + pytest.skip() + x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype) m_ref = nn.Linear(16, 32, bias=linear_bias, device="cuda", dtype=linear_dtype) if scaling_type_input is ScalingType.STATIC: cast_config_input = CastConfig( scaling_type=scaling_type_input, + scaling_granularity=scaling_granularity, static_scale=torch.tensor([1.0], device="cuda"), ) else: - cast_config_input = CastConfig(scaling_type=scaling_type_input) + cast_config_input = CastConfig( + scaling_type=scaling_type_input, + scaling_granularity=scaling_granularity, + ) if scaling_type_weight is ScalingType.STATIC: cast_config_weight = CastConfig( scaling_type=scaling_type_weight, + scaling_granularity=scaling_granularity, static_scale=torch.tensor([1.0], device="cuda"), ) else: - cast_config_weight = CastConfig(scaling_type=scaling_type_weight) + cast_config_weight = CastConfig( + scaling_type=scaling_type_weight, + scaling_granularity=scaling_granularity, + ) if scaling_type_grad_output is ScalingType.STATIC: cast_config_grad_output = CastConfig( scaling_type=scaling_type_grad_output, + scaling_granularity=scaling_granularity, static_scale=torch.tensor([1.0], device="cuda"), ) else: - cast_config_grad_output = CastConfig(scaling_type=scaling_type_grad_output) + cast_config_grad_output = CastConfig( + scaling_type=scaling_type_grad_output, + scaling_granularity=scaling_granularity, + ) config = Float8LinearConfig( cast_config_input=cast_config_input, diff --git a/test/float8/test_compile.py b/test/float8/test_compile.py index 8a0458bec3..899f63bdb3 100644 --- a/test/float8/test_compile.py +++ b/test/float8/test_compile.py @@ -18,7 +18,12 @@ import torch import torch.nn as nn -from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingType +from torchao.float8.config import ( + CastConfig, + Float8LinearConfig, + ScalingType, + ScalingGranularity, +) from torchao.float8.float8_linear import Float8Linear from torchao.float8.float8_linear_utils import ( convert_to_float8_training, @@ -60,6 +65,8 @@ def _test_compile_base( y_fp8.sum().backward() y_ref = m_ref(x) y_ref.sum().backward() + # TODO(future PR): can also test fp8 eager vs compile here with a tigher + # tolerance torch.testing.assert_close(y_fp8, y_ref, atol=9.5e-2, rtol=9.5e-2) torch.testing.assert_close( m_fp8.weight.grad, m_ref.weight.grad, atol=2e-1, rtol=2e-1 @@ -70,29 +77,42 @@ def _get_config( scaling_type_input, scaling_type_weight, scaling_type_grad_output, + scaling_granularity, emulate, ): if scaling_type_input is ScalingType.STATIC: cast_config_input = CastConfig( scaling_type=scaling_type_input, + scaling_granularity=scaling_granularity, static_scale=torch.tensor([1.0], device="cuda"), ) else: - cast_config_input = CastConfig(scaling_type=scaling_type_input) + cast_config_input = CastConfig( + scaling_type=scaling_type_input, + scaling_granularity=scaling_granularity, + ) if scaling_type_weight is ScalingType.STATIC: cast_config_weight = CastConfig( scaling_type=scaling_type_weight, + scaling_granularity=scaling_granularity, static_scale=torch.tensor([1.0], device="cuda"), ) else: - cast_config_weight = CastConfig(scaling_type=scaling_type_weight) + cast_config_weight = CastConfig( + scaling_type=scaling_type_weight, + scaling_granularity=scaling_granularity, + ) if scaling_type_grad_output is ScalingType.STATIC: cast_config_grad_output = CastConfig( scaling_type=scaling_type_grad_output, + scaling_granularity=scaling_granularity, static_scale=torch.tensor([1.0], device="cuda"), ) else: - cast_config_grad_output = CastConfig(scaling_type=scaling_type_grad_output) + cast_config_grad_output = CastConfig( + scaling_type=scaling_type_grad_output, + scaling_granularity=scaling_granularity, + ) config = Float8LinearConfig( cast_config_input=cast_config_input, @@ -103,6 +123,24 @@ def _get_config( return config +def is_supported( + scaling_granularity, + scaling_type_input, + scaling_type_weight, + scaling_type_grad_output, + dtype, +) -> bool: + if scaling_granularity is ScalingGranularity.AXISWISE: + if ( + scaling_type_input != ScalingType.DYNAMIC or + scaling_type_weight != ScalingType.DYNAMIC or + scaling_type_grad_output != ScalingType.DYNAMIC or + dtype != torch.bfloat16 + ): + return False + return True + + @pytest.mark.parametrize("fullgraph", [True]) @pytest.mark.parametrize( "scaling_type_input", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC] @@ -113,6 +151,9 @@ def _get_config( @pytest.mark.parametrize( "scaling_type_grad_output", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC] ) +@pytest.mark.parametrize( + "scaling_granularity", [ScalingGranularity.TENSORWISE, ScalingGranularity.AXISWISE] +) @pytest.mark.parametrize("emulate", [False, True] if is_cuda_8_9 else [True]) @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") @@ -122,11 +163,25 @@ def test_eager_only( scaling_type_input: ScalingType, scaling_type_weight: ScalingType, scaling_type_grad_output: ScalingType, + scaling_granularity: ScalingGranularity, dtype: torch.dtype, ): + if not is_supported( + scaling_granularity, + scaling_type_input, + scaling_type_weight, + scaling_type_grad_output, + dtype, + ): + pytest.skip() + torch._dynamo.reset() config = _get_config( - scaling_type_input, scaling_type_weight, scaling_type_grad_output, emulate, + scaling_type_input, + scaling_type_weight, + scaling_type_grad_output, + scaling_granularity, + emulate, ) _test_compile_base( "eager", @@ -147,6 +202,9 @@ def test_eager_only( @pytest.mark.parametrize( "scaling_type_grad_output", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC] ) +@pytest.mark.parametrize( + "scaling_granularity", [ScalingGranularity.TENSORWISE, ScalingGranularity.AXISWISE] +) @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") def test_aot_eager( @@ -155,11 +213,25 @@ def test_aot_eager( scaling_type_input: ScalingType, scaling_type_weight: ScalingType, scaling_type_grad_output: ScalingType, + scaling_granularity: ScalingGranularity, dtype: torch.dtype, ): + if not is_supported( + scaling_granularity, + scaling_type_input, + scaling_type_weight, + scaling_type_grad_output, + dtype, + ): + pytest.skip() + torch._dynamo.reset() config = _get_config( - scaling_type_input, scaling_type_weight, scaling_type_grad_output, emulate, + scaling_type_input, + scaling_type_weight, + scaling_type_grad_output, + scaling_granularity, + emulate, ) _test_compile_base( "aot_eager", @@ -180,6 +252,9 @@ def test_aot_eager( @pytest.mark.parametrize( "scaling_type_grad_output", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC] ) +@pytest.mark.parametrize( + "scaling_granularity", [ScalingGranularity.TENSORWISE, ScalingGranularity.AXISWISE] +) @unittest.skipIf(not torch.cuda.is_available() or not is_cuda_8_9, "CUDA with float8 support not available") @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) def test_inductor( @@ -188,11 +263,25 @@ def test_inductor( scaling_type_input: ScalingType, scaling_type_weight: ScalingType, scaling_type_grad_output: ScalingType, + scaling_granularity: ScalingGranularity, dtype: torch.dtype, ): + if not is_supported( + scaling_granularity, + scaling_type_input, + scaling_type_weight, + scaling_type_grad_output, + dtype, + ): + pytest.skip() + torch._dynamo.reset() config = _get_config( - scaling_type_input, scaling_type_weight, scaling_type_grad_output, emulate, + scaling_type_input, + scaling_type_weight, + scaling_type_grad_output, + scaling_granularity, + emulate, ) _test_compile_base( "inductor", diff --git a/test/float8/test_numerics_integration.py b/test/float8/test_numerics_integration.py index 6db05dc56d..396de7efa8 100644 --- a/test/float8/test_numerics_integration.py +++ b/test/float8/test_numerics_integration.py @@ -19,7 +19,12 @@ import torch import torch.nn as nn import torch.nn.functional as F -from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingType +from torchao.float8.config import ( + CastConfig, + Float8LinearConfig, + ScalingType, + ScalingGranularity, +) from torchao.float8.float8_linear_utils import ( convert_to_float8_training, linear_requires_sync, @@ -90,6 +95,10 @@ class TestFloat8NumericsIntegrationTest: "scaling_type_grad_output", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC], ) + @pytest.mark.parametrize( + "scaling_granularity", + [ScalingGranularity.TENSORWISE, ScalingGranularity.AXISWISE], + ) @pytest.mark.skipif(not is_cuda_8_9, reason="requires SM89 compatible machine") @pytest.mark.skipif(IS_ROCM, reason="test doesn't currently work on the ROCm stack") def test_encoder_fw_bw( @@ -97,10 +106,20 @@ def test_encoder_fw_bw( scaling_type_input: ScalingType, scaling_type_weight: ScalingType, scaling_type_grad_output: ScalingType, + scaling_granularity: ScalingGranularity, ): # TODO(later): maybe add float16 back if it becomes important data_dtype = torch.bfloat16 + if scaling_granularity is ScalingGranularity.AXISWISE: + if ( + scaling_type_input != ScalingType.DYNAMIC or + scaling_type_weight != ScalingType.DYNAMIC or + scaling_type_grad_output != ScalingType.DYNAMIC or + data_dtype != torch.bfloat16 + ): + pytest.skip() + # LLaMa 3 70B shapes model_ref = ( FeedForward( @@ -119,24 +138,34 @@ def test_encoder_fw_bw( if scaling_type_input is ScalingType.STATIC: cast_config_input = CastConfig( scaling_type=scaling_type_input, + scaling_granularity=scaling_granularity, static_scale=torch.tensor([1.0], device="cuda"), ) else: - cast_config_input = CastConfig(scaling_type=scaling_type_input) + cast_config_input = CastConfig( + scaling_type=scaling_type_input, + scaling_granularity=scaling_granularity, + ) if scaling_type_weight is ScalingType.STATIC: cast_config_weight = CastConfig( scaling_type=scaling_type_weight, static_scale=torch.tensor([1.0], device="cuda"), ) else: - cast_config_weight = CastConfig(scaling_type=scaling_type_weight) + cast_config_weight = CastConfig( + scaling_type=scaling_type_weight, + scaling_granularity=scaling_granularity, + ) if scaling_type_grad_output is ScalingType.STATIC: cast_config_grad_output = CastConfig( scaling_type=scaling_type_grad_output, static_scale=torch.tensor([1.0], device="cuda"), ) else: - cast_config_grad_output = CastConfig(scaling_type=scaling_type_grad_output) + cast_config_grad_output = CastConfig( + scaling_type=scaling_type_grad_output, + scaling_granularity=scaling_granularity, + ) config = Float8LinearConfig( cast_config_input=cast_config_input, diff --git a/torchao/float8/config.py b/torchao/float8/config.py index b24b5ba749..4d82bd1118 100644 --- a/torchao/float8/config.py +++ b/torchao/float8/config.py @@ -37,6 +37,13 @@ class ScalingGranularity(enum.Enum): # size 1. AXISWISE = "axiswise" + def short_str(self): + if self is ScalingGranularity.TENSORWISE: + return "ten" + else: + assert self is ScalingGranularity.AXISWISE + return "axs" + @dataclass(frozen=True) class CastConfig: @@ -45,12 +52,16 @@ class CastConfig: """ scaling_type: ScalingType = ScalingType.DYNAMIC + scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE static_scale: Optional[torch.Tensor] = None def __post_init__(self): if self.scaling_type is ScalingType.STATIC: assert self.static_scale is not None, \ "static_scale must be specified for static scaling" + if self.scaling_granularity is ScalingGranularity.AXISWISE: + assert self.scaling_type is ScalingType.DYNAMIC, \ + "only dynamic scaling type is supported for axiswise scaling granularity" @dataclass(frozen=True) class DelayedScalingConfig: @@ -144,6 +155,27 @@ class Float8LinearConfig: # configuration, this field may move to per-tensor configs. delayed_scaling_config: DelayedScalingConfig = DelayedScalingConfig() + def __post_init__(self): + # float8 all-gather only supports tensorwise, in the future may support blockwise + if self.cast_config_weight.scaling_granularity != ScalingGranularity.TENSORWISE: + assert not self.enable_fsdp_float8_all_gather, \ + f"enable_fsdp_float8_all_gather only supports tensorwise scaling granularity, got {self.cast_config_weight.scaling_granularity}" + + # for now, axiswise granularity is all-or-nothing + # TODO(future PR): enable more granular setting per-gemm-input + has_any_axiswise_scaling = ( + self.cast_config_input.scaling_granularity is ScalingGranularity.AXISWISE or + self.cast_config_weight.scaling_granularity is ScalingGranularity.AXISWISE or + self.cast_config_grad_output.scaling_granularity is ScalingGranularity.AXISWISE + ) + has_all_axiswise_scaling = ( + self.cast_config_input.scaling_granularity is ScalingGranularity.AXISWISE and + self.cast_config_weight.scaling_granularity is ScalingGranularity.AXISWISE and + self.cast_config_grad_output.scaling_granularity is ScalingGranularity.AXISWISE + ) + if has_any_axiswise_scaling: + assert has_all_axiswise_scaling, \ + "for now, axiswise scaling must be enabled for either all casts or none of the casts" # If True, use 'fnuz' float8 types for calculations. # Currently, ROCm only supports fnuz variants. diff --git a/torchao/float8/float8_linear.py b/torchao/float8/float8_linear.py index cb0ff7afb0..5f87e82fe4 100644 --- a/torchao/float8/float8_linear.py +++ b/torchao/float8/float8_linear.py @@ -14,7 +14,7 @@ import torch -from torchao.float8.config import Float8LinearConfig, ScalingType +from torchao.float8.config import Float8LinearConfig, ScalingType, ScalingGranularity from torchao.float8.float8_scaling_utils import ( _maybe_initialize_amaxes_scales_for_float8_cast, @@ -42,11 +42,17 @@ ) -# this code was resurrected from https://github.com/pytorch-labs/torchao.float8/pull/128/files @torch._dynamo.allow_in_graph -class manual_float8_matmul(torch.autograd.Function): +class manual_float8_matmul_with_args_in_float8(torch.autograd.Function): """ Like torch.matmul, but with the arguments in float8 + + Note: this function requires all arguments to already be Float8Tensor objects, + which only supports tensorwise scaling granularity. The reason we didn't just make this + function support axiswise scaling granularity is because that would need very + careful testing of delayed scaling, as delayed scaling modifies buffers inplace. + + In the future we'll probably have to unify, just postponing that until a future PR. """ @staticmethod @@ -97,6 +103,133 @@ def backward(ctx, grad_output_fp8): return grad_input, grad_weight.t() +@torch._dynamo.allow_in_graph +class manual_float8_matmul_with_args_in_hp(torch.autograd.Function): + """ + Like torch.matmul, but with the arguments in high precision and the cast to float8 + defined inside of this function. + + Note: this function currently only supports dynamic scaling type and + axiswise granularity. We will have to unify this with other scaling types + and other granularities in a separate PR. + """ + + # TODO(this PR): types of inputs + @staticmethod + def forward( + ctx, + input_hp: torch.Tensor, + weight_hp_t: torch.Tensor, + linear_mm_config: LinearMMConfig, + input_scaling_granularity: ScalingGranularity, + weight_scaling_granularity: ScalingGranularity, + grad_output_scaling_granularity: ScalingGranularity, + ): + ctx.save_for_backward(input_hp, weight_hp_t) + ctx.linear_mm_config = linear_mm_config + ctx.input_scaling_granularity = input_scaling_granularity + ctx.weight_scaling_granularity = weight_scaling_granularity + ctx.grad_output_scaling_granularity = grad_output_scaling_granularity + + input_fp8 = hp_tensor_to_float8_dynamic( + input_hp, + e4m3_dtype, + linear_mm_config, + gemm_input_role=GemmInputRole.INPUT, + scaling_granularity=input_scaling_granularity, + axiswise_dim=-1, + ) + + weight_fp8_t = hp_tensor_to_float8_dynamic( + weight_hp_t, + e4m3_dtype, + linear_mm_config, + gemm_input_role=GemmInputRole.WEIGHT, + scaling_granularity=weight_scaling_granularity, + axiswise_dim=0, + ) + + # the reshapes are needed in order to make the shapes compatible with + # torch.mm + orig_shape = input_fp8.shape + input_fp8_reshaped = input_fp8.reshape(-1, orig_shape[-1]) + res_bits = torch.mm(input_fp8_reshaped, weight_fp8_t) + res_bits = res_bits.reshape(*orig_shape[:-1], res_bits.shape[-1]) + return res_bits + + @staticmethod + def backward(ctx, grad_output): + input_hp, weight_hp_t = ctx.saved_tensors + + # TODO scaling + + # the reshapes are needed in order to make the shapes compatible with + # torch.mm + grad_output_orig_shape = grad_output.shape + grad_output_reshaped = grad_output.reshape( + -1, grad_output_orig_shape[-1] + ) + + # + # calculate grad_input + # + + grad_output_reshaped_fp8_dim0 = hp_tensor_to_float8_dynamic( + grad_output_reshaped, + e5m2_dtype, + ctx.linear_mm_config, + gemm_input_role=GemmInputRole.GRAD_OUTPUT, + scaling_granularity=ctx.grad_output_scaling_granularity, + axiswise_dim=-1, + ) + weight_t_fp8_dim0 = hp_tensor_to_float8_dynamic( + weight_hp_t, + e4m3_dtype, + ctx.linear_mm_config, + gemm_input_role=GemmInputRole.WEIGHT, + scaling_granularity=ctx.weight_scaling_granularity, + axiswise_dim=1, # will be transposed + ) + + grad_input = torch.mm( + grad_output_reshaped_fp8_dim0, + weight_t_fp8_dim0.t(), + ) + grad_input = grad_input.reshape( + *grad_output_orig_shape[:-1], grad_input.shape[-1] + ) + + input_hp_orig_shape = input_hp.shape + input_hp_reshaped = input_hp.reshape(-1, input_hp_orig_shape[-1]) + + # + # calculate grad_weight + # + + grad_output_reshaped_fp8_dim1 = hp_tensor_to_float8_dynamic( + grad_output_reshaped, + e5m2_dtype, + ctx.linear_mm_config, + gemm_input_role=GemmInputRole.GRAD_OUTPUT, + scaling_granularity=ctx.grad_output_scaling_granularity, + axiswise_dim=0, # will be transposed + ) + input_reshaped_fp8_dim1 = hp_tensor_to_float8_dynamic( + input_hp_reshaped, + e4m3_dtype, + ctx.linear_mm_config, + gemm_input_role=GemmInputRole.INPUT, + scaling_granularity=ctx.input_scaling_granularity, + axiswise_dim=0, + ) + + grad_weight = torch.mm( + grad_output_reshaped_fp8_dim1.t(), + input_reshaped_fp8_dim1, + ) + + return grad_input, grad_weight.t(), None, None, None, None + class Float8Linear(torch.nn.Linear): """ @@ -289,7 +422,10 @@ def cast_input_to_float8( ) elif self.scaling_type_input is ScalingType.DYNAMIC: input_fp8 = hp_tensor_to_float8_dynamic( - input, e4m3_dtype, self.linear_mm_config + input, + e4m3_dtype, + self.linear_mm_config, + gemm_input_role=GemmInputRole.INPUT, ) else: assert self.scaling_type_input is ScalingType.STATIC @@ -395,13 +531,33 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: if self.has_any_delayed_scaling: self.float8_pre_forward(input) - input_fp8 = self.cast_input_to_float8(input, self.is_amax_initialized) - weight_fp8 = self.cast_weight_to_float8(self.weight, self.is_amax_initialized) + # TODO(this PR): reuse with config, make a property + has_all_axiswise_scaling = ( + self.config.cast_config_input.scaling_granularity is ScalingGranularity.AXISWISE and + self.config.cast_config_weight.scaling_granularity is ScalingGranularity.AXISWISE and + self.config.cast_config_grad_output.scaling_granularity is ScalingGranularity.AXISWISE + ) + + if not has_all_axiswise_scaling: + input_fp8 = self.cast_input_to_float8(input, self.is_amax_initialized) + weight_fp8 = self.cast_weight_to_float8(self.weight, self.is_amax_initialized) - output = manual_float8_matmul.apply(input_fp8, weight_fp8.t()) + output = manual_float8_matmul_with_args_in_float8.apply(input_fp8, weight_fp8.t()) - # Cast grad_output to float8_e5m2 during backward - output = self.cast_output_to_float8_in_bw(output) + # Cast grad_output to float8_e5m2 during backward + output = self.cast_output_to_float8_in_bw(output) + + else: + # for now, axiswise path is separate + # TODO(future PR): unify to support mix and match + output = manual_float8_matmul_with_args_in_hp.apply( + input, + self.weight.t(), + self.linear_mm_config, + self.config.cast_config_input.scaling_granularity, + self.config.cast_config_weight.scaling_granularity, + self.config.cast_config_grad_output.scaling_granularity, + ) if self.bias is not None: output = output + self.bias.to(output.dtype) @@ -410,13 +566,21 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: self.float8_post_forward() return output - def scaling_repr(self): - # add scaling settings without using too many characters + def scaling_type_repr(self): + # add scaling type settings without using too many characters # example: "i:del,w:del,go:dyn" return f"i:{self.scaling_type_input.short_str()},w:{self.scaling_type_weight.short_str()},go:{self.scaling_type_grad_output.short_str()}" + def scaling_granularity_repr(self): + # add scaling granularity settings without using too many characters + # example: "i:ten,w:ten,g:ten" or "i:axs,w:axs,g:axs" + gi = self.config.cast_config_input.scaling_granularity.short_str() + gw = self.config.cast_config_weight.scaling_granularity.short_str() + ggo = self.config.cast_config_grad_output.scaling_granularity.short_str() + return f"i:{gi},w:{gw},go:{ggo}" + def extra_repr(self): - s = f'{super().extra_repr()}, scaling="{self.scaling_repr()}"' + s = f'{super().extra_repr()}, scaling_type="{self.scaling_type_repr()}", scaling_granularity="{self.scaling_granularity_repr()}"' return s @classmethod diff --git a/torchao/float8/float8_ops.py b/torchao/float8/float8_ops.py index 1bf9faaa4c..b97d032113 100644 --- a/torchao/float8/float8_ops.py +++ b/torchao/float8/float8_ops.py @@ -43,12 +43,9 @@ def decorator(func): [ aten.view.default, aten._unsafe_view.default, - aten.t.default, aten.as_strided.default, aten.clone.default, - aten.detach.default, aten.slice.Tensor, - aten.transpose.int, aten.fill_.Scalar, aten.reshape.default, ] @@ -65,13 +62,30 @@ def float8_desugar_op(aten_op, args, kwargs=None): ) +@implements( + [ + aten.detach.default, + ] +) +def float8_desugar_data_and_scale_op(aten_op, args, kwargs=None): + new_data = aten_op(args[0]._data, *args[1:], **kwargs) + new_scale = aten_op(args[0]._scale, *args[1:], **kwargs) + return Float8Tensor( + new_data, + new_scale, + args[0]._orig_dtype, + args[0]._linear_mm_config, + args[0]._gemm_input_role, + ) + + @implements( [ aten.t.default, aten.transpose.int, ] ) -def float8_desugar_data_and_scale(aten_op, args, kwargs=None): +def float8_transpose(aten_op, args, kwargs=None): new_data = aten_op(args[0]._data, *args[1:], **kwargs) new_scale = aten_op(args[0]._scale, *args[1:], **kwargs) From f15c2a02b0822784f718c51480c1fbe422ea5bb8 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Mon, 23 Sep 2024 09:54:13 -0700 Subject: [PATCH 3/8] Update [ghstack-poisoned] --- torchao/float8/float8_utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/torchao/float8/float8_utils.py b/torchao/float8/float8_utils.py index 55e520f8ca..e79cf27d88 100644 --- a/torchao/float8/float8_utils.py +++ b/torchao/float8/float8_utils.py @@ -100,7 +100,7 @@ def amax_history_to_scale_stack( @torch.no_grad() def tensor_to_amax( - x: torch.Tensor, + x: torch.Tensor, reduce_amax: bool = False, scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE, axiswise_dim: Optional[int] = None, @@ -110,7 +110,7 @@ def tensor_to_amax( else: assert scaling_granularity is ScalingGranularity.AXISWISE, "unsupported" assert axiswise_dim is not None, "unsupported" - amax = torch.amax(torch.abs(x), dim=axiswise_dim, keepdim=True) + amax = torch.amax(torch.abs(x), dim=axiswise_dim, keepdim=True) # If the user asked for distributed reduction, do it. # If the user did not ask for it, assume that it will @@ -123,8 +123,8 @@ def tensor_to_amax( @torch.no_grad() def tensor_to_scale( - x: torch.Tensor, - float8_dtype: torch.dtype, + x: torch.Tensor, + float8_dtype: torch.dtype, reduce_amax: bool = False, scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE, axiswise_dim: Optional[int] = None, From 9150b4f2b2f0e13d9cd79f876067b9907037841c Mon Sep 17 00:00:00 2001 From: vasiliy Date: Mon, 23 Sep 2024 10:55:16 -0700 Subject: [PATCH 4/8] Update [ghstack-poisoned] --- test/float8/test_base.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/float8/test_base.py b/test/float8/test_base.py index e6dd67951c..d3f48a7153 100644 --- a/test/float8/test_base.py +++ b/test/float8/test_base.py @@ -57,6 +57,7 @@ is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9) +is_cuda_9_0 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0) def bitwise_identical(a: Float8Tensor, b: Float8Tensor) -> bool: assert torch.all(a._data == b._data).item(), "scales are not identical" @@ -210,6 +211,8 @@ def test_axiswise_reshape(self): a_fp8_d2_r2 = a_fp8_d2.reshape(3, -1) @pytest.mark.parametrize("a_shape", [(16, 32), (2, 16, 32), (1, 2, 16, 32)]) + @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") + @unittest.skipIf(not is_cuda_9_0, "Requires CUDA capability >= 9.0") def test_axiswise_gemm(self, a_shape): a = torch.randn(*a_shape, dtype=torch.bfloat16, device="cuda") b = torch.randn(64, 32, dtype=torch.bfloat16, device="cuda") From 459e92c434378eb4d8b70d69299af94a69e4d45c Mon Sep 17 00:00:00 2001 From: vasiliy Date: Mon, 23 Sep 2024 12:25:42 -0700 Subject: [PATCH 5/8] Update [ghstack-poisoned] --- test/float8/test_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/float8/test_base.py b/test/float8/test_base.py index d3f48a7153..ebc33f0372 100644 --- a/test/float8/test_base.py +++ b/test/float8/test_base.py @@ -159,7 +159,7 @@ def test_axiswise_dynamic_cast(self, shape, dim_name): assert sqnr >= 25.0 def test_axiswise_reshape(self): - a = torch.randn(3, 5, 7, dtype=torch.bfloat16, device="cuda") + a = torch.randn(3, 5, 7, dtype=torch.bfloat16) linear_mm_config = LinearMMConfig() # if we scale across dim0, we can only reshape to [3, -1] From 732b231dffe87bad13a76ccb7cc859b8714a6239 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Mon, 23 Sep 2024 13:22:08 -0700 Subject: [PATCH 6/8] Update [ghstack-poisoned] --- test/float8/test_compile.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/float8/test_compile.py b/test/float8/test_compile.py index 74cc6faa5c..eacd317b1a 100644 --- a/test/float8/test_compile.py +++ b/test/float8/test_compile.py @@ -137,7 +137,7 @@ def is_supported( scaling_type_weight != ScalingType.DYNAMIC or scaling_type_grad_output != ScalingType.DYNAMIC or dtype != torch.bfloat16 or - (not IS_H100) + (not is_H100) ): return False return True From fc8d4efc031c16453f8430cbb7a172a1121a819a Mon Sep 17 00:00:00 2001 From: vasiliy Date: Wed, 2 Oct 2024 16:06:33 -0700 Subject: [PATCH 7/8] Update [ghstack-poisoned] --- torchao/ops.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/torchao/ops.py b/torchao/ops.py index 99e19dbbd4..79c02dfd85 100644 --- a/torchao/ops.py +++ b/torchao/ops.py @@ -5,11 +5,10 @@ lib = torch.library.Library("torchao", "FRAGMENT") -# TODO(before land): undo this, this is to work around https://github.com/pytorch/ao/issues/991 -# lib.define("quant_llm_linear(int EXPONENT, int MANTISSA, Tensor _in_feats, Tensor _weights, Tensor _scales, int splitK) -> Tensor") -# lib.define("unpack_tensor_core_tiled_layout(Tensor packed_w, int inner_k_tiles) -> Tensor") -# lib.define("dequantize_tensor_core_tiled_layout(Tensor packed_w, Tensor scales_and_zeros, int group_size, int inner_k_tiles) -> Tensor") -# lib.define("marlin_24_gemm(Tensor x, Tensor weight_marlin, Tensor meta, Tensor s, Tensor workspace, int bits, int size_m, int size_n, int size_k) -> Tensor") +lib.define("quant_llm_linear(int EXPONENT, int MANTISSA, Tensor _in_feats, Tensor _weights, Tensor _scales, int splitK) -> Tensor") +lib.define("unpack_tensor_core_tiled_layout(Tensor packed_w, int inner_k_tiles) -> Tensor") +lib.define("dequantize_tensor_core_tiled_layout(Tensor packed_w, Tensor scales_and_zeros, int group_size, int inner_k_tiles) -> Tensor") +lib.define("marlin_24_gemm(Tensor x, Tensor weight_marlin, Tensor meta, Tensor s, Tensor workspace, int bits, int size_m, int size_n, int size_k) -> Tensor") def register_custom_op(name): From ac6f768ba2fbc189381967e35bd0796d60256a06 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Wed, 2 Oct 2024 16:09:58 -0700 Subject: [PATCH 8/8] Update [ghstack-poisoned] --- torchao/float8/float8_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchao/float8/float8_utils.py b/torchao/float8/float8_utils.py index acb814404c..b6f42c5081 100644 --- a/torchao/float8/float8_utils.py +++ b/torchao/float8/float8_utils.py @@ -8,9 +8,9 @@ import torch import torch.distributed as dist -from torchao.float8.config import ScalingGranularity import torchao.float8.config as config +from torchao.float8.config import ScalingGranularity # Helpful visualizer for debugging (only supports fp32): # https://www.h-schmidt.net/FloatConverter/IEEE754.html