From 8603060895a6c1570c47aaf3a9b9b95e75d907bb Mon Sep 17 00:00:00 2001 From: The praxis Authors Date: Mon, 16 Oct 2023 12:19:37 -0700 Subject: [PATCH 1/5] Add activation quantization options for `quantization.for_transformer` PiperOrigin-RevId: 573891577 --- praxis/layers/quantization/quantize.py | 37 +++++++++++++++++++++++--- 1 file changed, 34 insertions(+), 3 deletions(-) diff --git a/praxis/layers/quantization/quantize.py b/praxis/layers/quantization/quantize.py index cccce99e..1f3c37be 100644 --- a/praxis/layers/quantization/quantize.py +++ b/praxis/layers/quantization/quantize.py @@ -371,6 +371,8 @@ def for_transformer( quantize_init_from_checkpoint_rules_task: bool = False, block_size: int = 0, # Internal quantization parameters. + num_bits_act: int | None = None, + use_symmetric_act: bool | None = None, ): """Find and quantize transformer. @@ -420,6 +422,10 @@ def for_transformer( that are defined in task_p.train.init_from_checkpoint_rules.values() block_size: block size for sub-channel quantization. Defaults to 0, which means off. + num_bits_act: The number of bits used for activation quantization. Only + valid when weight_quant_only is false. + use_symmetric_act: Use symmetric activation quantization. Only valid when + weight_quant_only is false. Returns: A modifier that quantizes transformers when applied to a config. @@ -463,6 +469,8 @@ def task(self): quantize_self_attention=quantize_self_attention, quantize_cross_attention=quantize_cross_attention, softmax_only=softmax_only, + use_symmetric_act=use_symmetric_act, + num_bits_act=num_bits_act, ) return task_p @@ -568,6 +576,8 @@ def set_transformer_quantization( use_int4_packed_weights: bool = True, int4_packed_weights_container_dtype: jnp.dtype = jnp.int32, # Internal quantization parameters. + num_bits_act: int | None = None, + use_symmetric_act: bool | None = None, ): """Sets quantization params for TransformerLm or TransformerEncoderDecoder. 
@@ -611,6 +621,10 @@ def set_transformer_quantization( False int4 weights will be kept in int8. int4_packed_weights_container_dtype: Container type for int4 weights: int32 to pack 8 int4s, or int8 to pack 2 int4s. + num_bits_act: The number of bits used for activation quantization. Only + valid when weight_quant_only is false. + use_symmetric_act: Use symmetric activation quantization. Only valid when + weight_quant_only is false. """ weight_quantization_params = WeightQuantizationParams( precision=num_bits, @@ -621,9 +635,26 @@ def set_transformer_quantization( int4_packed_weights_container_dtype=int4_packed_weights_container_dtype, # Pass internal quantization parameters. ) - act_quantization_params = ( - None if weight_quant_only else ActQuantizationParams(precision=num_bits) - ) + act_quantization_params = None + if ( + num_bits_act is not None or use_symmetric_act is not None + ) and weight_quant_only: + raise ValueError( + f'Activation quantization params (`num_bits_act` and' + f' `use_symmetric_act`) should not be set when `weight_quant_only` is' + f' set to True.' + ) + if not weight_quant_only: + if num_bits_act == None or use_symmetric_act == None: + raise ValueError( + f'Activation quantization params (`num_bits_act` and' + f' `use_symmetric_act`) have to be set when `weight_quant_only` is' + f' set to false.' + ) + act_quantization_params = ActQuantizationParams( + precision=num_bits_act, + symmetric=use_symmetric_act, + ) transformer_tpls = utils.find_target_tpl( config, layers.transformers.Transformer From adb61d315308ebb5badf16114debfd8f79f8f9ee Mon Sep 17 00:00:00 2001 From: The praxis Authors Date: Mon, 16 Oct 2023 12:26:36 -0700 Subject: [PATCH 2/5] Allows custom implementations of StackedTransformer to be used with praxis.layers.transformer_models.TransformerLM class. 
PiperOrigin-RevId: 573893539 --- praxis/layers/transformer_models.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/praxis/layers/transformer_models.py b/praxis/layers/transformer_models.py index 44b3403b..d8720ae8 100644 --- a/praxis/layers/transformer_models.py +++ b/praxis/layers/transformer_models.py @@ -511,8 +511,6 @@ def setup(self) -> None: xformer_params = xformer_params.pipeline_stage if issubclass(xformer_params.cls, transformers.StackedTransformerRepeated): xformer_params = xformer_params.block - if not issubclass(xformer_params.cls, transformers.StackedTransformer): - assert False, f'{xformer_params.cls} not supported.' assert ( xformer_params.model_dims == 0 or xformer_params.model_dims == self.model_dims From a31b34c8dd24b11f2e41a079fc9824aeddb2b853 Mon Sep 17 00:00:00 2001 From: Chandra Devarakonda Date: Mon, 16 Oct 2023 15:41:40 -0700 Subject: [PATCH 3/5] Praxis 1.2.0 release PiperOrigin-RevId: 573950784 --- RELEASE.md | 8 ++ praxis/pip_package/cloudbuild-release.yaml | 2 +- praxis/pip_package/requirements.txt | 108 ++++++++++----------- setup.py | 2 +- 4 files changed, 64 insertions(+), 56 deletions(-) diff --git a/RELEASE.md b/RELEASE.md index fbaf7102..d6380d45 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -1,3 +1,11 @@ +# Version: 1.2.0 +## Major Features and Improvements +## Breaking changes +## Deprecations +## Note +* Version: 1.2.0 +* Build Date: 20231016 +* Praxis commit: 7bd63412bf86a68e09fcd9455f76a4909d19377e # Version: 1.1.0 ## Major Features and Improvements * Move to python 3.10 as the minimal python requirement (previously on python 3.8). 
diff --git a/praxis/pip_package/cloudbuild-release.yaml b/praxis/pip_package/cloudbuild-release.yaml index 5c5c383c..11f596de 100644 --- a/praxis/pip_package/cloudbuild-release.yaml +++ b/praxis/pip_package/cloudbuild-release.yaml @@ -16,7 +16,7 @@ steps: substitutions: _PYTHON_VERSION: '3.10' - _RELEASE_VERSION: '1.1.0' # or rX.Y + _RELEASE_VERSION: '1.2.0' # or rX.Y _IMAGE_NAME: 'praxis_${_RELEASE_VERSION}_${_PYTHON_VERSION}' _WHEEL_FOLDER: '/tmp/wheels' options: diff --git a/praxis/pip_package/requirements.txt b/praxis/pip_package/requirements.txt index 1bb8a8d6..30680d84 100644 --- a/praxis/pip_package/requirements.txt +++ b/praxis/pip_package/requirements.txt @@ -27,7 +27,7 @@ argon2-cffi-bindings==21.2.0 # via argon2-cffi array-record==0.4.1 # via tfds-nightly -arrow==1.2.3 +arrow==1.3.0 # via isoduration asttokens==2.4.0 # via stack-data @@ -40,29 +40,29 @@ attrs==23.1.0 # jsonschema # lingvo # referencing -babel==2.12.1 +babel==2.13.0 # via jupyterlab-server backcall==0.2.0 # via ipython beautifulsoup4==4.12.2 # via nbconvert -bleach==6.0.0 +bleach==6.1.0 # via nbconvert cachetools==5.3.1 # via google-auth certifi==2023.7.22 # via requests -cffi==1.15.1 +cffi==1.16.0 # via argon2-cffi-bindings -charset-normalizer==3.2.0 +charset-normalizer==3.3.0 # via requests -chex==0.1.83 +chex==0.1.7 # via optax click==8.1.7 # via # tensorflow-datasets # tfds-nightly -clu==0.0.9 +clu==0.0.10 # via -r praxis-requirements.in comm==0.1.4 # via @@ -72,7 +72,7 @@ contextlib2==21.6.0 # via ml-collections contourpy==1.1.1 # via matplotlib -cycler==0.11.0 +cycler==0.12.1 # via matplotlib debugpy==1.8.0 # via ipykernel @@ -82,11 +82,12 @@ defusedxml==0.7.1 # via nbconvert dm-tree==0.1.8 # via + # chex # tensorflow-datasets # tfds-nightly -einops==0.6.1 +einops==0.7.0 # via -r praxis-requirements.in -etils[enp,epath,epy]==1.5.0 +etils[enp,epath,epy]==1.5.1 # via # -r praxis-requirements.in # array-record @@ -99,9 +100,9 @@ exceptiongroup==1.1.3 # via # anyio # ipython 
-executing==1.2.0 +executing==2.0.0 # via stack-data -fastjsonschema==2.18.0 +fastjsonschema==2.18.1 # via nbformat fiddle @ git+https://github.com/google/fiddle # via -r praxis-requirements.in @@ -111,7 +112,7 @@ flax==0.7.4 # via # -r praxis-requirements.in # clu -fonttools==4.42.1 +fonttools==4.43.1 # via matplotlib fqdn==1.5.1 # via jsonschema @@ -119,7 +120,7 @@ fsspec==2023.9.2 # via etils gast==0.4.0 # via tensorflow -google-auth==2.23.0 +google-auth==2.23.3 # via # google-auth-oauthlib # tensorboard @@ -127,17 +128,17 @@ google-auth-oauthlib==0.4.6 # via tensorboard google-pasta==0.2.0 # via tensorflow -googleapis-common-protos==1.60.0 +googleapis-common-protos==1.61.0 # via tensorflow-metadata graph-compression-google-research==0.0.4 # via lingvo graphviz==0.20.1 # via fiddle -grpcio==1.58.0 +grpcio==1.59.0 # via # tensorboard # tensorflow -h5py==3.9.0 +h5py==3.10.0 # via tensorflow idna==3.4 # via @@ -153,7 +154,7 @@ ipykernel==6.25.2 # jupyterlab # lingvo # qtconsole -ipython==8.15.0 +ipython==8.16.1 # via # ipykernel # ipywidgets @@ -174,15 +175,15 @@ jax @ git+https://github.com/google/jax # orbax-checkpoint jax-bitempered-loss==0.0.2 # via -r praxis-requirements.in -jaxlib==0.4.16 +jaxlib==0.4.18 # via # chex # clu # optax # orbax-checkpoint -jaxtyping==0.2.22 +jaxtyping==0.2.23 # via -r praxis-requirements.in -jedi==0.19.0 +jedi==0.19.1 # via ipython jinja2==3.1.2 # via @@ -205,7 +206,7 @@ jsonschema-specifications==2023.7.1 # via jsonschema jupyter==1.0.0 # via lingvo -jupyter-client==8.3.1 +jupyter-client==8.4.0 # via # ipykernel # jupyter-console @@ -214,7 +215,7 @@ jupyter-client==8.3.1 # qtconsole jupyter-console==6.6.3 # via jupyter -jupyter-core==5.3.1 +jupyter-core==5.4.0 # via # ipykernel # jupyter-client @@ -225,13 +226,13 @@ jupyter-core==5.3.1 # nbconvert # nbformat # qtconsole -jupyter-events==0.7.0 +jupyter-events==0.8.0 # via jupyter-server jupyter-http-over-ws==0.0.8 # via lingvo jupyter-lsp==2.2.0 # via jupyterlab 
-jupyter-server==2.7.3 +jupyter-server==2.8.0 # via # jupyter-lsp # jupyterlab @@ -240,7 +241,7 @@ jupyter-server==2.7.3 # notebook-shim jupyter-server-terminals==0.4.4 # via jupyter-server -jupyterlab==4.0.6 +jupyterlab==4.0.7 # via notebook jupyterlab-pygments==0.2.2 # via nbconvert @@ -258,11 +259,11 @@ kiwisolver==1.4.5 # via matplotlib libclang==16.0.6 # via tensorflow -libcst==1.0.1 +libcst==1.1.0 # via fiddle lingvo==0.12.7 # via -r praxis-requirements.in -markdown==3.4.4 +markdown==3.5 # via tensorboard markdown-it-py==3.0.0 # via rich @@ -279,7 +280,7 @@ matplotlib-inline==0.1.6 # ipython mdurl==0.1.2 # via markdown-it-py -mistune==3.0.1 +mistune==3.0.2 # via nbconvert ml-collections==0.1.1 # via clu @@ -291,7 +292,7 @@ model-pruning-google-research==0.0.5 # via lingvo mpmath==1.3.0 # via sympy -msgpack==1.0.6 +msgpack==1.0.7 # via # flax # orbax-checkpoint @@ -299,7 +300,7 @@ mypy-extensions==1.0.0 # via typing-inspect nbclient==0.8.0 # via nbconvert -nbconvert==7.8.0 +nbconvert==7.9.2 # via # jupyter # jupyter-server @@ -312,7 +313,7 @@ nest-asyncio==1.5.8 # via # ipykernel # orbax-checkpoint -notebook==7.0.4 +notebook==7.0.5 # via # jupyter # jupyter-http-over-ws @@ -320,7 +321,7 @@ notebook-shim==0.2.3 # via # jupyterlab # notebook -numpy==1.26.0 +numpy==1.23.1 # via # -r praxis-requirements.in # chex @@ -359,11 +360,11 @@ optax==0.1.7 # flax optax-shampoo==0.0.6 # via -r praxis-requirements.in -orbax-checkpoint==0.4.0 +orbax-checkpoint==0.4.1 # via flax overrides==7.4.0 # via jupyter-server -packaging==23.1 +packaging==23.2 # via # clu # ipykernel @@ -383,11 +384,11 @@ pexpect==4.8.0 # via ipython pickleshare==0.7.5 # via ipython -pillow==10.0.1 +pillow==10.1.0 # via # lingvo # matplotlib -platformdirs==3.10.0 +platformdirs==3.11.0 # via jupyter-core prometheus-client==0.17.1 # via jupyter-server @@ -410,7 +411,7 @@ protobuf==3.19.6 # tensorflow-hub # tensorflow-metadata # tfds-nightly -psutil==5.9.5 +psutil==5.9.6 # via # ipykernel # 
tensorflow-datasets @@ -485,9 +486,9 @@ rfc3986-validator==0.1.1 # via # jsonschema # jupyter-events -rich==13.5.3 +rich==13.6.0 # via flax -rpds-py==0.10.3 +rpds-py==0.10.6 # via # jsonschema # referencing @@ -495,7 +496,7 @@ rsa==4.9 # via google-auth scikit-learn==1.3.1 # via lingvo -scipy==1.11.2 +scipy==1.11.3 # via # jax # jaxlib @@ -521,7 +522,7 @@ sniffio==1.3.0 # via anyio soupsieve==2.5 # via beautifulsoup4 -stack-data==0.6.2 +stack-data==0.6.3 # via ipython sympy==1.12 # via lingvo @@ -542,7 +543,7 @@ tensorflow-datasets==4.8.3 # lingvo tensorflow-estimator==2.9.0 # via tensorflow -tensorflow-hub==0.14.0 +tensorflow-hub==0.15.0 # via # lingvo # tensorflow-text @@ -557,7 +558,7 @@ tensorflow-text==2.9.0 # via # -r praxis-requirements.in # lingvo -tensorstore==0.1.44 +tensorstore==0.1.45 # via # flax # orbax-checkpoint @@ -597,7 +598,7 @@ tqdm==4.66.1 # via # tensorflow-datasets # tfds-nightly -traitlets==5.10.0 +traitlets==5.11.2 # via # comm # ipykernel @@ -614,10 +615,12 @@ traitlets==5.10.0 # nbconvert # nbformat # qtconsole -typeguard==4.1.5 +typeguard==2.13.3 # via # -r praxis-requirements.in # jaxtyping +types-python-dateutil==2.8.19.14 + # via arrow typing-extensions==4.8.0 # via # async-lru @@ -630,17 +633,14 @@ typing-extensions==4.8.0 # libcst # orbax-checkpoint # tensorflow - # typeguard # typing-inspect typing-inspect==0.9.0 # via libcst uri-template==1.3.0 # via jsonschema -urllib3==1.26.16 - # via - # google-auth - # requests -wcwidth==0.2.6 +urllib3==2.0.6 + # via requests +wcwidth==0.2.8 # via prompt-toolkit webcolors==1.13 # via jsonschema @@ -648,9 +648,9 @@ webencodings==0.5.1 # via # bleach # tinycss2 -websocket-client==1.6.3 +websocket-client==1.6.4 # via jupyter-server -werkzeug==2.3.7 +werkzeug==3.0.0 # via tensorboard wheel==0.41.2 # via diff --git a/setup.py b/setup.py index 086d8bbb..6448bb62 100644 --- a/setup.py +++ b/setup.py @@ -47,7 +47,7 @@ def _get_requirements(): setup( name='praxis', - version='1.1.0', + version='1.2.0', 
description=( 'Functionalities such as a layers for building neural networks in Jax.' ), From 1a51bb8ce244c9c5f978e1d313beaaec5def06e0 Mon Sep 17 00:00:00 2001 From: The praxis Authors Date: Mon, 16 Oct 2023 11:03:15 -0700 Subject: [PATCH 4/5] Fix the bug in asymmetric activation quantization inference logic. PiperOrigin-RevId: 573866807 --- praxis/layers/quantization/operations.py | 18 +++++++++++++----- praxis/layers/quantization/operations_test.py | 19 ++++++++++++++++++- 2 files changed, 31 insertions(+), 6 deletions(-) diff --git a/praxis/layers/quantization/operations.py b/praxis/layers/quantization/operations.py index 5511a987..0f18382c 100644 --- a/praxis/layers/quantization/operations.py +++ b/praxis/layers/quantization/operations.py @@ -240,6 +240,17 @@ def einsum( Returns: A JTensor. """ + # Non performant equation for inference testing purposes + # TODO: b/305735188 - Improve the performance by using the integer einsum op. + if zp_act is not None: + dequantized_x = jnp.multiply(x, scale_act) - zp_act + # explicit broadcast if necessary. 
+ if w.ndim == 3 and scale.ndim == 1: + scale = jnp.expand_dims(scale, (1, 2)) + dequantized_w = jnp.multiply(w, scale) + if zp is not None: + dequantized_w = dequantized_w - zp + return jnp.einsum(eqn, dequantized_x, dequantized_w) use_int_dot_general = ( x.dtype in QUANTIZED_TYPES and w.dtype in QUANTIZED_TYPES @@ -302,11 +313,6 @@ def einsum( offset = compute_offset(x, zp, eqn) ret = ret - offset - if zp_act is not None: - # Non performent equation for inference testing purposes - dequantized_x = scale_act * x - zp_act - dequantized_w = scale * w - zp - ret = jnp.einsum(eqn, dequantized_x, dequantized_w) return ret @@ -623,6 +629,8 @@ def reduce_einsum_activation_precision( if squeeze: scale = jnp.squeeze(scale, axis=contract_dims) + if zp is not None: + zp = jnp.squeeze(zp, axis=contract_dims) return t, scale, zp diff --git a/praxis/layers/quantization/operations_test.py b/praxis/layers/quantization/operations_test.py index 4e23f20a..c45c1ff1 100644 --- a/praxis/layers/quantization/operations_test.py +++ b/praxis/layers/quantization/operations_test.py @@ -118,7 +118,24 @@ def test_quantized_einsum_with_asym_weight_act(self, eqn): ret = operations.einsum(eqn, qx, qw, sw, zpw, sx, zpx) expected = jnp.einsum(eqn, x, w) - self.assertAllClose(ret, expected, rtol=0.1, atol=0.5) + self.assertAllClose(ret, expected, rtol=0.02, atol=0.02) + + @parameterized.named_parameters( + ('eqn_with_dot', '...y,yz->...z'), + ) + def test_quantized_einsum_with_aym_weight_asym_act(self, eqn): + w = jax.random.uniform(jax.random.PRNGKey(0), (4, 3)) + x = jax.random.uniform(jax.random.PRNGKey(0), (2, 4)) + qw, sw, zpw = operations.reduce_einsum_weight_precision( + eqn, w, use_symmetric=True + ) + qx, sx, zpx = operations.reduce_einsum_activation_precision( + eqn, x, symmetric=False + ) + + ret = operations.einsum(eqn, qx, qw, sw, zpw, sx, zpx) + expected = jnp.einsum(eqn, x, w) + self.assertAllClose(ret, expected, rtol=0.02, atol=0.02) @parameterized.parameters( ('ab,bc->ac', (10, 4), 
(4, 5)), From c0b0af8d9de20a209b0640c6c74e3d6d1d3361f6 Mon Sep 17 00:00:00 2001 From: Chandra Devarakonda Date: Mon, 16 Oct 2023 23:00:44 +0000 Subject: [PATCH 5/5] Update praxis 1.2.0 requirements --- praxis/pip_package/requirements.txt | 4 ++-- requirements.in | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/praxis/pip_package/requirements.txt b/praxis/pip_package/requirements.txt index 30680d84..0eae5e75 100644 --- a/praxis/pip_package/requirements.txt +++ b/praxis/pip_package/requirements.txt @@ -104,7 +104,7 @@ executing==2.0.0 # via stack-data fastjsonschema==2.18.1 # via nbformat -fiddle @ git+https://github.com/google/fiddle +fiddle==0.2.11 # via -r praxis-requirements.in flatbuffers==1.12 # via tensorflow @@ -165,7 +165,7 @@ ipywidgets==8.1.1 # via jupyter isoduration==20.11.0 # via jsonschema -jax @ git+https://github.com/google/jax +jax==0.4.18 # via # -r praxis-requirements.in # chex diff --git a/requirements.in b/requirements.in index 5f8e04fd..e0defcae 100644 --- a/requirements.in +++ b/requirements.in @@ -5,9 +5,9 @@ absl-py clu einops etils -fiddle @ git+https://github.com/google/fiddle +fiddle==0.2.11 flax -jax @ git+https://github.com/google/jax +jax==0.4.18 jax-bitempered-loss jaxtyping lingvo