From e5a8d0b75ef10678ef816abb6b4ed5a5dd3c8ddc Mon Sep 17 00:00:00 2001
From: Shu Wang
Date: Tue, 28 May 2024 15:53:30 -0500
Subject: [PATCH 1/2] experiment_fp8_inference

---
 praxis/layers/quantization/operations.py | 7 +++++--
 praxis/layers/quantization/quantizer.py  | 4 ++--
 2 files changed, 7 insertions(+), 4 deletions(-)

diff --git a/praxis/layers/quantization/operations.py b/praxis/layers/quantization/operations.py
index 1199895d..592d3ce8 100644
--- a/praxis/layers/quantization/operations.py
+++ b/praxis/layers/quantization/operations.py
@@ -28,6 +28,7 @@
 from praxis.layers.quantization import optimization
 from praxis.layers.quantization import quantization_hparams
 from praxis.layers.quantization import utils
+from flax.linen import fp8_ops
 
 JTensor = pytypes.JTensor
 
@@ -605,7 +606,7 @@ def einsum(
       jax.dtypes.scalar_type_of(w.dtype) == float
       and jnp.finfo(w.dtype).bits == 8
   ):
-    w = w.astype(jnp.bfloat16)
+    pass # stay as fp8
 
   if x.dtype in INT_TYPES and w.dtype in INT_TYPES:
     assert not swap_xw, 'No need to swap x and w when both are int types.'
@@ -626,7 +627,9 @@
   if swap_xw:
     ret = jnp.einsum(eqn_normalized, w, x)
   else:
-    ret = jnp.einsum(eqn_normalized, x, w)
+    x = fp8_ops.quantize_dequantize(x, jnp.float8_e4m3fn, 1.01 * jnp.ones((1,)), jnp.bfloat16)
+    w_dq = fp8_ops.in_dq(jnp.bfloat16, w, 0.99*jnp.ones((1,)), jnp.ones((1024,))) # kernel_amax_history is dummy
+    ret = jnp.einsum(eqn_normalized, x, w_dq, _dot_general=fp8_ops.dot_general_with_precision)
 
   if scale_act is not None:
     if scale_act.ndim == 0:
diff --git a/praxis/layers/quantization/quantizer.py b/praxis/layers/quantization/quantizer.py
index d8248bee..e2012b76 100644
--- a/praxis/layers/quantization/quantizer.py
+++ b/praxis/layers/quantization/quantizer.py
@@ -269,8 +269,8 @@ def quantized_einsum(
         and jnp.finfo(dtype).bits == 8
     ):
       w = jax.lax.bitcast_convert_type(w, dtype)
-      # cast to bf16 since bf16 x fp8 is not supported.
-      w = w.astype(jnp.bfloat16)
+      # bf16 x fp8 is supported by nvidia
+      # w = w.astype(jnp.bfloat16)
     out = operations.einsum(
         eqn,
         x,

From c18469f7457cbe8eac12194d7b8b3b8f5bfb2bcc Mon Sep 17 00:00:00 2001
From: Shu Wang
Date: Tue, 4 Jun 2024 23:19:33 -0500
Subject: [PATCH 2/2] Experimental fp8 inference.

---
 praxis/base_layer.py                     | 31 +++++++++++++-
 praxis/layers/quantization/attentions.py |  6 +--
 praxis/layers/quantization/linears.py    |  5 ++-
 praxis/layers/quantization/operations.py | 17 +++---
 praxis/layers/quantization/quantizer.py  | 53 ++++++++++++++----------
 5 files changed, 79 insertions(+), 33 deletions(-)

diff --git a/praxis/base_layer.py b/praxis/base_layer.py
index 3f628b84..0d0bea48 100644
--- a/praxis/base_layer.py
+++ b/praxis/base_layer.py
@@ -110,6 +110,7 @@
 
 # Postfix for quantized scale and zero point names.
 QUANTIZED_SCALE_NAME_POSTFIX = '_quantized_scale'
+QUANTIZED_SCALE_ACT_NAME_POSTFIX = '_quantized_act_scale'
 QUANTIZED_ZP_NAME_POSTFIX = '_quantized_zp'
 
 # Postfix for sparsity mask
@@ -2294,13 +2295,20 @@ def create_quantized_variable(
     if scale_hparams is None:
       scale_hparams = WeightHParams(shape=scale_shape)
     else:
-      if len(scale_shape) > 0:
-        raise ValueError('Should either scale_shape or scale_hparams, not both')
+      pass
     self.create_variable(name=name, var_hparams=quantized_weight_hparams)
     self.create_variable(
        name=name + QUANTIZED_SCALE_NAME_POSTFIX,
        var_hparams=scale_hparams,
     )
+    dtype = weight_hparams.dtype
+    if (jax.dtypes.scalar_type_of(dtype) == float
+        and jnp.finfo(dtype).bits == 8
+    ):
+      self.create_variable(
+          name=name + QUANTIZED_SCALE_ACT_NAME_POSTFIX,
+          var_hparams=scale_hparams,
+      )
     if not use_symmetric:
       self.create_variable(
           name=name + QUANTIZED_ZP_NAME_POSTFIX,
@@ -2335,6 +2343,25 @@ def get_quantized_weight(
     zp = None if use_symmetric else self.theta[zp_name]
     return self.theta[name], self.theta[scale_name], zp
 
+  @nn.nowrap
+  def get_quantized_act_scale(
+      self, name: str,
+  ) -> JTensor:
+    """Gets the quantized activation scale.
+
+    `name` is the name of the weight tensor; the activation scale tensor is
+    assumed to have the postfix `_quantized_act_scale`.
+
+    Args:
+      name: Variable name for the weight tensor.
+
+    Returns:
+      The activation scale tensor.
+    """
+
+    scale_act_name = name + QUANTIZED_SCALE_ACT_NAME_POSTFIX
+    return self.theta[scale_act_name]
+
   @nn.nowrap
   def create_sparse_variable(self, name: str, weight_hparams: WeightHParams):
     """Creates the weight and mask tensors for sparse variables.
diff --git a/praxis/layers/quantization/attentions.py b/praxis/layers/quantization/attentions.py
index 4335c233..e8d3d33b 100644
--- a/praxis/layers/quantization/attentions.py
+++ b/praxis/layers/quantization/attentions.py
@@ -110,7 +110,7 @@ def _get_weight_scale_shape(self, block_size, use_block_size):
     else:
       weight_shape = [self.input_dim] + hd_shape
 
-    scale_shape = [self.input_dim] if self.is_output_projection else hd_shape
+    scale_shape = [1]
 
     if block_size > 0 and use_block_size:
       eqn = self._get_eqn()
@@ -546,9 +546,9 @@ def setup(self) -> None:
     self.set_up_weights(
         weight_name='w',
         weight_params=pc,
-        scale_shape=[3] + hd_shape,
+        scale_shape=[1],
     )
-    self.create_sparsity_variables('w', pc, scale_shape=[3] + hd_shape)
+    self.create_sparsity_variables('w', pc, scale_shape=[1])
     if self.use_bias:
       # Combined bias weight for q, k, v projections.
       pc_bias = WeightHParams(
diff --git a/praxis/layers/quantization/linears.py b/praxis/layers/quantization/linears.py
index ddaf602d..987261a2 100644
--- a/praxis/layers/quantization/linears.py
+++ b/praxis/layers/quantization/linears.py
@@ -106,7 +106,7 @@ def _get_weight_hparams(
     """
     wp = self.weight_split_dims_mapping
     weight_shape = [self.input_dims, self.output_dims]
-    scale_shape = [self.output_dims]
+    scale_shape = [1]
     block_size = self._sub_channel_block_size()
     if using_sub_channel:
       weight_shape = self._get_sub_channel_shape(weight_shape, block_size, 0)
@@ -190,11 +190,12 @@ def setup(self) -> None:
         weight_name='w',
         weight_params=weight_hparams,
         scale_hparams=scale_hparams,
+        scale_shape=[1],
     )
     self.create_sparsity_variables(
         'w',
         weight_hparams,
-        scale_shape=[self.output_dims],
+        scale_shape=[1],
     )
 
   def __call__(self, inputs: JTensor) -> JTensor:
diff --git a/praxis/layers/quantization/operations.py b/praxis/layers/quantization/operations.py
index 592d3ce8..c406d5bf 100644
--- a/praxis/layers/quantization/operations.py
+++ b/praxis/layers/quantization/operations.py
@@ -602,11 +602,12 @@ def einsum(
     w_dequantized = _dequantize(w, scale, zp, eqn_to_weight_contract_dims(eqn))
     return jnp.einsum(eqn, x_dequantized, w_dequantized)
 
+  use_fp8 = False
   if (
      jax.dtypes.scalar_type_of(w.dtype) == float
      and jnp.finfo(w.dtype).bits == 8
   ):
-    pass # stay as fp8
+    use_fp8 = True  # w stays as fp8
 
   if x.dtype in INT_TYPES and w.dtype in INT_TYPES:
     assert not swap_xw, 'No need to swap x and w when both are int types.'
@@ -627,12 +628,18 @@
   if swap_xw:
     ret = jnp.einsum(eqn_normalized, w, x)
   else:
-    x = fp8_ops.quantize_dequantize(x, jnp.float8_e4m3fn, 1.01 * jnp.ones((1,)), jnp.bfloat16)
-    w_dq = fp8_ops.in_dq(jnp.bfloat16, w, 0.99*jnp.ones((1,)), jnp.ones((1024,))) # kernel_amax_history is dummy
-    ret = jnp.einsum(eqn_normalized, x, w_dq, _dot_general=fp8_ops.dot_general_with_precision)
+    dot_general_with_precision = lambda lhs, rhs, dimension_numbers, \
+        precision=None, preferred_element_type=jnp.bfloat16: lax.dot_general(
+        lhs,
+        rhs,
+        dimension_numbers=dimension_numbers,
+        precision=precision,
+        preferred_element_type=jnp.bfloat16,  # TODO: use the proper output dtype
+    )
+    ret = jnp.einsum(eqn_normalized, x, w, preferred_element_type=jnp.bfloat16)
 
   if scale_act is not None:
-    if scale_act.ndim == 0:
+    if scale_act.ndim == 0 or use_fp8:
       scale *= scale_act
     else:
       ret *= jnp.expand_dims(scale_act, _get_expand_dims_lhs(eqn))
diff --git a/praxis/layers/quantization/quantizer.py b/praxis/layers/quantization/quantizer.py
index e2012b76..27073536 100644
--- a/praxis/layers/quantization/quantizer.py
+++ b/praxis/layers/quantization/quantizer.py
@@ -43,6 +43,7 @@
 QuantizationParams = quantization_hparams.QuantizationParams
 instance_field = base_layer.instance_field
 WeightQuantizationParams = quantization_hparams.WeightQuantizationParams
+QUANTIZED_SCALE_ACT_NAME_POSTFIX = '_quantized_act_scale'
 
 
 class QuantizationLayer(base_layer.BaseLayer):
@@ -131,6 +132,7 @@ def set_up_weights(
         jax.dtypes.scalar_type_of(dtype) == float
         and jnp.finfo(dtype).bits == 8
     ):
+      weight_params.dtype = dtype
       dtype = jnp.int8
     self.create_quantized_variable(
         weight_name,
@@ -219,6 +221,14 @@ def quantized_einsum(
       else:
         return jnp.einsum(eqn, x, w)
 
+    use_fp8 = False
+    dtype = self.quantization.weight_params.dtype
+    if (
+        jax.dtypes.scalar_type_of(dtype) == float
+        and jnp.finfo(dtype).bits == 8
+    ):
+      use_fp8 = True
+
     # Optionally create step count.
     step_count = None
     if self.quantization.weight_params.use_step_count:
@@ -249,28 +259,29 @@
       x = x.astype(jnp.int8)
       logging.info('Static activation quantization is not supported yet.')
     elif self.quantization.act_params is not None:
-      act_params = self.quantization.act_params
-      x, scale_act, zp_act = operations.reduce_einsum_activation_precision(
-          eqn,
-          x,
-          bits=act_params.precision,
-          per_channel=act_params.per_channel,
-          symmetric=act_params.symmetric,
-          percentile=act_params.clipping_coeff,
-      )
-      if act_params.precision <= 8:
-        if act_params.symmetric:
-          # TODO(rybakov): add support for asymmetric too.
-          x = x.astype(jnp.int8)
-
-    dtype = self.quantization.weight_params.dtype
-    if (
-        jax.dtypes.scalar_type_of(dtype) == float
-        and jnp.finfo(dtype).bits == 8
-    ):
+      if not use_fp8:
+        act_params = self.quantization.act_params
+        x, scale_act, zp_act = operations.reduce_einsum_activation_precision(
+            eqn,
+            x,
+            bits=act_params.precision,
+            per_channel=act_params.per_channel,
+            symmetric=act_params.symmetric,
+            percentile=act_params.clipping_coeff,
+        )
+        if act_params.precision <= 8 and act_params.symmetric:
+          if act_params.symmetric:
+            # TODO(rybakov): add support for asymmetric too.
+            x = x.astype(jnp.int8)
+      else:
+        # Per-tensor fp8 quantization of the activation.
+        scale_act = self.get_quantized_act_scale(weight_name)
+        x = fp8_ops_linen.quantize(x, jnp.float8_e4m3fn, scale_act, jnp.bfloat16)
+
+    if use_fp8:
+      # Cast the int8-stored weight back to fp8.
       w = jax.lax.bitcast_convert_type(w, dtype)
-      # bf16 x fp8 is supported by nvidia
-      # w = w.astype(jnp.bfloat16)
+
     out = operations.einsum(
         eqn,
         x,
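
Illustrative sketch (not part of the patches, and not a Praxis or flax.linen API): under the assumption that the weight checkpoint stores float8_e4m3fn values bitcast to int8 and that one per-tensor scale is kept for the weight and one for the activation, the fp8 inference path these two commits wire up reduces to the plain-JAX snippet below. The names fp8_einsum_sketch, act_scale and w_scale are hypothetical.

# Standalone illustration of the per-tensor fp8 einsum path (assumed semantics).
import jax
import jax.numpy as jnp

E4M3_MAX = 448.0  # largest finite value representable in float8_e4m3fn


def fp8_einsum_sketch(eqn, x, w_int8, act_scale, w_scale):
  """Per-tensor fp8 einsum with bf16 accumulation."""
  # The checkpoint stores the fp8 weight as raw int8 bits; reinterpret them.
  w_fp8 = jax.lax.bitcast_convert_type(w_int8, jnp.float8_e4m3fn)
  # Quantize the activation per tensor and cast to fp8.
  x_fp8 = jnp.clip(x / act_scale, -E4M3_MAX, E4M3_MAX).astype(jnp.float8_e4m3fn)
  # fp8 x fp8 contraction, accumulated in bfloat16; on GPUs with fp8 support
  # XLA can lower this to a native fp8 matmul.
  ret = jnp.einsum(eqn, x_fp8, w_fp8, preferred_element_type=jnp.bfloat16)
  # Fold both per-tensor scales back into the result.
  return ret * (act_scale * w_scale).astype(jnp.bfloat16)


# Usage: y = x @ w for a [batch, in] x [in, out] contraction.
x = jax.random.normal(jax.random.PRNGKey(0), (4, 16), dtype=jnp.bfloat16)
w = jax.random.normal(jax.random.PRNGKey(1), (16, 8), dtype=jnp.bfloat16)
w_scale = jnp.max(jnp.abs(w)) / E4M3_MAX      # per-tensor weight scale
act_scale = jnp.asarray(1.0, jnp.bfloat16)    # calibrated offline in practice
w_int8 = jax.lax.bitcast_convert_type(
    (w / w_scale).astype(jnp.float8_e4m3fn), jnp.int8)
y = fp8_einsum_sketch('bi,io->bo', x, w_int8, act_scale, w_scale)

Folding the activation scale into the weight scale after the matmul is the same design choice the second commit makes in operations.einsum, where scale_act is multiplied into scale whenever use_fp8 is set.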