From 8cbcf50b0c518812066645c05c3fce1b196676b9 Mon Sep 17 00:00:00 2001 From: shuw Date: Wed, 22 May 2024 22:56:05 +0000 Subject: [PATCH 1/6] Support fp8 direct quantization Call quantized_einsum --- praxis/layers/injection/fp8_nvidia_gpu.py | 29 ++++++++++++++++--- .../layers/injection/fp8_nvidia_gpu_test.py | 11 ++++--- 2 files changed, 32 insertions(+), 8 deletions(-) diff --git a/praxis/layers/injection/fp8_nvidia_gpu.py b/praxis/layers/injection/fp8_nvidia_gpu.py index 05d79b0f..40d74ead 100644 --- a/praxis/layers/injection/fp8_nvidia_gpu.py +++ b/praxis/layers/injection/fp8_nvidia_gpu.py @@ -70,6 +70,7 @@ class Fp8EinsumOp(base_layer.BaseLayer): """Wrapper around jnp.einsum used in standard Pax layers.""" amax_history_length: int = 1024 + use_direct_quant: bool = False def setup(self) -> None: scale_args, amax_history_args = _get_fp8_args( @@ -128,9 +129,7 @@ def quantized_einsum( return y, x_qdq return y - def __call__( - self, equation: str, *args: JTensor - ) -> Union[JTensor, tuple[JTensor, JTensor]]: + def __call__(self, equation: str, *args: JTensor) -> JTensor: assert len(args) == 2 x = args[0] k = args[1] @@ -141,7 +140,29 @@ def __call__( ), f'k dtype has to be {comp_dtype}, but got {k.dtype}' x = jnp.asarray(x, comp_dtype) - y = self.quantized_einsum(equation, x, k, return_quantized_x=False) + if self.use_direct_quant: + def _quantized_dot_general( + lhs, rhs, dimension_numbers, precision=None, + preferred_element_type=None + ): + theta = self.theta + return fp8_ops.q_dot_dq( + lhs, + rhs, + lhs_scale=theta.input_scale, + rhs_scale=theta.kernel_scale, + out_grad_scale=theta.output_grad_scale, + lhs_amax_history=theta.input_amax_history, + rhs_amax_history=theta.kernel_amax_history, + out_grad_amax_history=theta.output_grad_amax_history, + compute_dtype=comp_dtype, + dimension_numbers=dimension_numbers, + precision=precision, + preferred_element_type=preferred_element_type, + ) + y = jnp.einsum(equation, x, k, _dot_general=_quantized_dot_general) + else: + y = self.quantized_einsum(equation, x, k, return_quantized_x=False) return y diff --git a/praxis/layers/injection/fp8_nvidia_gpu_test.py b/praxis/layers/injection/fp8_nvidia_gpu_test.py index 3aaa0258..3c232103 100644 --- a/praxis/layers/injection/fp8_nvidia_gpu_test.py +++ b/praxis/layers/injection/fp8_nvidia_gpu_test.py @@ -17,7 +17,7 @@ from functools import partial -from absl.testing import absltest +from absl.testing import absltest, parameterized from flax.linen.fp8_ops import quantize_dequantize import jax from jax import numpy as jnp @@ -30,9 +30,10 @@ PARAMS = base_layer.PARAMS -class Fp8LinearsTest(test_utils.TestCase): +class Fp8LinearsTest(test_utils.TestCase, parameterized.TestCase): - def test_fp8_einsum_injection(self): + @parameterized.parameters([True, False]) + def test_fp8_einsum_injection(self, use_direct_quant): # Used to cast the inputs to be representable in FP8, so that the difference # of the results from the original gemm and fp8 gemm is small. 
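    # Editor's note (illustrative sketch, not part of the original test):
    # "representable in FP8" means the inputs are first run through a
    # quantize-dequantize round trip, so both the reference GEMM and the FP8
    # GEMM see values that FP8 can express exactly and their outputs stay
    # close. A minimal sketch of such a round trip (variable names made up):
    #   v = jnp.array([0.1234, 1.5678, 3.1415], jnp.float32)
    #   scale = jnp.ones((1,), jnp.float32)
    #   v_q = (v / scale).astype(jnp.float8_e4m3fn)    # quantize
    #   v_qdq = v_q.astype(jnp.float32) * scale        # dequantize
    #   # v_qdq holds the nearest FP8-representable values of v.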
cast_to_representable = partial( @@ -100,7 +101,9 @@ def _train(variables, x): } output1a, output1b = run(None, expected_shapes_original) - einsum_tpl = pax_fiddle.Config(fp8_ops.Fp8EinsumOp) + einsum_tpl = pax_fiddle.Config( + fp8_ops.Fp8EinsumOp, use_direct_quant=use_direct_quant + ) output2a, output2b = run(einsum_tpl, expected_shapes_new) dw1, dw2 = output1b[0][PARAMS]['w'], output2b[0][PARAMS]['w'] dx1, dx2 = output1b[1], output2b[1] From d83fc960bd9e82b07384778f86880cff0c8d0b85 Mon Sep 17 00:00:00 2001 From: Shu Wang Date: Wed, 18 Sep 2024 15:37:55 -0500 Subject: [PATCH 2/6] set default to be direct quantization --- praxis/layers/injection/fp8_nvidia_gpu.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/praxis/layers/injection/fp8_nvidia_gpu.py b/praxis/layers/injection/fp8_nvidia_gpu.py index 40d74ead..d51df425 100644 --- a/praxis/layers/injection/fp8_nvidia_gpu.py +++ b/praxis/layers/injection/fp8_nvidia_gpu.py @@ -70,7 +70,7 @@ class Fp8EinsumOp(base_layer.BaseLayer): """Wrapper around jnp.einsum used in standard Pax layers.""" amax_history_length: int = 1024 - use_direct_quant: bool = False + use_direct_quant: bool = True def setup(self) -> None: scale_args, amax_history_args = _get_fp8_args( From f4e3e7c2b22b84330b2aa4638c386d4be35ce149 Mon Sep 17 00:00:00 2001 From: shuw Date: Mon, 23 Sep 2024 20:08:08 -0700 Subject: [PATCH 3/6] directed quant pattern for gated einsum --- praxis/layers/injection/fp8_nvidia_gpu.py | 114 +++++++++++++++++----- 1 file changed, 92 insertions(+), 22 deletions(-) diff --git a/praxis/layers/injection/fp8_nvidia_gpu.py b/praxis/layers/injection/fp8_nvidia_gpu.py index d51df425..cb271fe8 100644 --- a/praxis/layers/injection/fp8_nvidia_gpu.py +++ b/praxis/layers/injection/fp8_nvidia_gpu.py @@ -202,29 +202,99 @@ def __call__(self, equation: str, *args: JTensor) -> tuple[JTensor, JTensor]: ), f'k dtype has to be {comp_dtype}, but got {k.dtype} and {k_gated.dtype}' x = jnp.asarray(x, comp_dtype) - y, x_qdq = self.quantized_einsum(equation, x, k, return_quantized_x=True) - theta = self.theta - k_gated_qdq = fp8_ops.in_qdq( - comp_dtype, - jnp.float8_e4m3fn, - k_gated, - theta.kernel_scale_gated, - theta.kernel_amax_history_gated, - ) - y_gated_qdq = jnp.einsum( - equation, - x_qdq, - k_gated_qdq, - _dot_general=fp8_ops.dot_general_with_precision, - ) - y_gated = fp8_ops.out_qdq( - comp_dtype, - jnp.float8_e5m2, - y_gated_qdq, - theta.output_grad_scale_gated, - theta.output_grad_amax_history_gated, - ) + if self.use_direct_quant: + q_x, new_input_scale = fp8_ops.in_q(comp_dtype, jnp.float8_e4m3fn, x, theta.input_scale, theta.input_amax_history) + # def create_one_sided_q_dot_dq(comp_dtype, q_x, new_input_scale, kernel_scale, out_grad_scale, kernel_amax_history, out_grad_amax_history): + # def _quantized_one_sided_dot_general( + # lhs, rhs, dimension_numbers, precision=None, + # preferred_element_type=None + # ): + # return fp8_ops.one_sided_q_dot_dq( + # lhs=lhs, + # q_lhs=q_x, + # lhs_scale=new_input_scale, + # rhs=rhs, + # rhs_scale=kernel_scale, + # out_grad_scale=out_grad_scale, + # rhs_amax_history=kernel_amax_history, + # out_grad_amax_history=out_grad_amax_history, + # compute_dtype=comp_dtype, + # dimension_numbers=dimension_numbers, + # precision=precision, + # preferred_element_type=preferred_element_type + # ) + # return _quantized_one_sided_dot_general + + # _one_sided_quantized_dot_general = create_one_sided_q_dot_dq( + # comp_dtype, q_x, new_input_scale, + # theta.kernel_scale, theta.out_grad_scale, + # 
theta.kernel_amax_history, theta.out_grad_amax_history + # ) + + # _one_sided_quantized_dot_general_gated = create_one_sided_q_dot_dq( + # comp_dtype, q_x, new_input_scale, + # theta.kernel_scale_gated, theta.out_grad_scale_gated, + # theta.kernel_amax_history_gated, theta.out_grad_amax_history_gated + # ) + def one_sided_q_dot_dq(comp_dtype, q_x, new_input_scale, kernel_scale, out_grad_scale, kernel_amax_history, out_grad_amax_history): + def decorator(func): + @wraps(func) + def wrapper(lhs, rhs, dimension_numbers, precision=None, preferred_element_type=None): + return fp8_ops.one_sided_q_dot_dq( + lhs=lhs, + q_lhs=q_x, + lhs_scale=new_input_scale, + rhs=rhs, + rhs_scale=kernel_scale, + out_grad_scale=out_grad_scale, + rhs_amax_history=kernel_amax_history, + out_grad_amax_history=out_grad_amax_history, + compute_dtype=comp_dtype, + dimension_numbers=dimension_numbers, + precision=precision, + preferred_element_type=preferred_element_type + ) + return wrapper + return decorator + common_args = (comp_dtype, q_x, new_input_scale) + main_fp8_metas = ( + theta.kernel_scale, theta.out_grad_scale, + theta.kernel_amax_history, theta.out_grad_amax_history + ) + gated_fp8_metas = ( + theta.kernel_scale_gated, theta.out_grad_scale_gated, + theta.kernel_amax_history_gated, theta.out_grad_amax_history_gated + ) + _dot_general_main = one_sided_q_dot_dq(*common_args, *main_fp8_metas) + _dot_general_gated = one_sided_q_dot_dq(*common_args, *gated_fp8_metas) + + y = jnp.einsum(equation, x, k, _dot_general=_dot_general_main) + y_gated = jnp.einsum(equation, x, k_gated, _dot_general=_dot_general_gated) + else: + y, x_qdq = self.quantized_einsum( + equation, x, k, return_quantized_x=True + ) + k_gated_qdq = fp8_ops.in_qdq( + comp_dtype, + jnp.float8_e4m3fn, + k_gated, + theta.kernel_scale_gated, + theta.kernel_amax_history_gated, + ) + y_gated_qdq = jnp.einsum( + equation, + x_qdq, + k_gated_qdq, + _dot_general=fp8_ops.dot_general_with_precision, + ) + y_gated = fp8_ops.out_qdq( + comp_dtype, + jnp.float8_e5m2, + y_gated_qdq, + theta.output_grad_scale_gated, + theta.output_grad_amax_history_gated, + ) return y, y_gated From e81a1489ea8290a2240534a2619cc0f4c86d7b5f Mon Sep 17 00:00:00 2001 From: shuw Date: Tue, 24 Sep 2024 11:53:42 -0700 Subject: [PATCH 4/6] pass e2e --- praxis/layers/injection/fp8_nvidia_gpu.py | 75 +++++------------------ 1 file changed, 15 insertions(+), 60 deletions(-) diff --git a/praxis/layers/injection/fp8_nvidia_gpu.py b/praxis/layers/injection/fp8_nvidia_gpu.py index cb271fe8..ed690a19 100644 --- a/praxis/layers/injection/fp8_nvidia_gpu.py +++ b/praxis/layers/injection/fp8_nvidia_gpu.py @@ -206,72 +206,27 @@ def __call__(self, equation: str, *args: JTensor) -> tuple[JTensor, JTensor]: if self.use_direct_quant: q_x, new_input_scale = fp8_ops.in_q(comp_dtype, jnp.float8_e4m3fn, x, theta.input_scale, theta.input_amax_history) - # def create_one_sided_q_dot_dq(comp_dtype, q_x, new_input_scale, kernel_scale, out_grad_scale, kernel_amax_history, out_grad_amax_history): - # def _quantized_one_sided_dot_general( - # lhs, rhs, dimension_numbers, precision=None, - # preferred_element_type=None - # ): - # return fp8_ops.one_sided_q_dot_dq( - # lhs=lhs, - # q_lhs=q_x, - # lhs_scale=new_input_scale, - # rhs=rhs, - # rhs_scale=kernel_scale, - # out_grad_scale=out_grad_scale, - # rhs_amax_history=kernel_amax_history, - # out_grad_amax_history=out_grad_amax_history, - # compute_dtype=comp_dtype, - # dimension_numbers=dimension_numbers, - # precision=precision, - # 
preferred_element_type=preferred_element_type - # ) - # return _quantized_one_sided_dot_general - - # _one_sided_quantized_dot_general = create_one_sided_q_dot_dq( - # comp_dtype, q_x, new_input_scale, - # theta.kernel_scale, theta.out_grad_scale, - # theta.kernel_amax_history, theta.out_grad_amax_history - # ) - - # _one_sided_quantized_dot_general_gated = create_one_sided_q_dot_dq( - # comp_dtype, q_x, new_input_scale, - # theta.kernel_scale_gated, theta.out_grad_scale_gated, - # theta.kernel_amax_history_gated, theta.out_grad_amax_history_gated - # ) - def one_sided_q_dot_dq(comp_dtype, q_x, new_input_scale, kernel_scale, out_grad_scale, kernel_amax_history, out_grad_amax_history): - def decorator(func): - @wraps(func) - def wrapper(lhs, rhs, dimension_numbers, precision=None, preferred_element_type=None): - return fp8_ops.one_sided_q_dot_dq( - lhs=lhs, - q_lhs=q_x, - lhs_scale=new_input_scale, - rhs=rhs, - rhs_scale=kernel_scale, - out_grad_scale=out_grad_scale, - rhs_amax_history=kernel_amax_history, - out_grad_amax_history=out_grad_amax_history, - compute_dtype=comp_dtype, - dimension_numbers=dimension_numbers, - precision=precision, - preferred_element_type=preferred_element_type - ) - return wrapper - return decorator common_args = (comp_dtype, q_x, new_input_scale) main_fp8_metas = ( - theta.kernel_scale, theta.out_grad_scale, - theta.kernel_amax_history, theta.out_grad_amax_history + theta.kernel_scale, theta.output_grad_scale, + theta.kernel_amax_history, theta.output_grad_amax_history ) gated_fp8_metas = ( - theta.kernel_scale_gated, theta.out_grad_scale_gated, - theta.kernel_amax_history_gated, theta.out_grad_amax_history_gated + theta.kernel_scale_gated, theta.output_grad_scale_gated, + theta.kernel_amax_history_gated, theta.output_grad_amax_history_gated ) - _dot_general_main = one_sided_q_dot_dq(*common_args, *main_fp8_metas) - _dot_general_gated = one_sided_q_dot_dq(*common_args, *gated_fp8_metas) - y = jnp.einsum(equation, x, k, _dot_general=_dot_general_main) - y_gated = jnp.einsum(equation, x, k_gated, _dot_general=_dot_general_gated) + @fp8_ops.one_sided_q_dot_dq_config(*common_args, *main_fp8_metas) + def _one_sided_quantized_dot_general(lhs, rhs, dimension_numbers, precision=None, preferred_element_type=None): + pass + + @fp8_ops.one_sided_q_dot_dq_config(*common_args, *gated_fp8_metas) + def _one_sided_quantized_dot_general_gated(lhs, rhs, dimension_numbers, precision=None, preferred_element_type=None): + pass + + y = jnp.einsum(equation, x, k, _dot_general=_one_sided_quantized_dot_general) + y_gated = jnp.einsum(equation, x, k_gated, _dot_general=_one_sided_quantized_dot_general_gated) + else: y, x_qdq = self.quantized_einsum( equation, x, k, return_quantized_x=True From 729b6da141cd48bab519d5744e595f6cd00af227 Mon Sep 17 00:00:00 2001 From: shuw Date: Tue, 24 Sep 2024 12:17:50 -0700 Subject: [PATCH 5/6] all in decorator --- praxis/layers/injection/fp8_nvidia_gpu.py | 32 +++++++++-------------- 1 file changed, 12 insertions(+), 20 deletions(-) diff --git a/praxis/layers/injection/fp8_nvidia_gpu.py b/praxis/layers/injection/fp8_nvidia_gpu.py index ed690a19..7630fdf8 100644 --- a/praxis/layers/injection/fp8_nvidia_gpu.py +++ b/praxis/layers/injection/fp8_nvidia_gpu.py @@ -139,27 +139,19 @@ def __call__(self, equation: str, *args: JTensor) -> JTensor: k.dtype == comp_dtype ), f'k dtype has to be {comp_dtype}, but got {k.dtype}' x = jnp.asarray(x, comp_dtype) - + if self.use_direct_quant: - def _quantized_dot_general( - lhs, rhs, dimension_numbers, precision=None, - 
preferred_element_type=None - ): - theta = self.theta - return fp8_ops.q_dot_dq( - lhs, - rhs, - lhs_scale=theta.input_scale, - rhs_scale=theta.kernel_scale, - out_grad_scale=theta.output_grad_scale, - lhs_amax_history=theta.input_amax_history, - rhs_amax_history=theta.kernel_amax_history, - out_grad_amax_history=theta.output_grad_amax_history, - compute_dtype=comp_dtype, - dimension_numbers=dimension_numbers, - precision=precision, - preferred_element_type=preferred_element_type, - ) + theta = self.theta + args = ( + theta.input_scale, theta.kernel_scale, theta.output_grad_scale, + theta.input_amax_history, theta.kernel_amax_history, theta.output_grad_amax_history, + comp_dtype + ) + + @fp8_ops.q_dot_dq_config(*args) + def _quantized_dot_general(lhs, rhs, dimension_numbers, precision=None, preferred_element_type=None): + pass + y = jnp.einsum(equation, x, k, _dot_general=_quantized_dot_general) else: y = self.quantized_einsum(equation, x, k, return_quantized_x=False) From 4c06ebcef9b23b76f64c6c2993ce82f0b102e009 Mon Sep 17 00:00:00 2001 From: shuw Date: Wed, 25 Sep 2024 08:58:52 -0700 Subject: [PATCH 6/6] Support direct quantization for gated einsum --- praxis/layers/injection/fp8_nvidia_gpu.py | 102 +++++++++++++++++----- 1 file changed, 80 insertions(+), 22 deletions(-) diff --git a/praxis/layers/injection/fp8_nvidia_gpu.py b/praxis/layers/injection/fp8_nvidia_gpu.py index 7630fdf8..05b860fb 100644 --- a/praxis/layers/injection/fp8_nvidia_gpu.py +++ b/praxis/layers/injection/fp8_nvidia_gpu.py @@ -141,23 +141,57 @@ def __call__(self, equation: str, *args: JTensor) -> JTensor: x = jnp.asarray(x, comp_dtype) if self.use_direct_quant: - theta = self.theta - args = ( - theta.input_scale, theta.kernel_scale, theta.output_grad_scale, - theta.input_amax_history, theta.kernel_amax_history, theta.output_grad_amax_history, - comp_dtype - ) - - @fp8_ops.q_dot_dq_config(*args) - def _quantized_dot_general(lhs, rhs, dimension_numbers, precision=None, preferred_element_type=None): - pass - + def _quantized_dot_general( + lhs, rhs, dimension_numbers, precision=None, + preferred_element_type=None + ): + theta = self.theta + return fp8_ops.q_dot_dq( + lhs, + rhs, + lhs_scale=theta.input_scale, + rhs_scale=theta.kernel_scale, + out_grad_scale=theta.output_grad_scale, + lhs_amax_history=theta.input_amax_history, + rhs_amax_history=theta.kernel_amax_history, + out_grad_amax_history=theta.output_grad_amax_history, + compute_dtype=comp_dtype, + dimension_numbers=dimension_numbers, + precision=precision, + preferred_element_type=preferred_element_type, + ) y = jnp.einsum(equation, x, k, _dot_general=_quantized_dot_general) else: y = self.quantized_einsum(equation, x, k, return_quantized_x=False) return y +# This decorator wraps a function to perform quantized dot product. +# It prepares the arguments for quantized_dot, including the pre-quantized input, +# scales, and amax histories. This allows for efficient FP8 matrix multiplication while +# managing quantization parameters. 
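# Editor's note (mechanism sketch, not part of the original commit): functions
# decorated with `quantized_dot_config` (e.g. `_quantized_dot_general` further
# below) are never executed -- their bodies are just `pass`. The decorator
# discards the function and returns a wrapper with the standard dot_general
# signature (lhs, rhs, dimension_numbers, precision, preferred_element_type)
# that closes over the pre-quantized operands and the FP8 metadata; jnp.einsum
# then calls that wrapper through its `_dot_general` keyword. Illustrative
# usage with hypothetical names:
#
#   @quantized_dot_config(comp_dtype, q_x, x_scale, q_k, k_scale,
#                         out_grad_scale, out_grad_amax_history)
#   def _qdot(lhs, rhs, dimension_numbers, precision=None,
#             preferred_element_type=None):
#     pass  # body unused; the decorator's wrapper does the real work
#
#   y = jnp.einsum('...d,df->...f', x, k, _dot_general=_qdot)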
+def quantized_dot_config(
  compute_dtype, q_lhs, lhs_scale, q_rhs, rhs_scale, out_grad_scale,
  out_grad_amax_history
):
  def decorator(func):
    def wrapper(lhs, rhs, dimension_numbers, precision=None, preferred_element_type=None):
      return fp8_ops.quantized_dot(
        lhs=lhs,
        q_lhs=q_lhs,
        lhs_scale=lhs_scale,
        rhs=rhs,
        q_rhs=q_rhs,
        rhs_scale=rhs_scale,
        out_grad_scale=out_grad_scale,
        out_grad_amax_history=out_grad_amax_history,
        compute_dtype=compute_dtype,
        dimension_numbers=dimension_numbers,
        precision=precision,
        preferred_element_type=preferred_element_type
      )
    return wrapper
  return decorator

class Fp8EinsumGatedOp(Fp8EinsumOp):
  """Wrapper around two jnp.einsum for gated FFN."""
@@ -197,28 +231,52 @@ def __call__(self, equation: str, *args: JTensor) -> tuple[JTensor, JTensor]:
    theta = self.theta

    if self.use_direct_quant:
      q_x, new_input_scale = fp8_ops.in_q(
          comp_dtype, jnp.float8_e4m3fn, x, theta.input_scale, theta.input_amax_history
      )
      q_k, new_kernel_scale = fp8_ops.in_q(
          comp_dtype, jnp.float8_e4m3fn, k, theta.kernel_scale, theta.kernel_amax_history
      )
      q_k_gated, new_kernel_scale_gated = fp8_ops.in_q(
          comp_dtype, jnp.float8_e4m3fn, k_gated, theta.kernel_scale_gated, theta.kernel_amax_history_gated
      )
      common_args = (comp_dtype, q_x, new_input_scale)
      main_fp8_metas = (
          q_k,
          new_kernel_scale,
          theta.output_grad_scale,
          theta.output_grad_amax_history
      )
      gated_fp8_metas = (
          q_k_gated,
          new_kernel_scale_gated,
          theta.output_grad_scale_gated,
          theta.output_grad_amax_history_gated
      )

      @quantized_dot_config(*common_args, *main_fp8_metas)
      def _quantized_dot_general(lhs, rhs, dimension_numbers, precision=None, preferred_element_type=None):
        pass

      @quantized_dot_config(*common_args, *gated_fp8_metas)
      def _quantized_dot_general_gated(lhs, rhs, dimension_numbers, precision=None, preferred_element_type=None):
        pass

      y = jnp.einsum(equation, x, k, _dot_general=_quantized_dot_general)
      y_gated = jnp.einsum(equation, x, k_gated, _dot_general=_quantized_dot_general_gated)
      y = fp8_ops.out_dq(
          dq_type=x.dtype,
          lhs_scale=new_input_scale,
          rhs_scale=new_kernel_scale,
          out=y
      )
      y_gated = fp8_ops.out_dq(
          dq_type=x.dtype,
          lhs_scale=new_input_scale,
          rhs_scale=new_kernel_scale_gated,
          out=y_gated
      )
    else:
      y, x_qdq = self.quantized_einsum(
          equation, x, k, return_quantized_x=True
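# ---------------------------------------------------------------------------
# Editor's sketch (not part of the patch series): the direct-quantization path
# above follows a delayed-scaling recipe -- quantize the activation once with
# fp8_ops.in_q, reuse that quantized tensor for both the main and the gated
# matmul, and dequantize each product with the input and kernel scales via
# fp8_ops.out_dq. The self-contained toy below shows only that dataflow; the
# names `toy_quantize`/`toy_q_matmul_dq` are invented, and trivial scales are
# used instead of amax-history-derived ones.
import jax.numpy as jnp

def toy_quantize(x, scale, q_dtype=jnp.float8_e4m3fn):
  # Divide by the per-tensor scale and round into the FP8 grid.
  return (x / scale).astype(q_dtype)

def toy_q_matmul_dq(q_a, q_b, a_scale, b_scale, out_dtype=jnp.float32):
  # Cast the FP8 operands up, matmul in a wider dtype, then dequantize by
  # multiplying the two per-tensor scales back in.
  out = jnp.matmul(q_a.astype(out_dtype), q_b.astype(out_dtype))
  return out * (a_scale * b_scale)

x = jnp.ones((4, 8), jnp.bfloat16)
k = jnp.ones((8, 16), jnp.bfloat16)
k_gated = jnp.ones((8, 16), jnp.bfloat16)
s = jnp.float32(1.0)  # trivial scale; real code derives scales from amax history

q_x = toy_quantize(x, s)  # quantize the shared activation once
y = toy_q_matmul_dq(q_x, toy_quantize(k, s), s, s)
y_gated = toy_q_matmul_dq(q_x, toy_quantize(k_gated, s), s, s)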