diff --git a/praxis/layers/injection/fp8_nvidia_gpu.py b/praxis/layers/injection/fp8_nvidia_gpu.py
index 05d79b0f..05b860fb 100644
--- a/praxis/layers/injection/fp8_nvidia_gpu.py
+++ b/praxis/layers/injection/fp8_nvidia_gpu.py
@@ -70,6 +70,7 @@ class Fp8EinsumOp(base_layer.BaseLayer):
   """Wrapper around jnp.einsum used in standard Pax layers."""
 
   amax_history_length: int = 1024
+  use_direct_quant: bool = True
 
   def setup(self) -> None:
     scale_args, amax_history_args = _get_fp8_args(
@@ -128,9 +129,7 @@ def quantized_einsum(
       return y, x_qdq
     return y
 
-  def __call__(
-      self, equation: str, *args: JTensor
-  ) -> Union[JTensor, tuple[JTensor, JTensor]]:
+  def __call__(self, equation: str, *args: JTensor) -> JTensor:
     assert len(args) == 2
     x = args[0]
     k = args[1]
@@ -140,11 +139,59 @@ def __call__(
         k.dtype == comp_dtype
     ), f'k dtype has to be {comp_dtype}, but got {k.dtype}'
     x = jnp.asarray(x, comp_dtype)
-
-    y = self.quantized_einsum(equation, x, k, return_quantized_x=False)
+
+    if self.use_direct_quant:
+      def _quantized_dot_general(
+          lhs, rhs, dimension_numbers, precision=None,
+          preferred_element_type=None
+      ):
+        theta = self.theta
+        return fp8_ops.q_dot_dq(
+            lhs,
+            rhs,
+            lhs_scale=theta.input_scale,
+            rhs_scale=theta.kernel_scale,
+            out_grad_scale=theta.output_grad_scale,
+            lhs_amax_history=theta.input_amax_history,
+            rhs_amax_history=theta.kernel_amax_history,
+            out_grad_amax_history=theta.output_grad_amax_history,
+            compute_dtype=comp_dtype,
+            dimension_numbers=dimension_numbers,
+            precision=precision,
+            preferred_element_type=preferred_element_type,
+        )
+      y = jnp.einsum(equation, x, k, _dot_general=_quantized_dot_general)
+    else:
+      y = self.quantized_einsum(equation, x, k, return_quantized_x=False)
 
     return y
 
+
+# This decorator wraps a function to perform quantized dot product.
+# It prepares the arguments for quantized_dot, including the pre-quantized input,
+# scales, and amax histories. This allows for efficient FP8 matrix multiplication
+# while managing quantization parameters.
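+# Note: quantized_dot_config is a decorator factory. The decorator it returns
+# discards the body of the wrapped function and substitutes a call to
+# fp8_ops.quantized_dot that closes over the captured operands and metadata,
+# which is why the decorated helpers in Fp8EinsumGatedOp below contain only `pass`.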
+def quantized_dot_config(
+    compute_dtype, q_lhs, lhs_scale, q_rhs, rhs_scale, out_grad_scale,
+    out_grad_amax_history
+):
+  def decorator(func):
+    def wrapper(lhs, rhs, dimension_numbers, precision=None, preferred_element_type=None):
+      return fp8_ops.quantized_dot(
+          lhs=lhs,
+          q_lhs=q_lhs,
+          lhs_scale=lhs_scale,
+          rhs=rhs,
+          q_rhs=q_rhs,
+          rhs_scale=rhs_scale,
+          out_grad_scale=out_grad_scale,
+          out_grad_amax_history=out_grad_amax_history,
+          compute_dtype=compute_dtype,
+          dimension_numbers=dimension_numbers,
+          precision=precision,
+          preferred_element_type=preferred_element_type
+      )
+    return wrapper
+  return decorator
 
 
 class Fp8EinsumGatedOp(Fp8EinsumOp):
   """Wrapper around two jnp.einsum for gated FFN."""
@@ -181,29 +228,78 @@ def __call__(self, equation: str, *args: JTensor) -> tuple[JTensor, JTensor]:
     ), f'k dtype has to be {comp_dtype}, but got {k.dtype} and {k_gated.dtype}'
     x = jnp.asarray(x, comp_dtype)
 
-    y, x_qdq = self.quantized_einsum(equation, x, k, return_quantized_x=True)
     theta = self.theta
-    k_gated_qdq = fp8_ops.in_qdq(
-        comp_dtype,
-        jnp.float8_e4m3fn,
-        k_gated,
-        theta.kernel_scale_gated,
-        theta.kernel_amax_history_gated,
-    )
-    y_gated_qdq = jnp.einsum(
-        equation,
-        x_qdq,
-        k_gated_qdq,
-        _dot_general=fp8_ops.dot_general_with_precision,
-    )
-    y_gated = fp8_ops.out_qdq(
-        comp_dtype,
-        jnp.float8_e5m2,
-        y_gated_qdq,
-        theta.output_grad_scale_gated,
-        theta.output_grad_amax_history_gated,
-    )
+
+    if self.use_direct_quant:
+      q_x, new_input_scale = fp8_ops.in_q(
+          comp_dtype, jnp.float8_e4m3fn, x, theta.input_scale, theta.input_amax_history
+      )
+      q_k, new_kernel_scale = fp8_ops.in_q(
+          comp_dtype, jnp.float8_e4m3fn, k, theta.kernel_scale, theta.kernel_amax_history
+      )
+      q_k_gated, new_kernel_scale_gated = fp8_ops.in_q(
+          comp_dtype, jnp.float8_e4m3fn, k_gated, theta.kernel_scale_gated, theta.kernel_amax_history_gated
+      )
+
+      common_args = (comp_dtype, q_x, new_input_scale)
+      main_fp8_metas = (
+          q_k,
+          new_kernel_scale,
+          theta.output_grad_scale,
+          theta.output_grad_amax_history
+      )
+      gated_fp8_metas = (
+          q_k_gated,
+          new_kernel_scale_gated,
+          theta.output_grad_scale_gated,
+          theta.output_grad_amax_history_gated
+      )
+
+      @quantized_dot_config(*common_args, *main_fp8_metas)
+      def _quantized_dot_general(lhs, rhs, dimension_numbers, precision=None, preferred_element_type=None):
+        pass
+
+      @quantized_dot_config(*common_args, *gated_fp8_metas)
+      def _quantized_dot_general_gated(lhs, rhs, dimension_numbers, precision=None, preferred_element_type=None):
+        pass
+
+      y = jnp.einsum(equation, x, k, _dot_general=_quantized_dot_general)
+      y_gated = jnp.einsum(equation, x, k_gated, _dot_general=_quantized_dot_general_gated)
+
+      y = fp8_ops.out_dq(
+          dq_type=x.dtype,
+          lhs_scale=new_input_scale,
+          rhs_scale=new_kernel_scale,
+          out=y
+      )
+      y_gated = fp8_ops.out_dq(
+          dq_type=x.dtype,
+          lhs_scale=new_input_scale,
+          rhs_scale=new_kernel_scale_gated,
+          out=y_gated
+      )
+    else:
+      y, x_qdq = self.quantized_einsum(
+          equation, x, k, return_quantized_x=True
+      )
+      k_gated_qdq = fp8_ops.in_qdq(
+          comp_dtype,
+          jnp.float8_e4m3fn,
+          k_gated,
+          theta.kernel_scale_gated,
+          theta.kernel_amax_history_gated,
+      )
+      y_gated_qdq = jnp.einsum(
+          equation,
+          x_qdq,
+          k_gated_qdq,
+          _dot_general=fp8_ops.dot_general_with_precision,
+      )
+      y_gated = fp8_ops.out_qdq(
+          comp_dtype,
+          jnp.float8_e5m2,
+          y_gated_qdq,
+          theta.output_grad_scale_gated,
+          theta.output_grad_amax_history_gated,
+      )
 
     return y, y_gated
 
diff --git a/praxis/layers/injection/fp8_nvidia_gpu_test.py b/praxis/layers/injection/fp8_nvidia_gpu_test.py
index 3aaa0258..3c232103 100644
--- a/praxis/layers/injection/fp8_nvidia_gpu_test.py
+++ b/praxis/layers/injection/fp8_nvidia_gpu_test.py
@@ -17,7 +17,7 @@
 
 from functools import partial
 
-from absl.testing import absltest
+from absl.testing import absltest, parameterized
 from flax.linen.fp8_ops import quantize_dequantize
 import jax
 from jax import numpy as jnp
@@ -30,9 +30,10 @@
 PARAMS = base_layer.PARAMS
 
 
-class Fp8LinearsTest(test_utils.TestCase):
+class Fp8LinearsTest(test_utils.TestCase, parameterized.TestCase):
 
-  def test_fp8_einsum_injection(self):
+  @parameterized.parameters([True, False])
+  def test_fp8_einsum_injection(self, use_direct_quant):
     # Used to cast the inputs to be representable in FP8, so that the difference
     # of the results from the original gemm and fp8 gemm is small.
     cast_to_representable = partial(
@@ -100,7 +101,9 @@ def _train(variables, x):
     }
 
     output1a, output1b = run(None, expected_shapes_original)
-    einsum_tpl = pax_fiddle.Config(fp8_ops.Fp8EinsumOp)
+    einsum_tpl = pax_fiddle.Config(
+        fp8_ops.Fp8EinsumOp, use_direct_quant=use_direct_quant
+    )
     output2a, output2b = run(einsum_tpl, expected_shapes_new)
     dw1, dw2 = output1b[0][PARAMS]['w'], output2b[0][PARAMS]['w']
     dx1, dx2 = output1b[1], output2b[1]
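
Usage note (not part of the diff): a minimal sketch of how the new flag is
expected to be exercised from a Pax config, mirroring the parameterized test
above. The module alias `fp8_ops` for `praxis.layers.injection.fp8_nvidia_gpu`
and the `linears.Linear` / `einsum_tpl` injection point are assumptions for
illustration, not taken from this diff.

    from praxis import pax_fiddle
    from praxis.layers import linears
    from praxis.layers.injection import fp8_nvidia_gpu as fp8_ops

    # Direct-quantization path (new default): q_dot_dq / quantized_dot.
    einsum_tpl = pax_fiddle.Config(fp8_ops.Fp8EinsumOp, use_direct_quant=True)

    # Original quantize-dequantize path, kept behind the same flag.
    qdq_einsum_tpl = pax_fiddle.Config(
        fp8_ops.Fp8EinsumOp, use_direct_quant=False
    )

    # Inject the FP8 einsum into a standard Pax layer that exposes an
    # einsum template (dims are illustrative).
    linear_tpl = pax_fiddle.Config(
        linears.Linear,
        input_dims=16,
        output_dims=32,
        einsum_tpl=einsum_tpl,
    )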