diff --git a/.github/workflows/ci-lint.yml b/.github/workflows/ci-lint.yml new file mode 100644 index 0000000000..dede434d68 --- /dev/null +++ b/.github/workflows/ci-lint.yml @@ -0,0 +1,21 @@ +name: pre-commit + +on: + pull_request: + push: + branches: [master] + +jobs: + pre-commit: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4.1.7 + with: + # Ensure the full history is fetched + # This is required to run pre-commit on a specific set of commits + # TODO: Remove this when all the pre-commit issues are fixed + fetch-depth: 0 + - uses: actions/setup-python@v5.1.1 + with: + python-version: 3.13 + - uses: pre-commit/action@v3.0.1 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000000..387a3efbf3 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,39 @@ +# pre-commit is a tool to perform a predefined set of tasks manually and/or +# automatically before git commits are made. +# +# Config reference: https://pre-commit.com/#pre-commit-configyaml---top-level +# +# Common tasks +# +# - Register git hooks: pre-commit install --install-hooks +# - Run on all files: pre-commit run --all-files +# +# These pre-commit hooks are run as CI. +# +# NOTE: if it can be avoided, add configs/args in pyproject.toml or below instead of creating a new `.config.file`. +# https://pre-commit.ci/#configuration +ci: + autoupdate_schedule: monthly + autofix_commit_msg: | + [pre-commit.ci] Apply automatic pre-commit fixes + +repos: + # general + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.6.0 + hooks: + - id: end-of-file-fixer + exclude: '\.svg$' + - id: trailing-whitespace + exclude: '\.svg$' + - id: check-json + - id: check-yaml + args: [--allow-multiple-documents, --unsafe] + - id: check-toml + + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.5.6 + hooks: + - id: ruff + args: ["--fix"] + - id: ruff-format diff --git a/MANIFEST.in b/MANIFEST.in index 21f0e01bb1..c858b9fc8e 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1 +1 @@ -recursive-include static *.* \ No newline at end of file +recursive-include static *.* diff --git a/README.md b/README.md index f5c3befd3b..773549056b 100644 --- a/README.md +++ b/README.md @@ -87,7 +87,7 @@ pip3 install tensorflow_model_analysis--py3-none-any.whl ### Running tests -To run tests, run +To run tests, run ``` python -m unittest discover -p *_test.py diff --git a/docs/api_docs/python/tfma-evaluators.md b/docs/api_docs/python/tfma-evaluators.md index bb098cfd48..a1797765da 100644 --- a/docs/api_docs/python/tfma-evaluators.md +++ b/docs/api_docs/python/tfma-evaluators.md @@ -2,4 +2,3 @@ # TFMA Evaluators ::: tensorflow_model_analysis.evaluators - diff --git a/docs/api_docs/python/tfma-experimental.md b/docs/api_docs/python/tfma-experimental.md index 77af1ee691..31fcfa0f21 100644 --- a/docs/api_docs/python/tfma-experimental.md +++ b/docs/api_docs/python/tfma-experimental.md @@ -2,4 +2,3 @@ # TFMA Experimental ::: tensorflow_model_analysis.experimental - diff --git a/docs/api_docs/python/tfma-extractors.md b/docs/api_docs/python/tfma-extractors.md index 08ab4f7263..acb636aab4 100644 --- a/docs/api_docs/python/tfma-extractors.md +++ b/docs/api_docs/python/tfma-extractors.md @@ -2,4 +2,3 @@ # TFMA Extractors ::: tensorflow_model_analysis.extractors - diff --git a/docs/api_docs/python/tfma-metrics.md b/docs/api_docs/python/tfma-metrics.md index 0e424df164..3ea124dd54 100644 --- a/docs/api_docs/python/tfma-metrics.md +++ b/docs/api_docs/python/tfma-metrics.md @@ -2,4 +2,3 @@ # TFMA Metrics 
::: tensorflow_model_analysis.metrics - diff --git a/docs/api_docs/python/tfma-post_export_metrics.md b/docs/api_docs/python/tfma-post_export_metrics.md index c476e10ca1..4025b2d1ef 100644 --- a/docs/api_docs/python/tfma-post_export_metrics.md +++ b/docs/api_docs/python/tfma-post_export_metrics.md @@ -2,4 +2,3 @@ # TFMA Post_Export_Metrics ::: tensorflow_model_analysis.post_export_metrics - diff --git a/docs/api_docs/python/tfma-sdk.md b/docs/api_docs/python/tfma-sdk.md index 267fd7d780..9cb41fd2a5 100644 --- a/docs/api_docs/python/tfma-sdk.md +++ b/docs/api_docs/python/tfma-sdk.md @@ -2,4 +2,3 @@ # TFMA SDK ::: tensorflow_model_analysis.sdk - diff --git a/docs/api_docs/python/tfma-types.md b/docs/api_docs/python/tfma-types.md index 4b859da9cb..04aa1770c0 100644 --- a/docs/api_docs/python/tfma-types.md +++ b/docs/api_docs/python/tfma-types.md @@ -2,4 +2,3 @@ # TFMA Types ::: tensorflow_model_analysis.types - diff --git a/docs/api_docs/python/tfma-utils.md b/docs/api_docs/python/tfma-utils.md index 6057d9e977..75a6332379 100644 --- a/docs/api_docs/python/tfma-utils.md +++ b/docs/api_docs/python/tfma-utils.md @@ -2,4 +2,3 @@ # TFMA Utils ::: tensorflow_model_analysis.utils - diff --git a/docs/api_docs/python/tfma-validators.md b/docs/api_docs/python/tfma-validators.md index 835d8f55c8..7018ff856a 100644 --- a/docs/api_docs/python/tfma-validators.md +++ b/docs/api_docs/python/tfma-validators.md @@ -2,4 +2,3 @@ # TFMA Validators ::: tensorflow_model_analysis.validators - diff --git a/docs/api_docs/python/tfma-version.md b/docs/api_docs/python/tfma-version.md index fe4a03fbe4..d64a57d26e 100644 --- a/docs/api_docs/python/tfma-version.md +++ b/docs/api_docs/python/tfma-version.md @@ -2,4 +2,3 @@ # TFMA Version ::: tensorflow_model_analysis.version - diff --git a/docs/api_docs/python/tfma-view.md b/docs/api_docs/python/tfma-view.md index 935553d98d..fb3b6d46ef 100644 --- a/docs/api_docs/python/tfma-view.md +++ b/docs/api_docs/python/tfma-view.md @@ -2,4 +2,3 @@ # TFMA View ::: tensorflow_model_analysis.view - diff --git a/docs/api_docs/python/tfma-writers.md b/docs/api_docs/python/tfma-writers.md index 26ca97058c..48d97a31d5 100644 --- a/docs/api_docs/python/tfma-writers.md +++ b/docs/api_docs/python/tfma-writers.md @@ -2,4 +2,3 @@ # TFMA Writers ::: tensorflow_model_analysis.writers - diff --git a/docs/javascripts/mathjax.js b/docs/javascripts/mathjax.js index 0be88e0419..7e48906afd 100644 --- a/docs/javascripts/mathjax.js +++ b/docs/javascripts/mathjax.js @@ -11,7 +11,7 @@ window.MathJax = { } }; -document$.subscribe(() => { +document$.subscribe(() => { MathJax.startup.output.clearCache() MathJax.typesetClear() MathJax.texReset() diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000000..90e680b79f --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,140 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+[build-system]
+requires = [
+    "setuptools",
+    "wheel",
+]
+
+[tool.ruff]
+line-length = 88
+
+[tool.ruff.lint]
+select = [
+    # pycodestyle
+    "E",
+    "W",
+    # Pyflakes
+    "F",
+    # pyupgrade
+    "UP",
+    # flake8-bugbear
+    "B",
+    # flake8-simplify
+    "SIM",
+    # isort
+    "I",
+    # pep8 naming
+    "N",
+    # pydocstyle
+    "D",
+    # annotations
+    "ANN",
+    # debugger
+    "T10",
+    # flake8-pytest
+    "PT",
+    # flake8-return
+    "RET",
+    # flake8-unused-arguments
+    "ARG",
+    # flake8-fixme
+    "FIX",
+    # flake8-eradicate
+    "ERA",
+    # pandas-vet
+    "PD",
+    # numpy-specific rules
+    "NPY",
+]
+
+ignore = [
+    "D104", # Missing docstring in public package
+    "D100", # Missing docstring in public module
+    "D211", # No blank line before class
+    "PD901", # Avoid using 'df' for pandas dataframes. Perfectly fine in functions with limited scope
+    "ANN201", # Missing return type annotation for public function (makes no sense for NoneType return types...)
+    "ANN101", # Missing type annotation for `self`
+    "ANN204", # Missing return type annotation for special method
+    "ANN002", # Missing type annotation for `*args`
+    "ANN003", # Missing type annotation for `**kwargs`
+    "D105", # Missing docstring in magic method
+    "D203", # 1 blank line required before class docstring
+    "D204", # 1 blank line required after class docstring
+    "D413", # 1 blank line after parameters
+    "SIM108", # Simplify if/else to one line; not always clearer
+    "D206", # Docstrings should be indented with spaces; unnecessary when running ruff-format
+    "E501", # Line length too long; unnecessary when running ruff-format
+    "W191", # Indentation contains tabs; unnecessary when running ruff-format
+
+    # FIX AND REMOVE BELOW CODES
+    "ANN001", # Missing type annotation for function argument
+    "ANN102", # Missing type annotation for `cls` in classmethod
+    "ANN202", # Missing return type annotation for private function
+    "ANN401", # Dynamically typed expressions (typing.Any) are disallowed
+    "ARG001", # Unused function argument
+    "ARG002", # Unused method argument
+    "ARG005", # Unused lambda argument
+    "B007", # Loop control variable not used within loop body
+    "B008", # Do not perform function call in argument defaults
+    "B020", # Loop control variable overrides iterable it iterates
+    "B028", # No explicit `stacklevel` keyword argument for `warnings.warn`
+    "B904", # Within an `except` clause, raise exceptions with `raise ... from err` or `raise ...
from None` + "D101", # Missing docstring in public class + "D102", # Missing docstring in public method + "D103", # Missing docstring in public function + "D106", # Missing docstring in public nested class + "D107", # Missing docstring in `__init__` + "D210", # No whitespaces allowed surrounding docstring text + "D401", # First line of docstring should be in imperative mood + "D404", # First word of the docstring should not be "This" + "E721", # Use `is` or `isinstance()` for type comparisons + "E722", # Do not use bare `except` + "E731", # Do not assign a `lambda` expression, use a `def` + "E741", # Ambiguous variable name + "ERA001", # Found commented-out code + "F403", # `from module import *` used; unable to detect undefined names + "F405", # import module may be undefined, or defined from star imports + "F841", # Local variable is assigned to but never used + "FIX002", # Line contains TODO, consider resolving the issue + "N801", # Class name should use CapWords convention + "N802", # Function name should be lowercase + "N811", # Constant imported as non-constant + "NPY002", # Replace legacy `np.random` call with `np.random.Generator` + "PD004", # `.notna` is preferred to `.notnull` + "PD010", # `.pivot_table` is preferred to `.pivot` or `.unstack` + "PD011", # Use `.to_numpy()` instead of `.values` + "PT009", # Use a regular `assert` instead of unittest-style assert + "PT018", # Assertion should be broken down into multiple parts + "PT027", # Use `pytest.raises` instead of unittest-style `assertRaises` or `assertRaisesRegex` + "RET503", # Missing explicit `return` at the end of function able to return non-`None` value + "RET504", # Unnecessary assignment to `result` before `return` statement + "RET505", # Unnecessary `elif` or `else` after `return` statement + "RET506", # Unnecessary `else` after `raise` statement + "SIM101", # Multiple `isinstance` calls, merge into a single call + "SIM102", # Use a single `if` statement instead of nested `if` statements + "SIM103", # Return the condition directly + "SIM109", # Use `in` for multiple equality comparisons + "SIM110", # Use `any(...)` instead of `for` loop for early return + "SIM117", # Use a single `with` statement with multiple contexts instead of nested `with` statements + "SIM118", # Use `key in dict` instead of `key in dict.keys()` + "SIM401", # Use `self.get(key, None)` instead of an `if` block + "UP008", # Use `super()` instead of `super(__class__, self)` + "UP028", # Replace `yield` over `for` loop with `yield from` + "UP031", # Use format specifiers instead of percent format +] + +[tool.ruff.lint.per-file-ignores] +"__init__.py" = ["F401"] diff --git a/setup.py b/setup.py index 16b7b60090..3ceee3406e 100644 --- a/setup.py +++ b/setup.py @@ -16,17 +16,15 @@ The widget is based on the template generated from jupyter-widget's widget-cookiecutter. """ -from distutils import log -from distutils import spawn + import os -from pathlib import Path import platform import subprocess import sys +from distutils import log, spawn +from pathlib import Path -from setuptools import Command -from setuptools import find_packages -from setuptools import setup +from setuptools import Command, find_packages, setup from setuptools.command.build_py import build_py as _build_py from setuptools.command.develop import develop as _develop from setuptools.command.egg_info import egg_info @@ -40,352 +38,359 @@ # Find the Protocol Compiler. 
-if 'PROTOC' in os.environ and os.path.exists(os.environ['PROTOC']): - protoc = os.environ['PROTOC'] -elif os.path.exists('../src/protoc'): - protoc = '../src/protoc' -elif os.path.exists('../src/protoc.exe'): - protoc = '../src/protoc.exe' -elif os.path.exists('../vsprojects/Debug/protoc.exe'): - protoc = '../vsprojects/Debug/protoc.exe' -elif os.path.exists('../vsprojects/Release/protoc.exe'): - protoc = '../vsprojects/Release/protoc.exe' +if "PROTOC" in os.environ and os.path.exists(os.environ["PROTOC"]): + protoc = os.environ["PROTOC"] +elif os.path.exists("../src/protoc"): + protoc = "../src/protoc" +elif os.path.exists("../src/protoc.exe"): + protoc = "../src/protoc.exe" +elif os.path.exists("../vsprojects/Debug/protoc.exe"): + protoc = "../vsprojects/Debug/protoc.exe" +elif os.path.exists("../vsprojects/Release/protoc.exe"): + protoc = "../vsprojects/Release/protoc.exe" else: - protoc = spawn.find_executable('protoc') + protoc = spawn.find_executable("protoc") # Get version from version module. -with open('tensorflow_model_analysis/version.py') as fp: - globals_dict = {} - exec(fp.read(), globals_dict) # pylint: disable=exec-used -__version__ = globals_dict['VERSION'] +with open("tensorflow_model_analysis/version.py") as fp: + globals_dict = {} + exec(fp.read(), globals_dict) # pylint: disable=exec-used +__version__ = globals_dict["VERSION"] here = os.path.dirname(os.path.abspath(__file__)) -node_root = os.path.join(here, 'tensorflow_model_analysis', 'notebook', - 'jupyter', 'js') -is_repo = os.path.exists(os.path.join(here, '.git')) +node_root = os.path.join(here, "tensorflow_model_analysis", "notebook", "jupyter", "js") +is_repo = os.path.exists(os.path.join(here, ".git")) -npm_path = os.pathsep.join([ - os.path.join(node_root, 'node_modules', '.bin'), - os.environ.get('PATH', os.defpath), -]) +npm_path = os.pathsep.join( + [ + os.path.join(node_root, "node_modules", ".bin"), + os.environ.get("PATH", os.defpath), + ] +) # Set this to true if ipywidgets js should be built. This would require nodejs. -build_js = os.environ.get('BUILD_JS') is not None +build_js = os.environ.get("BUILD_JS") is not None log.set_verbosity(log.DEBUG) -log.info('setup.py entered') -log.info('$PATH=%s' % os.environ['PATH']) +log.info("setup.py entered") +log.info("$PATH=%s" % os.environ["PATH"]) def generate_proto(source, require=True): - """Invokes the Protocol Compiler to generate a _pb2.py.""" - - # Does nothing if the output already exists and is newer than - # the input. - - if not require and not os.path.exists(source): - return - - output = source.replace('.proto', '_pb2.py').replace('../src/', '') - - if (not os.path.exists(output) or - (os.path.exists(source) and - os.path.getmtime(source) > os.path.getmtime(output))): - print('Generating %s...' % output) - - if not os.path.exists(source): - sys.stderr.write("Can't find required file: %s\n" % source) - sys.exit(-1) - - if protoc is None: - sys.stderr.write( - 'protoc is not installed nor found in ../src. Please compile it ' - 'or install the binary package.\n') - sys.exit(-1) - - protoc_command = [ - protoc, - '-I/usr/include', - '-I.', - '-I./tensorflow_model_analysis/proto', - '--python_out=.', - source, - ] - if subprocess.call(protoc_command) != 0: - sys.exit(-1) + """Invokes the Protocol Compiler to generate a _pb2.py.""" + # Does nothing if the output already exists and is newer than + # the input. 
+ + if not require and not os.path.exists(source): + return + + output = source.replace(".proto", "_pb2.py").replace("../src/", "") + + if not os.path.exists(output) or ( + os.path.exists(source) and os.path.getmtime(source) > os.path.getmtime(output) + ): + print("Generating %s..." % output) + + if not os.path.exists(source): + sys.stderr.write("Can't find required file: %s\n" % source) + sys.exit(-1) + + if protoc is None: + sys.stderr.write( + "protoc is not installed nor found in ../src. Please compile it " + "or install the binary package.\n" + ) + sys.exit(-1) + + protoc_command = [ + protoc, + "-I/usr/include", + "-I.", + "-I./tensorflow_model_analysis/proto", + "--python_out=.", + source, + ] + if subprocess.call(protoc_command) != 0: + sys.exit(-1) def generate_tfma_protos(): - """Generate necessary .proto file if it doesn't exist.""" - generate_proto('tensorflow_model_analysis/proto/config.proto', False) - generate_proto('tensorflow_model_analysis/proto/metrics_for_slice.proto', - False) - generate_proto('tensorflow_model_analysis/proto/validation_result.proto', - False) + """Generate necessary .proto file if it doesn't exist.""" + generate_proto("tensorflow_model_analysis/proto/config.proto", False) + generate_proto("tensorflow_model_analysis/proto/metrics_for_slice.proto", False) + generate_proto("tensorflow_model_analysis/proto/validation_result.proto", False) class build_py(_build_py): # pylint: disable=invalid-name - """Build necessary dependencies.""" + """Build necessary dependencies.""" - def run(self): - generate_tfma_protos() - # _build_py is an old-style class, so super() doesn't work. - _build_py.run(self) + def run(self): + generate_tfma_protos() + # _build_py is an old-style class, so super() doesn't work. + _build_py.run(self) class develop(_develop): # pylint: disable=invalid-name - """Build necessary dependencies in develop mode.""" + """Build necessary dependencies in develop mode.""" - def run(self): - generate_tfma_protos() - _develop.run(self) + def run(self): + generate_tfma_protos() + _develop.run(self) def js_prerelease(command, strict=False): - """Decorator for building minified js/css prior to another command.""" + """Decorator for building minified js/css prior to another command.""" + + class DecoratedCommand(command): + """Decorated command.""" + + def run(self): + jsdeps = self.distribution.get_command_obj("jsdeps") + if not is_repo and all(os.path.exists(t) for t in jsdeps.targets): + # sdist, nothing to do + command.run(self) + return + + try: + self.distribution.run_command("jsdeps") + except Exception as e: # pylint: disable=broad-except + missing = [t for t in jsdeps.targets if not os.path.exists(t)] + if strict or missing: + log.warn("rebuilding js and css failed") + if missing: + log.error("missing files: %s" % missing) + raise e + else: + log.warn("rebuilding js and css failed (not a problem)") + log.warn(str(e)) + command.run(self) + update_package_data(self.distribution) + + return DecoratedCommand - class DecoratedCommand(command): - """Decorated command.""" - def run(self): - jsdeps = self.distribution.get_command_obj('jsdeps') - if not is_repo and all(os.path.exists(t) for t in jsdeps.targets): - # sdist, nothing to do - command.run(self) - return +def update_package_data(distribution): + """Update package_data to catch changes during setup.""" + build_py_cmd = distribution.get_command_obj("build_py") + # distribution.package_data = find_package_data() + # re-init build_py options which load package_data + build_py_cmd.finalize_options() 
- try: - self.distribution.run_command('jsdeps') - except Exception as e: # pylint: disable=broad-except - missing = [t for t in jsdeps.targets if not os.path.exists(t)] - if strict or missing: - log.warn('rebuilding js and css failed') - if missing: - log.error('missing files: %s' % missing) - raise e - else: - log.warn('rebuilding js and css failed (not a problem)') - log.warn(str(e)) - command.run(self) - update_package_data(self.distribution) - return DecoratedCommand +class NPM(Command): + """NPM builder. + Builds the js and css using npm. + """ -def update_package_data(distribution): - """update package_data to catch changes during setup.""" - build_py_cmd = distribution.get_command_obj('build_py') - # distribution.package_data = find_package_data() - # re-init build_py options which load package_data - build_py_cmd.finalize_options() + description = "install package.json dependencies using npm" + user_options = [] -class NPM(Command): - """NPM builder. - - Builds the js and css using npm. - """ - - description = 'install package.json dependencies using npm' - - user_options = [] - - node_modules = os.path.join(node_root, 'node_modules') - - targets = [ - os.path.join(here, 'tensorflow_model_analysis', 'static', 'extension.js'), - os.path.join(here, 'tensorflow_model_analysis', 'static', 'index.js'), - os.path.join(here, 'tensorflow_model_analysis', 'static', - 'vulcanized_tfma.js'), - ] - - def initialize_options(self): - pass - - def finalize_options(self): - pass - - def get_npm_name(self): - npm_name = 'npm' - if platform.system() == 'Windows': - npm_name = 'npm.cmd' - - return npm_name - - def has_npm(self): - npm_name = self.get_npm_name() - try: - subprocess.check_call([npm_name, '--version']) - return True - except: # pylint: disable=bare-except - return False - - def should_run_npm_install(self): - return self.has_npm() - - def run(self): - if not build_js: - return - - has_npm = self.has_npm() - if not has_npm: - log.error( - "`npm` unavailable. If you're running this command using sudo, make" - ' sure `npm` is available to sudo') - - env = os.environ.copy() - env['PATH'] = npm_path - - if self.should_run_npm_install(): - log.info( - 'Installing build dependencies with npm. This may take a while...') - npm_name = self.get_npm_name() - subprocess.check_call([npm_name, 'install'], - cwd=node_root, - stdout=sys.stdout, - stderr=sys.stderr) - os.utime(self.node_modules, None) - - for t in self.targets: - if not os.path.exists(t): - msg = 'Missing file: %s' % t + node_modules = os.path.join(node_root, "node_modules") + + targets = [ + os.path.join(here, "tensorflow_model_analysis", "static", "extension.js"), + os.path.join(here, "tensorflow_model_analysis", "static", "index.js"), + os.path.join(here, "tensorflow_model_analysis", "static", "vulcanized_tfma.js"), + ] + + def initialize_options(self): + pass + + def finalize_options(self): + pass + + def get_npm_name(self): + npm_name = "npm" + if platform.system() == "Windows": + npm_name = "npm.cmd" + + return npm_name + + def has_npm(self): + npm_name = self.get_npm_name() + try: + subprocess.check_call([npm_name, "--version"]) + return True + except: # pylint: disable=bare-except + return False + + def should_run_npm_install(self): + return self.has_npm() + + def run(self): + if not build_js: + return + + has_npm = self.has_npm() if not has_npm: - msg += ('\nnpm is required to build a development version of a widget' - ' extension') - raise ValueError(msg) + log.error( + "`npm` unavailable. 
If you're running this command using sudo, make" + " sure `npm` is available to sudo" + ) + + env = os.environ.copy() + env["PATH"] = npm_path + + if self.should_run_npm_install(): + log.info( + "Installing build dependencies with npm. This may take a while..." + ) + npm_name = self.get_npm_name() + subprocess.check_call( + [npm_name, "install"], + cwd=node_root, + stdout=sys.stdout, + stderr=sys.stderr, + ) + os.utime(self.node_modules, None) + + for t in self.targets: + if not os.path.exists(t): + msg = "Missing file: %s" % t + if not has_npm: + msg += ( + "\nnpm is required to build a development version of a widget" + " extension" + ) + raise ValueError(msg) + + # update package data in case this created new files + update_package_data(self.distribution) - # update package data in case this created new files - update_package_data(self.distribution) def _make_docs_packages(): - return [ - req for req in Path("./requirements-docs.txt") - .expanduser() - .resolve() - .read_text() - .splitlines() - if req - ] + return [ + req + for req in Path("./requirements-docs.txt") + .expanduser() + .resolve() + .read_text() + .splitlines() + if req + ] + def _make_extra_packages_tfjs(): - # Packages needed for tfjs. - return [ - 'tensorflowjs>=4.22.0,<5', - ] + # Packages needed for tfjs. + return [ + "tensorflowjs>=4.22.0,<5", + ] def select_constraint(default, nightly=None, git_master=None): - """Select dependency constraint based on TFX_DEPENDENCY_SELECTOR env var.""" - selector = os.environ.get('TFX_DEPENDENCY_SELECTOR') - if selector == 'UNCONSTRAINED': - return '' - elif selector == 'NIGHTLY' and nightly is not None: - return nightly - elif selector == 'GIT_MASTER' and git_master is not None: - return git_master - else: - return default + """Select dependency constraint based on TFX_DEPENDENCY_SELECTOR env var.""" + selector = os.environ.get("TFX_DEPENDENCY_SELECTOR") + if selector == "UNCONSTRAINED": + return "" + elif selector == "NIGHTLY" and nightly is not None: + return nightly + elif selector == "GIT_MASTER" and git_master is not None: + return git_master + else: + return default # Get the long description from the README file. -with open('README.md') as fp: - _LONG_DESCRIPTION = fp.read() +with open("README.md") as fp: + _LONG_DESCRIPTION = fp.read() setup_args = { - 'name': 'tensorflow_model_analysis', - 'version': __version__, - 'description': 'A library for analyzing TensorFlow models', - 'long_description': _LONG_DESCRIPTION, - 'long_description_content_type': 'text/markdown', - 'include_package_data': True, - 'data_files': [ + "name": "tensorflow_model_analysis", + "version": __version__, + "description": "A library for analyzing TensorFlow models", + "long_description": _LONG_DESCRIPTION, + "long_description_content_type": "text/markdown", + "include_package_data": True, + "data_files": [ ( - 'share/jupyter/nbextensions/tensorflow_model_analysis', + "share/jupyter/nbextensions/tensorflow_model_analysis", [ - 'tensorflow_model_analysis/static/extension.js', - 'tensorflow_model_analysis/static/index.js', - 'tensorflow_model_analysis/static/index.js.map', - 'tensorflow_model_analysis/static/vulcanized_tfma.js', + "tensorflow_model_analysis/static/extension.js", + "tensorflow_model_analysis/static/index.js", + "tensorflow_model_analysis/static/index.js.map", + "tensorflow_model_analysis/static/vulcanized_tfma.js", ], ), ], # Make sure to sync the versions of common dependencies (numpy, six, and # protobuf) with TF. 
- 'install_requires': [ + "install_requires": [ # Sort alphabetically - 'absl-py>=0.9,<2.0.0', + "absl-py>=0.9,<2.0.0", 'apache-beam[gcp]>=2.53,<3;python_version>="3.11"', 'apache-beam[gcp]>=2.50,<2.51;python_version<"3.11"', - 'ipython>=7,<8', - 'ipywidgets>=7,<8', - 'numpy>=1.23.5', - 'pandas>=1.0,<2', - 'pillow>=9.4.0', + "ipython>=7,<8", + "ipywidgets>=7,<8", + "numpy>=1.23.5", + "pandas>=1.0,<2", + "pillow>=9.4.0", 'protobuf>=4.25.2,<6.0.0;python_version>="3.11"', 'protobuf>=4.21.6,<6.0.0;python_version<"3.11"', - 'pyarrow>=10,<11', - 'rouge-score>=0.1.2,<2', - 'sacrebleu>=2.3,<4', - 'scipy>=1.4.1,<2', - 'six>=1.12,<2', - 'tensorflow>=2.17,<2.18', - 'tensorflow-estimator>=2.10', - 'tensorflow-metadata' + "pyarrow>=10,<11", + "rouge-score>=0.1.2,<2", + "sacrebleu>=2.3,<4", + "scipy>=1.4.1,<2", + "six>=1.12,<2", + "tensorflow>=2.17,<2.18", + "tensorflow-estimator>=2.10", + "tensorflow-metadata" + select_constraint( - default='>=1.17.1,<1.18.0', - nightly='>=1.18.0.dev', - git_master='@git+https://github.com/tensorflow/metadata@master', + default=">=1.17.1,<1.18.0", + nightly=">=1.18.0.dev", + git_master="@git+https://github.com/tensorflow/metadata@master", ), - 'tfx-bsl' + "tfx-bsl" + select_constraint( - default='>=1.17.1,<1.18.0', - nightly='>=1.18.0.dev', - git_master='@git+https://github.com/tensorflow/tfx-bsl@master', + default=">=1.17.1,<1.18.0", + nightly=">=1.18.0.dev", + git_master="@git+https://github.com/tensorflow/tfx-bsl@master", ), - 'tf-keras', + "tf-keras", ], - 'extras_require': { - 'all': [*_make_extra_packages_tfjs(), *_make_docs_packages()], - 'docs': _make_docs_packages(), + "extras_require": { + "all": [*_make_extra_packages_tfjs(), *_make_docs_packages()], + "docs": _make_docs_packages(), }, - 'python_requires': '>=3.9,<4', - 'packages': find_packages(), - 'zip_safe': False, - 'cmdclass': { - 'build_py': js_prerelease(build_py), - 'develop': js_prerelease(develop), - 'egg_info': js_prerelease(egg_info), - 'sdist': js_prerelease(sdist, strict=True), - 'jsdeps': NPM, + "python_requires": ">=3.9,<4", + "packages": find_packages(), + "zip_safe": False, + "cmdclass": { + "build_py": js_prerelease(build_py), + "develop": js_prerelease(develop), + "egg_info": js_prerelease(egg_info), + "sdist": js_prerelease(sdist, strict=True), + "jsdeps": NPM, }, - 'author': 'Google LLC', - 'author_email': 'tensorflow-extended-dev@googlegroups.com', - 'license': 'Apache 2.0', - 'classifiers': [ - 'Development Status :: 4 - Beta', - 'Intended Audience :: Developers', - 'Intended Audience :: Education', - 'Intended Audience :: Science/Research', - 'License :: OSI Approved :: Apache Software License', - 'Operating System :: OS Independent', - 'Programming Language :: Python', - 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.9', - 'Programming Language :: Python :: 3.10', - 'Programming Language :: Python :: 3.11', - 'Programming Language :: Python :: 3 :: Only', - 'Topic :: Scientific/Engineering', - 'Topic :: Scientific/Engineering :: Mathematics', - 'Topic :: Scientific/Engineering :: Artificial Intelligence', - 'Topic :: Software Development', - 'Topic :: Software Development :: Libraries', - 'Topic :: Software Development :: Libraries :: Python Modules', + "author": "Google LLC", + "author_email": "tensorflow-extended-dev@googlegroups.com", + "license": "Apache 2.0", + "classifiers": [ + "Development Status :: 4 - Beta", + "Intended Audience :: Developers", + "Intended Audience :: Education", + "Intended Audience :: Science/Research", + "License :: OSI 
Approved :: Apache Software License", + "Operating System :: OS Independent", + "Programming Language :: Python", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3 :: Only", + "Topic :: Scientific/Engineering", + "Topic :: Scientific/Engineering :: Mathematics", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Software Development", + "Topic :: Software Development :: Libraries", + "Topic :: Software Development :: Libraries :: Python Modules", ], - 'namespace_packages': [], - 'requires': [], - 'keywords': 'tensorflow model analysis tfx', - 'url': 'https://www.tensorflow.org/tfx/model_analysis/get_started', - 'download_url': 'https://github.com/tensorflow/model-analysis/tags', + "namespace_packages": [], + "requires": [], + "keywords": "tensorflow model analysis tfx", + "url": "https://www.tensorflow.org/tfx/model_analysis/get_started", + "download_url": "https://github.com/tensorflow/model-analysis/tags", } setup(**setup_args) diff --git a/tensorflow_model_analysis/__init__.py b/tensorflow_model_analysis/__init__.py index f130fc578b..108c40b352 100644 --- a/tensorflow_model_analysis/__init__.py +++ b/tensorflow_model_analysis/__init__.py @@ -27,153 +27,161 @@ # pylint: disable=g-statement-before-imports # See b/148667210 for why the ImportError is ignored. try: - from tensorflow_model_analysis.sdk import * + # TODO(b/73882264): The orders should be kept in order to make benchmark on + # DataFlow work. We need to look into why the import orders matters for the + # DataFlow benchmark. + from tensorflow_model_analysis import ( + evaluators, + extractors, + metrics, + slicer, + utils, + validators, + view, + writers, + ) - # Allow api module types to be imported at the top-level since they are the - # main public interface to using TFMA. 
- from tensorflow_model_analysis.api import dataframe - from tensorflow_model_analysis.api.model_eval_lib import AttributionsForSlice - from tensorflow_model_analysis.api.model_eval_lib import analyze_raw_data - from tensorflow_model_analysis.api.model_eval_lib import BatchedInputsToExtracts - from tensorflow_model_analysis.api.model_eval_lib import default_eval_shared_model - from tensorflow_model_analysis.api.model_eval_lib import default_evaluators - from tensorflow_model_analysis.api.model_eval_lib import default_extractors - from tensorflow_model_analysis.api.model_eval_lib import default_writers - from tensorflow_model_analysis.api.model_eval_lib import ExtractAndEvaluate - from tensorflow_model_analysis.api.model_eval_lib import ExtractEvaluateAndWriteResults - from tensorflow_model_analysis.api.model_eval_lib import InputsToExtracts - from tensorflow_model_analysis.api.model_eval_lib import is_batched_input - from tensorflow_model_analysis.api.model_eval_lib import is_legacy_estimator - from tensorflow_model_analysis.api.model_eval_lib import load_attributions - from tensorflow_model_analysis.api.model_eval_lib import load_eval_result - from tensorflow_model_analysis.api.model_eval_lib import load_eval_results - from tensorflow_model_analysis.api.model_eval_lib import load_metrics - from tensorflow_model_analysis.api.model_eval_lib import load_plots - from tensorflow_model_analysis.api.model_eval_lib import load_validation_result - from tensorflow_model_analysis.api.model_eval_lib import make_eval_results - from tensorflow_model_analysis.api.model_eval_lib import MetricsForSlice - from tensorflow_model_analysis.api.model_eval_lib import multiple_data_analysis - from tensorflow_model_analysis.api.model_eval_lib import multiple_model_analysis - from tensorflow_model_analysis.api.model_eval_lib import PlotsForSlice - from tensorflow_model_analysis.api.model_eval_lib import run_model_analysis - from tensorflow_model_analysis.api.model_eval_lib import WriteResults - from tensorflow_model_analysis.api.model_eval_lib import ValidationResult - from tensorflow_model_analysis.api.verifier_lib import Validate + # Allow api module types to be imported at the top-level since they are the + # main public interface to using TFMA. + # TODO(b/228406044): Stop exposing tfma.types and migrate all internal users + # to use the top-level symbols exported below (e.g. tfma.Extracts). + from tensorflow_model_analysis.api import dataframe, types + from tensorflow_model_analysis.api.model_eval_lib import ( + AttributionsForSlice, + BatchedInputsToExtracts, + ExtractAndEvaluate, + ExtractEvaluateAndWriteResults, + InputsToExtracts, + MetricsForSlice, + PlotsForSlice, + ValidationResult, + WriteResults, + analyze_raw_data, + default_eval_shared_model, + default_evaluators, + default_extractors, + default_writers, + is_batched_input, + is_legacy_estimator, + load_attributions, + load_eval_result, + load_eval_results, + load_metrics, + load_plots, + load_validation_result, + make_eval_results, + multiple_data_analysis, + multiple_model_analysis, + run_model_analysis, + ) - # TODO(b/73882264): The orders should be kept in order to make benchmark on - # DataFlow work. We need to look into why the import orders matters for the - # DataFlow benchmark. 
- from tensorflow_model_analysis import extractors - from tensorflow_model_analysis import slicer - from tensorflow_model_analysis import validators - from tensorflow_model_analysis import evaluators - from tensorflow_model_analysis import metrics - from tensorflow_model_analysis import utils - from tensorflow_model_analysis import writers - from tensorflow_model_analysis import view - # TODO(b/228406044): Stop exposing tfma.types and migrate all internal users - # to use the top-level symbols exported below (e.g. tfma.Extracts). - from tensorflow_model_analysis.api import types + # Allow types to be imported at the top-level since they live in root dir. + # TODO(b/120222218): Remove after passing of native FPL supported. + # TODO(b/120222218): Remove after passing of native FPL supported. + from tensorflow_model_analysis.api.types import ( + AddMetricsCallbackType, + EvalSharedModel, + Extracts, + FeaturesPredictionsLabels, + MaterializedColumn, + MaybeMultipleEvalSharedModels, + ModelLoader, + RaggedTensorValue, + SparseTensorValue, + TensorType, + TensorTypeMaybeDict, + TensorValue, + VarLenTensorValue, + ) + from tensorflow_model_analysis.api.verifier_lib import Validate + from tensorflow_model_analysis.sdk import * - # TODO(b/171992041): Deprecate use of EvalResult in the future. - from tensorflow_model_analysis.view.view_types import EvalResult + # Import VERSION as __version__ for compatibility with other TFX components. + from tensorflow_model_analysis.version import VERSION as __version__ - # Allow types to be imported at the top-level since they live in root dir. - from tensorflow_model_analysis.api.types import AddMetricsCallbackType - from tensorflow_model_analysis.api.types import EvalSharedModel - from tensorflow_model_analysis.api.types import Extracts - # TODO(b/120222218): Remove after passing of native FPL supported. - from tensorflow_model_analysis.api.types import FeaturesPredictionsLabels - # TODO(b/120222218): Remove after passing of native FPL supported. - from tensorflow_model_analysis.api.types import MaterializedColumn - from tensorflow_model_analysis.api.types import MaybeMultipleEvalSharedModels - from tensorflow_model_analysis.api.types import ModelLoader - from tensorflow_model_analysis.api.types import RaggedTensorValue - from tensorflow_model_analysis.api.types import SparseTensorValue - from tensorflow_model_analysis.api.types import TensorType - from tensorflow_model_analysis.api.types import TensorTypeMaybeDict - from tensorflow_model_analysis.api.types import TensorValue - from tensorflow_model_analysis.api.types import VarLenTensorValue - - # Import VERSION as __version__ for compatibility with other TFX components. - from tensorflow_model_analysis.version import VERSION as __version__ + # TODO(b/171992041): Deprecate use of EvalResult in the future. 
+ from tensorflow_model_analysis.view.view_types import EvalResult except ImportError as err: - import sys + import sys - sys.stderr.write('Error importing: {}'.format(err)) + sys.stderr.write(f"Error importing: {err}") # pylint: enable=g-statement-before-imports # pylint: enable=g-import-not-at-top + def _jupyter_nbextension_paths(): - return [{ - 'section': 'notebook', - 'src': 'static', - 'dest': 'tensorflow_model_analysis', - 'require': 'tensorflow_model_analysis/extension' - }] + return [ + { + "section": "notebook", + "src": "static", + "dest": "tensorflow_model_analysis", + "require": "tensorflow_model_analysis/extension", + } + ] + __all__ = [ - 'AddMetricsCallbackType', - 'AggregationOptions', - 'analyze_raw_data', - 'AttributionsForSlice', - 'BatchedInputsToExtracts', - 'BinarizationOptions', - 'ConfidenceIntervalOptions', - 'CrossSliceMetricThreshold', - 'CrossSliceMetricThresholds', - 'CrossSlicingSpec', - 'default_eval_shared_model', - 'default_evaluators', - 'default_extractors', - 'default_writers', - 'EvalConfig', - 'EvalResult', - 'EvalSharedModel', - 'ExampleWeightOptions', - 'ExtractAndEvaluate', - 'ExtractEvaluateAndWriteResults', - 'Extracts', - 'FeaturesPredictionsLabels', - 'GenericChangeThreshold', - 'GenericValueThreshold', - 'InputsToExtracts', - 'is_batched_input', - 'is_legacy_estimator', - 'load_attributions', - 'load_eval_result', - 'load_eval_results', - 'load_metrics', - 'load_plots', - 'load_validation_result', - 'make_eval_results', - 'MaterializedColumn', - 'MaybeMultipleEvalSharedModels', - 'MetricConfig', - 'MetricsForSlice', - 'MetricsSpec', - 'MetricThreshold', - 'ModelLoader', - 'ModelSpec', - 'multiple_data_analysis', - 'multiple_model_analysis', - 'Options', - 'PaddingOptions', - 'PerSliceMetricThreshold', - 'PerSliceMetricThresholds', - 'PlotsForSlice', - 'RaggedTensorValue', - 'RepeatedInt32Value', - 'RepeatedStringValue', - 'run_model_analysis', - 'SlicingSpec', - 'SparseTensorValue', - 'TensorType', - 'TensorTypeMaybeDict', - 'TensorValue', - 'Validate', - 'ValidationResult', - 'VarLenTensorValue', - 'WriteResults' + "AddMetricsCallbackType", + "AggregationOptions", + "analyze_raw_data", + "AttributionsForSlice", + "BatchedInputsToExtracts", + "BinarizationOptions", + "ConfidenceIntervalOptions", + "CrossSliceMetricThreshold", + "CrossSliceMetricThresholds", + "CrossSlicingSpec", + "default_eval_shared_model", + "default_evaluators", + "default_extractors", + "default_writers", + "EvalConfig", + "EvalResult", + "EvalSharedModel", + "ExampleWeightOptions", + "ExtractAndEvaluate", + "ExtractEvaluateAndWriteResults", + "Extracts", + "FeaturesPredictionsLabels", + "GenericChangeThreshold", + "GenericValueThreshold", + "InputsToExtracts", + "is_batched_input", + "is_legacy_estimator", + "load_attributions", + "load_eval_result", + "load_eval_results", + "load_metrics", + "load_plots", + "load_validation_result", + "make_eval_results", + "MaterializedColumn", + "MaybeMultipleEvalSharedModels", + "MetricConfig", + "MetricsForSlice", + "MetricsSpec", + "MetricThreshold", + "ModelLoader", + "ModelSpec", + "multiple_data_analysis", + "multiple_model_analysis", + "Options", + "PaddingOptions", + "PerSliceMetricThreshold", + "PerSliceMetricThresholds", + "PlotsForSlice", + "RaggedTensorValue", + "RepeatedInt32Value", + "RepeatedStringValue", + "run_model_analysis", + "SlicingSpec", + "SparseTensorValue", + "TensorType", + "TensorTypeMaybeDict", + "TensorValue", + "Validate", + "ValidationResult", + "VarLenTensorValue", + "WriteResults", ] diff --git 
a/tensorflow_model_analysis/api/dataframe.py b/tensorflow_model_analysis/api/dataframe.py index b98f2bad4d..474dacc7fb 100644 --- a/tensorflow_model_analysis/api/dataframe.py +++ b/tensorflow_model_analysis/api/dataframe.py @@ -20,151 +20,162 @@ import numpy as np import pandas as pd -from tensorflow_model_analysis.proto import metrics_for_slice_pb2 +from google.protobuf import descriptor, message, wrappers_pb2 -from google.protobuf import wrappers_pb2 -from google.protobuf import descriptor -from google.protobuf import message +from tensorflow_model_analysis.proto import metrics_for_slice_pb2 MetricsForSlice = metrics_for_slice_pb2.MetricsForSlice PlotsForSlice = metrics_for_slice_pb2.PlotsForSlice -_OVERALL = 'Overall' +_OVERALL = "Overall" # DataFrame output columns. -_METRIC_VALUES = 'metric_values' -_PLOT_DATA = 'plot_data' -_SLICE_STR = 'stringified_slices' -_SLICES = 'slices' -_METRIC_KEYS = 'metric_keys' -_PLOT_KEYS = 'plot_keys' +_METRIC_VALUES = "metric_values" +_PLOT_DATA = "plot_data" +_SLICE_STR = "stringified_slices" +_SLICES = "slices" +_METRIC_KEYS = "metric_keys" +_PLOT_KEYS = "plot_keys" @dataclasses.dataclass class _ColumnData: - metric_keys: Dict[str, List[Tuple[Any, int]]] = dataclasses.field( - default_factory=lambda: collections.defaultdict(list) - ) - values: Dict[str, List[Tuple[Any, int]]] = dataclasses.field( - default_factory=lambda: collections.defaultdict(list) - ) - slices: Dict[str, List[Tuple[Any, int]]] = dataclasses.field( - default_factory=lambda: collections.defaultdict(list) - ) + metric_keys: Dict[str, List[Tuple[Any, int]]] = dataclasses.field( + default_factory=lambda: collections.defaultdict(list) + ) + values: Dict[str, List[Tuple[Any, int]]] = dataclasses.field( + default_factory=lambda: collections.defaultdict(list) + ) + slices: Dict[str, List[Tuple[Any, int]]] = dataclasses.field( + default_factory=lambda: collections.defaultdict(list) + ) @dataclasses.dataclass(frozen=True) class MetricsDataFrames: - double_value: Optional[pd.DataFrame] = None - confusion_matrix_at_thresholds: Optional[pd.DataFrame] = None - multi_class_confusion_matrix_at_thresholds: Optional[pd.DataFrame] = None - bytes_value: Optional[pd.DataFrame] = None - array_value: Optional[pd.DataFrame] = None + double_value: Optional[pd.DataFrame] = None + confusion_matrix_at_thresholds: Optional[pd.DataFrame] = None + multi_class_confusion_matrix_at_thresholds: Optional[pd.DataFrame] = None + bytes_value: Optional[pd.DataFrame] = None + array_value: Optional[pd.DataFrame] = None @dataclasses.dataclass(frozen=True) class PlotsDataFrames: - calibration_histogram_buckets: Optional[pd.DataFrame] = None - confusion_matrix_at_thresholds: Optional[pd.DataFrame] = None - multi_class_confusion_matrix_at_thresholds: Optional[pd.DataFrame] = None - multi_label_confusion_matrix_at_thresholds: Optional[pd.DataFrame] = None - debug_message: Optional[pd.DataFrame] = None + calibration_histogram_buckets: Optional[pd.DataFrame] = None + confusion_matrix_at_thresholds: Optional[pd.DataFrame] = None + multi_class_confusion_matrix_at_thresholds: Optional[pd.DataFrame] = None + multi_label_confusion_matrix_at_thresholds: Optional[pd.DataFrame] = None + debug_message: Optional[pd.DataFrame] = None @dataclasses.dataclass(frozen=True) class _ColumnPrefixes: - slices: str - metric_keys: str - metric_values: str + slices: str + metric_keys: str + metric_values: str _metric_columns = _ColumnPrefixes(_SLICES, _METRIC_KEYS, _METRIC_VALUES) _plot_columns = _ColumnPrefixes(_SLICES, _PLOT_KEYS, _PLOT_DATA) 
-_WRAPPED_PRIMITIVES = (wrappers_pb2.DoubleValue, wrappers_pb2.FloatValue, - wrappers_pb2.BoolValue, wrappers_pb2.BytesValue, - wrappers_pb2.StringValue, wrappers_pb2.Int64Value, - wrappers_pb2.Int32Value) +_WRAPPED_PRIMITIVES = ( + wrappers_pb2.DoubleValue, + wrappers_pb2.FloatValue, + wrappers_pb2.BoolValue, + wrappers_pb2.BytesValue, + wrappers_pb2.StringValue, + wrappers_pb2.Int64Value, + wrappers_pb2.Int32Value, +) def _flatten_proto( root_field: message.Message, field_name: str, index: int, - include_empty_columns: bool = False) -> Iterator[Tuple[str, Any, int]]: - """Generates the leaf primitive fields by traversing the proto recursively. - - Traverses a proto and emits a tuple of the name, value, index at which the - value should be inserted. If include_empty_columns is True, unset fields are - also emitted with value of None. The index is the order of which this - primitive should be inserted to the DataFrame. - Note: that nested or misaligned repeated fields are not supported and will - lead to undefined behavior. - - Args: - root_field: The root message field. - field_name: The child field name under the root field where the traversal - begins. - index: The starting row index where the DataFrame was at. - include_empty_columns: If True, the unset fields are also emitted. - - Returns: - An iterator of the field name, field value, and index of the underlying - proto primitives. - """ - - def _is_repeated_field(parent, field_name): - return (parent.DESCRIPTOR.fields_by_name[field_name].label == - descriptor.FieldDescriptor.LABEL_REPEATED) - - def _flatten_proto_in( - parent: message.Message, - field_name: str, - field_value: Any, - index: int, - is_repeated_field: bool = False) -> Iterator[Tuple[str, Any, int]]: - if isinstance(field_value, message.Message): - # Test the message field is unset. - if not is_repeated_field and not parent.HasField(field_name): - if include_empty_columns: - yield (field_name, None, index) - return - elif isinstance(field_value, _WRAPPED_PRIMITIVES): - # Preserve the field_name of the wrapped primitives. - yield (field_name, field_value.value, index) - else: - for field in field_value.DESCRIPTOR.fields: - yield from _flatten_proto_in(field_value, field.name, - getattr(field_value, field.name), index) - # Handling repeated field. - elif _is_repeated_field(parent, field_name): - for i, single_field_value in enumerate(field_value): - yield from _flatten_proto_in( - parent, - field_name, - single_field_value, - index + i, - is_repeated_field=True) - # Python primitives. - else: - yield (field_name, field_value, index) + include_empty_columns: bool = False, +) -> Iterator[Tuple[str, Any, int]]: + """Generates the leaf primitive fields by traversing the proto recursively. + + Traverses a proto and emits a tuple of the name, value, index at which the + value should be inserted. If include_empty_columns is True, unset fields are + also emitted with value of None. The index is the order of which this + primitive should be inserted to the DataFrame. + Note: that nested or misaligned repeated fields are not supported and will + lead to undefined behavior. + + Args: + ---- + root_field: The root message field. + field_name: The child field name under the root field where the traversal + begins. + index: The starting row index where the DataFrame was at. + include_empty_columns: If True, the unset fields are also emitted. + + Returns: + ------- + An iterator of the field name, field value, and index of the underlying + proto primitives. 
+ """ + + def _is_repeated_field(parent, field_name): + return ( + parent.DESCRIPTOR.fields_by_name[field_name].label + == descriptor.FieldDescriptor.LABEL_REPEATED + ) + + def _flatten_proto_in( + parent: message.Message, + field_name: str, + field_value: Any, + index: int, + is_repeated_field: bool = False, + ) -> Iterator[Tuple[str, Any, int]]: + if isinstance(field_value, message.Message): + # Test the message field is unset. + if not is_repeated_field and not parent.HasField(field_name): + if include_empty_columns: + yield (field_name, None, index) + return + elif isinstance(field_value, _WRAPPED_PRIMITIVES): + # Preserve the field_name of the wrapped primitives. + yield (field_name, field_value.value, index) + else: + for field in field_value.DESCRIPTOR.fields: + yield from _flatten_proto_in( + field_value, field.name, getattr(field_value, field.name), index + ) + # Handling repeated field. + elif _is_repeated_field(parent, field_name): + for i, single_field_value in enumerate(field_value): + yield from _flatten_proto_in( + parent, + field_name, + single_field_value, + index + i, + is_repeated_field=True, + ) + # Python primitives. + else: + yield (field_name, field_value, index) - field_value = getattr(root_field, field_name) - return _flatten_proto_in(root_field, field_name, field_value, index) + field_value = getattr(root_field, field_name) + return _flatten_proto_in(root_field, field_name, field_value, index) def _get_slice_value( - slice_key: metrics_for_slice_pb2.SingleSliceKey + slice_key: metrics_for_slice_pb2.SingleSliceKey, ) -> Union[float, bytes, int]: - """Determines the primitive value stored by the slice.""" - value_type = slice_key.WhichOneof('kind') - if value_type == 'float_value': - return slice_key.float_value - elif value_type == 'bytes_value': - return slice_key.bytes_value - elif value_type == 'int64_value': - return slice_key.int64_value - else: - raise NotImplementedError(f'{value_type} in {slice_key} is not supported.') + """Determines the primitive value stored by the slice.""" + value_type = slice_key.WhichOneof("kind") + if value_type == "float_value": + return slice_key.float_value + elif value_type == "bytes_value": + return slice_key.bytes_value + elif value_type == "int64_value": + return slice_key.int64_value + else: + raise NotImplementedError(f"{value_type} in {slice_key} is not supported.") def _to_dataframes( @@ -172,86 +183,99 @@ def _to_dataframes( column_prefixes: _ColumnPrefixes, include_empty_columns: bool = False, ) -> Dict[str, pd.DataFrame]: - """The implementation of loading TFMA metrics or plots as DataFrames. - - Args: - metrics_or_plots: an iterable of MetricsForSlice or PlotsForSlice. - column_prefixes: the column names of the first layer columns of the - multi-index columns. - include_empty_columns: if True, keeps all the unset fields with value set to - None. - - Returns: - A map of the DataFrame type to the corresponding DataFrame. - """ - slice_key_name = column_prefixes.slices - metric_key_name = column_prefixes.metric_keys - metric_value_name = column_prefixes.metric_values - column_data = collections.defaultdict(_ColumnData) - index = 0 - # For each slice. - for metrics_or_plots_for_slice in metrics_or_plots: - slices = [(single_slice_key.column, _get_slice_value(single_slice_key)) - for single_slice_key in - metrics_or_plots_for_slice.slice_key.single_slice_keys] - # For each metric inside the slice. 
- if isinstance(metrics_or_plots_for_slice, MetricsForSlice): - key_and_values = metrics_or_plots_for_slice.metric_keys_and_values - elif isinstance(metrics_or_plots_for_slice, PlotsForSlice): - key_and_values = metrics_or_plots_for_slice.plot_keys_and_values - else: - raise NotImplementedError( - f'{type(metrics_or_plots_for_slice)} is not supported.') - for key_and_value in key_and_values: - metric_value = key_and_value.value - for field in metric_value.DESCRIPTOR.fields: - metric_type = field.name - metric = getattr(metric_value, metric_type) - if (isinstance(metric, message.Message) and - not metric_value.HasField(metric_type) or not metric): - continue - # Flattens the metric_values. - # Initializes index_end to 'index - 1' to indicate that there is no item - # added yet. It is set to 'index' once the loop starts, indicating that - # at least one item is found. - index_end = index - 1 - for k, v, index_end in _flatten_proto(metric_value, metric_type, index): - column_data[metric_type].values[k].append((v, index_end)) - # index_end is later used as the exclusive range end, thus, the +1. - index_end += 1 - # Insert a column per leaf field in MetricKey. - for k, v, _ in _flatten_proto( - key_and_value, - 'key', - index, - include_empty_columns=include_empty_columns): - column_data[metric_type].metric_keys[k].extend([ - (v, i) for i in range(index, index_end) - ]) - # Insert each slice - if slices: - for slice_name, slice_value in slices: - column_data[metric_type].slices[slice_name].extend( - [slice_value, i] for i in range(index, index_end)) + """The implementation of loading TFMA metrics or plots as DataFrames. + + Args: + ---- + metrics_or_plots: an iterable of MetricsForSlice or PlotsForSlice. + column_prefixes: the column names of the first layer columns of the + multi-index columns. + include_empty_columns: if True, keeps all the unset fields with value set to + None. + + Returns: + ------- + A map of the DataFrame type to the corresponding DataFrame. + """ + slice_key_name = column_prefixes.slices + metric_key_name = column_prefixes.metric_keys + metric_value_name = column_prefixes.metric_values + column_data = collections.defaultdict(_ColumnData) + index = 0 + # For each slice. + for metrics_or_plots_for_slice in metrics_or_plots: + slices = [ + (single_slice_key.column, _get_slice_value(single_slice_key)) + for single_slice_key in metrics_or_plots_for_slice.slice_key.single_slice_keys + ] + # For each metric inside the slice. + if isinstance(metrics_or_plots_for_slice, MetricsForSlice): + key_and_values = metrics_or_plots_for_slice.metric_keys_and_values + elif isinstance(metrics_or_plots_for_slice, PlotsForSlice): + key_and_values = metrics_or_plots_for_slice.plot_keys_and_values else: - column_data[metric_type].slices[_OVERALL].extend( - ['', i] for i in range(index, index_end)) - index = index_end - dfs = {} - for metric_type, data in column_data.items(): - columns = pd.MultiIndex.from_tuples( - [(slice_key_name, key) for key in data.slices] - + [(metric_key_name, key) for key in data.metric_keys] - + [(metric_value_name, key) for key in data.values] - ) - all_data = itertools.chain(data.slices.values(), data.metric_keys.values(), - data.values.values()) - df = pd.DataFrame({ - column: pd.Series(*zip(*values)) - for column, values in zip(columns, all_data) - }) - dfs[metric_type] = df - return dfs + raise NotImplementedError( + f"{type(metrics_or_plots_for_slice)} is not supported." 
+ ) + for key_and_value in key_and_values: + metric_value = key_and_value.value + for field in metric_value.DESCRIPTOR.fields: + metric_type = field.name + metric = getattr(metric_value, metric_type) + if ( + isinstance(metric, message.Message) + and not metric_value.HasField(metric_type) + or not metric + ): + continue + # Flattens the metric_values. + # Initializes index_end to 'index - 1' to indicate that there is no item + # added yet. It is set to 'index' once the loop starts, indicating that + # at least one item is found. + index_end = index - 1 + for k, v, index_end in _flatten_proto(metric_value, metric_type, index): + column_data[metric_type].values[k].append((v, index_end)) + # index_end is later used as the exclusive range end, thus, the +1. + index_end += 1 + # Insert a column per leaf field in MetricKey. + for k, v, _ in _flatten_proto( + key_and_value, + "key", + index, + include_empty_columns=include_empty_columns, + ): + column_data[metric_type].metric_keys[k].extend( + [(v, i) for i in range(index, index_end)] + ) + # Insert each slice + if slices: + for slice_name, slice_value in slices: + column_data[metric_type].slices[slice_name].extend( + [slice_value, i] for i in range(index, index_end) + ) + else: + column_data[metric_type].slices[_OVERALL].extend( + ["", i] for i in range(index, index_end) + ) + index = index_end + dfs = {} + for metric_type, data in column_data.items(): + columns = pd.MultiIndex.from_tuples( + [(slice_key_name, key) for key in data.slices] + + [(metric_key_name, key) for key in data.metric_keys] + + [(metric_value_name, key) for key in data.values] + ) + all_data = itertools.chain( + data.slices.values(), data.metric_keys.values(), data.values.values() + ) + df = pd.DataFrame( + { + column: pd.Series(*zip(*values)) + for column, values in zip(columns, all_data) + } + ) + dfs[metric_type] = df + return dfs def metrics_as_dataframes( @@ -259,84 +283,87 @@ def metrics_as_dataframes( *, # Keyword only args below. include_empty_columns: bool = False, ) -> MetricsDataFrames: - """Convert the deserialized MetricsForSlice protos to Pandas DataFrame. - - To load all metrics: - dfs = metrics_as_dataframe(tfma.load_metrics(output_path)) - then users can load "double_value" DataFrame as - dfs.double_value - and confusion_metrics_at_thresholds DataFrame as: - dfs.confusion_metrics_at_thresholds: - - For example, if the input metrics proto looks like this: - slice_key { - single_slice_keys { - column: "age" - float_value: 38.0 - } - } - metric_keys_and_values { - key { - name: "mean_absolute_error" - } - value { - double_value { - value: 0.1 + """Convert the deserialized MetricsForSlice protos to Pandas DataFrame. 
+ + To load all metrics: + dfs = metrics_as_dataframe(tfma.load_metrics(output_path)) + then users can load "double_value" DataFrame as + dfs.double_value + and confusion_metrics_at_thresholds DataFrame as: + dfs.confusion_metrics_at_thresholds: + + For example, if the input metrics proto looks like this: + slice_key { + single_slice_keys { + column: "age" + float_value: 38.0 } } - } - metric_keys_and_values { - key { - name: "mean_squared_logarithmic_error" + metric_keys_and_values { + key { + name: "mean_absolute_error" + } + value { + double_value { + value: 0.1 + } + } } - value { - double_value { - value: 0.02 + metric_keys_and_values { + key { + name: "mean_squared_logarithmic_error" + } + value { + double_value { + value: 0.02 + } } } - } - - The corresponding output table will look like this: - - | | metric_values| metric_keys | slices | - | | double_value | name | age | sex | - |---:|-------------:|:-------------------------------|----------:|:---------| - | 0 | 0.1 | mean_absolute_error | 38 | male | - | 1 | 0.02 | mean_squared_logarithmic_error | 38 | female | - - One typical use of this DataFrame table is to re-organize it in the form of - slices vs. metrics table. For single model single output: - `auto_pivot(dfs.double_value)` - This will pivot on the non-unique "metric_keys.*" columns. - - Args: - metrics: The directory path of the metrics file or an iterable of - MetricsForSlice proto. - include_empty_columns: Include a column if its value is not empty (None) in - corresponding field in the MetricKey. - - Returns: - A DataFrame with the following columns if the value is not None: - * slices.: feature columns derived from - tfma.proto.metrics_for_slice_pb2.SliceKey. There are multiple feature - columns if there is any feature-cross, e.g., slice.featureA and - slice.featureB if there is FeatureA and FeatureB cross. - * metric_keys.: recursively flattened items in - metric_key: name, model_name, output_name, is_diff, example_weighted, - flattened sub_key (class_id, k, top_k), flattened aggregation_type - (micro_average, macro_average, weighted_macro_average). - * metric_values.: column(s) of metric_value(s); - metric_value specified by metric_type and flattened to its leaf - primitives. E.g., - 'double_value' if metric_type is 'double_value'; - 'true_positives', 'false_negatives', etc. if metric_type is - 'confusion_matrix_at_thresholds'. - """ - dfs = _to_dataframes( - metrics, - column_prefixes=_metric_columns, - include_empty_columns=include_empty_columns) - return MetricsDataFrames(**dfs) + + The corresponding output table will look like this: + + | | metric_values| metric_keys | slices | + | | double_value | name | age | sex | + |---:|-------------:|:-------------------------------|----------:|:---------| + | 0 | 0.1 | mean_absolute_error | 38 | male | + | 1 | 0.02 | mean_squared_logarithmic_error | 38 | female | + + One typical use of this DataFrame table is to re-organize it in the form of + slices vs. metrics table. For single model single output: + `auto_pivot(dfs.double_value)` + This will pivot on the non-unique "metric_keys.*" columns. + + Args: + ---- + metrics: The directory path of the metrics file or an iterable of + MetricsForSlice proto. + include_empty_columns: Include a column if its value is not empty (None) in + corresponding field in the MetricKey. + + Returns: + ------- + A DataFrame with the following columns if the value is not None: + * slices.: feature columns derived from + tfma.proto.metrics_for_slice_pb2.SliceKey. 
There are multiple feature + columns if there is any feature-cross, e.g., slice.featureA and + slice.featureB if there is FeatureA and FeatureB cross. + * metric_keys.: recursively flattened items in + metric_key: name, model_name, output_name, is_diff, example_weighted, + flattened sub_key (class_id, k, top_k), flattened aggregation_type + (micro_average, macro_average, weighted_macro_average). + * metric_values.: column(s) of metric_value(s); + metric_value specified by metric_type and flattened to its leaf + primitives. E.g., + 'double_value' if metric_type is 'double_value'; + 'true_positives', 'false_negatives', etc. if metric_type is + 'confusion_matrix_at_thresholds'. + """ + dfs = _to_dataframes( + metrics, + column_prefixes=_metric_columns, + include_empty_columns=include_empty_columns, + ) + return MetricsDataFrames(**dfs) def plots_as_dataframes( @@ -344,215 +371,230 @@ def plots_as_dataframes( *, # Keyword only args below. include_empty_columns: bool = False, ) -> PlotsDataFrames: - """Read and deserialize the PlotsForSlice records as Pandas DataFrame. + """Read and deserialize the PlotsForSlice records as Pandas DataFrame. - To load confusion_matrix_at_thresholds: - df = plots_as_dataframes(tfma.load_plots(eval_path)) + To load confusion_matrix_at_thresholds: + df = plots_as_dataframes(tfma.load_plots(eval_path)) - For example, if the input plots_for_slice proto looks like this: - slice_key { - single_slice_keys { - column: "age" - float_value: 38.0 + For example, if the input plots_for_slice proto looks like this: + slice_key { + single_slice_keys { + column: "age" + float_value: 38.0 + } } - } - plot_keys_and_values { - key {} - value { - confusion_matrix_at_thresholds { - matrices { - threshold: 0.5 - false_negatives: 10 - true_negatives: 10 - false_positives: 10 - true_positives: 10 - precision: 0.9 - recall: 0.8 - } - matrices { - threshold: 0.5 - false_negatives: 10 - true_negatives: 10 - false_positives: 10 - true_positives: 10 - precision: 0.9 - recall: 0.8 + plot_keys_and_values { + key {} + value { + confusion_matrix_at_thresholds { + matrices { + threshold: 0.5 + false_negatives: 10 + true_negatives: 10 + false_positives: 10 + true_positives: 10 + precision: 0.9 + recall: 0.8 + } + matrices { + threshold: 0.5 + false_negatives: 10 + true_negatives: 10 + false_positives: 10 + true_positives: 10 + precision: 0.9 + recall: 0.8 + } } } } - } - - The corresponding output table will look like this: - - | plot_data | plot_keys | slices | - | threshold | false_negatives | ... | is_diff | ... | age | - |----------:|----------------:|----:|:---------------------| ... |:----------| - | 0.5 | 10 | ... | False | ... | 38 | - | 0.5 | 10 | ... | False | ... | 38 | - # pylint: enable=line-too-long - - One typical use of this DataFrame table is to re-organize it in the form of - slices vs. plots table. - E.g., for single model single output: - result.pivot(index='slice', columns='name', values='plot_data'). - This only works when there is one unique value as the pivot values. - Otherewise, a user needs to specify more columns or indices to make sure that - the metric value is unique per column and per index. - E.g., for single model and multiple outputs: - result.pivot(index='slice', columns=['output_name', 'name'], - values='plot_data'). - - Args: - plots: a path to the evaluation directory or an iterable of PlotsForSlice - proto. - include_empty_columns: include a column if its value is not empty (None) in - corresponding field in the MetricKey. 
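Both loaders return frames with two-level columns. A small pandas-only sketch of the `slices` / `metric_keys` / `metric_values` layout listed in the Returns section above (the values are made up for illustration and are not real TFMA output):

```
import pandas as pd

# Two-level columns: the first level groups slices, metric_keys and
# metric_values; the second level holds the individual field names.
columns = pd.MultiIndex.from_tuples([
    ("slices", "age"),
    ("metric_keys", "name"),
    ("metric_values", "double_value"),
])
df = pd.DataFrame(
    [
        [38.0, "mean_absolute_error", 0.1],
        [38.0, "mean_squared_logarithmic_error", 0.02],
    ],
    columns=columns,
)

# Selecting by the first level yields a per-group sub-frame, e.g. only
# the metric values:
print(df["metric_values"])
```

Selecting by the first column level in this way is also how `auto_pivot` below picks its index, pivot, and value columns.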
- - Returns: - A DataFrame with the following columns if the value is not None: - * slice_str: the string representation of the slice. - E.g., "age: 10; sex: male". - * : feature columns derived from - tfma.proto.metrics_for_slice_pb2.SliceKey. There are multiple feature - columns if there is any feature-cross, e.g., slice.featureA and - slice.featureB if there is FeatureA and FeatureB cross. - * recursively flattened items in metric_key: name, model_name, output_name, - is_diff, example_weighted, flattened sub_key (class_id, k, top_k), - flattened aggregation_type (micro_average, macro_average, - weighted_macro_average). - * multiple columns of plot_values: plot_value specified by plot_type - and flattened to its leaf primitives. E.g., there will be columns - false_negatives, false_positives, true_negatives, true_positives, - precision, recall when plot_type is `confusion_matrix_at_thresholds`. - """ - dfs = _to_dataframes( - plots, - column_prefixes=_plot_columns, - include_empty_columns=include_empty_columns) - return PlotsDataFrames(**dfs) + + The corresponding output table will look like this: + + | plot_data | plot_keys | slices | + | threshold | false_negatives | ... | is_diff | ... | age | + |----------:|----------------:|----:|:---------------------| ... |:----------| + | 0.5 | 10 | ... | False | ... | 38 | + | 0.5 | 10 | ... | False | ... | 38 | + # pylint: enable=line-too-long + + One typical use of this DataFrame table is to re-organize it in the form of + slices vs. plots table. + E.g., for single model single output: + result.pivot(index='slice', columns='name', values='plot_data'). + This only works when there is one unique value as the pivot values. + Otherewise, a user needs to specify more columns or indices to make sure that + the metric value is unique per column and per index. + E.g., for single model and multiple outputs: + result.pivot(index='slice', columns=['output_name', 'name'], + values='plot_data'). + + Args: + ---- + plots: a path to the evaluation directory or an iterable of PlotsForSlice + proto. + include_empty_columns: include a column if its value is not empty (None) in + corresponding field in the MetricKey. + + Returns: + ------- + A DataFrame with the following columns if the value is not None: + * slice_str: the string representation of the slice. + E.g., "age: 10; sex: male". + * : feature columns derived from + tfma.proto.metrics_for_slice_pb2.SliceKey. There are multiple feature + columns if there is any feature-cross, e.g., slice.featureA and + slice.featureB if there is FeatureA and FeatureB cross. + * recursively flattened items in metric_key: name, model_name, output_name, + is_diff, example_weighted, flattened sub_key (class_id, k, top_k), + flattened aggregation_type (micro_average, macro_average, + weighted_macro_average). + * multiple columns of plot_values: plot_value specified by plot_type + and flattened to its leaf primitives. E.g., there will be columns + false_negatives, false_positives, true_negatives, true_positives, + precision, recall when plot_type is `confusion_matrix_at_thresholds`. 
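For instance, a minimal sketch along the lines of the docstring above (assuming `eval_path` points at an existing evaluation output directory and using the `tensorflow_model_analysis.experimental.dataframe` module as in the tests below):

```
import tensorflow_model_analysis as tfma
from tensorflow_model_analysis.experimental import dataframe

# Placeholder: a directory previously written by a TFMA evaluation run.
eval_path = "/tmp/tfma_eval_output"

# Deserialize the PlotsForSlice records and convert them into DataFrames.
dfs = dataframe.plots_as_dataframes(tfma.load_plots(eval_path))

# Confusion matrix entries are flattened into one column per leaf field
# (threshold, true_positives, precision, recall, ...).
cm_df = dfs.confusion_matrix_at_thresholds

# Optionally pivot into a slices-vs-plot-data view.
pivoted = dataframe.auto_pivot(cm_df)
```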
+ """ + dfs = _to_dataframes( + plots, + column_prefixes=_plot_columns, + include_empty_columns=include_empty_columns, + ) + return PlotsDataFrames(**dfs) def _collapse_column_names(columns: pd.MultiIndex) -> pd.Index: - """Reduce multi-index column names by removing layers with the same value.""" - dropables = [i for i, x in enumerate(zip(*columns)) if len(set(x)) == 1] - return columns.droplevel(dropables) - - -def _stringify_slices(df: pd.DataFrame, - slice_key_name: str = _SLICES, - drop_slices: bool = True) -> pd.DataFrame: - """Stringify all the slice columns into one column. - - For example, if there are two slice column of 'sex' and 'age', the function - takes all the slice columns, convert them to string in the format of - : and concatenate all slice columns using '; '. The - final result looks like: - ':; :.' - - Args: - df: a metrics DataFrame or plots DataFrame. - slice_key_name: the first level column name that groups all slice columns. - drop_slices: if True, drops the original slice columns. - - Returns: - A DataFrame with all slice columns stringified and concatenated. - """ - - def _concatenate(x): - return '; '.join( - f'{col}:{val}' for col, val in zip(df[slice_key_name].columns, x) - if pd.notnull(val)) - - t = df.slices.agg(_concatenate, axis=1) - df = df.drop(slice_key_name, axis=1, level=0) if drop_slices else df.copy() - df.insert(0, (slice_key_name, _SLICE_STR), t) - return df.sort_values((slice_key_name, _SLICE_STR)) - - -def _auto_pivot(df: pd.DataFrame, column_prefixes: _ColumnPrefixes, - stringify_slices: bool, - collapse_column_names: bool) -> pd.DataFrame: - """Implements auto_pivot.""" - df = _stringify_slices(df, column_prefixes.slices) if stringify_slices else df - # TODO(b/277280388): Use Series.sort_values after upgrading to pandas > 1.2. - # See https://github.com/pandas-dev/pandas/issues/35922 - # Replace the df_unique logic with the following block. - # df_unique = ( - # df[column_prefixes.metric_keys] - # .nunique(dropna=False) - # .sort_index() - # .sort_values(ascending=False, kind='stable') - # ) - df_unique = df[column_prefixes.metric_keys].nunique(dropna=False).sort_index() - _, tags = np.unique(df_unique.values, return_inverse=True) - df_unique = df_unique.iloc[np.argsort(-tags, kind='stable')] - - pivot_columns = [(column_prefixes.metric_keys, column) - for column, nunique in df_unique.items() - if nunique > 1] - metric_value_columns = df[column_prefixes.metric_values].columns - slice_columns = df[column_prefixes.slices].columns - value_columns = [ - (column_prefixes.metric_values, c) for c in metric_value_columns - ] - index_columns = [(column_prefixes.slices, c) for c in slice_columns] - result = df.pivot( - index=index_columns, columns=pivot_columns, values=value_columns) - if stringify_slices: - result.index.name = column_prefixes.slices - if collapse_column_names and isinstance(result.columns, pd.MultiIndex): - result.columns = _collapse_column_names(result.columns) - return result - - -def auto_pivot(df: pd.DataFrame, - stringify_slices: bool = True, - collapse_column_names: bool = True) -> pd.DataFrame: - """Automatically pivots a metric or plots DataFrame. - - Given a DataFrame provided by metrics/plots_as_dataframes, one can - automatically pivot the table on all non-unique metric_key columns, with the - slices as the index columns, metric_values as the values columns. 
- E.g., given this raw DataFrame: - - | | metric_values| metric_keys | slices | - | | double_value | name | age | sex | - |---:|-------------:|:-------------------------------|----------:|:---------| - | 0 | 0.1 | mean_absolute_error | nan | nan | - | 1 | 0.02 | mean_squared_logarithmic_error | 38 | male | - - Since the only non-unique metric_key column is metric_keys.name, auto_pivot - with stringify_slices set to True will generates the following DataFrame: - - | slices | 'mean_absolute_error' | 'mean_squared_logarithmic_error' | - |:---------------|---------------------------------------------------------:| - |Overall | 0.1 | nan | - |sex:male; age:38| nan | 0.02 | - - Args: - df: a DataFrame from one of the MetricsDataFrames or PlotsDataFrames. - stringify_slices: stringify all the slice columns and collapse them into one - column by concatenating the corresponding strings. This is turned on by - default. - collapse_column_names: collapsing the multi-index column names by removing - layer(s) with only the same string. This is turned on by default. - Returns: - A DataFrame that pivoted from the metrics DataFrame or plots DataFrame. - """ - if _SLICES in df: - if _PLOT_KEYS in df and _PLOT_DATA in df: - return _auto_pivot( - df, - _plot_columns, - stringify_slices=stringify_slices, - collapse_column_names=collapse_column_names) - elif _METRIC_KEYS in df and _METRIC_VALUES in df: - return _auto_pivot( - df, - _metric_columns, - stringify_slices=stringify_slices, - collapse_column_names=collapse_column_names) - - raise NotImplementedError( - 'Only a metrics or a plots DataFrame is supported. This DataFrame has' - f'the following columns: {df.columns}') + """Reduce multi-index column names by removing layers with the same value.""" + dropables = [i for i, x in enumerate(zip(*columns)) if len(set(x)) == 1] + return columns.droplevel(dropables) + + +def _stringify_slices( + df: pd.DataFrame, slice_key_name: str = _SLICES, drop_slices: bool = True +) -> pd.DataFrame: + """Stringify all the slice columns into one column. + + For example, if there are two slice column of 'sex' and 'age', the function + takes all the slice columns, convert them to string in the format of + : and concatenate all slice columns using '; '. The + final result looks like: + ':; :.' + + Args: + ---- + df: a metrics DataFrame or plots DataFrame. + slice_key_name: the first level column name that groups all slice columns. + drop_slices: if True, drops the original slice columns. + + Returns: + ------- + A DataFrame with all slice columns stringified and concatenated. + """ + + def _concatenate(x): + return "; ".join( + f"{col}:{val}" + for col, val in zip(df[slice_key_name].columns, x) + if pd.notnull(val) + ) + + t = df.slices.agg(_concatenate, axis=1) + df = df.drop(slice_key_name, axis=1, level=0) if drop_slices else df.copy() + df.insert(0, (slice_key_name, _SLICE_STR), t) + return df.sort_values((slice_key_name, _SLICE_STR)) + + +def _auto_pivot( + df: pd.DataFrame, + column_prefixes: _ColumnPrefixes, + stringify_slices: bool, + collapse_column_names: bool, +) -> pd.DataFrame: + """Implements auto_pivot.""" + df = _stringify_slices(df, column_prefixes.slices) if stringify_slices else df + # TODO(b/277280388): Use Series.sort_values after upgrading to pandas > 1.2. + # See https://github.com/pandas-dev/pandas/issues/35922 + # Replace the df_unique logic with the following block. 
+ # df_unique = ( + # df[column_prefixes.metric_keys] + # .nunique(dropna=False) + # .sort_index() + # .sort_values(ascending=False, kind='stable') + # ) + df_unique = df[column_prefixes.metric_keys].nunique(dropna=False).sort_index() + _, tags = np.unique(df_unique.values, return_inverse=True) + df_unique = df_unique.iloc[np.argsort(-tags, kind="stable")] + + pivot_columns = [ + (column_prefixes.metric_keys, column) + for column, nunique in df_unique.items() + if nunique > 1 + ] + metric_value_columns = df[column_prefixes.metric_values].columns + slice_columns = df[column_prefixes.slices].columns + value_columns = [(column_prefixes.metric_values, c) for c in metric_value_columns] + index_columns = [(column_prefixes.slices, c) for c in slice_columns] + result = df.pivot(index=index_columns, columns=pivot_columns, values=value_columns) + if stringify_slices: + result.index.name = column_prefixes.slices + if collapse_column_names and isinstance(result.columns, pd.MultiIndex): + result.columns = _collapse_column_names(result.columns) + return result + + +def auto_pivot( + df: pd.DataFrame, stringify_slices: bool = True, collapse_column_names: bool = True +) -> pd.DataFrame: + """Automatically pivots a metric or plots DataFrame. + + Given a DataFrame provided by metrics/plots_as_dataframes, one can + automatically pivot the table on all non-unique metric_key columns, with the + slices as the index columns, metric_values as the values columns. + E.g., given this raw DataFrame: + + | | metric_values| metric_keys | slices | + | | double_value | name | age | sex | + |---:|-------------:|:-------------------------------|----------:|:---------| + | 0 | 0.1 | mean_absolute_error | nan | nan | + | 1 | 0.02 | mean_squared_logarithmic_error | 38 | male | + + Since the only non-unique metric_key column is metric_keys.name, auto_pivot + with stringify_slices set to True will generates the following DataFrame: + + | slices | 'mean_absolute_error' | 'mean_squared_logarithmic_error' | + |:---------------|---------------------------------------------------------:| + |Overall | 0.1 | nan | + |sex:male; age:38| nan | 0.02 | + + Args: + ---- + df: a DataFrame from one of the MetricsDataFrames or PlotsDataFrames. + stringify_slices: stringify all the slice columns and collapse them into one + column by concatenating the corresponding strings. This is turned on by + default. + collapse_column_names: collapsing the multi-index column names by removing + layer(s) with only the same string. This is turned on by default. + + Returns: + ------- + A DataFrame that pivoted from the metrics DataFrame or plots DataFrame. + """ + if _SLICES in df: + if _PLOT_KEYS in df and _PLOT_DATA in df: + return _auto_pivot( + df, + _plot_columns, + stringify_slices=stringify_slices, + collapse_column_names=collapse_column_names, + ) + elif _METRIC_KEYS in df and _METRIC_VALUES in df: + return _auto_pivot( + df, + _metric_columns, + stringify_slices=stringify_slices, + collapse_column_names=collapse_column_names, + ) + + raise NotImplementedError( + "Only a metrics or a plots DataFrame is supported. 
This DataFrame has" + f"the following columns: {df.columns}" + ) diff --git a/tensorflow_model_analysis/api/dataframe_test.py b/tensorflow_model_analysis/api/dataframe_test.py index 26d8434562..b9c37f6b15 100644 --- a/tensorflow_model_analysis/api/dataframe_test.py +++ b/tensorflow_model_analysis/api/dataframe_test.py @@ -17,19 +17,18 @@ import numpy as np import pandas as pd import tensorflow as tf +from google.protobuf import text_format + from tensorflow_model_analysis.experimental import dataframe from tensorflow_model_analysis.proto import metrics_for_slice_pb2 -from google.protobuf import text_format - class MetricsAsDataFrameTest(tf.test.TestCase): - - def setUp(self): - super().setUp() - self.metrics_for_slices = [ - text_format.Parse( - """ + def setUp(self): + super().setUp() + self.metrics_for_slices = [ + text_format.Parse( + """ slice_key { single_slice_keys { column: "age" @@ -64,9 +63,11 @@ def setUp(self): } } } - """, metrics_for_slice_pb2.MetricsForSlice()), - text_format.Parse( - """ + """, + metrics_for_slice_pb2.MetricsForSlice(), + ), + text_format.Parse( + """ slice_key {} metric_keys_and_values { key { @@ -78,12 +79,14 @@ def setUp(self): } } } - """, metrics_for_slice_pb2.MetricsForSlice()) - ] + """, + metrics_for_slice_pb2.MetricsForSlice(), + ), + ] - self.metrics_overall_slice_only = [ - text_format.Parse( - """ + self.metrics_overall_slice_only = [ + text_format.Parse( + """ slice_key {} metric_keys_and_values { key { @@ -105,66 +108,78 @@ def setUp(self): } } } - """, metrics_for_slice_pb2.MetricsForSlice()) - ] + """, + metrics_for_slice_pb2.MetricsForSlice(), + ) + ] - def testLoadMetricsAsDataFrame_DoubleValueOnly(self): - dfs = dataframe.metrics_as_dataframes(self.metrics_for_slices) + def testLoadMetricsAsDataFrame_DoubleValueOnly(self): + dfs = dataframe.metrics_as_dataframes(self.metrics_for_slices) - expected = pd.DataFrame({ - ('slices', 'age'): [38.0, 38.0, None], - ('slices', 'sex'): [b'Female', b'Female', None], - ('slices', 'Overall'): [None, None, ''], - ('metric_keys', 'name'): [ - 'mean_absolute_error', 'mean_squared_logarithmic_error', - 'mean_absolute_error' - ], - ('metric_keys', 'model_name'): ['', '', ''], - ('metric_keys', 'output_name'): ['', '', ''], - ('metric_keys', 'example_weighted'): [False, False, None], - ('metric_keys', 'is_diff'): [False, False, False], - ('metric_values', 'double_value'): [0.1, 0.02, 0.3], - }) - pd.testing.assert_frame_equal(expected, dfs.double_value) + expected = pd.DataFrame( + { + ("slices", "age"): [38.0, 38.0, None], + ("slices", "sex"): [b"Female", b"Female", None], + ("slices", "Overall"): [None, None, ""], + ("metric_keys", "name"): [ + "mean_absolute_error", + "mean_squared_logarithmic_error", + "mean_absolute_error", + ], + ("metric_keys", "model_name"): ["", "", ""], + ("metric_keys", "output_name"): ["", "", ""], + ("metric_keys", "example_weighted"): [False, False, None], + ("metric_keys", "is_diff"): [False, False, False], + ("metric_values", "double_value"): [0.1, 0.02, 0.3], + } + ) + pd.testing.assert_frame_equal(expected, dfs.double_value) - def testLoadMetricsAsDataFrame_DoubleValueIncludeEmptyColumn(self): - dfs = dataframe.metrics_as_dataframes( - self.metrics_for_slices, include_empty_columns=True) - expected = pd.DataFrame({ - ('slices', 'age'): [38.0, 38.0, None], - ('slices', 'sex'): [b'Female', b'Female', None], - ('slices', 'Overall'): [None, None, ''], - ('metric_keys', 'name'): [ - 'mean_absolute_error', 'mean_squared_logarithmic_error', - 'mean_absolute_error' - ], - 
('metric_keys', 'model_name'): ['', '', ''], - ('metric_keys', 'output_name'): ['', '', ''], - ('metric_keys', 'sub_key'): [None, None, None], - ('metric_keys', 'aggregation_type'): [None, None, None], - ('metric_keys', 'example_weighted'): [False, False, None], - ('metric_keys', 'is_diff'): [False, False, False], - ('metric_values', 'double_value'): [0.1, 0.02, 0.3], - }) - pd.testing.assert_frame_equal(expected, dfs.double_value) + def testLoadMetricsAsDataFrame_DoubleValueIncludeEmptyColumn(self): + dfs = dataframe.metrics_as_dataframes( + self.metrics_for_slices, include_empty_columns=True + ) + expected = pd.DataFrame( + { + ("slices", "age"): [38.0, 38.0, None], + ("slices", "sex"): [b"Female", b"Female", None], + ("slices", "Overall"): [None, None, ""], + ("metric_keys", "name"): [ + "mean_absolute_error", + "mean_squared_logarithmic_error", + "mean_absolute_error", + ], + ("metric_keys", "model_name"): ["", "", ""], + ("metric_keys", "output_name"): ["", "", ""], + ("metric_keys", "sub_key"): [None, None, None], + ("metric_keys", "aggregation_type"): [None, None, None], + ("metric_keys", "example_weighted"): [False, False, None], + ("metric_keys", "is_diff"): [False, False, False], + ("metric_values", "double_value"): [0.1, 0.02, 0.3], + } + ) + pd.testing.assert_frame_equal(expected, dfs.double_value) - def testLoadMetricsAsDataFrame_DoubleValueOverallSliceOnly(self): - dfs = dataframe.metrics_as_dataframes( - self.metrics_overall_slice_only, include_empty_columns=False) - expected = pd.DataFrame({ - ('slices', 'Overall'): ['', ''], - ('metric_keys', 'name'): ['mean_absolute_error', 'example_count'], - ('metric_keys', 'model_name'): ['', ''], - ('metric_keys', 'output_name'): ['', ''], - ('metric_keys', 'is_diff'): [False, False], - ('metric_values', 'double_value'): [0.3, 10], - }) - pd.testing.assert_frame_equal(expected, dfs.double_value) + def testLoadMetricsAsDataFrame_DoubleValueOverallSliceOnly(self): + dfs = dataframe.metrics_as_dataframes( + self.metrics_overall_slice_only, include_empty_columns=False + ) + expected = pd.DataFrame( + { + ("slices", "Overall"): ["", ""], + ("metric_keys", "name"): ["mean_absolute_error", "example_count"], + ("metric_keys", "model_name"): ["", ""], + ("metric_keys", "output_name"): ["", ""], + ("metric_keys", "is_diff"): [False, False], + ("metric_values", "double_value"): [0.3, 10], + } + ) + pd.testing.assert_frame_equal(expected, dfs.double_value) - def testLoadMetricsAsDataFrame_Empty(self): - metrics_for_slices = [ - text_format.Parse( - """ + def testLoadMetricsAsDataFrame_Empty(self): + metrics_for_slices = [ + text_format.Parse( + """ slice_key { single_slice_keys { column: "age" @@ -175,154 +190,167 @@ def testLoadMetricsAsDataFrame_Empty(self): bytes_value: "Female" } } - """, metrics_for_slice_pb2.MetricsForSlice()), - ] - dfs = dataframe.metrics_as_dataframes(metrics_for_slices) - self.assertTrue(all(d is None for d in dataclasses.astuple(dfs))) + """, + metrics_for_slice_pb2.MetricsForSlice(), + ), + ] + dfs = dataframe.metrics_as_dataframes(metrics_for_slices) + self.assertTrue(all(d is None for d in dataclasses.astuple(dfs))) - def testAutoPivot_MetricsDataFrame(self): - df = pd.DataFrame({ - ('slices', 'age'): [38.0, 38.0, None], - ('slices', 'sex'): [b'Female', b'Female', None], - ('metric_keys', 'name'): [ - 'mean_absolute_error', 'mean_squared_logarithmic_error', - 'mean_absolute_error' - ], - ('metric_keys', 'model_name'): ['', '', ''], - ('metric_keys', 'output_name'): ['', '', ''], - ('metric_keys', 
'example_weighted'): [False, False, None], - ('metric_keys', 'is_diff'): [False, False, False], - ('metric_values', 'double_value'): [0.1, 0.02, 0.3], - }) - df = dataframe.auto_pivot( - df, stringify_slices=False, collapse_column_names=False) - mux = pd.MultiIndex.from_tuples( - [ - (('metric_values', 'double_value'), False, 'mean_absolute_error'), - ( - ('metric_values', 'double_value'), - False, - 'mean_squared_logarithmic_error', + def testAutoPivot_MetricsDataFrame(self): + df = pd.DataFrame( + { + ("slices", "age"): [38.0, 38.0, None], + ("slices", "sex"): [b"Female", b"Female", None], + ("metric_keys", "name"): [ + "mean_absolute_error", + "mean_squared_logarithmic_error", + "mean_absolute_error", + ], + ("metric_keys", "model_name"): ["", "", ""], + ("metric_keys", "output_name"): ["", "", ""], + ("metric_keys", "example_weighted"): [False, False, None], + ("metric_keys", "is_diff"): [False, False, False], + ("metric_values", "double_value"): [0.1, 0.02, 0.3], + } + ) + df = dataframe.auto_pivot( + df, stringify_slices=False, collapse_column_names=False + ) + mux = pd.MultiIndex.from_tuples( + [ + (("metric_values", "double_value"), False, "mean_absolute_error"), + ( + ("metric_values", "double_value"), + False, + "mean_squared_logarithmic_error", + ), + (("metric_values", "double_value"), np.nan, "mean_absolute_error"), + ], + names=( + None, + ("metric_keys", "example_weighted"), + ("metric_keys", "name"), ), - (('metric_values', 'double_value'), np.nan, 'mean_absolute_error'), - ], - names=( - None, - ('metric_keys', 'example_weighted'), - ('metric_keys', 'name'), - ), - ) - mix = pd.MultiIndex.from_tuples( - [(np.nan, np.nan), (38.0, b'Female')], - names=[('slices', 'age'), ('slices', 'sex')], - ) - expected = pd.DataFrame( - [[np.nan, np.nan, 0.3], [0.1, 0.02, np.nan]], - index=mix, - columns=mux, - ) - pd.testing.assert_frame_equal(expected, df, check_column_type=False) + ) + mix = pd.MultiIndex.from_tuples( + [(np.nan, np.nan), (38.0, b"Female")], + names=[("slices", "age"), ("slices", "sex")], + ) + expected = pd.DataFrame( + [[np.nan, np.nan, 0.3], [0.1, 0.02, np.nan]], + index=mix, + columns=mux, + ) + pd.testing.assert_frame_equal(expected, df, check_column_type=False) - def testAutoPivot_MetricsDataFrameStringifySlices(self): - df = pd.DataFrame({ - ('slices', 'age'): [38.0, 38.0, None], - ('slices', 'sex'): [b'Female', b'Female', None], - ('metric_keys', 'name'): [ - 'mean_absolute_error', 'mean_squared_logarithmic_error', - 'mean_absolute_error' - ], - ('metric_keys', 'model_name'): ['', '', ''], - ('metric_keys', 'output_name'): ['', '', ''], - ('metric_keys', 'example_weighted'): [False, False, None], - ('metric_keys', 'is_diff'): [False, False, False], - ('metric_values', 'double_value'): [0.1, 0.02, 0.3], - }) - df = dataframe.auto_pivot( - df, stringify_slices=True, collapse_column_names=False) - mux = pd.MultiIndex.from_tuples( - [ - (('metric_values', 'double_value'), np.nan, 'mean_absolute_error'), - (('metric_values', 'double_value'), False, 'mean_absolute_error'), - ( - ('metric_values', 'double_value'), - False, - 'mean_squared_logarithmic_error', + def testAutoPivot_MetricsDataFrameStringifySlices(self): + df = pd.DataFrame( + { + ("slices", "age"): [38.0, 38.0, None], + ("slices", "sex"): [b"Female", b"Female", None], + ("metric_keys", "name"): [ + "mean_absolute_error", + "mean_squared_logarithmic_error", + "mean_absolute_error", + ], + ("metric_keys", "model_name"): ["", "", ""], + ("metric_keys", "output_name"): ["", "", ""], + ("metric_keys", 
"example_weighted"): [False, False, None], + ("metric_keys", "is_diff"): [False, False, False], + ("metric_values", "double_value"): [0.1, 0.02, 0.3], + } + ) + df = dataframe.auto_pivot( + df, stringify_slices=True, collapse_column_names=False + ) + mux = pd.MultiIndex.from_tuples( + [ + (("metric_values", "double_value"), np.nan, "mean_absolute_error"), + (("metric_values", "double_value"), False, "mean_absolute_error"), + ( + ("metric_values", "double_value"), + False, + "mean_squared_logarithmic_error", + ), + ], + names=( + None, + ("metric_keys", "example_weighted"), + ("metric_keys", "name"), ), - ], - names=( - None, - ('metric_keys', 'example_weighted'), - ('metric_keys', 'name'), - ), - ) - index = pd.Index( - ['', "age:38.0; sex:b'Female'"], dtype='object', name='slices' - ) - expected = pd.DataFrame( - [[0.3, np.nan, np.nan], [np.nan, 0.1, 0.02]], - index=index, - columns=mux, - ) - pd.testing.assert_frame_equal(expected, df, check_column_type=False) + ) + index = pd.Index(["", "age:38.0; sex:b'Female'"], dtype="object", name="slices") + expected = pd.DataFrame( + [[0.3, np.nan, np.nan], [np.nan, 0.1, 0.02]], + index=index, + columns=mux, + ) + pd.testing.assert_frame_equal(expected, df, check_column_type=False) - def testAutoPivot_MetricsDataFrameCollapseColumnNames(self): - df = pd.DataFrame({ - ('slices', 'age'): [38.0, 38.0, None], - ('slices', 'sex'): [b'Female', b'Female', None], - ('metric_keys', 'name'): [ - 'mean_absolute_error', - 'mean_squared_logarithmic_error', - 'mean_absolute_error', - ], - ('metric_keys', 'model_name'): ['', '', ''], - ('metric_keys', 'output_name'): ['', '', ''], - ('metric_keys', 'example_weighted'): [False, False, None], - ('metric_keys', 'is_diff'): [False, False, False], - ('metric_values', 'double_value'): [0.1, 0.02, 0.3], - }) - df = dataframe.auto_pivot( - df, stringify_slices=False, collapse_column_names=True) - mux = pd.MultiIndex.from_tuples( - [ - (False, 'mean_absolute_error'), - (False, 'mean_squared_logarithmic_error'), - (np.nan, 'mean_absolute_error'), - ], - names=[('metric_keys', 'example_weighted'), ('metric_keys', 'name')], - ) - mix = pd.MultiIndex.from_tuples( - [(np.nan, np.nan), (38.0, b'Female')], - names=[('slices', 'age'), ('slices', 'sex')], - ) - expected = pd.DataFrame( - [[np.nan, np.nan, 0.3], [0.1, 0.02, np.nan]], - index=mix, - columns=mux, - ) - pd.testing.assert_frame_equal(expected, df, check_column_type=False) + def testAutoPivot_MetricsDataFrameCollapseColumnNames(self): + df = pd.DataFrame( + { + ("slices", "age"): [38.0, 38.0, None], + ("slices", "sex"): [b"Female", b"Female", None], + ("metric_keys", "name"): [ + "mean_absolute_error", + "mean_squared_logarithmic_error", + "mean_absolute_error", + ], + ("metric_keys", "model_name"): ["", "", ""], + ("metric_keys", "output_name"): ["", "", ""], + ("metric_keys", "example_weighted"): [False, False, None], + ("metric_keys", "is_diff"): [False, False, False], + ("metric_values", "double_value"): [0.1, 0.02, 0.3], + } + ) + df = dataframe.auto_pivot( + df, stringify_slices=False, collapse_column_names=True + ) + mux = pd.MultiIndex.from_tuples( + [ + (False, "mean_absolute_error"), + (False, "mean_squared_logarithmic_error"), + (np.nan, "mean_absolute_error"), + ], + names=[("metric_keys", "example_weighted"), ("metric_keys", "name")], + ) + mix = pd.MultiIndex.from_tuples( + [(np.nan, np.nan), (38.0, b"Female")], + names=[("slices", "age"), ("slices", "sex")], + ) + expected = pd.DataFrame( + [[np.nan, np.nan, 0.3], [0.1, 0.02, np.nan]], + index=mix, + 
columns=mux, + ) + pd.testing.assert_frame_equal(expected, df, check_column_type=False) - def testAutoPivot_MetricsDataFrameOverallSliceOnly(self): - dfs = dataframe.metrics_as_dataframes( - self.metrics_overall_slice_only, include_empty_columns=False) - df = dfs.double_value - expected = df.pivot( - index=[ - ('slices', 'Overall'), - ], - columns=[('metric_keys', 'name')], - values=[('metric_values', 'double_value')]) - df = dataframe.auto_pivot( - df, stringify_slices=False, collapse_column_names=False) - pd.testing.assert_frame_equal(expected, df, check_column_type=False) + def testAutoPivot_MetricsDataFrameOverallSliceOnly(self): + dfs = dataframe.metrics_as_dataframes( + self.metrics_overall_slice_only, include_empty_columns=False + ) + df = dfs.double_value + expected = df.pivot( + index=[ + ("slices", "Overall"), + ], + columns=[("metric_keys", "name")], + values=[("metric_values", "double_value")], + ) + df = dataframe.auto_pivot( + df, stringify_slices=False, collapse_column_names=False + ) + pd.testing.assert_frame_equal(expected, df, check_column_type=False) class PlotsAsDataFrameTest(tf.test.TestCase): - - def setUp(self): - super().setUp() - self.plots_for_slice = [ - text_format.Parse( - """ + def setUp(self): + super().setUp() + self.plots_for_slice = [ + text_format.Parse( + """ slice_key { single_slice_keys { column: "age" @@ -361,61 +389,68 @@ def setUp(self): } } } - """, metrics_for_slice_pb2.PlotsForSlice()) - ] + """, + metrics_for_slice_pb2.PlotsForSlice(), + ) + ] - def testLoadPlotsAsDataFrame(self): - dfs = dataframe.plots_as_dataframes(self.plots_for_slice) - expected = pd.DataFrame({ - ('slices', 'age'): [38.0, 38.0], - ('slices', 'sex'): [b'Female', b'Female'], - ('plot_keys', 'name'): ['', ''], - ('plot_keys', 'model_name'): ['', ''], - ('plot_keys', 'output_name'): ['', ''], - ('plot_keys', 'example_weighted'): [False, False], - ('plot_data', 'threshold'): [0.5, 0.5], - ('plot_data', 'false_negatives'): [10.0, 10.0], - ('plot_data', 'true_negatives'): [10.0, 10.0], - ('plot_data', 'false_positives'): [10.0, 10.0], - ('plot_data', 'true_positives'): [10.0, 10.0], - ('plot_data', 'precision'): [0.9, 0.9], - ('plot_data', 'recall'): [0.8, 0.8], - ('plot_data', 'false_positive_rate'): [0.0, 0.0], - ('plot_data', 'f1'): [0.0, 0.0], - ('plot_data', 'accuracy'): [0.0, 0.0], - ('plot_data', 'false_omission_rate'): [0.0, 0.0], - }) - pd.testing.assert_frame_equal(expected, dfs.confusion_matrix_at_thresholds) + def testLoadPlotsAsDataFrame(self): + dfs = dataframe.plots_as_dataframes(self.plots_for_slice) + expected = pd.DataFrame( + { + ("slices", "age"): [38.0, 38.0], + ("slices", "sex"): [b"Female", b"Female"], + ("plot_keys", "name"): ["", ""], + ("plot_keys", "model_name"): ["", ""], + ("plot_keys", "output_name"): ["", ""], + ("plot_keys", "example_weighted"): [False, False], + ("plot_data", "threshold"): [0.5, 0.5], + ("plot_data", "false_negatives"): [10.0, 10.0], + ("plot_data", "true_negatives"): [10.0, 10.0], + ("plot_data", "false_positives"): [10.0, 10.0], + ("plot_data", "true_positives"): [10.0, 10.0], + ("plot_data", "precision"): [0.9, 0.9], + ("plot_data", "recall"): [0.8, 0.8], + ("plot_data", "false_positive_rate"): [0.0, 0.0], + ("plot_data", "f1"): [0.0, 0.0], + ("plot_data", "accuracy"): [0.0, 0.0], + ("plot_data", "false_omission_rate"): [0.0, 0.0], + } + ) + pd.testing.assert_frame_equal(expected, dfs.confusion_matrix_at_thresholds) - def testLoadPlotsAsDataFrame_IncludeEmptyColumn(self): - dfs = dataframe.plots_as_dataframes( - 
self.plots_for_slice, include_empty_columns=True) - expected = pd.DataFrame({ - ('slices', 'age'): [38.0, 38.0], - ('slices', 'sex'): [b'Female', b'Female'], - ('plot_keys', 'name'): ['', ''], - ('plot_keys', 'model_name'): ['', ''], - ('plot_keys', 'output_name'): ['', ''], - ('plot_keys', 'sub_key'): [None, None], - ('plot_keys', 'example_weighted'): [False, False], - ('plot_data', 'threshold'): [0.5, 0.5], - ('plot_data', 'false_negatives'): [10.0, 10.0], - ('plot_data', 'true_negatives'): [10.0, 10.0], - ('plot_data', 'false_positives'): [10.0, 10.0], - ('plot_data', 'true_positives'): [10.0, 10.0], - ('plot_data', 'precision'): [0.9, 0.9], - ('plot_data', 'recall'): [0.8, 0.8], - ('plot_data', 'false_positive_rate'): [0.0, 0.0], - ('plot_data', 'f1'): [0.0, 0.0], - ('plot_data', 'accuracy'): [0.0, 0.0], - ('plot_data', 'false_omission_rate'): [0.0, 0.0], - }) - pd.testing.assert_frame_equal(expected, dfs.confusion_matrix_at_thresholds) + def testLoadPlotsAsDataFrame_IncludeEmptyColumn(self): + dfs = dataframe.plots_as_dataframes( + self.plots_for_slice, include_empty_columns=True + ) + expected = pd.DataFrame( + { + ("slices", "age"): [38.0, 38.0], + ("slices", "sex"): [b"Female", b"Female"], + ("plot_keys", "name"): ["", ""], + ("plot_keys", "model_name"): ["", ""], + ("plot_keys", "output_name"): ["", ""], + ("plot_keys", "sub_key"): [None, None], + ("plot_keys", "example_weighted"): [False, False], + ("plot_data", "threshold"): [0.5, 0.5], + ("plot_data", "false_negatives"): [10.0, 10.0], + ("plot_data", "true_negatives"): [10.0, 10.0], + ("plot_data", "false_positives"): [10.0, 10.0], + ("plot_data", "true_positives"): [10.0, 10.0], + ("plot_data", "precision"): [0.9, 0.9], + ("plot_data", "recall"): [0.8, 0.8], + ("plot_data", "false_positive_rate"): [0.0, 0.0], + ("plot_data", "f1"): [0.0, 0.0], + ("plot_data", "accuracy"): [0.0, 0.0], + ("plot_data", "false_omission_rate"): [0.0, 0.0], + } + ) + pd.testing.assert_frame_equal(expected, dfs.confusion_matrix_at_thresholds) - def testLoadPlotsAsDataFrame_Empty(self): - plots_for_slice = [ - text_format.Parse( - """ + def testLoadPlotsAsDataFrame_Empty(self): + plots_for_slice = [ + text_format.Parse( + """ slice_key { single_slice_keys { column: "age" @@ -426,95 +461,104 @@ def testLoadPlotsAsDataFrame_Empty(self): bytes_value: "Female" } } - """, metrics_for_slice_pb2.PlotsForSlice()) - ] + """, + metrics_for_slice_pb2.PlotsForSlice(), + ) + ] + + dfs = dataframe.plots_as_dataframes(plots_for_slice) + self.assertIsNone(dfs.confusion_matrix_at_thresholds) - dfs = dataframe.plots_as_dataframes(plots_for_slice) - self.assertIsNone(dfs.confusion_matrix_at_thresholds) + def testAutoPivot_PlotsDataFrame(self): + dfs = dataframe.plots_as_dataframes(self.plots_for_slice) + df = dataframe.auto_pivot( + dfs.confusion_matrix_at_thresholds, stringify_slices=False + ) + expected = pd.DataFrame( + { + ("slices", "age"): [38.0, 38.0], + ("slices", "sex"): [b"Female", b"Female"], + ("plot_keys", "name"): ["", ""], + ("plot_keys", "model_name"): ["", ""], + ("plot_keys", "output_name"): ["", ""], + ("plot_keys", "example_weighted"): [False, False], + ("plot_data", "threshold"): [0.5, 0.5], + ("plot_data", "false_negatives"): [10.0, 10.0], + ("plot_data", "true_negatives"): [10.0, 10.0], + ("plot_data", "false_positives"): [10.0, 10.0], + ("plot_data", "true_positives"): [10.0, 10.0], + ("plot_data", "precision"): [0.9, 0.9], + ("plot_data", "recall"): [0.8, 0.8], + ("plot_data", "false_positive_rate"): [0.0, 0.0], + ("plot_data", "f1"): [0.0, 
0.0], + ("plot_data", "accuracy"): [0.0, 0.0], + ("plot_data", "false_omission_rate"): [0.0, 0.0], + } + ).pivot( + index=[("slices", "age"), ("slices", "sex")], + columns=[], + values=[ + ("plot_data", "threshold"), + ("plot_data", "false_negatives"), + ("plot_data", "true_negatives"), + ("plot_data", "false_positives"), + ("plot_data", "true_positives"), + ("plot_data", "precision"), + ("plot_data", "recall"), + ("plot_data", "false_positive_rate"), + ("plot_data", "f1"), + ("plot_data", "accuracy"), + ("plot_data", "false_omission_rate"), + ], + ) + pd.testing.assert_frame_equal(expected, df, check_column_type=False) - def testAutoPivot_PlotsDataFrame(self): - dfs = dataframe.plots_as_dataframes(self.plots_for_slice) - df = dataframe.auto_pivot( - dfs.confusion_matrix_at_thresholds, stringify_slices=False) - expected = pd.DataFrame({ - ('slices', 'age'): [38.0, 38.0], - ('slices', 'sex'): [b'Female', b'Female'], - ('plot_keys', 'name'): ['', ''], - ('plot_keys', 'model_name'): ['', ''], - ('plot_keys', 'output_name'): ['', ''], - ('plot_keys', 'example_weighted'): [False, False], - ('plot_data', 'threshold'): [0.5, 0.5], - ('plot_data', 'false_negatives'): [10.0, 10.0], - ('plot_data', 'true_negatives'): [10.0, 10.0], - ('plot_data', 'false_positives'): [10.0, 10.0], - ('plot_data', 'true_positives'): [10.0, 10.0], - ('plot_data', 'precision'): [0.9, 0.9], - ('plot_data', 'recall'): [0.8, 0.8], - ('plot_data', 'false_positive_rate'): [0.0, 0.0], - ('plot_data', 'f1'): [0.0, 0.0], - ('plot_data', 'accuracy'): [0.0, 0.0], - ('plot_data', 'false_omission_rate'): [0.0, 0.0], - }).pivot( - index=[('slices', 'age'), ('slices', 'sex')], - columns=[], - values=[ - ('plot_data', 'threshold'), - ('plot_data', 'false_negatives'), - ('plot_data', 'true_negatives'), - ('plot_data', 'false_positives'), - ('plot_data', 'true_positives'), - ('plot_data', 'precision'), - ('plot_data', 'recall'), - ('plot_data', 'false_positive_rate'), - ('plot_data', 'f1'), - ('plot_data', 'accuracy'), - ('plot_data', 'false_omission_rate'), - ], - ) - pd.testing.assert_frame_equal(expected, df, check_column_type=False) + def testAutoPivot_PlotsDataFrameCollapseColumnNames(self): + dfs = dataframe.plots_as_dataframes(self.plots_for_slice) + df = dataframe.auto_pivot( + dfs.confusion_matrix_at_thresholds, + stringify_slices=False, + collapse_column_names=True, + ) + expected = pd.DataFrame( + { + ("slices", "age"): [38.0, 38.0], + ("slices", "sex"): [b"Female", b"Female"], + ("plot_keys", "name"): ["", ""], + ("plot_keys", "model_name"): ["", ""], + ("plot_keys", "output_name"): ["", ""], + ("plot_keys", "example_weighted"): [False, False], + ("plot_data", "threshold"): [0.5, 0.5], + ("plot_data", "false_negatives"): [10.0, 10.0], + ("plot_data", "true_negatives"): [10.0, 10.0], + ("plot_data", "false_positives"): [10.0, 10.0], + ("plot_data", "true_positives"): [10.0, 10.0], + ("plot_data", "precision"): [0.9, 0.9], + ("plot_data", "recall"): [0.8, 0.8], + ("plot_data", "false_positive_rate"): [0.0, 0.0], + ("plot_data", "f1"): [0.0, 0.0], + ("plot_data", "accuracy"): [0.0, 0.0], + ("plot_data", "false_omission_rate"): [0.0, 0.0], + } + ).pivot( + index=[("slices", "age"), ("slices", "sex")], + columns=[], + values=[ + ("plot_data", "threshold"), + ("plot_data", "false_negatives"), + ("plot_data", "true_negatives"), + ("plot_data", "false_positives"), + ("plot_data", "true_positives"), + ("plot_data", "precision"), + ("plot_data", "recall"), + ("plot_data", "false_positive_rate"), + ("plot_data", "f1"), + ("plot_data", 
"accuracy"), + ("plot_data", "false_omission_rate"), + ], + ) + pd.testing.assert_frame_equal(expected, df, check_column_type=False) - def testAutoPivot_PlotsDataFrameCollapseColumnNames(self): - dfs = dataframe.plots_as_dataframes(self.plots_for_slice) - df = dataframe.auto_pivot( - dfs.confusion_matrix_at_thresholds, - stringify_slices=False, - collapse_column_names=True) - expected = pd.DataFrame({ - ('slices', 'age'): [38.0, 38.0], - ('slices', 'sex'): [b'Female', b'Female'], - ('plot_keys', 'name'): ['', ''], - ('plot_keys', 'model_name'): ['', ''], - ('plot_keys', 'output_name'): ['', ''], - ('plot_keys', 'example_weighted'): [False, False], - ('plot_data', 'threshold'): [0.5, 0.5], - ('plot_data', 'false_negatives'): [10.0, 10.0], - ('plot_data', 'true_negatives'): [10.0, 10.0], - ('plot_data', 'false_positives'): [10.0, 10.0], - ('plot_data', 'true_positives'): [10.0, 10.0], - ('plot_data', 'precision'): [0.9, 0.9], - ('plot_data', 'recall'): [0.8, 0.8], - ('plot_data', 'false_positive_rate'): [0.0, 0.0], - ('plot_data', 'f1'): [0.0, 0.0], - ('plot_data', 'accuracy'): [0.0, 0.0], - ('plot_data', 'false_omission_rate'): [0.0, 0.0], - }).pivot( - index=[('slices', 'age'), ('slices', 'sex')], - columns=[], - values=[ - ('plot_data', 'threshold'), - ('plot_data', 'false_negatives'), - ('plot_data', 'true_negatives'), - ('plot_data', 'false_positives'), - ('plot_data', 'true_positives'), - ('plot_data', 'precision'), - ('plot_data', 'recall'), - ('plot_data', 'false_positive_rate'), - ('plot_data', 'f1'), - ('plot_data', 'accuracy'), - ('plot_data', 'false_omission_rate'), - ], - ) - pd.testing.assert_frame_equal(expected, df, check_column_type=False) -if __name__ == '__main__': - tf.test.main() +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow_model_analysis/api/model_eval_lib.py b/tensorflow_model_analysis/api/model_eval_lib.py index 4f6ffedf55..f38db66c95 100644 --- a/tensorflow_model_analysis/api/model_eval_lib.py +++ b/tensorflow_model_analysis/api/model_eval_lib.py @@ -19,70 +19,74 @@ import tempfile from typing import Any, Dict, Iterable, Iterator, List, Optional, Set, Union -from absl import logging import apache_beam as beam import pandas as pd import pyarrow as pa import tensorflow as tf +from absl import logging +from tensorflow_metadata.proto.v0 import schema_pb2 +from tfx_bsl import beam as tfx_bsl_beam +from tfx_bsl.arrow import table_util +from tfx_bsl.tfxio import raw_tf_record, tensor_adapter, tf_example_record + from tensorflow_model_analysis import constants from tensorflow_model_analysis.api import types -from tensorflow_model_analysis.evaluators import evaluator -from tensorflow_model_analysis.evaluators import metrics_plots_and_validations_evaluator -from tensorflow_model_analysis.extractors import example_weights_extractor -from tensorflow_model_analysis.extractors import extractor -from tensorflow_model_analysis.extractors import features_extractor -from tensorflow_model_analysis.extractors import labels_extractor -from tensorflow_model_analysis.extractors import materialized_predictions_extractor -from tensorflow_model_analysis.extractors import predictions_extractor -from tensorflow_model_analysis.extractors import slice_key_extractor -from tensorflow_model_analysis.extractors import sql_slice_key_extractor -from tensorflow_model_analysis.extractors import tfjs_predict_extractor -from tensorflow_model_analysis.extractors import tflite_predict_extractor -from tensorflow_model_analysis.extractors import transformed_features_extractor -from 
tensorflow_model_analysis.extractors import unbatch_extractor -from tensorflow_model_analysis.proto import config_pb2 -from tensorflow_model_analysis.proto import metrics_for_slice_pb2 -from tensorflow_model_analysis.proto import validation_result_pb2 +from tensorflow_model_analysis.evaluators import ( + evaluator, + metrics_plots_and_validations_evaluator, +) +from tensorflow_model_analysis.extractors import ( + example_weights_extractor, + extractor, + features_extractor, + labels_extractor, + materialized_predictions_extractor, + predictions_extractor, + slice_key_extractor, + sql_slice_key_extractor, + tfjs_predict_extractor, + tflite_predict_extractor, + transformed_features_extractor, + unbatch_extractor, +) +from tensorflow_model_analysis.proto import ( + config_pb2, + metrics_for_slice_pb2, + validation_result_pb2, +) from tensorflow_model_analysis.slicer import slicer_lib as slicer -from tensorflow_model_analysis.utils import config_util -from tensorflow_model_analysis.utils import model_util +from tensorflow_model_analysis.utils import config_util, model_util from tensorflow_model_analysis.validators import validator from tensorflow_model_analysis.view import util as view_util from tensorflow_model_analysis.view import view_types -from tensorflow_model_analysis.writers import eval_config_writer -from tensorflow_model_analysis.writers import metrics_plots_and_validations_writer -from tensorflow_model_analysis.writers import writer -from tfx_bsl import beam as tfx_bsl_beam -from tfx_bsl.arrow import table_util -from tfx_bsl.tfxio import raw_tf_record -from tfx_bsl.tfxio import tensor_adapter -from tfx_bsl.tfxio import tf_example_record - -from tensorflow_metadata.proto.v0 import schema_pb2 +from tensorflow_model_analysis.writers import ( + eval_config_writer, + metrics_plots_and_validations_writer, + writer, +) tfx_bsl_beam.fix_code_type_pickling() # This is a legacy eval tag used to report failure with estimators correctly. -_LEGACY_EVAL_TAG = 'eval' +_LEGACY_EVAL_TAG = "eval" def _assert_tensorflow_version(): - """Check that we're using a compatible TF version.""" - # Fail with a clear error in case we are not using a compatible TF version. - major, minor, _ = tf.version.VERSION.split('.') - if (int(major) not in (1, 2)) or (int(major) == 1 and int(minor) < 15): - raise RuntimeError( - 'Tensorflow version >= 1.15, < 3 is required. Found (%s). Please ' - 'install the latest 1.x or 2.x version from ' - 'https://github.com/tensorflow/tensorflow. ' - % tf.version.VERSION - ) - if int(major) == 2: - logging.warning( - 'Tensorflow version (%s) found. Note that TFMA support for TF 2.0 ' - 'is currently in beta', - tf.version.VERSION, - ) + """Check that we're using a compatible TF version.""" + # Fail with a clear error in case we are not using a compatible TF version. + major, minor, _ = tf.version.VERSION.split(".") + if (int(major) not in (1, 2)) or (int(major) == 1 and int(minor) < 15): + raise RuntimeError( + "Tensorflow version >= 1.15, < 3 is required. Found (%s). Please " + "install the latest 1.x or 2.x version from " + "https://github.com/tensorflow/tensorflow. " % tf.version.VERSION + ) + if int(major) == 2: + logging.warning( + "Tensorflow version (%s) found. Note that TFMA support for TF 2.0 " + "is currently in beta", + tf.version.VERSION, + ) def _is_legacy_eval( @@ -90,34 +94,36 @@ def _is_legacy_eval( eval_shared_models: Optional[List[types.EvalSharedModel]], eval_config: Optional[config_pb2.EvalConfig], ): - """Returns True if legacy evaluation is being used. 
- - A legacy evaluation is an evalution that uses only a single EvalSharedModel, - has no tags (or uses "eval" as its tag), and does not specify an eval_config - The legacy evaluation is based on using add_metrics_callbacks to create a - modified version of the graph saved with an EvalSavedModel. The newer version - of evaluation supports both add_metrics_callbacks as well as metrics defined - in MetricsSpecs inside of EvalConfig. The newer version works with both "eval" - and serving models and also supports multi-model evaluation. This function is - used by code to support backwards compatibility for callers that have not - updated to use the new EvalConfig. - - Args: - config_version: Optionally, An explicit version of the config determined - elsewhere. This is used to handle cases where the provided eval_config was - generated internally, and thus not a reliable indicator of user intent. - eval_shared_models: Optionally, the model(s) to be evaluated. - eval_config: Optionally, an EvalConfig specifying v2 config. - - Returns: - Whether the user inputs should trigger a legacy evaluation. - """ - return (config_version is not None and config_version == 1) or ( - eval_shared_models - and len(eval_shared_models) == 1 - and eval_shared_models[0].model_type == constants.TFMA_EVAL - and not eval_config - ) + """Returns True if legacy evaluation is being used. + + A legacy evaluation is an evalution that uses only a single EvalSharedModel, + has no tags (or uses "eval" as its tag), and does not specify an eval_config + The legacy evaluation is based on using add_metrics_callbacks to create a + modified version of the graph saved with an EvalSavedModel. The newer version + of evaluation supports both add_metrics_callbacks as well as metrics defined + in MetricsSpecs inside of EvalConfig. The newer version works with both "eval" + and serving models and also supports multi-model evaluation. This function is + used by code to support backwards compatibility for callers that have not + updated to use the new EvalConfig. + + Args: + ---- + config_version: Optionally, An explicit version of the config determined + elsewhere. This is used to handle cases where the provided eval_config was + generated internally, and thus not a reliable indicator of user intent. + eval_shared_models: Optionally, the model(s) to be evaluated. + eval_config: Optionally, an EvalConfig specifying v2 config. + + Returns: + ------- + Whether the user inputs should trigger a legacy evaluation. 
+ """ + return (config_version is not None and config_version == 1) or ( + eval_shared_models + and len(eval_shared_models) == 1 + and eval_shared_models[0].model_type == constants.TFMA_EVAL + and not eval_config + ) def _default_eval_config( @@ -127,134 +133,129 @@ def _default_eval_config( compute_confidence_intervals: Optional[bool], min_slice_size: int, ): - """Creates default EvalConfig (for use in legacy evaluations).""" - model_specs = [] - for shared_model in eval_shared_models: - example_weight_key = shared_model.example_weight_key - example_weight_keys = {} - if example_weight_key and isinstance(example_weight_key, dict): - example_weight_keys = example_weight_key - example_weight_key = '' - model_specs.append( - config_pb2.ModelSpec( - name=shared_model.model_name, - example_weight_key=example_weight_key, - example_weight_keys=example_weight_keys, + """Creates default EvalConfig (for use in legacy evaluations).""" + model_specs = [] + for shared_model in eval_shared_models: + example_weight_key = shared_model.example_weight_key + example_weight_keys = {} + if example_weight_key and isinstance(example_weight_key, dict): + example_weight_keys = example_weight_key + example_weight_key = "" + model_specs.append( + config_pb2.ModelSpec( + name=shared_model.model_name, + example_weight_key=example_weight_key, + example_weight_keys=example_weight_keys, + ) ) + slicing_specs = None + if slice_spec: + slicing_specs = [s.to_proto() for s in slice_spec] + options = config_pb2.Options() + options.compute_confidence_intervals.value = compute_confidence_intervals + options.min_slice_size.value = min_slice_size + if not write_config: + options.disabled_outputs.values.append(eval_config_writer.EVAL_CONFIG_FILE) + return config_pb2.EvalConfig( + model_specs=model_specs, slicing_specs=slicing_specs, options=options ) - slicing_specs = None - if slice_spec: - slicing_specs = [s.to_proto() for s in slice_spec] - options = config_pb2.Options() - options.compute_confidence_intervals.value = compute_confidence_intervals - options.min_slice_size.value = min_slice_size - if not write_config: - options.disabled_outputs.values.append(eval_config_writer.EVAL_CONFIG_FILE) - return config_pb2.EvalConfig( - model_specs=model_specs, slicing_specs=slicing_specs, options=options - ) def _model_types( eval_shared_models: Optional[List[types.EvalSharedModel]], ) -> Optional[Set[str]]: - """Returns model types associated with given EvalSharedModels.""" - if not eval_shared_models: - return None - else: - return set([m.model_type for m in eval_shared_models]) + """Returns model types associated with given EvalSharedModels.""" + if not eval_shared_models: + return None + else: + return set([m.model_type for m in eval_shared_models]) def _update_eval_config_with_defaults( eval_config: config_pb2.EvalConfig, eval_shared_model: Optional[types.MaybeMultipleEvalSharedModels], ) -> config_pb2.EvalConfig: - """Returns updated eval config with default values.""" - eval_shared_models = model_util.verify_and_update_eval_shared_models( - eval_shared_model - ) - has_baseline = eval_shared_models and len(eval_shared_models) > 1 - return config_util.update_eval_config_with_defaults( - eval_config=eval_config, - has_baseline=has_baseline, - rubber_stamp=model_util.has_rubber_stamp(eval_shared_models), - ) + """Returns updated eval config with default values.""" + eval_shared_models = model_util.verify_and_update_eval_shared_models( + eval_shared_model + ) + has_baseline = eval_shared_models and len(eval_shared_models) > 1 + return 
config_util.update_eval_config_with_defaults( + eval_config=eval_config, + has_baseline=has_baseline, + rubber_stamp=model_util.has_rubber_stamp(eval_shared_models), + ) def _get_extract_num_bytes(extract: types.Extracts) -> int: - """Returns the number of bytes in the input.""" - if constants.ARROW_RECORD_BATCH_KEY in extract: - return extract[constants.ARROW_RECORD_BATCH_KEY].nbytes - if constants.INPUT_KEY in extract: - if isinstance(extract[constants.INPUT_KEY], bytes): - return len(extract[constants.INPUT_KEY]) - logging.warning('Failed to extract number of input bytes.') - return 0 + """Returns the number of bytes in the input.""" + if constants.ARROW_RECORD_BATCH_KEY in extract: + return extract[constants.ARROW_RECORD_BATCH_KEY].nbytes + if constants.INPUT_KEY in extract: + if isinstance(extract[constants.INPUT_KEY], bytes): + return len(extract[constants.INPUT_KEY]) + logging.warning("Failed to extract number of input bytes.") + return 0 def _increment_counter(counter_name: str, value: int) -> int: - """Increments the specified counter by the value.""" - counter = beam.metrics.Metrics.counter( - constants.METRICS_NAMESPACE, counter_name - ) - counter.inc(value) - return value + """Increments the specified counter by the value.""" + counter = beam.metrics.Metrics.counter(constants.METRICS_NAMESPACE, counter_name) + counter.inc(value) + return value @beam.ptransform_fn def _TrackBytesProcessed( # pylint: disable=invalid-name dataset: beam.PCollection[types.Extracts], ) -> beam.pvalue.PCollection[int]: - """Gathers telemetry on input Extracts.""" - - return ( - dataset - | 'GetExtractSize' >> beam.Map(_get_extract_num_bytes) - | 'SumTotalBytes' >> beam.CombineGlobally(sum) - | 'IncrementCounter' - >> beam.Map(lambda x: _increment_counter('extract_input_bytes', x)) - ) + """Gathers telemetry on input Extracts.""" + return ( + dataset + | "GetExtractSize" >> beam.Map(_get_extract_num_bytes) + | "SumTotalBytes" >> beam.CombineGlobally(sum) + | "IncrementCounter" + >> beam.Map(lambda x: _increment_counter("extract_input_bytes", x)) + ) MetricsForSlice = metrics_for_slice_pb2.MetricsForSlice def load_metrics( - output_path: str, output_file_format: str = 'tfrecord' + output_path: str, output_file_format: str = "tfrecord" ) -> Iterator[MetricsForSlice]: - """Read and deserialize the MetricsForSlice records.""" - for m in metrics_plots_and_validations_writer.load_and_deserialize_metrics( - output_path, output_file_format - ): - yield m + """Read and deserialize the MetricsForSlice records.""" + for m in metrics_plots_and_validations_writer.load_and_deserialize_metrics( + output_path, output_file_format + ): + yield m PlotsForSlice = metrics_for_slice_pb2.PlotsForSlice def load_plots( - output_path: str, output_file_format: str = 'tfrecord' + output_path: str, output_file_format: str = "tfrecord" ) -> Iterator[PlotsForSlice]: - """Read and deserialize the PlotsForSlice records.""" - for p in metrics_plots_and_validations_writer.load_and_deserialize_plots( - output_path, output_file_format - ): - yield p + """Read and deserialize the PlotsForSlice records.""" + for p in metrics_plots_and_validations_writer.load_and_deserialize_plots( + output_path, output_file_format + ): + yield p AttributionsForSlice = metrics_for_slice_pb2.AttributionsForSlice def load_attributions( - output_path: str, output_file_format: str = 'tfrecord' + output_path: str, output_file_format: str = "tfrecord" ) -> Iterator[AttributionsForSlice]: - """Read and deserialize the AttributionsForSlice records.""" - for ( - a - 
) in metrics_plots_and_validations_writer.load_and_deserialize_attributions( - output_path, output_file_format - ): - yield a + """Read and deserialize the AttributionsForSlice records.""" + for a in metrics_plots_and_validations_writer.load_and_deserialize_attributions( + output_path, output_file_format + ): + yield a # Define types here to avoid type errors between OSS and internal code. @@ -262,131 +263,133 @@ def load_attributions( def load_validation_result( - output_path: str, output_file_format: str = '' + output_path: str, output_file_format: str = "" ) -> ValidationResult: - """Read and deserialize the ValidationResult.""" - return metrics_plots_and_validations_writer.load_and_deserialize_validation_result( - output_path, output_file_format - ) + """Read and deserialize the ValidationResult.""" + return metrics_plots_and_validations_writer.load_and_deserialize_validation_result( + output_path, output_file_format + ) def make_eval_results( results: List[view_types.EvalResult], mode: str ) -> view_types.EvalResults: - """Run model analysis for a single model on multiple data sets. + """Run model analysis for a single model on multiple data sets. - Args: - results: A list of TFMA evaluation results. - mode: The mode of the evaluation. Currently, tfma.DATA_CENTRIC_MODE and - tfma.MODEL_CENTRIC_MODE are supported. + Args: + ---- + results: A list of TFMA evaluation results. + mode: The mode of the evaluation. Currently, tfma.DATA_CENTRIC_MODE and + tfma.MODEL_CENTRIC_MODE are supported. - Returns: - An `tfma.view.EvalResults` object containing all evaluation results. This - can be used to construct a time series view. - """ - return view_types.EvalResults(results, mode) + Returns: + ------- + An `tfma.view.EvalResults` object containing all evaluation results. This + can be used to construct a time series view. + """ + return view_types.EvalResults(results, mode) def load_eval_results( output_paths: Union[str, List[str]], - output_file_format: Optional[str] = 'tfrecord', + output_file_format: Optional[str] = "tfrecord", mode: str = constants.MODEL_CENTRIC_MODE, model_name: Optional[str] = None, ) -> view_types.EvalResults: - """Loads results for multiple models or multiple data sets. - - Args: - output_paths: A single path or list of output paths of completed tfma runs. - output_file_format: Optional file extension to filter files by. - mode: The mode of the evaluation. Currently, tfma.DATA_CENTRIC_MODE and - tfma.MODEL_CENTRIC_MODE are supported. - model_name: Filters to only return results for given model. If unset all - models are returned. - - Returns: - An EvalResults containing the evaluation results serialized at output_paths. - This can be used to construct a time series view. - """ - results = [] - if not isinstance(output_paths, list): - output_paths = [output_paths] - for output_path in output_paths: - if model_name is None: - _, _, _, model_locations = eval_config_writer.load_eval_run(output_path) - model_names = list(model_locations) - else: - model_names = [model_name] - for model_name in model_names: - results.append( - load_eval_result( - output_path, output_file_format, model_name=model_name - ) - ) - return make_eval_results(results, mode) + """Loads results for multiple models or multiple data sets. + + Args: + ---- + output_paths: A single path or list of output paths of completed tfma runs. + output_file_format: Optional file extension to filter files by. + mode: The mode of the evaluation. 
Currently, tfma.DATA_CENTRIC_MODE and + tfma.MODEL_CENTRIC_MODE are supported. + model_name: Filters to only return results for given model. If unset all + models are returned. + + Returns: + ------- + An EvalResults containing the evaluation results serialized at output_paths. + This can be used to construct a time series view. + """ + results = [] + if not isinstance(output_paths, list): + output_paths = [output_paths] + for output_path in output_paths: + if model_name is None: + _, _, _, model_locations = eval_config_writer.load_eval_run(output_path) + model_names = list(model_locations) + else: + model_names = [model_name] + for model_name in model_names: + results.append( + load_eval_result(output_path, output_file_format, model_name=model_name) + ) + return make_eval_results(results, mode) def load_eval_result( output_path: str, - output_file_format: Optional[str] = 'tfrecord', + output_file_format: Optional[str] = "tfrecord", model_name: Optional[str] = None, ) -> view_types.EvalResult: - """Loads EvalResult object for use with the visualization functions. - - Args: - output_path: Output directory containing config, metrics, plots, etc. - output_file_format: Optional file extension to filter files by. - model_name: Optional model name. Required if multi-model evaluation was run. - - Returns: - EvalResult object for use with the visualization functions. - """ - # Config, metrics, and plots files should all exist under the given output - # directory, but fairness plugin has a use-case where only the metrics are - # provided so we support all files as being optional (the EvalResult will have - # corresponding None values for files that are not present). - eval_config, data_location, file_format, model_locations = ( - eval_config_writer.load_eval_run(output_path) - ) - metrics_list = [] - for p in metrics_plots_and_validations_writer.load_and_deserialize_metrics( - output_path, output_file_format - ): - metrics = view_util.convert_metrics_proto_to_dict(p, model_name=model_name) - if metrics is not None: - metrics_list.append(metrics) - plots_list = [] - for p in metrics_plots_and_validations_writer.load_and_deserialize_plots( - output_path, output_file_format - ): - plots = view_util.convert_plots_proto_to_dict(p, model_name=model_name) - if plots is not None: - plots_list.append(plots) - attributions_list = [] - for ( - a - ) in metrics_plots_and_validations_writer.load_and_deserialize_attributions( - output_path, output_file_format - ): - attributions = view_util.convert_attributions_proto_to_dict( - a, model_name=model_name + """Loads EvalResult object for use with the visualization functions. + + Args: + ---- + output_path: Output directory containing config, metrics, plots, etc. + output_file_format: Optional file extension to filter files by. + model_name: Optional model name. Required if multi-model evaluation was run. + + Returns: + ------- + EvalResult object for use with the visualization functions. + """ + # Config, metrics, and plots files should all exist under the given output + # directory, but fairness plugin has a use-case where only the metrics are + # provided so we support all files as being optional (the EvalResult will have + # corresponding None values for files that are not present). 
+ eval_config, data_location, file_format, model_locations = ( + eval_config_writer.load_eval_run(output_path) + ) + metrics_list = [] + for p in metrics_plots_and_validations_writer.load_and_deserialize_metrics( + output_path, output_file_format + ): + metrics = view_util.convert_metrics_proto_to_dict(p, model_name=model_name) + if metrics is not None: + metrics_list.append(metrics) + plots_list = [] + for p in metrics_plots_and_validations_writer.load_and_deserialize_plots( + output_path, output_file_format + ): + plots = view_util.convert_plots_proto_to_dict(p, model_name=model_name) + if plots is not None: + plots_list.append(plots) + attributions_list = [] + for a in metrics_plots_and_validations_writer.load_and_deserialize_attributions( + output_path, output_file_format + ): + attributions = view_util.convert_attributions_proto_to_dict( + a, model_name=model_name + ) + if attributions is not None: + attributions_list.append(attributions) + if not model_locations: + model_location = "" + elif model_name is None: + model_location = list(model_locations.values())[0] + else: + model_location = model_locations[model_name] + return view_types.EvalResult( # pytype: disable=wrong-arg-types + slicing_metrics=metrics_list, + plots=plots_list, + attributions=attributions_list, + config=eval_config, + data_location=data_location, + file_format=file_format, + model_location=model_location, ) - if attributions is not None: - attributions_list.append(attributions) - if not model_locations: - model_location = '' - elif model_name is None: - model_location = list(model_locations.values())[0] - else: - model_location = model_locations[model_name] - return view_types.EvalResult( # pytype: disable=wrong-arg-types - slicing_metrics=metrics_list, - plots=plots_list, - attributions=attributions_list, - config=eval_config, - data_location=data_location, - file_format=file_format, - model_location=model_location, - ) def default_eval_shared_model( @@ -397,122 +400,119 @@ def default_eval_shared_model( additional_fetches: Optional[List[str]] = None, blacklist_feature_fetches: Optional[List[str]] = None, tags: Optional[List[str]] = None, - model_name: str = '', + model_name: str = "", eval_config: Optional[config_pb2.EvalConfig] = None, custom_model_loader: Optional[types.ModelLoader] = None, rubber_stamp: Optional[bool] = False, resource_hints: Optional[Dict[str, Any]] = None, backend_config: Optional[Any] = None, ) -> types.EvalSharedModel: - """Returns default EvalSharedModel. - - Args: - eval_saved_model_path: Path to EvalSavedModel. - add_metrics_callbacks: Optional list of callbacks for adding additional - metrics to the graph (see EvalSharedModel for more information on how to - configure additional metrics). Metrics for example count and example - weights will be added automatically. Only used if EvalSavedModel used. - include_default_metrics: DEPRECATED. Use - eval_config.options.include_default_metrics. - example_weight_key: DEPRECATED. Use - eval_config.model_specs.example_weight_key or - eval_config.model_specs.example_weight_keys. - additional_fetches: Optional prefixes of additional tensors stored in - signature_def.inputs that should be fetched at prediction time. The - "features" and "labels" tensors are handled automatically and should not - be included. Only used if EvalSavedModel used. - blacklist_feature_fetches: Optional list of tensor names in the features - dictionary which should be excluded from the fetches request. This is - useful in scenarios where features are large (e.g. 
images) and can lead to - excessive memory use if stored. Only used if EvalSavedModel used. - tags: Optional model tags (e.g. 'serve' for serving or 'eval' for - EvalSavedModel). - model_name: Optional name of the model being created (should match - ModelSpecs.name). The name should only be provided if multiple models are - being evaluated. - eval_config: Eval config. - custom_model_loader: Optional custom model loader for non-TF models. - rubber_stamp: True when this run is a first run without a baseline model - while a baseline is configured, the diff thresholds will be ignored. - resource_hints: The beam resource hints to apply to the PTransform which - runs inference for this model. - backend_config: Optional configuration of backend running model inference - with *some* prediction extractors. - """ - if not eval_config: - # Default to tfma eval model unless eval - is_baseline = False - if tags and _LEGACY_EVAL_TAG in tags: - model_type = constants.TFMA_EVAL - elif tags and tf.saved_model.SERVING in tags: - model_type = constants.TF_ESTIMATOR + """Returns default EvalSharedModel. + + Args: + ---- + eval_saved_model_path: Path to EvalSavedModel. + add_metrics_callbacks: Optional list of callbacks for adding additional + metrics to the graph (see EvalSharedModel for more information on how to + configure additional metrics). Metrics for example count and example + weights will be added automatically. Only used if EvalSavedModel used. + include_default_metrics: DEPRECATED. Use + eval_config.options.include_default_metrics. + example_weight_key: DEPRECATED. Use + eval_config.model_specs.example_weight_key or + eval_config.model_specs.example_weight_keys. + additional_fetches: Optional prefixes of additional tensors stored in + signature_def.inputs that should be fetched at prediction time. The + "features" and "labels" tensors are handled automatically and should not + be included. Only used if EvalSavedModel used. + blacklist_feature_fetches: Optional list of tensor names in the features + dictionary which should be excluded from the fetches request. This is + useful in scenarios where features are large (e.g. images) and can lead to + excessive memory use if stored. Only used if EvalSavedModel used. + tags: Optional model tags (e.g. 'serve' for serving or 'eval' for + EvalSavedModel). + model_name: Optional name of the model being created (should match + ModelSpecs.name). The name should only be provided if multiple models are + being evaluated. + eval_config: Eval config. + custom_model_loader: Optional custom model loader for non-TF models. + rubber_stamp: True when this run is a first run without a baseline model + while a baseline is configured, the diff thresholds will be ignored. + resource_hints: The beam resource hints to apply to the PTransform which + runs inference for this model. + backend_config: Optional configuration of backend running model inference + with *some* prediction extractors. 
+ """ + if not eval_config: + # Default to tfma eval model unless eval + is_baseline = False + if tags and _LEGACY_EVAL_TAG in tags: + model_type = constants.TFMA_EVAL + elif tags and tf.saved_model.SERVING in tags: + model_type = constants.TF_ESTIMATOR + else: + model_type = constants.TFMA_EVAL + if tags is None: + tags = [_LEGACY_EVAL_TAG] else: - model_type = constants.TFMA_EVAL - if tags is None: - tags = [_LEGACY_EVAL_TAG] - else: - model_spec = model_util.get_model_spec(eval_config, model_name) - if not model_spec: - raise ValueError( - 'ModelSpec for model name {} not found in EvalConfig: ' - 'config={}'.format(model_name, eval_config) - ) - is_baseline = model_spec.is_baseline - model_type = model_util.get_model_type( - model_spec, eval_saved_model_path, tags - ) - if tags is None: - # Default to serving unless tfma_eval is used. - if model_type == constants.TFMA_EVAL: - tags = [_LEGACY_EVAL_TAG] - else: - tags = [tf.saved_model.SERVING] - if model_spec.example_weight_key or model_spec.example_weight_keys: - example_weight_key = ( - model_spec.example_weight_key or model_spec.example_weight_keys - ) - if eval_config.options.HasField('include_default_metrics'): - include_default_metrics = ( - eval_config.options.include_default_metrics.value - ) - - model_loader = custom_model_loader - if not model_loader and model_type in constants.VALID_TF_MODEL_TYPES: - model_loader = types.ModelLoader( - construct_fn=model_util.model_construct_fn( - eval_saved_model_path=eval_saved_model_path, - add_metrics_callbacks=add_metrics_callbacks, - include_default_metrics=include_default_metrics, - additional_fetches=additional_fetches, - blacklist_feature_fetches=blacklist_feature_fetches, - model_type=model_type, + model_spec = model_util.get_model_spec(eval_config, model_name) + if not model_spec: + raise ValueError( + f"ModelSpec for model name {model_name} not found in EvalConfig: " + f"config={eval_config}" + ) + is_baseline = model_spec.is_baseline + model_type = model_util.get_model_type(model_spec, eval_saved_model_path, tags) + if tags is None: + # Default to serving unless tfma_eval is used. 
+ if model_type == constants.TFMA_EVAL: + tags = [_LEGACY_EVAL_TAG] + else: + tags = [tf.saved_model.SERVING] + if model_spec.example_weight_key or model_spec.example_weight_keys: + example_weight_key = ( + model_spec.example_weight_key or model_spec.example_weight_keys + ) + if eval_config.options.HasField("include_default_metrics"): + include_default_metrics = eval_config.options.include_default_metrics.value + + model_loader = custom_model_loader + if not model_loader and model_type in constants.VALID_TF_MODEL_TYPES: + model_loader = types.ModelLoader( + construct_fn=model_util.model_construct_fn( + eval_saved_model_path=eval_saved_model_path, + add_metrics_callbacks=add_metrics_callbacks, + include_default_metrics=include_default_metrics, + additional_fetches=additional_fetches, + blacklist_feature_fetches=blacklist_feature_fetches, + model_type=model_type, + tags=tags, + ), tags=tags, - ), - tags=tags, - ) + ) - return types.EvalSharedModel( - model_name=model_name, - model_type=model_type, - model_path=eval_saved_model_path, - add_metrics_callbacks=add_metrics_callbacks, - include_default_metrics=include_default_metrics, - example_weight_key=example_weight_key, - additional_fetches=additional_fetches, - model_loader=model_loader, - rubber_stamp=rubber_stamp, - is_baseline=is_baseline, - resource_hints=resource_hints, - backend_config=backend_config, - ) + return types.EvalSharedModel( + model_name=model_name, + model_type=model_type, + model_path=eval_saved_model_path, + add_metrics_callbacks=add_metrics_callbacks, + include_default_metrics=include_default_metrics, + example_weight_key=example_weight_key, + additional_fetches=additional_fetches, + model_loader=model_loader, + rubber_stamp=rubber_stamp, + is_baseline=is_baseline, + resource_hints=resource_hints, + backend_config=backend_config, + ) def _has_sql_slices(eval_config: Optional[config_pb2.EvalConfig]) -> bool: - if eval_config: - for spec in eval_config.slicing_specs: - if spec.slice_keys_sql: - return True - return False + if eval_config: + for spec in eval_config.slicing_specs: + if spec.slice_keys_sql: + return True + return False def default_extractors( # pylint: disable=invalid-name @@ -524,158 +524,160 @@ def default_extractors( # pylint: disable=invalid-name custom_predict_extractor: Optional[extractor.Extractor] = None, config_version: Optional[int] = None, ) -> List[extractor.Extractor]: - """Returns the default extractors for use in ExtractAndEvaluate. - - Args: - eval_shared_model: Shared model (single-model evaluation) or list of shared - models (multi-model evaluation). Required unless the predictions are - provided alongside of the features (i.e. model-agnostic evaluations). - eval_config: Eval config. - slice_spec: Deprecated (use EvalConfig). - materialize: True to have extractors create materialized output. - tensor_adapter_config: Tensor adapter config which specifies how to obtain - tensors from the Arrow RecordBatch. If None, an attempt will be made to - create the tensors using default TensorRepresentations. - custom_predict_extractor: Optional custom predict extractor for non-TF - models. - config_version: Optional config version for this evaluation. This should not - be explicitly set by users. It is only intended to be used in cases where - the provided eval_config was generated internally, and thus not a reliable - indicator of user intent. - - Raises: - NotImplementedError: If eval_config contains mixed serving and eval models. 
- """ - if materialize is None: - # TODO(b/172969312): Once analysis table is supported, remove defaulting - # to false unless 'analysis' is in disabled_outputs. - materialize = False - if slice_spec and eval_config: - raise ValueError('slice_spec is deprecated, only use eval_config') - - if eval_config is not None: - eval_config = _update_eval_config_with_defaults( - eval_config, eval_shared_model + """Returns the default extractors for use in ExtractAndEvaluate. + + Args: + ---- + eval_shared_model: Shared model (single-model evaluation) or list of shared + models (multi-model evaluation). Required unless the predictions are + provided alongside of the features (i.e. model-agnostic evaluations). + eval_config: Eval config. + slice_spec: Deprecated (use EvalConfig). + materialize: True to have extractors create materialized output. + tensor_adapter_config: Tensor adapter config which specifies how to obtain + tensors from the Arrow RecordBatch. If None, an attempt will be made to + create the tensors using default TensorRepresentations. + custom_predict_extractor: Optional custom predict extractor for non-TF + models. + config_version: Optional config version for this evaluation. This should not + be explicitly set by users. It is only intended to be used in cases where + the provided eval_config was generated internally, and thus not a reliable + indicator of user intent. + + Raises: + ------ + NotImplementedError: If eval_config contains mixed serving and eval models. + """ + if materialize is None: + # TODO(b/172969312): Once analysis table is supported, remove defaulting + # to false unless 'analysis' is in disabled_outputs. + materialize = False + if slice_spec and eval_config: + raise ValueError("slice_spec is deprecated, only use eval_config") + + if eval_config is not None: + eval_config = _update_eval_config_with_defaults(eval_config, eval_shared_model) + tensor_representations = None + if tensor_adapter_config: + tensor_representations = tensor_adapter_config.tensor_representations + + eval_shared_models = model_util.verify_and_update_eval_shared_models( + eval_shared_model ) - tensor_representations = None - if tensor_adapter_config: - tensor_representations = tensor_adapter_config.tensor_representations - - eval_shared_models = model_util.verify_and_update_eval_shared_models( - eval_shared_model - ) - slicing_extractors = [] - if _has_sql_slices(eval_config): - slicing_extractors.append( - sql_slice_key_extractor.SqlSliceKeyExtractor(eval_config) + slicing_extractors = [] + if _has_sql_slices(eval_config): + slicing_extractors.append( + sql_slice_key_extractor.SqlSliceKeyExtractor(eval_config) + ) + slicing_extractors.extend( + [ + unbatch_extractor.UnbatchExtractor(), + slice_key_extractor.SliceKeyExtractor( + eval_config=eval_config, materialize=materialize + ), + ] ) - slicing_extractors.extend([ - unbatch_extractor.UnbatchExtractor(), - slice_key_extractor.SliceKeyExtractor( - eval_config=eval_config, materialize=materialize - ), - ]) - - extract_features = features_extractor.FeaturesExtractor( - eval_config=eval_config, tensor_representations=tensor_representations - ) - extract_labels = labels_extractor.LabelsExtractor(eval_config=eval_config) - extract_example_weights = example_weights_extractor.ExampleWeightsExtractor( - eval_config=eval_config - ) - extract_materialized_predictions = ( - materialized_predictions_extractor.MaterializedPredictionsExtractor( - eval_config=eval_config - ) - ) - if eval_shared_model: - model_types = _model_types(eval_shared_models) - 
logging.info('eval_shared_models have model_types: %s', model_types) - assert model_types is not None - if ( - not model_types.issubset(constants.VALID_TF_MODEL_TYPES) - and not custom_predict_extractor - ): - raise NotImplementedError( - 'either a custom_predict_extractor must be used or model type must ' - 'be one of: {}. evalconfig={}'.format( - str(constants.VALID_TF_MODEL_TYPES), eval_config - ) - ) - - if model_types == {constants.MATERIALIZED_PREDICTION}: - return [ - extract_features, - extract_labels, - extract_example_weights, - extract_materialized_predictions, - ] + slicing_extractors - elif model_types == {constants.TF_LITE}: - # TODO(b/163889779): Convert TFLite extractor to operate on batched - # extracts. Then we can remove the input extractor. - return [ - extract_features, - transformed_features_extractor.TransformedFeaturesExtractor( - eval_config=eval_config, eval_shared_model=eval_shared_model - ), - extract_labels, - extract_example_weights, - ( - custom_predict_extractor - or tflite_predict_extractor.TFLitePredictExtractor( - eval_config=eval_config, eval_shared_model=eval_shared_model - ) - ), - ] + slicing_extractors - elif constants.TF_LITE in model_types: - raise NotImplementedError( - 'support for mixing tf_lite and non-tf_lite models is not ' - 'implemented: eval_config={}'.format(eval_config) - ) - elif model_types == {constants.TF_JS}: - return [ - extract_features, - extract_labels, - extract_example_weights, - ( - custom_predict_extractor - or tfjs_predict_extractor.TFJSPredictExtractor( - eval_config=eval_config, eval_shared_model=eval_shared_model - ) - ), - ] + slicing_extractors - elif constants.TF_JS in model_types: - raise NotImplementedError( - 'support for mixing tf_js and non-tf_js models is not ' - 'implemented: eval_config={}'.format(eval_config) - ) - else: - extractors = [extract_features] - if not custom_predict_extractor: - extractors.append( - transformed_features_extractor.TransformedFeaturesExtractor( - eval_config=eval_config, eval_shared_model=eval_shared_model - ) + + extract_features = features_extractor.FeaturesExtractor( + eval_config=eval_config, tensor_representations=tensor_representations + ) + extract_labels = labels_extractor.LabelsExtractor(eval_config=eval_config) + extract_example_weights = example_weights_extractor.ExampleWeightsExtractor( + eval_config=eval_config + ) + extract_materialized_predictions = ( + materialized_predictions_extractor.MaterializedPredictionsExtractor( + eval_config=eval_config ) - extractors.extend([ - extract_labels, - extract_example_weights, - ( - custom_predict_extractor - or predictions_extractor.PredictionsExtractor( - eval_config=eval_config, eval_shared_model=eval_shared_model - ) - ), - ]) - extractors.extend(slicing_extractors) - return extractors - else: - return [ - extract_features, - extract_labels, - extract_example_weights, - extract_materialized_predictions, - ] + slicing_extractors + ) + if eval_shared_model: + model_types = _model_types(eval_shared_models) + logging.info("eval_shared_models have model_types: %s", model_types) + assert model_types is not None + if ( + not model_types.issubset(constants.VALID_TF_MODEL_TYPES) + and not custom_predict_extractor + ): + raise NotImplementedError( + "either a custom_predict_extractor must be used or model type must " + f"be one of: {str(constants.VALID_TF_MODEL_TYPES)}. 
evalconfig={eval_config}" + ) + + if model_types == {constants.MATERIALIZED_PREDICTION}: + return [ + extract_features, + extract_labels, + extract_example_weights, + extract_materialized_predictions, + ] + slicing_extractors + elif model_types == {constants.TF_LITE}: + # TODO(b/163889779): Convert TFLite extractor to operate on batched + # extracts. Then we can remove the input extractor. + return [ + extract_features, + transformed_features_extractor.TransformedFeaturesExtractor( + eval_config=eval_config, eval_shared_model=eval_shared_model + ), + extract_labels, + extract_example_weights, + ( + custom_predict_extractor + or tflite_predict_extractor.TFLitePredictExtractor( + eval_config=eval_config, eval_shared_model=eval_shared_model + ) + ), + ] + slicing_extractors + elif constants.TF_LITE in model_types: + raise NotImplementedError( + "support for mixing tf_lite and non-tf_lite models is not " + f"implemented: eval_config={eval_config}" + ) + elif model_types == {constants.TF_JS}: + return [ + extract_features, + extract_labels, + extract_example_weights, + ( + custom_predict_extractor + or tfjs_predict_extractor.TFJSPredictExtractor( + eval_config=eval_config, eval_shared_model=eval_shared_model + ) + ), + ] + slicing_extractors + elif constants.TF_JS in model_types: + raise NotImplementedError( + "support for mixing tf_js and non-tf_js models is not " + f"implemented: eval_config={eval_config}" + ) + else: + extractors = [extract_features] + if not custom_predict_extractor: + extractors.append( + transformed_features_extractor.TransformedFeaturesExtractor( + eval_config=eval_config, eval_shared_model=eval_shared_model + ) + ) + extractors.extend( + [ + extract_labels, + extract_example_weights, + ( + custom_predict_extractor + or predictions_extractor.PredictionsExtractor( + eval_config=eval_config, eval_shared_model=eval_shared_model + ) + ), + ] + ) + extractors.extend(slicing_extractors) + return extractors + else: + return [ + extract_features, + extract_labels, + extract_example_weights, + extract_materialized_predictions, + ] + slicing_extractors def default_evaluators( # pylint: disable=invalid-name @@ -688,56 +690,55 @@ def default_evaluators( # pylint: disable=invalid-name random_seed_for_testing: Optional[int] = None, config_version: Optional[int] = None, ) -> List[evaluator.Evaluator]: - """Returns the default evaluators for use in ExtractAndEvaluate. - - Args: - eval_shared_model: Optional shared model (single-model evaluation) or list - of shared models (multi-model evaluation). Only required if there are - metrics to be computed in-graph using the model. - eval_config: Eval config. - schema: A schema to use for customizing default evaluators. - compute_confidence_intervals: Deprecated (use eval_config). - min_slice_size: Deprecated (use eval_config). - serialize: Deprecated. - random_seed_for_testing: Provide for deterministic tests only. - config_version: Optional config version for this evaluation. This should not - be explicitly set by users. It is only intended to be used in cases where - the provided eval_config was generated internally, and thus not a reliable - indicator of user intent. - """ - disabled_outputs = [] - eval_shared_models = model_util.verify_and_update_eval_shared_models( - eval_shared_model - ) - if eval_config: - eval_config = _update_eval_config_with_defaults( - eval_config, eval_shared_model + """Returns the default evaluators for use in ExtractAndEvaluate. 
+ + Args: + ---- + eval_shared_model: Optional shared model (single-model evaluation) or list + of shared models (multi-model evaluation). Only required if there are + metrics to be computed in-graph using the model. + eval_config: Eval config. + schema: A schema to use for customizing default evaluators. + compute_confidence_intervals: Deprecated (use eval_config). + min_slice_size: Deprecated (use eval_config). + serialize: Deprecated. + random_seed_for_testing: Provide for deterministic tests only. + config_version: Optional config version for this evaluation. This should not + be explicitly set by users. It is only intended to be used in cases where + the provided eval_config was generated internally, and thus not a reliable + indicator of user intent. + """ + disabled_outputs = [] + eval_shared_models = model_util.verify_and_update_eval_shared_models( + eval_shared_model ) - disabled_outputs = eval_config.options.disabled_outputs.values - if _model_types(eval_shared_models) == {constants.TF_LITE} or _model_types( - eval_shared_models - ) == {constants.TF_JS}: - # no in-graph metrics present when tflite or tfjs is used. - if eval_shared_models: - eval_shared_models = [ - v._replace(include_default_metrics=False) - for v in eval_shared_models - ] - if ( - constants.METRICS_KEY in disabled_outputs - and constants.PLOTS_KEY in disabled_outputs - and constants.ATTRIBUTIONS_KEY in disabled_outputs - ): - return [] - - return [ - metrics_plots_and_validations_evaluator.MetricsPlotsAndValidationsEvaluator( - eval_config=eval_config, - eval_shared_model=eval_shared_model, - schema=schema, - random_seed_for_testing=random_seed_for_testing, - ) - ] + if eval_config: + eval_config = _update_eval_config_with_defaults(eval_config, eval_shared_model) + disabled_outputs = eval_config.options.disabled_outputs.values + if _model_types(eval_shared_models) == {constants.TF_LITE} or _model_types( + eval_shared_models + ) == {constants.TF_JS}: + # no in-graph metrics present when tflite or tfjs is used. + if eval_shared_models: + eval_shared_models = [ + v._replace(include_default_metrics=False) + for v in eval_shared_models + ] + if ( + constants.METRICS_KEY in disabled_outputs + and constants.PLOTS_KEY in disabled_outputs + and constants.ATTRIBUTIONS_KEY in disabled_outputs + ): + return [] + + return [ + metrics_plots_and_validations_evaluator.MetricsPlotsAndValidationsEvaluator( + eval_config=eval_config, + eval_shared_model=eval_shared_model, + schema=schema, + random_seed_for_testing=random_seed_for_testing, + ) + ] def default_writers( @@ -746,92 +747,91 @@ def default_writers( eval_config: Optional[config_pb2.EvalConfig] = None, display_only_data_location: Optional[str] = None, display_only_data_file_format: Optional[str] = None, - output_file_format: str = 'tfrecord', + output_file_format: str = "tfrecord", add_metric_callbacks: Optional[List[types.AddMetricsCallbackType]] = None, ) -> List[writer.Writer]: # pylint: disable=invalid-name - """Returns the default writers for use in WriteResults. - - Note, sharding will be enabled by default if an output_file_format is - provided. Filenames will be -SSSSS-of-NNNNN. - where SSSSS is the shard number and NNNNN is the number of shards. - - Args: - output_path: Output path. - eval_shared_model: Optional shared model (single-model evaluation) or list - of shared models (multi-model evaluation). Required unless the predictions - are provided alongside of the features (i.e. model-agnostic evaluations). 
- eval_config: Eval config for writing out config along with results. Also - used for to check for missing slices. - display_only_data_location: Optional path indicating where the examples were - read from. This is used only for display purposes - data will not actually - be read from this path. - display_only_data_file_format: Optional format of the input examples. This - is used only for display purposes. - output_file_format: File format to use when saving files. Currently only - 'tfrecord' is supported. - add_metric_callbacks: Optional list of metric callbacks (if used). - """ - writers = [] - - if not add_metric_callbacks: - add_metric_callbacks = [] - # The add_metric_callbacks are used in the metrics and plots serialization - # code to post process the metric data by calling populate_stats_and_pop. - # While both the legacy (V1) and new (V2) evaluation implementations support - # EvalSavedModels using add_metric_callbacks, this particular code is only - # required for the legacy evaluation based on the MetricsAndPlotsEvaluator. - # The V2 MetricsAndPlotsEvaluator output requires no additional processing. - # Since the V1 code only supports a single EvalSharedModel, we only set the - # add_metrics_callbacks if a dict is not passed. - if ( - eval_shared_model - and not isinstance(eval_shared_model, dict) - and not isinstance(eval_shared_model, list) - ): - add_metric_callbacks = eval_shared_model.add_metrics_callbacks - - eval_shared_models = model_util.verify_and_update_eval_shared_models( - eval_shared_model - ) - - if eval_config: - model_locations = {} - for v in eval_shared_models or [None]: - k = '' if v is None else v.model_name - model_locations[k] = ( - '' if v is None or v.model_path is None else v.model_path - ) + """Returns the default writers for use in WriteResults. + + Note, sharding will be enabled by default if an output_file_format is + provided. Filenames will be -SSSSS-of-NNNNN. + where SSSSS is the shard number and NNNNN is the number of shards. + + Args: + ---- + output_path: Output path. + eval_shared_model: Optional shared model (single-model evaluation) or list + of shared models (multi-model evaluation). Required unless the predictions + are provided alongside of the features (i.e. model-agnostic evaluations). + eval_config: Eval config for writing out config along with results. Also + used for to check for missing slices. + display_only_data_location: Optional path indicating where the examples were + read from. This is used only for display purposes - data will not actually + be read from this path. + display_only_data_file_format: Optional format of the input examples. This + is used only for display purposes. + output_file_format: File format to use when saving files. Currently only + 'tfrecord' is supported. + add_metric_callbacks: Optional list of metric callbacks (if used). + """ + writers = [] + + if not add_metric_callbacks: + add_metric_callbacks = [] + # The add_metric_callbacks are used in the metrics and plots serialization + # code to post process the metric data by calling populate_stats_and_pop. + # While both the legacy (V1) and new (V2) evaluation implementations support + # EvalSavedModels using add_metric_callbacks, this particular code is only + # required for the legacy evaluation based on the MetricsAndPlotsEvaluator. + # The V2 MetricsAndPlotsEvaluator output requires no additional processing. + # Since the V1 code only supports a single EvalSharedModel, we only set the + # add_metrics_callbacks if a dict is not passed. 
+ if ( + eval_shared_model + and not isinstance(eval_shared_model, dict) + and not isinstance(eval_shared_model, list) + ): + add_metric_callbacks = eval_shared_model.add_metrics_callbacks + + eval_shared_models = model_util.verify_and_update_eval_shared_models( + eval_shared_model + ) + + if eval_config: + model_locations = {} + for v in eval_shared_models or [None]: + k = "" if v is None else v.model_name + model_locations[k] = ( + "" if v is None or v.model_path is None else v.model_path + ) + writers.append( + eval_config_writer.EvalConfigWriter( + output_path, + eval_config=eval_config, + data_location=display_only_data_location, + data_file_format=display_only_data_file_format, + model_locations=model_locations, + ) + ) + + output_paths = { + constants.METRICS_KEY: os.path.join(output_path, constants.METRICS_KEY), + constants.PLOTS_KEY: os.path.join(output_path, constants.PLOTS_KEY), + constants.ATTRIBUTIONS_KEY: os.path.join( + output_path, constants.ATTRIBUTIONS_KEY + ), + constants.VALIDATIONS_KEY: os.path.join(output_path, constants.VALIDATIONS_KEY), + } writers.append( - eval_config_writer.EvalConfigWriter( - output_path, - eval_config=eval_config, - data_location=display_only_data_location, - data_file_format=display_only_data_file_format, - model_locations=model_locations, + metrics_plots_and_validations_writer.MetricsPlotsAndValidationsWriter( + output_paths=output_paths, + # Empty EvalConfig supported for backwards compatibility. + eval_config=eval_config or config_pb2.EvalConfig(), + add_metrics_callbacks=add_metric_callbacks, + output_file_format=output_file_format, + rubber_stamp=model_util.has_rubber_stamp(eval_shared_models), ) ) - - output_paths = { - constants.METRICS_KEY: os.path.join(output_path, constants.METRICS_KEY), - constants.PLOTS_KEY: os.path.join(output_path, constants.PLOTS_KEY), - constants.ATTRIBUTIONS_KEY: os.path.join( - output_path, constants.ATTRIBUTIONS_KEY - ), - constants.VALIDATIONS_KEY: os.path.join( - output_path, constants.VALIDATIONS_KEY - ), - } - writers.append( - metrics_plots_and_validations_writer.MetricsPlotsAndValidationsWriter( - output_paths=output_paths, - # Empty EvalConfig supported for backwards compatibility. - eval_config=eval_config or config_pb2.EvalConfig(), - add_metrics_callbacks=add_metric_callbacks, - output_file_format=output_file_format, - rubber_stamp=model_util.has_rubber_stamp(eval_shared_models), - ) - ) - return writers + return writers @beam.ptransform_fn @@ -842,17 +842,17 @@ def default_writers( def InputsToExtracts( # pylint: disable=invalid-name inputs: beam.pvalue.PCollection, ) -> beam.pvalue.PCollection: - """Converts serialized inputs (e.g. examples) to Extracts if not already.""" + """Converts serialized inputs (e.g. 
examples) to Extracts if not already.""" - def to_extracts(x: Union[bytes, str, types.Extracts]) -> types.Extracts: - result = {} - if isinstance(x, dict): - result.update(x) - else: - result[constants.INPUT_KEY] = x - return result + def to_extracts(x: Union[bytes, str, types.Extracts]) -> types.Extracts: + result = {} + if isinstance(x, dict): + result.update(x) + else: + result[constants.INPUT_KEY] = x + return result - return inputs | 'AddInputKey' >> beam.Map(to_extracts) + return inputs | "AddInputKey" >> beam.Map(to_extracts) @beam.ptransform_fn @@ -861,19 +861,19 @@ def to_extracts(x: Union[bytes, str, types.Extracts]) -> types.Extracts: def BatchedInputsToExtracts( # pylint: disable=invalid-name batched_inputs: beam.pvalue.PCollection, ) -> beam.pvalue.PCollection: - """Converts Arrow RecordBatch inputs to Extracts.""" - - def to_extracts( - x: Union[bytes, types.Extracts, pa.RecordBatch], - ) -> types.Extracts: - result = {} - if isinstance(x, dict): - result.update(x) - else: - result[constants.ARROW_RECORD_BATCH_KEY] = x - return result + """Converts Arrow RecordBatch inputs to Extracts.""" + + def to_extracts( + x: Union[bytes, types.Extracts, pa.RecordBatch], + ) -> types.Extracts: + result = {} + if isinstance(x, dict): + result.update(x) + else: + result[constants.ARROW_RECORD_BATCH_KEY] = x + return result - return batched_inputs | 'AddArrowRecordBatchKey' >> beam.Map(to_extracts) + return batched_inputs | "AddArrowRecordBatchKey" >> beam.Map(to_extracts) @beam.ptransform_fn @@ -884,93 +884,91 @@ def ExtractAndEvaluate( # pylint: disable=invalid-name extractors: List[extractor.Extractor], evaluators: List[evaluator.Evaluator], ) -> evaluator.Evaluation: - """Performs Extractions and Evaluations in provided order.""" - # evaluation[k] = list of values for k - evaluation = {} - - def update(evaluation: Dict[str, Any], new_evaluation: Dict[str, Any]): - for k, v in new_evaluation.items(): - if k not in evaluation: - evaluation[k] = [] - evaluation[k].append(v) - return evaluation - - _ = extracts | 'TrackInputBytes' >> _TrackBytesProcessed() # pylint: disable=no-value-for-parameter - # Run evaluators that run before extraction (i.e. that only require - # the incoming input extract added by ReadInputs) - for v in evaluators: - if not v.run_after: - update(evaluation, extracts | v.stage_name >> v.ptransform) - for x in extractors: - extracts = extracts | x.stage_name >> x.ptransform + """Performs Extractions and Evaluations in provided order.""" + # evaluation[k] = list of values for k + evaluation = {} + + def update(evaluation: Dict[str, Any], new_evaluation: Dict[str, Any]): + for k, v in new_evaluation.items(): + if k not in evaluation: + evaluation[k] = [] + evaluation[k].append(v) + return evaluation + + _ = extracts | "TrackInputBytes" >> _TrackBytesProcessed() # pylint: disable=no-value-for-parameter + # Run evaluators that run before extraction (i.e. that only require + # the incoming input extract added by ReadInputs) for v in evaluators: - if v.run_after == x.stage_name: - update(evaluation, extracts | v.stage_name >> v.ptransform) - for v in evaluators: - if v.run_after == extractor.LAST_EXTRACTOR_STAGE_NAME: - update(evaluation, extracts | v.stage_name >> v.ptransform) - - # Merge multi-valued keys if necessary. - result = {} - for k, v in evaluation.items(): - if len(v) == 1: - result[k] = v[0] - continue - - # Note that we assume that if a key is multivalued, its values are - # dictionaries with disjoint keys. 
The combined value will simply be the - # disjoint union of all the dictionaries. - result[k] = ( - v - | 'FlattenEvaluationOutput(%s)' % k >> beam.Flatten() - | 'CombineEvaluationOutput(%s)' % k - >> beam.CombinePerKey(_CombineEvaluationDictionariesFn()) - ) + if not v.run_after: + update(evaluation, extracts | v.stage_name >> v.ptransform) + for x in extractors: + extracts = extracts | x.stage_name >> x.ptransform + for v in evaluators: + if v.run_after == x.stage_name: + update(evaluation, extracts | v.stage_name >> v.ptransform) + for v in evaluators: + if v.run_after == extractor.LAST_EXTRACTOR_STAGE_NAME: + update(evaluation, extracts | v.stage_name >> v.ptransform) - return result + # Merge multi-valued keys if necessary. + result = {} + for k, v in evaluation.items(): + if len(v) == 1: + result[k] = v[0] + continue + + # Note that we assume that if a key is multivalued, its values are + # dictionaries with disjoint keys. The combined value will simply be the + # disjoint union of all the dictionaries. + result[k] = ( + v + | "FlattenEvaluationOutput(%s)" % k >> beam.Flatten() + | "CombineEvaluationOutput(%s)" % k + >> beam.CombinePerKey(_CombineEvaluationDictionariesFn()) + ) + + return result class _CombineEvaluationDictionariesFn(beam.CombineFn): - """CombineFn to combine dictionaries generated by different evaluators.""" - - def create_accumulator(self) -> Dict[str, Any]: - return {} - - def _merge( - self, accumulator: Dict[str, Any], output_dict: Dict[str, Any] - ) -> None: - intersection = set(accumulator) & set(output_dict) - if intersection: - raise ValueError( - 'Dictionaries generated by different evaluators should have ' - 'different keys, but keys %s appeared in the output of multiple ' - 'evaluators' % intersection - ) - accumulator.update(output_dict) - - def add_input( - self, accumulator: Dict[str, Any], output_dict: Dict[str, Any] - ) -> Dict[str, Any]: - if not isinstance(output_dict, dict): - raise TypeError( - 'for outputs written to by multiple evaluators, the outputs must all ' - 'be dictionaries, but got output of type %s, value %s' - % (type(output_dict), str(output_dict)) - ) - self._merge(accumulator, output_dict) - return accumulator - - def merge_accumulators( - self, accumulators: Iterable[Dict[str, Any]] - ) -> Dict[str, Any]: - accumulators = iter(accumulators) - result = next(accumulators) - for acc in accumulators: - self._merge(result, acc) - return result + """CombineFn to combine dictionaries generated by different evaluators.""" + + def create_accumulator(self) -> Dict[str, Any]: + return {} + + def _merge(self, accumulator: Dict[str, Any], output_dict: Dict[str, Any]) -> None: + intersection = set(accumulator) & set(output_dict) + if intersection: + raise ValueError( + "Dictionaries generated by different evaluators should have " + "different keys, but keys %s appeared in the output of multiple " + "evaluators" % intersection + ) + accumulator.update(output_dict) + + def add_input( + self, accumulator: Dict[str, Any], output_dict: Dict[str, Any] + ) -> Dict[str, Any]: + if not isinstance(output_dict, dict): + raise TypeError( + "for outputs written to by multiple evaluators, the outputs must all " + "be dictionaries, but got output of type %s, value %s" + % (type(output_dict), str(output_dict)) + ) + self._merge(accumulator, output_dict) + return accumulator - def extract_output(self, accumulator: Dict[str, Any]) -> Dict[str, Any]: - return accumulator + def merge_accumulators( + self, accumulators: Iterable[Dict[str, Any]] + ) -> Dict[str, 
Any]: + accumulators = iter(accumulators) + result = next(accumulators) + for acc in accumulators: + self._merge(result, acc) + return result + + def extract_output(self, accumulator: Dict[str, Any]) -> Dict[str, Any]: + return accumulator @beam.ptransform_fn @@ -979,52 +977,53 @@ def WriteResults( # pylint: disable=invalid-name evaluation_or_validation: Union[evaluator.Evaluation, validator.Validation], writers: List[writer.Writer], ) -> Dict[str, beam.PCollection]: - """Writes Evaluation or Validation results using given writers. - - Args: - evaluation_or_validation: Evaluation or Validation output. - writers: Writes to use for writing out output. - - Raises: - ValueError: If Evaluation or Validation is empty. - - Returns: - A dict of writer results keyed by the writer stage name. - """ - if not evaluation_or_validation: - raise ValueError('Evaluations and Validations cannot be empty') - result = {} - for w in writers: - result[w.stage_name] = ( - evaluation_or_validation | w.stage_name >> w.ptransform - ) - return result + """Writes Evaluation or Validation results using given writers. + + Args: + ---- + evaluation_or_validation: Evaluation or Validation output. + writers: Writes to use for writing out output. + + Raises: + ------ + ValueError: If Evaluation or Validation is empty. + + Returns: + ------- + A dict of writer results keyed by the writer stage name. + """ + if not evaluation_or_validation: + raise ValueError("Evaluations and Validations cannot be empty") + result = {} + for w in writers: + result[w.stage_name] = evaluation_or_validation | w.stage_name >> w.ptransform + return result def is_legacy_estimator( eval_shared_model: Optional[types.MaybeMultipleEvalSharedModels] = None, ) -> bool: - """Returns true if there is a legacy estimator. - - Args: - eval_shared_model: Shared model (single-model evaluation) or list of shared - models (multi-model evaluation). Required unless the predictions are - provided alongside of the features (i.e. model-agnostic evaluations). - - Returns: - A boolean indicating if legacy predict extractor will be used. - """ - eval_shared_models = model_util.verify_and_update_eval_shared_models( - eval_shared_model - ) - model_types = _model_types(eval_shared_models) - return ( - model_types is not None - and model_types == {constants.TFMA_EVAL} - and all( - _LEGACY_EVAL_TAG in m.model_loader.tags for m in eval_shared_models - ) - ) + """Returns true if there is a legacy estimator. + + Args: + ---- + eval_shared_model: Shared model (single-model evaluation) or list of shared + models (multi-model evaluation). Required unless the predictions are + provided alongside of the features (i.e. model-agnostic evaluations). + + Returns: + ------- + A boolean indicating if legacy predict extractor will be used. + """ + eval_shared_models = model_util.verify_and_update_eval_shared_models( + eval_shared_model + ) + model_types = _model_types(eval_shared_models) + return ( + model_types is not None + and model_types == {constants.TFMA_EVAL} + and all(_LEGACY_EVAL_TAG in m.model_loader.tags for m in eval_shared_models) + ) def is_batched_input( @@ -1032,30 +1031,32 @@ def is_batched_input( eval_config: Optional[config_pb2.EvalConfig] = None, config_version: Optional[int] = None, ) -> bool: - """Returns true if batched input should be used. - - We will keep supporting the legacy unbatched V1 PredictExtractor as it parses - the features and labels, and is the only solution currently that allows for - slicing on transformed features. 
Eventually we should have support for - transformed features via keras preprocessing layers. - - Args: - eval_shared_model: Shared model (single-model evaluation) or list of shared - models (multi-model evaluation). Required unless the predictions are - provided alongside of the features (i.e. model-agnostic evaluations). - eval_config: Eval config. - config_version: Optional config version for this evaluation. This should not - be explicitly set by users. It is only intended to be used in cases where - the provided eval_config was generated internally, and thus not a reliable - indicator of user intent. - - Returns: - A boolean indicating if batched extractors should be used. - """ - eval_shared_models = model_util.verify_and_update_eval_shared_models( - eval_shared_model - ) - return not _is_legacy_eval(config_version, eval_shared_models, eval_config) + """Returns true if batched input should be used. + + We will keep supporting the legacy unbatched V1 PredictExtractor as it parses + the features and labels, and is the only solution currently that allows for + slicing on transformed features. Eventually we should have support for + transformed features via keras preprocessing layers. + + Args: + ---- + eval_shared_model: Shared model (single-model evaluation) or list of shared + models (multi-model evaluation). Required unless the predictions are + provided alongside of the features (i.e. model-agnostic evaluations). + eval_config: Eval config. + config_version: Optional config version for this evaluation. This should not + be explicitly set by users. It is only intended to be used in cases where + the provided eval_config was generated internally, and thus not a reliable + indicator of user intent. + + Returns: + ------- + A boolean indicating if batched extractors should be used. + """ + eval_shared_models = model_util.verify_and_update_eval_shared_models( + eval_shared_model + ) + return not _is_legacy_eval(config_version, eval_shared_models, eval_config) @beam.ptransform_fn @@ -1078,158 +1079,159 @@ def ExtractEvaluateAndWriteResults( # pylint: disable=invalid-name schema: Optional[schema_pb2.Schema] = None, config_version: Optional[int] = None, ) -> Dict[str, beam.PCollection]: - """PTransform for performing extraction, evaluation, and writing results. - - Users who want to construct their own Beam pipelines instead of using the - lightweight run_model_analysis functions should use this PTransform. - - Example usage: - - ```python - eval_config = tfma.EvalConfig(model_specs=[...], metrics_specs=[...], - slicing_specs=[...]) - eval_shared_model = tfma.default_eval_shared_model( - eval_saved_model_path=model_location, eval_config=eval_config) - tfx_io = tf_example_record.TFExampleRecord( - file_pattern=data_location, - raw_record_column_name=tfma.ARROW_INPUT_COLUMN) - with beam.Pipeline(runner=...) as p: - _ = (p - | 'ReadData' >> tfx_io.BeamSource() - | 'ExtractEvaluateAndWriteResults' >> - tfma.ExtractEvaluateAndWriteResults( - eval_shared_model=eval_shared_model, - eval_config=eval_config, - ...)) - result = tfma.load_eval_result(output_path=output_path) - tfma.view.render_slicing_metrics(result) - - NOTE: If running with an EvalSavedModel (i.e. the ModelSpec has signature_name - "eval"), then instead of using the tfxio.BeamSource() code use the following - beam.io.ReadFromTFRecord(data_location) - ``` - - Note that the exact serialization format is an internal implementation detail - and subject to change. Users should only use the TFMA functions to write and - read the results. 
- - Args: - examples: PCollection of input examples or Arrow Record batches. Examples - can be any format the model accepts (e.g. string containing CSV row, - TensorFlow.Example, etc). If the examples are in the form of a dict it - will be assumed that input is already in the form of tfma.Extracts with - examples stored under tfma.INPUT_KEY (any other keys will be passed along - unchanged to downstream extractors and evaluators). - eval_shared_model: Optional shared model (single-model evaluation) or list - of shared models (multi-model evaluation). Only required if needed by - default extractors, evaluators, or writers and for display purposes of the - model path. - eval_config: Eval config. - extractors: Optional list of Extractors to apply to Extracts. Typically - these will be added by calling the default_extractors function. If no - extractors are provided, default_extractors (non-materialized) will be - used. - evaluators: Optional list of Evaluators for evaluating Extracts. Typically - these will be added by calling the default_evaluators function. If no - evaluators are provided, default_evaluators will be used. - writers: Optional list of Writers for writing Evaluation output. Typically - these will be added by calling the default_writers function. If no writers - are provided, default_writers will be used. - output_path: Path to output results to (config file, metrics, plots, etc). - display_only_data_location: Optional path indicating where the examples were - read from. This is used only for display purposes - data will not actually - be read from this path. - display_only_file_format: Optional format of the examples. This is used only - for display purposes. - slice_spec: Deprecated (use EvalConfig). - write_config: Deprecated (use EvalConfig). - compute_confidence_intervals: Deprecated (use EvalConfig). - min_slice_size: Deprecated (use EvalConfig). - random_seed_for_testing: Provide for deterministic tests only. - tensor_adapter_config: Tensor adapter config which specifies how to obtain - tensors from the Arrow RecordBatch. If None, an attempt will be made to - create the tensors using default TensorRepresentations. - schema: A schema to use for customizing evaluators. - config_version: Optional config version for this evaluation. This should not - be explicitly set by users. It is only intended to be used in cases where - the provided eval_config was generated internally, and thus not a reliable - indicator of user intent. - - Raises: - ValueError: If EvalConfig invalid or matching Extractor not found for an - Evaluator. - - Returns: - A dict of writer results keyed by the writer stage name. - """ - eval_shared_models = model_util.verify_and_update_eval_shared_models( - eval_shared_model - ) - - if eval_config is None: - config_version = 1 if config_version is None else config_version - eval_config = _default_eval_config( - eval_shared_models, - slice_spec, - write_config, - compute_confidence_intervals, - min_slice_size, - ) - else: - config_version = 2 if config_version is None else config_version - eval_config = _update_eval_config_with_defaults( - eval_config, eval_shared_model + """PTransform for performing extraction, evaluation, and writing results. + + Users who want to construct their own Beam pipelines instead of using the + lightweight run_model_analysis functions should use this PTransform. 
+ + Example usage: + + ```python + eval_config = tfma.EvalConfig(model_specs=[...], metrics_specs=[...], + slicing_specs=[...]) + eval_shared_model = tfma.default_eval_shared_model( + eval_saved_model_path=model_location, eval_config=eval_config) + tfx_io = tf_example_record.TFExampleRecord( + file_pattern=data_location, + raw_record_column_name=tfma.ARROW_INPUT_COLUMN) + with beam.Pipeline(runner=...) as p: + _ = (p + | 'ReadData' >> tfx_io.BeamSource() + | 'ExtractEvaluateAndWriteResults' >> + tfma.ExtractEvaluateAndWriteResults( + eval_shared_model=eval_shared_model, + eval_config=eval_config, + ...)) + result = tfma.load_eval_result(output_path=output_path) + tfma.view.render_slicing_metrics(result) + + NOTE: If running with an EvalSavedModel (i.e. the ModelSpec has signature_name + "eval"), then instead of using the tfxio.BeamSource() code use the following + beam.io.ReadFromTFRecord(data_location) + ``` + + Note that the exact serialization format is an internal implementation detail + and subject to change. Users should only use the TFMA functions to write and + read the results. + + Args: + ---- + examples: PCollection of input examples or Arrow Record batches. Examples + can be any format the model accepts (e.g. string containing CSV row, + TensorFlow.Example, etc). If the examples are in the form of a dict it + will be assumed that input is already in the form of tfma.Extracts with + examples stored under tfma.INPUT_KEY (any other keys will be passed along + unchanged to downstream extractors and evaluators). + eval_shared_model: Optional shared model (single-model evaluation) or list + of shared models (multi-model evaluation). Only required if needed by + default extractors, evaluators, or writers and for display purposes of the + model path. + eval_config: Eval config. + extractors: Optional list of Extractors to apply to Extracts. Typically + these will be added by calling the default_extractors function. If no + extractors are provided, default_extractors (non-materialized) will be + used. + evaluators: Optional list of Evaluators for evaluating Extracts. Typically + these will be added by calling the default_evaluators function. If no + evaluators are provided, default_evaluators will be used. + writers: Optional list of Writers for writing Evaluation output. Typically + these will be added by calling the default_writers function. If no writers + are provided, default_writers will be used. + output_path: Path to output results to (config file, metrics, plots, etc). + display_only_data_location: Optional path indicating where the examples were + read from. This is used only for display purposes - data will not actually + be read from this path. + display_only_file_format: Optional format of the examples. This is used only + for display purposes. + slice_spec: Deprecated (use EvalConfig). + write_config: Deprecated (use EvalConfig). + compute_confidence_intervals: Deprecated (use EvalConfig). + min_slice_size: Deprecated (use EvalConfig). + random_seed_for_testing: Provide for deterministic tests only. + tensor_adapter_config: Tensor adapter config which specifies how to obtain + tensors from the Arrow RecordBatch. If None, an attempt will be made to + create the tensors using default TensorRepresentations. + schema: A schema to use for customizing evaluators. + config_version: Optional config version for this evaluation. This should not + be explicitly set by users. 
It is only intended to be used in cases where + the provided eval_config was generated internally, and thus not a reliable + indicator of user intent. + + Raises: + ------ + ValueError: If EvalConfig invalid or matching Extractor not found for an + Evaluator. + + Returns: + ------- + A dict of writer results keyed by the writer stage name. + """ + eval_shared_models = model_util.verify_and_update_eval_shared_models( + eval_shared_model ) - config_util.verify_eval_config(eval_config) - if not extractors: - extractors = default_extractors( - eval_config=eval_config, - eval_shared_model=eval_shared_model, - tensor_adapter_config=tensor_adapter_config, - config_version=config_version, - ) + if eval_config is None: + config_version = 1 if config_version is None else config_version + eval_config = _default_eval_config( + eval_shared_models, + slice_spec, + write_config, + compute_confidence_intervals, + min_slice_size, + ) + else: + config_version = 2 if config_version is None else config_version + eval_config = _update_eval_config_with_defaults(eval_config, eval_shared_model) + config_util.verify_eval_config(eval_config) - if not evaluators: - evaluators = default_evaluators( - eval_config=eval_config, - eval_shared_model=eval_shared_model, - random_seed_for_testing=random_seed_for_testing, - schema=schema, - config_version=config_version, - ) + if not extractors: + extractors = default_extractors( + eval_config=eval_config, + eval_shared_model=eval_shared_model, + tensor_adapter_config=tensor_adapter_config, + config_version=config_version, + ) - for v in evaluators: - evaluator.verify_evaluator(v, extractors) + if not evaluators: + evaluators = default_evaluators( + eval_config=eval_config, + eval_shared_model=eval_shared_model, + random_seed_for_testing=random_seed_for_testing, + schema=schema, + config_version=config_version, + ) - if not writers: - writers = default_writers( - output_path=output_path, - eval_shared_model=eval_shared_model, - eval_config=eval_config, - display_only_data_location=display_only_data_location, - display_only_data_file_format=display_only_file_format, - ) + for v in evaluators: + evaluator.verify_evaluator(v, extractors) + + if not writers: + writers = default_writers( + output_path=output_path, + eval_shared_model=eval_shared_model, + eval_config=eval_config, + display_only_data_location=display_only_data_location, + display_only_data_file_format=display_only_file_format, + ) - # pylint: disable=no-value-for-parameter - if is_batched_input(eval_shared_model, eval_config, config_version): - extracts = examples | 'BatchedInputsToExtracts' >> BatchedInputsToExtracts() - else: - extracts = examples | 'InputsToExtracts' >> InputsToExtracts() + # pylint: disable=no-value-for-parameter + if is_batched_input(eval_shared_model, eval_config, config_version): + extracts = examples | "BatchedInputsToExtracts" >> BatchedInputsToExtracts() + else: + extracts = examples | "InputsToExtracts" >> InputsToExtracts() - return ( - extracts - | 'ExtractAndEvaluate' - >> ExtractAndEvaluate(extractors=extractors, evaluators=evaluators) - | 'WriteResults' >> WriteResults(writers=writers) - ) + return ( + extracts + | "ExtractAndEvaluate" + >> ExtractAndEvaluate(extractors=extractors, evaluators=evaluators) + | "WriteResults" >> WriteResults(writers=writers) + ) def run_model_analysis( eval_shared_model: Optional[types.MaybeMultipleEvalSharedModels] = None, eval_config: Optional[config_pb2.EvalConfig] = None, - data_location: str = '', - file_format: str = 'tfrecords', + data_location: 
str = "", + file_format: str = "tfrecords", output_path: Optional[str] = None, extractors: Optional[List[extractor.Extractor]] = None, evaluators: Optional[List[evaluator.Evaluator]] = None, @@ -1242,125 +1244,123 @@ def run_model_analysis( random_seed_for_testing: Optional[int] = None, schema: Optional[schema_pb2.Schema] = None, ) -> Union[view_types.EvalResult, view_types.EvalResults]: - """Runs TensorFlow model analysis. - - It runs a Beam pipeline to compute the slicing metrics exported in TensorFlow - Eval SavedModel and returns the results. - - This is a simplified API for users who want to quickly get something running - locally. Users who wish to create their own Beam pipelines can use the - Evaluate PTransform instead. - - Args: - eval_shared_model: Optional shared model (single-model evaluation) or list - of shared models (multi-model evaluation). Only required if needed by - default extractors, evaluators, or writers. - eval_config: Eval config. - data_location: The location of the data files. - file_format: The file format of the data, can be either 'text' or - 'tfrecords' for now. By default, 'tfrecords' will be used. - output_path: The directory to output metrics and results to. If None, we use - a temporary directory. - extractors: Optional list of Extractors to apply to Extracts. Typically - these will be added by calling the default_extractors function. If no - extractors are provided, default_extractors (non-materialized) will be - used. - evaluators: Optional list of Evaluators for evaluating Extracts. Typically - these will be added by calling the default_evaluators function. If no - evaluators are provided, default_evaluators will be used. - writers: Optional list of Writers for writing Evaluation output. Typically - these will be added by calling the default_writers function. If no writers - are provided, default_writers will be used. - pipeline_options: Optional arguments to run the Pipeline, for instance - whether to run directly. - slice_spec: Deprecated (use EvalConfig). - write_config: Deprecated (use EvalConfig). - compute_confidence_intervals: Deprecated (use EvalConfig). - min_slice_size: Deprecated (use EvalConfig). - random_seed_for_testing: Provide for deterministic tests only. - schema: Optional tf.Metadata schema of the input data. - - Returns: - An EvalResult that can be used with the TFMA visualization functions. - - Raises: - ValueError: If the file_format is unknown to us. 
- """ - _assert_tensorflow_version() - - if output_path is None: - output_path = tempfile.mkdtemp() - if not tf.io.gfile.exists(output_path): - tf.io.gfile.makedirs(output_path) - - if eval_config is None: - config_version = 1 - eval_shared_models = model_util.verify_and_update_eval_shared_models( - eval_shared_model - ) - eval_config = _default_eval_config( - eval_shared_models, - slice_spec, - write_config, - compute_confidence_intervals, - min_slice_size, - ) - else: - config_version = 2 - eval_config = _update_eval_config_with_defaults( - eval_config, eval_shared_model - ) - - tensor_adapter_config = None - with beam.Pipeline(options=pipeline_options) as p: - if file_format == 'tfrecords': - if is_batched_input(eval_shared_model, eval_config, config_version): - if is_legacy_estimator(eval_shared_model): - tfxio = raw_tf_record.RawTfRecordTFXIO( - file_pattern=data_location, - raw_record_column_name=constants.ARROW_INPUT_COLUMN, - telemetry_descriptors=['StandaloneTFMA'], - ) - else: - tfxio = tf_example_record.TFExampleRecord( - file_pattern=data_location, - schema=schema, - raw_record_column_name=constants.ARROW_INPUT_COLUMN, - telemetry_descriptors=['StandaloneTFMA'], - ) - if schema is not None: - tensor_adapter_config = tensor_adapter.TensorAdapterConfig( - arrow_schema=tfxio.ArrowSchema(), - tensor_representations=tfxio.TensorRepresentations(), - ) - data = p | 'ReadFromTFRecordToArrow' >> tfxio.BeamSource() - else: - data = p | 'ReadFromTFRecord' >> beam.io.ReadFromTFRecord( - file_pattern=data_location, - compression_type=beam.io.filesystem.CompressionTypes.AUTO, + """Runs TensorFlow model analysis. + + It runs a Beam pipeline to compute the slicing metrics exported in TensorFlow + Eval SavedModel and returns the results. + + This is a simplified API for users who want to quickly get something running + locally. Users who wish to create their own Beam pipelines can use the + Evaluate PTransform instead. + + Args: + ---- + eval_shared_model: Optional shared model (single-model evaluation) or list + of shared models (multi-model evaluation). Only required if needed by + default extractors, evaluators, or writers. + eval_config: Eval config. + data_location: The location of the data files. + file_format: The file format of the data, can be either 'text' or + 'tfrecords' for now. By default, 'tfrecords' will be used. + output_path: The directory to output metrics and results to. If None, we use + a temporary directory. + extractors: Optional list of Extractors to apply to Extracts. Typically + these will be added by calling the default_extractors function. If no + extractors are provided, default_extractors (non-materialized) will be + used. + evaluators: Optional list of Evaluators for evaluating Extracts. Typically + these will be added by calling the default_evaluators function. If no + evaluators are provided, default_evaluators will be used. + writers: Optional list of Writers for writing Evaluation output. Typically + these will be added by calling the default_writers function. If no writers + are provided, default_writers will be used. + pipeline_options: Optional arguments to run the Pipeline, for instance + whether to run directly. + slice_spec: Deprecated (use EvalConfig). + write_config: Deprecated (use EvalConfig). + compute_confidence_intervals: Deprecated (use EvalConfig). + min_slice_size: Deprecated (use EvalConfig). + random_seed_for_testing: Provide for deterministic tests only. + schema: Optional tf.Metadata schema of the input data. 
+ + Returns: + ------- + An EvalResult that can be used with the TFMA visualization functions. + + Raises: + ------ + ValueError: If the file_format is unknown to us. + """ + _assert_tensorflow_version() + + if output_path is None: + output_path = tempfile.mkdtemp() + if not tf.io.gfile.exists(output_path): + tf.io.gfile.makedirs(output_path) + + if eval_config is None: + config_version = 1 + eval_shared_models = model_util.verify_and_update_eval_shared_models( + eval_shared_model + ) + eval_config = _default_eval_config( + eval_shared_models, + slice_spec, + write_config, + compute_confidence_intervals, + min_slice_size, ) - elif file_format == 'text': - tfxio = raw_tf_record.RawBeamRecordTFXIO( - physical_format='csv', - raw_record_column_name=constants.ARROW_INPUT_COLUMN, - telemetry_descriptors=['StandaloneTFMA'], - ) - data = ( - p - | 'ReadFromText' - >> beam.io.textio.ReadFromText( - data_location, coder=beam.coders.BytesCoder() - ) - | 'ConvertToArrow' >> tfxio.BeamSource() - ) else: - raise ValueError('unknown file_format: {}'.format(file_format)) + config_version = 2 + eval_config = _update_eval_config_with_defaults(eval_config, eval_shared_model) + + tensor_adapter_config = None + with beam.Pipeline(options=pipeline_options) as p: + if file_format == "tfrecords": + if is_batched_input(eval_shared_model, eval_config, config_version): + if is_legacy_estimator(eval_shared_model): + tfxio = raw_tf_record.RawTfRecordTFXIO( + file_pattern=data_location, + raw_record_column_name=constants.ARROW_INPUT_COLUMN, + telemetry_descriptors=["StandaloneTFMA"], + ) + else: + tfxio = tf_example_record.TFExampleRecord( + file_pattern=data_location, + schema=schema, + raw_record_column_name=constants.ARROW_INPUT_COLUMN, + telemetry_descriptors=["StandaloneTFMA"], + ) + if schema is not None: + tensor_adapter_config = tensor_adapter.TensorAdapterConfig( + arrow_schema=tfxio.ArrowSchema(), + tensor_representations=tfxio.TensorRepresentations(), + ) + data = p | "ReadFromTFRecordToArrow" >> tfxio.BeamSource() + else: + data = p | "ReadFromTFRecord" >> beam.io.ReadFromTFRecord( + file_pattern=data_location, + compression_type=beam.io.filesystem.CompressionTypes.AUTO, + ) + elif file_format == "text": + tfxio = raw_tf_record.RawBeamRecordTFXIO( + physical_format="csv", + raw_record_column_name=constants.ARROW_INPUT_COLUMN, + telemetry_descriptors=["StandaloneTFMA"], + ) + data = ( + p + | "ReadFromText" + >> beam.io.textio.ReadFromText( + data_location, coder=beam.coders.BytesCoder() + ) + | "ConvertToArrow" >> tfxio.BeamSource() + ) + else: + raise ValueError(f"unknown file_format: {file_format}") - # pylint: disable=no-value-for-parameter - _ = ( - data - | 'ExtractEvaluateAndWriteResults' - >> ExtractEvaluateAndWriteResults( + # pylint: disable=no-value-for-parameter + _ = data | "ExtractEvaluateAndWriteResults" >> ExtractEvaluateAndWriteResults( eval_config=eval_config, eval_shared_model=eval_shared_model, display_only_data_location=data_location, @@ -1374,16 +1374,15 @@ def run_model_analysis( schema=schema, config_version=config_version, ) - ) - # pylint: enable=no-value-for-parameter + # pylint: enable=no-value-for-parameter - if len(eval_config.model_specs) <= 1: - return load_eval_result(output_path) - else: - results = [] - for spec in eval_config.model_specs: - results.append(load_eval_result(output_path, model_name=spec.name)) - return view_types.EvalResults(results, constants.MODEL_CENTRIC_MODE) + if len(eval_config.model_specs) <= 1: + return load_eval_result(output_path) + else: + results 
= [] + for spec in eval_config.model_specs: + results.append(load_eval_result(output_path, model_name=spec.name)) + return view_types.EvalResults(results, constants.MODEL_CENTRIC_MODE) def single_model_analysis( @@ -1393,86 +1392,92 @@ def single_model_analysis( eval_config: Optional[config_pb2.EvalConfig] = None, slice_spec: Optional[List[slicer.SingleSliceSpec]] = None, ) -> view_types.EvalResult: - """Run model analysis for a single model on a single data set. - - This is a convenience wrapper around run_model_analysis for a single model - with a single data set. For more complex use cases, use - tfma.run_model_analysis. - - Args: - model_location: Path to the export eval saved model. - data_location: The location of the data files. - output_path: The directory to output metrics and results to. If None, we use - a temporary directory. - eval_config: Eval config. - slice_spec: Deprecated (use EvalConfig). - - Returns: - An EvalResult that can be used with the TFMA visualization functions. - """ - # Get working_dir ready. - if output_path is None: - output_path = tempfile.mkdtemp() - if not tf.io.gfile.exists(output_path): - tf.io.gfile.makedirs(output_path) - - if slice_spec and eval_config: - raise ValueError('slice_spec is deprecated, only use eval_config') - if slice_spec: - eval_config = config_pb2.EvalConfig( - slicing_specs=[s.to_proto() for s in slice_spec] - ) + """Run model analysis for a single model on a single data set. + + This is a convenience wrapper around run_model_analysis for a single model + with a single data set. For more complex use cases, use + tfma.run_model_analysis. + + Args: + ---- + model_location: Path to the export eval saved model. + data_location: The location of the data files. + output_path: The directory to output metrics and results to. If None, we use + a temporary directory. + eval_config: Eval config. + slice_spec: Deprecated (use EvalConfig). + + Returns: + ------- + An EvalResult that can be used with the TFMA visualization functions. + """ + # Get working_dir ready. + if output_path is None: + output_path = tempfile.mkdtemp() + if not tf.io.gfile.exists(output_path): + tf.io.gfile.makedirs(output_path) + + if slice_spec and eval_config: + raise ValueError("slice_spec is deprecated, only use eval_config") + if slice_spec: + eval_config = config_pb2.EvalConfig( + slicing_specs=[s.to_proto() for s in slice_spec] + ) - return run_model_analysis( - eval_config=eval_config, - eval_shared_model=default_eval_shared_model( - eval_saved_model_path=model_location - ), - data_location=data_location, - output_path=output_path, - ) # pytype: disable=bad-return-type + return run_model_analysis( + eval_config=eval_config, + eval_shared_model=default_eval_shared_model( + eval_saved_model_path=model_location + ), + data_location=data_location, + output_path=output_path, + ) # pytype: disable=bad-return-type def multiple_model_analysis( model_locations: List[str], data_location: str, **kwargs ) -> view_types.EvalResults: - """Run model analysis for multiple models on the same data set. - - Args: - model_locations: A list of paths to the export eval saved model. - data_location: The location of the data files. - **kwargs: The args used for evaluation. See tfma.single_model_analysis() for - details. - - Returns: - A tfma.EvalResults containing all the evaluation results with the same order - as model_locations. 
- """ - results = [] - for m in model_locations: - results.append(single_model_analysis(m, data_location, **kwargs)) - return view_types.EvalResults(results, constants.MODEL_CENTRIC_MODE) + """Run model analysis for multiple models on the same data set. + + Args: + ---- + model_locations: A list of paths to the export eval saved model. + data_location: The location of the data files. + **kwargs: The args used for evaluation. See tfma.single_model_analysis() for + details. + + Returns: + ------- + A tfma.EvalResults containing all the evaluation results with the same order + as model_locations. + """ + results = [] + for m in model_locations: + results.append(single_model_analysis(m, data_location, **kwargs)) + return view_types.EvalResults(results, constants.MODEL_CENTRIC_MODE) def multiple_data_analysis( model_location: str, data_locations: List[str], **kwargs ) -> view_types.EvalResults: - """Run model analysis for a single model on multiple data sets. - - Args: - model_location: The location of the exported eval saved model. - data_locations: A list of data set locations. - **kwargs: The args used for evaluation. See tfma.run_model_analysis() for - details. - - Returns: - A tfma.EvalResults containing all the evaluation results with the same order - as data_locations. - """ - results = [] - for d in data_locations: - results.append(single_model_analysis(model_location, d, **kwargs)) - return view_types.EvalResults(results, constants.DATA_CENTRIC_MODE) + """Run model analysis for a single model on multiple data sets. + + Args: + ---- + model_location: The location of the exported eval saved model. + data_locations: A list of data set locations. + **kwargs: The args used for evaluation. See tfma.run_model_analysis() for + details. + + Returns: + ------- + A tfma.EvalResults containing all the evaluation results with the same order + as data_locations. + """ + results = [] + for d in data_locations: + results.append(single_model_analysis(model_location, d, **kwargs)) + return view_types.EvalResults(results, constants.DATA_CENTRIC_MODE) def analyze_raw_data( @@ -1484,138 +1489,139 @@ def analyze_raw_data( writers: Optional[List[writer.Writer]] = None, add_metric_callbacks: Optional[List[types.AddMetricsCallbackType]] = None, ) -> view_types.EvalResult: - """Runs TensorFlow model analysis on a pandas.DataFrame. - - This function allows you to use TFMA with Pandas DataFrames. The dataframe - must include a 'predicted' column for the predicted label and a 'label' column - for the actual label. - - In addition to a DataFrame, this function requires an eval_config, a - `tfma.EvalConfig` object containing various configuration parameters (see - [config.proto](https://github.com/tensorflow/model-analysis/blob/master/tensorflow_model_analysis/proto/config.proto) - for a comprehensive list)... - - * the metrics to compute - * the slices to compute metrics on - * the DataFrame's column names for example labels and predictions ('label' - and 'prediction' by default) - * confidence interval options - - This function returns a `tfma.EvalResult`, which contains TFMA's computed - metrics and can be used to generate plots with - `tfma.view.render_slicing_metrics`. 
- - Example usage: - - ```python - model_specs = [ - tfma.ModelSpec( - prediction_key='prediction', - label_key='label') - ] - metrics_specs = [ - tfma.MetricsSpec(metrics=[ - tfma.MetricConfig(class_name='Accuracy'), - tfma.MetricConfig(class_name='ExampleCount') - ]) - ] - slicing_specs = [ - tfma.SlicingSpec(), # the empty slice represents overall dataset - tfma.SlicingSpec(feature_keys=['language']) - ] - eval_config = tfma.EvalConfig( - model_specs=model_specs, - metrics_specs=metrics_specs, - slicing_specs=slicing_specs) - result = tfma.analyze_raw_data(df, eval_config) - tfma.view.render_slicing_metrics(result) - - # Example with Fairness Indicators - from tensorflow_model_analysis.addons.fairness.post_export_metrics import - fairness_indicators - from tensorflow_model_analysis.addons.fairness.view import widget_view - add_metrics_callbacks = [ - tfma.post_export_metrics.fairness_indicators(thresholds=[0.25, 0.5, 0.75]) - ] - result = tfma.analyze_raw_data( - data=df, - metrics_specs=metrics_specs, - slicing_specs=slicing_specs, - add_metric_callbacks=add_metrics_callbacks - ) - widget_view.render_fairness_indicator(result) - ``` - - Args: - data: A pandas.DataFrame, where rows correspond to examples and columns - correspond to features. One column must indicate a row's predicted label, - and one column must indicate a row's actual label. - eval_config: A `tfma.EvalConfig`, which contains various configuration - parameters including metrics, slices, and label/prediction column names. - output_path: Path to write EvalResult to. - extractors: Optional list of Extractors to apply to Extracts. Typically - these will be added by calling the default_extractors function. If no - extractors are provided, default_extractors (non-materialized) will be - used. - evaluators: Optional list of Evaluators for evaluating Extracts. Typically - these will be added by calling the default_evaluators function. If no - evaluators are provided, default_evaluators will be used. - writers: Optional list of Writers for writing Evaluation output. Typically - these will be added by calling the default_writers function. If no writers - are provided, default_writers with `add_metric_callbacks` will be used. - add_metric_callbacks: Optional list of metric callbacks (if used). - - Returns: - A tfma.EvalResult to extract metrics or generate visualizations from. - - Raises: - KeyError: If the prediction or label columns are not found within the - DataFrame. - """ - for model_spec in eval_config.model_specs: # pytype: disable=attribute-error - model_spec.prediction_key = model_spec.prediction_key or 'prediction' - model_spec.label_key = model_spec.label_key or 'label' - if model_spec.prediction_key not in data.columns: - raise KeyError( - 'The prediction_key column was not found. Looked for %s but found: %s' - % (model_spec.prediction_key, list(data.columns)) - ) - if model_spec.label_key not in data.columns: - raise KeyError( - 'The label_key column was not found. 
Looked for %s but found: %s' - % (model_spec.label_key, list(data.columns)) - ) - - # TODO(b/153570803): Validity check / assertions for dataframe structure - if eval_config.slicing_specs is None: # pytype: disable=attribute-error - eval_config.slicing_specs = [config_pb2.SlicingSpec(feature_keys=[''])] - if output_path is None: - output_path = tempfile.mkdtemp() - - arrow_data = table_util.CanonicalizeRecordBatch( - pa.RecordBatch.from_pandas(data) - ) - beam_data = beam.Create([arrow_data]) - - if not writers: - writers = default_writers( - output_path, - eval_config=eval_config, - add_metric_callbacks=add_metric_callbacks, + """Runs TensorFlow model analysis on a pandas.DataFrame. + + This function allows you to use TFMA with Pandas DataFrames. The dataframe + must include a 'predicted' column for the predicted label and a 'label' column + for the actual label. + + In addition to a DataFrame, this function requires an eval_config, a + `tfma.EvalConfig` object containing various configuration parameters (see + [config.proto](https://github.com/tensorflow/model-analysis/blob/master/tensorflow_model_analysis/proto/config.proto) + for a comprehensive list)... + + * the metrics to compute + * the slices to compute metrics on + * the DataFrame's column names for example labels and predictions ('label' + and 'prediction' by default) + * confidence interval options + + This function returns a `tfma.EvalResult`, which contains TFMA's computed + metrics and can be used to generate plots with + `tfma.view.render_slicing_metrics`. + + Example usage: + + ```python + model_specs = [ + tfma.ModelSpec( + prediction_key='prediction', + label_key='label') + ] + metrics_specs = [ + tfma.MetricsSpec(metrics=[ + tfma.MetricConfig(class_name='Accuracy'), + tfma.MetricConfig(class_name='ExampleCount') + ]) + ] + slicing_specs = [ + tfma.SlicingSpec(), # the empty slice represents overall dataset + tfma.SlicingSpec(feature_keys=['language']) + ] + eval_config = tfma.EvalConfig( + model_specs=model_specs, + metrics_specs=metrics_specs, + slicing_specs=slicing_specs) + result = tfma.analyze_raw_data(df, eval_config) + tfma.view.render_slicing_metrics(result) + + # Example with Fairness Indicators + from tensorflow_model_analysis.addons.fairness.post_export_metrics import + fairness_indicators + from tensorflow_model_analysis.addons.fairness.view import widget_view + add_metrics_callbacks = [ + tfma.post_export_metrics.fairness_indicators(thresholds=[0.25, 0.5, 0.75]) + ] + result = tfma.analyze_raw_data( + data=df, + metrics_specs=metrics_specs, + slicing_specs=slicing_specs, + add_metric_callbacks=add_metrics_callbacks ) + widget_view.render_fairness_indicator(result) + ``` + + Args: + ---- + data: A pandas.DataFrame, where rows correspond to examples and columns + correspond to features. One column must indicate a row's predicted label, + and one column must indicate a row's actual label. + eval_config: A `tfma.EvalConfig`, which contains various configuration + parameters including metrics, slices, and label/prediction column names. + output_path: Path to write EvalResult to. + extractors: Optional list of Extractors to apply to Extracts. Typically + these will be added by calling the default_extractors function. If no + extractors are provided, default_extractors (non-materialized) will be + used. + evaluators: Optional list of Evaluators for evaluating Extracts. Typically + these will be added by calling the default_evaluators function. If no + evaluators are provided, default_evaluators will be used. 
+ writers: Optional list of Writers for writing Evaluation output. Typically + these will be added by calling the default_writers function. If no writers + are provided, default_writers with `add_metric_callbacks` will be used. + add_metric_callbacks: Optional list of metric callbacks (if used). + + Returns: + ------- + A tfma.EvalResult to extract metrics or generate visualizations from. + + Raises: + ------ + KeyError: If the prediction or label columns are not found within the + DataFrame. + """ + for model_spec in eval_config.model_specs: # pytype: disable=attribute-error + model_spec.prediction_key = model_spec.prediction_key or "prediction" + model_spec.label_key = model_spec.label_key or "label" + if model_spec.prediction_key not in data.columns: + raise KeyError( + "The prediction_key column was not found. Looked for %s but found: %s" + % (model_spec.prediction_key, list(data.columns)) + ) + if model_spec.label_key not in data.columns: + raise KeyError( + "The label_key column was not found. Looked for %s but found: %s" + % (model_spec.label_key, list(data.columns)) + ) - with beam.Pipeline() as p: - _ = ( - p - | beam_data - | 'ExtractEvaluateAndWriteResults' - >> ExtractEvaluateAndWriteResults( # pylint: disable=no-value-for-parameter - extractors=extractors, - evaluators=evaluators, - writers=writers, + # TODO(b/153570803): Validity check / assertions for dataframe structure + if eval_config.slicing_specs is None: # pytype: disable=attribute-error + eval_config.slicing_specs = [config_pb2.SlicingSpec(feature_keys=[""])] + if output_path is None: + output_path = tempfile.mkdtemp() + + arrow_data = table_util.CanonicalizeRecordBatch(pa.RecordBatch.from_pandas(data)) + beam_data = beam.Create([arrow_data]) + + if not writers: + writers = default_writers( + output_path, eval_config=eval_config, - output_path=output_path, + add_metric_callbacks=add_metric_callbacks, + ) + + with beam.Pipeline() as p: + _ = ( + p + | beam_data + | "ExtractEvaluateAndWriteResults" + >> ExtractEvaluateAndWriteResults( # pylint: disable=no-value-for-parameter + extractors=extractors, + evaluators=evaluators, + writers=writers, + eval_config=eval_config, + output_path=output_path, + ) ) - ) - return load_eval_result(output_path) + return load_eval_result(output_path) diff --git a/tensorflow_model_analysis/api/model_eval_lib_test.py b/tensorflow_model_analysis/api/model_eval_lib_test.py index abf89f3259..6493c107c7 100644 --- a/tensorflow_model_analysis/api/model_eval_lib_test.py +++ b/tensorflow_model_analysis/api/model_eval_lib_test.py @@ -18,270 +18,268 @@ import tempfile import unittest -from absl.testing import absltest -from absl.testing import parameterized import apache_beam as beam import numpy as np import pandas as pd import tensorflow as tf +from absl.testing import absltest, parameterized +from google.protobuf import text_format, wrappers_pb2 from tensorflow import keras +from tensorflow_metadata.proto.v0 import schema_pb2 +from tfx_bsl.coders import example_coder + from tensorflow_model_analysis import constants from tensorflow_model_analysis.api import model_eval_lib from tensorflow_model_analysis.evaluators import metrics_plots_and_validations_evaluator -from tensorflow_model_analysis.metrics import calibration_plot -from tensorflow_model_analysis.metrics import confusion_matrix_metrics -from tensorflow_model_analysis.metrics import metric_specs -from tensorflow_model_analysis.metrics import metric_types -from tensorflow_model_analysis.metrics import metric_util -from 
tensorflow_model_analysis.metrics import ndcg -from tensorflow_model_analysis.proto import config_pb2 -from tensorflow_model_analysis.proto import metrics_for_slice_pb2 -from tensorflow_model_analysis.proto import validation_result_pb2 -from tensorflow_model_analysis.utils import example_keras_model -from tensorflow_model_analysis.utils import test_util +from tensorflow_model_analysis.metrics import ( + calibration_plot, + confusion_matrix_metrics, + metric_specs, + metric_types, + metric_util, + ndcg, +) +from tensorflow_model_analysis.proto import ( + config_pb2, + metrics_for_slice_pb2, + validation_result_pb2, +) +from tensorflow_model_analysis.utils import example_keras_model, test_util from tensorflow_model_analysis.utils.keras_lib import tf_keras from tensorflow_model_analysis.view import view_types -from tfx_bsl.coders import example_coder - -from google.protobuf import wrappers_pb2 -from google.protobuf import text_format -from tensorflow_metadata.proto.v0 import schema_pb2 try: - import tensorflow_ranking as tfr # pylint: disable=g-import-not-at-top + import tensorflow_ranking as tfr # pylint: disable=g-import-not-at-top - _TFR_IMPORTED = True + _TFR_IMPORTED = True except (ImportError, tf.errors.NotFoundError): - _TFR_IMPORTED = False + _TFR_IMPORTED = False try: - from tensorflowjs.converters import converter as tfjs_converter # pylint: disable=g-import-not-at-top + from tensorflowjs.converters import ( + converter as tfjs_converter, # pylint: disable=g-import-not-at-top + ) - _TFJS_IMPORTED = True + _TFJS_IMPORTED = True except ModuleNotFoundError: - _TFJS_IMPORTED = False + _TFJS_IMPORTED = False _TEST_SEED = 982735 -_TF_MAJOR_VERSION = int(tf.version.VERSION.split('.')[0]) - - -class EvaluateTest( - test_util.TensorflowModelAnalysisTest, parameterized.TestCase -): - - def setUp(self): - super().setUp() - self.longMessage = True # pylint: disable=invalid-name - - def _getTempDir(self): - return tempfile.mkdtemp() - - def _exportEvalSavedModel(self, classifier): - temp_eval_export_dir = os.path.join(self._getTempDir(), 'eval_export_dir') - _, eval_export_dir = classifier(None, temp_eval_export_dir) - return eval_export_dir - - def _exportKerasModel(self, classifier): - temp_export_dir = os.path.join(self._getTempDir(), 'saved_model_export_dir') - classifier.export(temp_export_dir) - return temp_export_dir - - def _writeTFExamplesToTFRecords(self, examples): - data_location = os.path.join(self._getTempDir(), 'input_data.rio') - with tf.io.TFRecordWriter(data_location) as writer: - for example in examples: - writer.write(example.SerializeToString()) - return data_location - - def _writeCSVToTextFile(self, examples): - data_location = os.path.join(self._getTempDir(), 'input_data.csv') - with open(data_location, 'w') as writer: - for example in examples: - writer.write(example + '\n') - return data_location - - def assertMetricsAlmostEqual( - self, - got_slicing_metrics, - expected_slicing_metrics, - output_name='', - subkey='', - ): - if got_slicing_metrics: - for s, m in got_slicing_metrics: - metrics = m[output_name][subkey] - self.assertIn(s, expected_slicing_metrics) - for metric_name in expected_slicing_metrics[s]: - self.assertIn(metric_name, metrics) - self.assertDictElementsAlmostEqual( - metrics[metric_name], expected_slicing_metrics[s][metric_name] - ) - else: - # Only pass if expected_slicing_metrics also evaluates to False. - self.assertFalse( - expected_slicing_metrics, msg='Actual slicing_metrics was empty.' 
- ) - - def assertSliceMetricsEqual(self, expected_metrics, got_metrics): - self.assertCountEqual( - list(expected_metrics), - list(got_metrics), - msg='keys do not match. expected_metrics: %s, got_metrics: %s' - % (expected_metrics, got_metrics), - ) - for key in expected_metrics: - self.assertProtoEquals( - expected_metrics[key], - got_metrics[key], - msg='value for key %s does not match' % key, - ) - - def assertSliceListEqual(self, expected_list, got_list, value_assert_fn): - self.assertEqual( - len(expected_list), - len(got_list), - msg='expected_list: %s, got_list: %s' % (expected_list, got_list), - ) - for index, (expected, got) in enumerate(zip(expected_list, got_list)): - (expected_key, expected_value) = expected - (got_key, got_value) = got - self.assertEqual( - expected_key, got_key, msg='key mismatch at index %d' % index - ) - value_assert_fn(expected_value, got_value) - - def assertSlicePlotsListEqual(self, expected_list, got_list): - self.assertSliceListEqual(expected_list, got_list, self.assertProtoEquals) - - def assertSliceMetricsListEqual(self, expected_list, got_list): - self.assertSliceListEqual( - expected_list, got_list, self.assertSliceMetricsEqual - ) - - @parameterized.named_parameters( - ('tflite', constants.TF_LITE), ('tfjs', constants.TF_JS) - ) - def testMixedModelTypes(self, model_type): - examples = [self._makeExample(age=3.0, language='english', label=1.0)] - data_location = self._writeTFExamplesToTFRecords(examples) - eval_config = config_pb2.EvalConfig( - model_specs=[ - config_pb2.ModelSpec(name='model1'), - config_pb2.ModelSpec(name='model2', model_type=model_type), - ] - ) - eval_shared_models = [ - model_eval_lib.default_eval_shared_model( - model_name='model1', - eval_saved_model_path='/model1/path', - eval_config=eval_config, - ), - model_eval_lib.default_eval_shared_model( - model_name='model2', - eval_saved_model_path='/model2/path', - eval_config=eval_config, - ), - ] - with self.assertRaisesRegex( - NotImplementedError, 'support for mixing .* models is not implemented' +_TF_MAJOR_VERSION = int(tf.version.VERSION.split(".")[0]) + + +class EvaluateTest(test_util.TensorflowModelAnalysisTest, parameterized.TestCase): + def setUp(self): + super().setUp() + self.longMessage = True # pylint: disable=invalid-name + + def _getTempDir(self): + return tempfile.mkdtemp() + + def _exportEvalSavedModel(self, classifier): + temp_eval_export_dir = os.path.join(self._getTempDir(), "eval_export_dir") + _, eval_export_dir = classifier(None, temp_eval_export_dir) + return eval_export_dir + + def _exportKerasModel(self, classifier): + temp_export_dir = os.path.join(self._getTempDir(), "saved_model_export_dir") + classifier.export(temp_export_dir) + return temp_export_dir + + def _writeTFExamplesToTFRecords(self, examples): + data_location = os.path.join(self._getTempDir(), "input_data.rio") + with tf.io.TFRecordWriter(data_location) as writer: + for example in examples: + writer.write(example.SerializeToString()) + return data_location + + def _writeCSVToTextFile(self, examples): + data_location = os.path.join(self._getTempDir(), "input_data.csv") + with open(data_location, "w") as writer: + for example in examples: + writer.write(example + "\n") + return data_location + + def assertMetricsAlmostEqual( + self, + got_slicing_metrics, + expected_slicing_metrics, + output_name="", + subkey="", ): - model_eval_lib.run_model_analysis( - eval_config=eval_config, - eval_shared_model=eval_shared_models, - data_location=data_location, - output_path=self._getTempDir(), - ) - - 
def testRunModelAnalysis(self): - examples = [ - self._makeExample(age=3.0, language='english', label=1.0), - self._makeExample(age=3.0, language='chinese', label=0.0), - self._makeExample(age=4.0, language='english', label=1.0), - self._makeExample(age=5.0, language='chinese', label=1.0), - self._makeExample(age=5.0, language='hindi', label=1.0), - ] - classifier = example_keras_model.get_example_classifier_model( - example_keras_model.LANGUAGE - ) - classifier.compile(optimizer=keras.optimizers.Adam(), loss='mse') - classifier.fit( - tf.constant([e.SerializeToString() for e in examples]), - np.array([ - e.features.feature[example_keras_model.LABEL].float_list.value[:][0] - for e in examples - ]), - batch_size=1, - ) - model_location = self._exportKerasModel(classifier) - data_location = self._writeTFExamplesToTFRecords(examples) - eval_config = config_pb2.EvalConfig( - model_specs=[ - config_pb2.ModelSpec( - name='model1', example_weight_key='age', label_key='label' + if got_slicing_metrics: + for s, m in got_slicing_metrics: + metrics = m[output_name][subkey] + self.assertIn(s, expected_slicing_metrics) + for metric_name in expected_slicing_metrics[s]: + self.assertIn(metric_name, metrics) + self.assertDictElementsAlmostEqual( + metrics[metric_name], expected_slicing_metrics[s][metric_name] + ) + else: + # Only pass if expected_slicing_metrics also evaluates to False. + self.assertFalse( + expected_slicing_metrics, msg="Actual slicing_metrics was empty." ) - ], - slicing_specs=[config_pb2.SlicingSpec(feature_keys=['language'])], - metrics_specs=[ - config_pb2.MetricsSpec( - metrics=[ - config_pb2.MetricConfig( - class_name='ExampleCount', - ), - config_pb2.MetricConfig( - class_name='Accuracy', - ), - ] + + def assertSliceMetricsEqual(self, expected_metrics, got_metrics): + self.assertCountEqual( + list(expected_metrics), + list(got_metrics), + msg="keys do not match. expected_metrics: %s, got_metrics: %s" + % (expected_metrics, got_metrics), + ) + for key in expected_metrics: + self.assertProtoEquals( + expected_metrics[key], + got_metrics[key], + msg="value for key %s does not match" % key, ) - ], - options=config_pb2.Options( - min_slice_size=wrappers_pb2.Int32Value(value=2) - ), - ) - eval_result = model_eval_lib.run_model_analysis( - eval_shared_model=model_eval_lib.default_eval_shared_model( - eval_saved_model_path=model_location, eval_config=eval_config - ), - data_location=data_location, - eval_config=eval_config, - output_path=self._getTempDir(), + + def assertSliceListEqual(self, expected_list, got_list, value_assert_fn): + self.assertEqual( + len(expected_list), + len(got_list), + msg="expected_list: %s, got_list: %s" % (expected_list, got_list), + ) + for index, (expected, got) in enumerate(zip(expected_list, got_list)): + (expected_key, expected_value) = expected + (got_key, got_value) = got + self.assertEqual( + expected_key, got_key, msg="key mismatch at index %d" % index + ) + value_assert_fn(expected_value, got_value) + + def assertSlicePlotsListEqual(self, expected_list, got_list): + self.assertSliceListEqual(expected_list, got_list, self.assertProtoEquals) + + def assertSliceMetricsListEqual(self, expected_list, got_list): + self.assertSliceListEqual(expected_list, got_list, self.assertSliceMetricsEqual) + + @parameterized.named_parameters( + ("tflite", constants.TF_LITE), ("tfjs", constants.TF_JS) ) - # We only check some of the metrics to ensure that the end-to-end - # pipeline works. 
- expected = { - (('language', 'hindi'),): { - '__ERROR__': { - 'debugMessage': ( - 'Example count for this slice key is lower than the ' - 'minimum required value: 2. No data is aggregated for ' - 'this slice.' + def testMixedModelTypes(self, model_type): + examples = [self._makeExample(age=3.0, language="english", label=1.0)] + data_location = self._writeTFExamplesToTFRecords(examples) + eval_config = config_pb2.EvalConfig( + model_specs=[ + config_pb2.ModelSpec(name="model1"), + config_pb2.ModelSpec(name="model2", model_type=model_type), + ] + ) + eval_shared_models = [ + model_eval_lib.default_eval_shared_model( + model_name="model1", + eval_saved_model_path="/model1/path", + eval_config=eval_config, + ), + model_eval_lib.default_eval_shared_model( + model_name="model2", + eval_saved_model_path="/model2/path", + eval_config=eval_config, + ), + ] + with self.assertRaisesRegex( + NotImplementedError, "support for mixing .* models is not implemented" + ): + model_eval_lib.run_model_analysis( + eval_config=eval_config, + eval_shared_model=eval_shared_models, + data_location=data_location, + output_path=self._getTempDir(), + ) + + def testRunModelAnalysis(self): + examples = [ + self._makeExample(age=3.0, language="english", label=1.0), + self._makeExample(age=3.0, language="chinese", label=0.0), + self._makeExample(age=4.0, language="english", label=1.0), + self._makeExample(age=5.0, language="chinese", label=1.0), + self._makeExample(age=5.0, language="hindi", label=1.0), + ] + classifier = example_keras_model.get_example_classifier_model( + example_keras_model.LANGUAGE + ) + classifier.compile(optimizer=keras.optimizers.Adam(), loss="mse") + classifier.fit( + tf.constant([e.SerializeToString() for e in examples]), + np.array( + [ + e.features.feature[example_keras_model.LABEL].float_list.value[:][0] + for e in examples + ] + ), + batch_size=1, + ) + model_location = self._exportKerasModel(classifier) + data_location = self._writeTFExamplesToTFRecords(examples) + eval_config = config_pb2.EvalConfig( + model_specs=[ + config_pb2.ModelSpec( + name="model1", example_weight_key="age", label_key="label" + ) + ], + slicing_specs=[config_pb2.SlicingSpec(feature_keys=["language"])], + metrics_specs=[ + config_pb2.MetricsSpec( + metrics=[ + config_pb2.MetricConfig( + class_name="ExampleCount", + ), + config_pb2.MetricConfig( + class_name="Accuracy", + ), + ] ) + ], + options=config_pb2.Options(min_slice_size=wrappers_pb2.Int32Value(value=2)), + ) + eval_result = model_eval_lib.run_model_analysis( + eval_shared_model=model_eval_lib.default_eval_shared_model( + eval_saved_model_path=model_location, eval_config=eval_config + ), + data_location=data_location, + eval_config=eval_config, + output_path=self._getTempDir(), + ) + # We only check some of the metrics to ensure that the end-to-end + # pipeline works. + expected = { + (("language", "hindi"),): { + "__ERROR__": { + "debugMessage": ( + "Example count for this slice key is lower than the " + "minimum required value: 2. No data is aggregated for " + "this slice." 
+ ) + }, }, - }, - (('language', 'chinese'),): { - 'accuracy': {'doubleValue': 0.0}, - 'example_count': {'doubleValue': 8.0}, - }, - (('language', 'english'),): { - 'accuracy': {'doubleValue': 0.0}, - 'example_count': {'doubleValue': 7.0}, - }, - } - self.assertEqual(eval_result.model_location, model_location) - self.assertEqual(eval_result.data_location, data_location) - self.assertEqual( - eval_result.config.slicing_specs[0], - config_pb2.SlicingSpec(feature_keys=['language']), - ) - self.assertMetricsAlmostEqual(eval_result.slicing_metrics, expected) - for _, plot in eval_result.plots: - self.assertFalse(plot) - - @parameterized.named_parameters( - { - 'testcase_name': 'WithHistogram', - 'eval_config': text_format.Parse( - """ + (("language", "chinese"),): { + "accuracy": {"doubleValue": 0.0}, + "example_count": {"doubleValue": 8.0}, + }, + (("language", "english"),): { + "accuracy": {"doubleValue": 0.0}, + "example_count": {"doubleValue": 7.0}, + }, + } + self.assertEqual(eval_result.model_location, model_location) + self.assertEqual(eval_result.data_location, data_location) + self.assertEqual( + eval_result.config.slicing_specs[0], + config_pb2.SlicingSpec(feature_keys=["language"]), + ) + self.assertMetricsAlmostEqual(eval_result.slicing_metrics, expected) + for _, plot in eval_result.plots: + self.assertFalse(plot) + + @parameterized.named_parameters( + { + "testcase_name": "WithHistogram", + "eval_config": text_format.Parse( + """ model_specs { label_key: "labels" prediction_key: "predictions" @@ -310,10 +308,10 @@ def testRunModelAnalysis(self): } } """, - config_pb2.EvalConfig(), - ), - 'expected_class_0_recall': text_format.Parse( - """ + config_pb2.EvalConfig(), + ), + "expected_class_0_recall": text_format.Parse( + """ array_value { data_type: FLOAT64 shape: 3 @@ -322,10 +320,10 @@ def testRunModelAnalysis(self): float64_values: 0.0 } """, - metrics_for_slice_pb2.MetricValue(), - ), - 'expected_class_1_recall': text_format.Parse( - """ + metrics_for_slice_pb2.MetricValue(), + ), + "expected_class_1_recall": text_format.Parse( + """ array_value { data_type: FLOAT64 shape: 3 @@ -334,13 +332,13 @@ def testRunModelAnalysis(self): float64_values: 0.0 } """, - metrics_for_slice_pb2.MetricValue(), - ), - }, - { - 'testcase_name': 'NoHistogram', - 'eval_config': text_format.Parse( - """ + metrics_for_slice_pb2.MetricValue(), + ), + }, + { + "testcase_name": "NoHistogram", + "eval_config": text_format.Parse( + """ model_specs { label_key: "labels" prediction_key: "predictions" @@ -369,206 +367,194 @@ def testRunModelAnalysis(self): } } """, - config_pb2.EvalConfig(), - ), - 'expected_class_0_recall': text_format.Parse( - 'double_value { value: 1.0 }', - metrics_for_slice_pb2.MetricValue(), - ), - 'expected_class_1_recall': text_format.Parse( - 'double_value { value: 0.0 }', - metrics_for_slice_pb2.MetricValue(), - ), - }, - ) - def testRunModelAnalysisMultiMicroAggregation( - self, eval_config, expected_class_0_recall, expected_class_1_recall - ): - # class 0 is all TPs so has recall 1.0, class 1 is all FPs so has recall 0.0 - examples = [ - self._makeExample(labels=[1.0, 1.0], predictions=[0.9, 0.1]), - self._makeExample(labels=[1.0, 1.0], predictions=[0.9, 0.1]), - ] - data_location = self._writeTFExamplesToTFRecords(examples) - output_dir = self._getTempDir() - model_eval_lib.run_model_analysis( - eval_config=eval_config, - data_location=data_location, - output_path=output_dir, - ) - - metrics_for_slice = list(model_eval_lib.load_metrics(output_dir)) - 
self.assertLen(metrics_for_slice, 1) - metric_keys_to_values = { - metric_types.MetricKey.from_proto(kv.key): kv.value - for kv in metrics_for_slice[0].metric_keys_and_values - } - class_0_key = metric_types.MetricKey( - name='recall_class_0', - aggregation_type=metric_types.AggregationType(micro_average=True), - ) - class_1_key = metric_types.MetricKey( - name='recall_class_1', - aggregation_type=metric_types.AggregationType(micro_average=True), - ) - self.assertIn(class_0_key, metric_keys_to_values) - self.assertEqual( - expected_class_0_recall, metric_keys_to_values[class_0_key] - ) - self.assertIn(class_1_key, metric_keys_to_values) - self.assertEqual( - expected_class_1_recall, metric_keys_to_values[class_1_key] + config_pb2.EvalConfig(), + ), + "expected_class_0_recall": text_format.Parse( + "double_value { value: 1.0 }", + metrics_for_slice_pb2.MetricValue(), + ), + "expected_class_1_recall": text_format.Parse( + "double_value { value: 0.0 }", + metrics_for_slice_pb2.MetricValue(), + ), + }, ) + def testRunModelAnalysisMultiMicroAggregation( + self, eval_config, expected_class_0_recall, expected_class_1_recall + ): + # class 0 is all TPs so has recall 1.0, class 1 is all FPs so has recall 0.0 + examples = [ + self._makeExample(labels=[1.0, 1.0], predictions=[0.9, 0.1]), + self._makeExample(labels=[1.0, 1.0], predictions=[0.9, 0.1]), + ] + data_location = self._writeTFExamplesToTFRecords(examples) + output_dir = self._getTempDir() + model_eval_lib.run_model_analysis( + eval_config=eval_config, + data_location=data_location, + output_path=output_dir, + ) - def testRunModelAnalysisWithExplicitModelAgnosticPredictions(self): - examples = [ - self._makeExample( - age=3.0, language='english', label=1.0, prediction=0.9 - ), - self._makeExample( - age=3.0, language='chinese', label=0.0, prediction=0.4 - ), - self._makeExample( - age=4.0, language='english', label=1.0, prediction=0.7 - ), - self._makeExample( - age=5.0, language='chinese', label=1.0, prediction=0.2 - ), - ] - metrics_specs = [ - config_pb2.MetricsSpec( - metrics=[config_pb2.MetricConfig(class_name='ExampleCount')], - example_weights=config_pb2.ExampleWeightOptions(unweighted=True), - ), - config_pb2.MetricsSpec( - metrics=[ - config_pb2.MetricConfig(class_name='WeightedExampleCount') - ], - example_weights=config_pb2.ExampleWeightOptions(weighted=True), - ), - config_pb2.MetricsSpec( - metrics=[config_pb2.MetricConfig(class_name='BinaryAccuracy')], - example_weights=config_pb2.ExampleWeightOptions(weighted=True), - ), - ] - slicing_specs = [config_pb2.SlicingSpec(feature_keys=['language'])] - model_spec = config_pb2.ModelSpec( - prediction_key='prediction', - label_key='label', - example_weight_key='age', - ) - eval_config = config_pb2.EvalConfig( - model_specs=[model_spec], - metrics_specs=metrics_specs, - slicing_specs=slicing_specs, - ) - data_location = self._writeTFExamplesToTFRecords(examples) - eval_result = model_eval_lib.run_model_analysis( - eval_config=eval_config, - data_location=data_location, - output_path=self._getTempDir(), - ) - expected = { - (('language', 'chinese'),): { - 'binary_accuracy': {'doubleValue': 0.375}, - 'weighted_example_count': {'doubleValue': 8.0}, - 'example_count': {'doubleValue': 2.0}, - }, - (('language', 'english'),): { - 'binary_accuracy': {'doubleValue': 1.0}, - 'weighted_example_count': {'doubleValue': 7.0}, - 'example_count': {'doubleValue': 2.0}, - }, - } - self.assertEqual(eval_result.data_location, data_location) - self.assertEqual( - eval_result.config.slicing_specs[0], - 
config_pb2.SlicingSpec(feature_keys=['language']), + metrics_for_slice = list(model_eval_lib.load_metrics(output_dir)) + self.assertLen(metrics_for_slice, 1) + metric_keys_to_values = { + metric_types.MetricKey.from_proto(kv.key): kv.value + for kv in metrics_for_slice[0].metric_keys_and_values + } + class_0_key = metric_types.MetricKey( + name="recall_class_0", + aggregation_type=metric_types.AggregationType(micro_average=True), + ) + class_1_key = metric_types.MetricKey( + name="recall_class_1", + aggregation_type=metric_types.AggregationType(micro_average=True), + ) + self.assertIn(class_0_key, metric_keys_to_values) + self.assertEqual(expected_class_0_recall, metric_keys_to_values[class_0_key]) + self.assertIn(class_1_key, metric_keys_to_values) + self.assertEqual(expected_class_1_recall, metric_keys_to_values[class_1_key]) + + def testRunModelAnalysisWithExplicitModelAgnosticPredictions(self): + examples = [ + self._makeExample(age=3.0, language="english", label=1.0, prediction=0.9), + self._makeExample(age=3.0, language="chinese", label=0.0, prediction=0.4), + self._makeExample(age=4.0, language="english", label=1.0, prediction=0.7), + self._makeExample(age=5.0, language="chinese", label=1.0, prediction=0.2), + ] + metrics_specs = [ + config_pb2.MetricsSpec( + metrics=[config_pb2.MetricConfig(class_name="ExampleCount")], + example_weights=config_pb2.ExampleWeightOptions(unweighted=True), + ), + config_pb2.MetricsSpec( + metrics=[config_pb2.MetricConfig(class_name="WeightedExampleCount")], + example_weights=config_pb2.ExampleWeightOptions(weighted=True), + ), + config_pb2.MetricsSpec( + metrics=[config_pb2.MetricConfig(class_name="BinaryAccuracy")], + example_weights=config_pb2.ExampleWeightOptions(weighted=True), + ), + ] + slicing_specs = [config_pb2.SlicingSpec(feature_keys=["language"])] + model_spec = config_pb2.ModelSpec( + prediction_key="prediction", + label_key="label", + example_weight_key="age", + ) + eval_config = config_pb2.EvalConfig( + model_specs=[model_spec], + metrics_specs=metrics_specs, + slicing_specs=slicing_specs, + ) + data_location = self._writeTFExamplesToTFRecords(examples) + eval_result = model_eval_lib.run_model_analysis( + eval_config=eval_config, + data_location=data_location, + output_path=self._getTempDir(), + ) + expected = { + (("language", "chinese"),): { + "binary_accuracy": {"doubleValue": 0.375}, + "weighted_example_count": {"doubleValue": 8.0}, + "example_count": {"doubleValue": 2.0}, + }, + (("language", "english"),): { + "binary_accuracy": {"doubleValue": 1.0}, + "weighted_example_count": {"doubleValue": 7.0}, + "example_count": {"doubleValue": 2.0}, + }, + } + self.assertEqual(eval_result.data_location, data_location) + self.assertEqual( + eval_result.config.slicing_specs[0], + config_pb2.SlicingSpec(feature_keys=["language"]), + ) + self.assertMetricsAlmostEqual(eval_result.slicing_metrics, expected) + + @parameterized.named_parameters( + ("tf_keras", constants.TF_KERAS), + ("tf_lite", constants.TF_LITE), + ("tf_js", constants.TF_JS), + ("baseline_missing", constants.TF_KERAS, True), + ("rubber_stamp", constants.TF_KERAS, True, True), + ("tf_keras_custom_metrics", constants.TF_KERAS, False, False, True), ) - self.assertMetricsAlmostEqual(eval_result.slicing_metrics, expected) - - @parameterized.named_parameters( - ('tf_keras', constants.TF_KERAS), - ('tf_lite', constants.TF_LITE), - ('tf_js', constants.TF_JS), - ('baseline_missing', constants.TF_KERAS, True), - ('rubber_stamp', constants.TF_KERAS, True, True), - ('tf_keras_custom_metrics', 
constants.TF_KERAS, False, False, True), - ) - def testRunModelAnalysisWithKerasModel( - self, - model_type, - remove_baseline=False, - rubber_stamp=False, - add_custom_metrics=False, - ): - if model_type == constants.TF_JS and not _TFJS_IMPORTED: - self.skipTest('This test requires TensorFlow JS.') - - # Custom metrics not supported in TFv1 - if _TF_MAJOR_VERSION < 2: - add_custom_metrics = False - - def _build_keras_model( - eval_config, export_name='export_dir', rubber_stamp=False + def testRunModelAnalysisWithKerasModel( + self, + model_type, + remove_baseline=False, + rubber_stamp=False, + add_custom_metrics=False, ): - input_layer = tf_keras.layers.Input(shape=(28 * 28,), name='data') - output_layer = tf_keras.layers.Dense(10, activation=tf.nn.softmax)( - input_layer - ) - model = tf_keras.models.Model(input_layer, output_layer) - model.compile( - optimizer=tf_keras.optimizers.Adam(lr=0.001), - loss=tf_keras.losses.categorical_crossentropy, - ) - if add_custom_metrics: - model.add_metric(tf.reduce_sum(input_layer), 'custom') - model_location = os.path.join(self._getTempDir(), export_name) - if model_type == constants.TF_LITE: - converter = tf.compat.v2.lite.TFLiteConverter.from_keras_model(model) - tflite_model = converter.convert() - tf.io.gfile.makedirs(model_location) - with tf.io.gfile.GFile( - os.path.join(model_location, 'tflite'), 'wb' - ) as f: - f.write(tflite_model) - elif model_type == constants.TF_JS: - src_model_path = tempfile.mkdtemp() - model.export(src_model_path) - - tfjs_converter.convert([ - '--input_format=tf_saved_model', - '--saved_model_tags=serve', - '--signature_name=serving_default', - src_model_path, - model_location, - ]) - else: - model.export(model_location) - return model_eval_lib.default_eval_shared_model( - eval_saved_model_path=model_location, - eval_config=eval_config, - rubber_stamp=rubber_stamp, - ) - - examples = [ - self._makeExample( - data=[0.0] * 28 * 28, - label=[0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - ), - self._makeExample( - data=[1.0] * 28 * 28, - label=[0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0], - ), - self._makeExample( - data=[1.0] * 28 * 28, - label=[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0], - ), - ] - data_location = self._writeTFExamplesToTFRecords(examples) - - schema = text_format.Parse( - """ + if model_type == constants.TF_JS and not _TFJS_IMPORTED: + self.skipTest("This test requires TensorFlow JS.") + + # Custom metrics not supported in TFv1 + if _TF_MAJOR_VERSION < 2: + add_custom_metrics = False + + def _build_keras_model( + eval_config, export_name="export_dir", rubber_stamp=False + ): + input_layer = tf_keras.layers.Input(shape=(28 * 28,), name="data") + output_layer = tf_keras.layers.Dense(10, activation=tf.nn.softmax)( + input_layer + ) + model = tf_keras.models.Model(input_layer, output_layer) + model.compile( + optimizer=tf_keras.optimizers.Adam(lr=0.001), + loss=tf_keras.losses.categorical_crossentropy, + ) + if add_custom_metrics: + model.add_metric(tf.reduce_sum(input_layer), "custom") + model_location = os.path.join(self._getTempDir(), export_name) + if model_type == constants.TF_LITE: + converter = tf.compat.v2.lite.TFLiteConverter.from_keras_model(model) + tflite_model = converter.convert() + tf.io.gfile.makedirs(model_location) + with tf.io.gfile.GFile( + os.path.join(model_location, "tflite"), "wb" + ) as f: + f.write(tflite_model) + elif model_type == constants.TF_JS: + src_model_path = tempfile.mkdtemp() + model.export(src_model_path) + + tfjs_converter.convert( + [ + 
"--input_format=tf_saved_model", + "--saved_model_tags=serve", + "--signature_name=serving_default", + src_model_path, + model_location, + ] + ) + else: + model.export(model_location) + return model_eval_lib.default_eval_shared_model( + eval_saved_model_path=model_location, + eval_config=eval_config, + rubber_stamp=rubber_stamp, + ) + + examples = [ + self._makeExample( + data=[0.0] * 28 * 28, + label=[0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + ), + self._makeExample( + data=[1.0] * 28 * 28, + label=[0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0], + ), + self._makeExample( + data=[1.0] * 28 * 28, + label=[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0], + ), + ] + data_location = self._writeTFExamplesToTFRecords(examples) + + schema = text_format.Parse( + """ tensor_representation_group { key: "" value { @@ -594,103 +580,103 @@ def _build_keras_model( presence: { min_fraction: 1 } } """, - schema_pb2.Schema(), - ) - # TODO(b/73109633): Remove when field is removed or its default changes to - # False. - if hasattr(schema, 'generate_legacy_feature_spec'): - schema.generate_legacy_feature_spec = False - - metrics_spec = config_pb2.MetricsSpec() - for metric in (confusion_matrix_metrics.AUC(name='auc'),): - cfg = metric_util.serialize_keras_object(metric) - metrics_spec.metrics.append( - config_pb2.MetricConfig( - class_name=cfg['class_name'], config=json.dumps(cfg['config']) - ) - ) - tf_keras.backend.clear_session() - slicing_specs = [ - config_pb2.SlicingSpec(), - config_pb2.SlicingSpec(feature_keys=['non_existent_slice']), - ] - metrics_spec.metrics.append( - config_pb2.MetricConfig( - class_name='ExampleCount', - per_slice_thresholds=[ - config_pb2.PerSliceMetricThreshold( - slicing_specs=slicing_specs, - threshold=config_pb2.MetricThreshold( - value_threshold=config_pb2.GenericValueThreshold( - lower_bound={'value': 1} - ) + schema_pb2.Schema(), + ) + # TODO(b/73109633): Remove when field is removed or its default changes to + # False. + if hasattr(schema, "generate_legacy_feature_spec"): + schema.generate_legacy_feature_spec = False + + metrics_spec = config_pb2.MetricsSpec() + for metric in (confusion_matrix_metrics.AUC(name="auc"),): + cfg = metric_util.serialize_keras_object(metric) + metrics_spec.metrics.append( + config_pb2.MetricConfig( + class_name=cfg["class_name"], config=json.dumps(cfg["config"]) + ) + ) + tf_keras.backend.clear_session() + slicing_specs = [ + config_pb2.SlicingSpec(), + config_pb2.SlicingSpec(feature_keys=["non_existent_slice"]), + ] + metrics_spec.metrics.append( + config_pb2.MetricConfig( + class_name="ExampleCount", + per_slice_thresholds=[ + config_pb2.PerSliceMetricThreshold( + slicing_specs=slicing_specs, + threshold=config_pb2.MetricThreshold( + value_threshold=config_pb2.GenericValueThreshold( + lower_bound={"value": 1} + ) + ), ), - ), - # Change thresholds would be ignored when rubber stamp is true. - config_pb2.PerSliceMetricThreshold( - slicing_specs=slicing_specs, - threshold=config_pb2.MetricThreshold( - change_threshold=config_pb2.GenericChangeThreshold( - direction=config_pb2.MetricDirection.HIGHER_IS_BETTER, - absolute={'value': 1}, - ) + # Change thresholds would be ignored when rubber stamp is true. 
+ config_pb2.PerSliceMetricThreshold( + slicing_specs=slicing_specs, + threshold=config_pb2.MetricThreshold( + change_threshold=config_pb2.GenericChangeThreshold( + direction=config_pb2.MetricDirection.HIGHER_IS_BETTER, + absolute={"value": 1}, + ) + ), ), - ), - ], + ], + ) ) - ) - for class_id in (0, 5): - metrics_spec.binarize.class_ids.values.append(class_id) - eval_config = config_pb2.EvalConfig( - model_specs=[config_pb2.ModelSpec(label_key='label')], - metrics_specs=[metrics_spec], - ) - if model_type != constants.TF_KERAS: - for s in eval_config.model_specs: - s.model_type = model_type - - model = _build_keras_model(eval_config, rubber_stamp=rubber_stamp) - baseline = _build_keras_model(eval_config, 'baseline_export') - if remove_baseline: - eval_shared_model = model - else: - eval_shared_model = {'candidate': model, 'baseline': baseline} - output_path = self._getTempDir() - # Raise RuntimeError for missing baseline with change thresholds. - if not rubber_stamp and remove_baseline: - with self.assertRaises(RuntimeError): - model_eval_lib.run_model_analysis( - eval_config=eval_config, - eval_shared_model=eval_shared_model, - data_location=data_location, - output_path=output_path, - schema=schema, + for class_id in (0, 5): + metrics_spec.binarize.class_ids.values.append(class_id) + eval_config = config_pb2.EvalConfig( + model_specs=[config_pb2.ModelSpec(label_key="label")], + metrics_specs=[metrics_spec], ) - # Will not have any result since the pipeline didn't run. - return - else: - eval_results = model_eval_lib.run_model_analysis( - eval_config=eval_config, - eval_shared_model=eval_shared_model, - data_location=data_location, - output_path=output_path, - schema=schema, - ) - - # Directly check validation file since it is not in EvalResult. - validations_file = os.path.join( - output_path, f'{constants.VALIDATIONS_KEY}.tfrecord' - ) - self.assertTrue(os.path.exists(validations_file)) - validation_records = [] - for record in tf.compat.v1.python_io.tf_record_iterator(validations_file): - validation_records.append( - validation_result_pb2.ValidationResult.FromString(record) - ) - self.assertLen(validation_records, 1) - # Change thresholds ignored when rubber stamping - expected_result = text_format.Parse( - """ + if model_type != constants.TF_KERAS: + for s in eval_config.model_specs: + s.model_type = model_type + + model = _build_keras_model(eval_config, rubber_stamp=rubber_stamp) + baseline = _build_keras_model(eval_config, "baseline_export") + if remove_baseline: + eval_shared_model = model + else: + eval_shared_model = {"candidate": model, "baseline": baseline} + output_path = self._getTempDir() + # Raise RuntimeError for missing baseline with change thresholds. + if not rubber_stamp and remove_baseline: + with self.assertRaises(RuntimeError): + model_eval_lib.run_model_analysis( + eval_config=eval_config, + eval_shared_model=eval_shared_model, + data_location=data_location, + output_path=output_path, + schema=schema, + ) + # Will not have any result since the pipeline didn't run. + return + else: + eval_results = model_eval_lib.run_model_analysis( + eval_config=eval_config, + eval_shared_model=eval_shared_model, + data_location=data_location, + output_path=output_path, + schema=schema, + ) + + # Directly check validation file since it is not in EvalResult. 
+ validations_file = os.path.join( + output_path, f"{constants.VALIDATIONS_KEY}.tfrecord" + ) + self.assertTrue(os.path.exists(validations_file)) + validation_records = [] + for record in tf.compat.v1.python_io.tf_record_iterator(validations_file): + validation_records.append( + validation_result_pb2.ValidationResult.FromString(record) + ) + self.assertLen(validation_records, 1) + # Change thresholds ignored when rubber stamping + expected_result = text_format.Parse( + """ validation_ok: false rubber_stamp: %s missing_slices: { @@ -702,13 +688,14 @@ def _build_keras_model( } num_matching_slices: 1 } - }""" % rubber_stamp, - validation_result_pb2.ValidationResult(), - ) - # Normal run with change threshold not satisfied. - if not rubber_stamp and not remove_baseline: - text_format.Parse( - """ + }""" + % rubber_stamp, + validation_result_pb2.ValidationResult(), + ) + # Normal run with change threshold not satisfied. + if not rubber_stamp and not remove_baseline: + text_format.Parse( + """ metric_validations_per_slice { slice_key {} failures { @@ -748,148 +735,147 @@ def _build_keras_model( metric_value { double_value {} } } }""", - expected_result, - ) - self.assertProtoEquals(expected_result, validation_records[0]) - - def check_eval_result(eval_result, model_location): - self.assertEqual(eval_result.model_location, model_location) - self.assertEqual(eval_result.data_location, data_location) - self.assertLen(eval_result.slicing_metrics, 1) - got_slice_key, got_metrics = eval_result.slicing_metrics[0] - self.assertEqual(got_slice_key, ()) - self.assertIn('', got_metrics) # output_name - got_metrics = got_metrics[''] - expected_metrics = { - 'classId:0': { - 'auc': True, - }, - 'classId:5': { - 'auc': True, - }, - } - if ( - model_type - not in (constants.TF_LITE, constants.TF_JS, constants.TF_KERAS) - and _TF_MAJOR_VERSION >= 2 - ): - expected_metrics[''] = {'loss': True} - if add_custom_metrics: - expected_metrics['']['custom'] = True - for class_id in expected_metrics: - self.assertIn(class_id, got_metrics) - for k in expected_metrics[class_id]: - self.assertIn(k, got_metrics[class_id]) - - # TODO(b/173657964): assert exception for the missing baseline but non - # rubber stamping test. 
- if rubber_stamp or remove_baseline: - self.assertIsInstance(eval_results, view_types.EvalResult) - check_eval_result(eval_results, model.model_path) - else: - self.assertLen(eval_results._results, 2) - eval_result_0, eval_result_1 = eval_results._results - check_eval_result(eval_result_0, model.model_path) - check_eval_result(eval_result_1, baseline.model_path) - - def testRunModelAnalysisWithKerasMultiOutputModel(self): - - def _build_keras_model(eval_config, export_name='export_dir'): - layers_per_output = {} - for output_name in ('output_1', 'output_2'): - layers_per_output[output_name] = tf_keras.layers.Input( - shape=(1,), name=output_name + expected_result, + ) + self.assertProtoEquals(expected_result, validation_records[0]) + + def check_eval_result(eval_result, model_location): + self.assertEqual(eval_result.model_location, model_location) + self.assertEqual(eval_result.data_location, data_location) + self.assertLen(eval_result.slicing_metrics, 1) + got_slice_key, got_metrics = eval_result.slicing_metrics[0] + self.assertEqual(got_slice_key, ()) + self.assertIn("", got_metrics) # output_name + got_metrics = got_metrics[""] + expected_metrics = { + "classId:0": { + "auc": True, + }, + "classId:5": { + "auc": True, + }, + } + if ( + model_type + not in (constants.TF_LITE, constants.TF_JS, constants.TF_KERAS) + and _TF_MAJOR_VERSION >= 2 + ): + expected_metrics[""] = {"loss": True} + if add_custom_metrics: + expected_metrics[""]["custom"] = True + for class_id in expected_metrics: + self.assertIn(class_id, got_metrics) + for k in expected_metrics[class_id]: + self.assertIn(k, got_metrics[class_id]) + + # TODO(b/173657964): assert exception for the missing baseline but non + # rubber stamping test. + if rubber_stamp or remove_baseline: + self.assertIsInstance(eval_results, view_types.EvalResult) + check_eval_result(eval_results, model.model_path) + else: + self.assertLen(eval_results._results, 2) + eval_result_0, eval_result_1 = eval_results._results + check_eval_result(eval_result_0, model.model_path) + check_eval_result(eval_result_1, baseline.model_path) + + def testRunModelAnalysisWithKerasMultiOutputModel(self): + def _build_keras_model(eval_config, export_name="export_dir"): + layers_per_output = {} + for output_name in ("output_1", "output_2"): + layers_per_output[output_name] = tf_keras.layers.Input( + shape=(1,), name=output_name + ) + model = tf_keras.models.Model(layers_per_output, layers_per_output) + model.compile(loss=tf_keras.losses.categorical_crossentropy) + model_location = os.path.join(self._getTempDir(), export_name) + model.export(model_location) + return model_eval_lib.default_eval_shared_model( + eval_saved_model_path=model_location, + eval_config=eval_config, + rubber_stamp=False, + ) + + examples = [ + self._makeExample(output_1=1.0, output_2=0.0, label_1=0.0, label_2=0.0), + self._makeExample(output_1=0.7, output_2=0.3, label_1=1.0, label_2=1.0), + self._makeExample(output_1=0.5, output_2=0.8, label_1=0.0, label_2=1.0), + ] + data_location = self._writeTFExamplesToTFRecords(examples) + + metrics_spec = config_pb2.MetricsSpec( + output_names=["output_1", "output_2"], + output_weights={"output_1": 1.0, "output_2": 1.0}, ) - model = tf_keras.models.Model(layers_per_output, layers_per_output) - model.compile(loss=tf_keras.losses.categorical_crossentropy) - model_location = os.path.join(self._getTempDir(), export_name) - model.export(model_location) - return model_eval_lib.default_eval_shared_model( - eval_saved_model_path=model_location, - 
eval_config=eval_config, - rubber_stamp=False, - ) - - examples = [ - self._makeExample(output_1=1.0, output_2=0.0, label_1=0.0, label_2=0.0), - self._makeExample(output_1=0.7, output_2=0.3, label_1=1.0, label_2=1.0), - self._makeExample(output_1=0.5, output_2=0.8, label_1=0.0, label_2=1.0), - ] - data_location = self._writeTFExamplesToTFRecords(examples) - - metrics_spec = config_pb2.MetricsSpec( - output_names=['output_1', 'output_2'], - output_weights={'output_1': 1.0, 'output_2': 1.0}, - ) - for metric in (confusion_matrix_metrics.AUC(name='auc'),): - cfg = metric_util.serialize_keras_object(metric) - metrics_spec.metrics.append( - config_pb2.MetricConfig( - class_name=cfg['class_name'], config=json.dumps(cfg['config']) - ) - ) - slicing_specs = [ - config_pb2.SlicingSpec(), - config_pb2.SlicingSpec(feature_keys=['non_existent_slice']), - ] - metrics_spec.metrics.append( - config_pb2.MetricConfig( - class_name='ExampleCount', - per_slice_thresholds=[ - config_pb2.PerSliceMetricThreshold( - slicing_specs=slicing_specs, - threshold=config_pb2.MetricThreshold( - value_threshold=config_pb2.GenericValueThreshold( - lower_bound={'value': 1} - ) + for metric in (confusion_matrix_metrics.AUC(name="auc"),): + cfg = metric_util.serialize_keras_object(metric) + metrics_spec.metrics.append( + config_pb2.MetricConfig( + class_name=cfg["class_name"], config=json.dumps(cfg["config"]) + ) + ) + slicing_specs = [ + config_pb2.SlicingSpec(), + config_pb2.SlicingSpec(feature_keys=["non_existent_slice"]), + ] + metrics_spec.metrics.append( + config_pb2.MetricConfig( + class_name="ExampleCount", + per_slice_thresholds=[ + config_pb2.PerSliceMetricThreshold( + slicing_specs=slicing_specs, + threshold=config_pb2.MetricThreshold( + value_threshold=config_pb2.GenericValueThreshold( + lower_bound={"value": 1} + ) + ), ), - ), - # Change thresholds would be ignored when rubber stamp is true. - config_pb2.PerSliceMetricThreshold( - slicing_specs=slicing_specs, - threshold=config_pb2.MetricThreshold( - change_threshold=config_pb2.GenericChangeThreshold( - direction=config_pb2.MetricDirection.HIGHER_IS_BETTER, - absolute={'value': 1}, - ) + # Change thresholds would be ignored when rubber stamp is true. 
+ config_pb2.PerSliceMetricThreshold( + slicing_specs=slicing_specs, + threshold=config_pb2.MetricThreshold( + change_threshold=config_pb2.GenericChangeThreshold( + direction=config_pb2.MetricDirection.HIGHER_IS_BETTER, + absolute={"value": 1}, + ) + ), ), - ), + ], + ) + ) + eval_config = config_pb2.EvalConfig( + model_specs=[ + config_pb2.ModelSpec( + label_keys={"output_1": "label_1", "output_2": "label_2"} + ) ], + metrics_specs=[metrics_spec], ) - ) - eval_config = config_pb2.EvalConfig( - model_specs=[ - config_pb2.ModelSpec( - label_keys={'output_1': 'label_1', 'output_2': 'label_2'} - ) - ], - metrics_specs=[metrics_spec], - ) - model = _build_keras_model(eval_config) - baseline = _build_keras_model(eval_config, 'baseline_export') - eval_shared_model = {'candidate': model, 'baseline': baseline} - output_path = self._getTempDir() - eval_results = model_eval_lib.run_model_analysis( - eval_config=eval_config, - eval_shared_model=eval_shared_model, - data_location=data_location, - output_path=output_path, - ) + model = _build_keras_model(eval_config) + baseline = _build_keras_model(eval_config, "baseline_export") + eval_shared_model = {"candidate": model, "baseline": baseline} + output_path = self._getTempDir() + eval_results = model_eval_lib.run_model_analysis( + eval_config=eval_config, + eval_shared_model=eval_shared_model, + data_location=data_location, + output_path=output_path, + ) - # Directly check validation file since it is not in EvalResult. - validations_file = os.path.join( - output_path, f'{constants.VALIDATIONS_KEY}.tfrecord' - ) - self.assertTrue(os.path.exists(validations_file)) - validation_records = [] - for record in tf.compat.v1.python_io.tf_record_iterator(validations_file): - validation_records.append( - validation_result_pb2.ValidationResult.FromString(record) - ) - self.assertLen(validation_records, 1) - expected_result = text_format.Parse( - """ + # Directly check validation file since it is not in EvalResult. + validations_file = os.path.join( + output_path, f"{constants.VALIDATIONS_KEY}.tfrecord" + ) + self.assertTrue(os.path.exists(validations_file)) + validation_records = [] + for record in tf.compat.v1.python_io.tf_record_iterator(validations_file): + validation_records.append( + validation_result_pb2.ValidationResult.FromString(record) + ) + self.assertLen(validation_records, 1) + expected_result = text_format.Parse( + """ metric_validations_per_slice { slice_key {} failures { @@ -934,56 +920,54 @@ def _build_keras_model(eval_config, export_name='export_dir'): num_matching_slices: 1 } }""", - validation_result_pb2.ValidationResult(), - ) - self.assertProtoEquals(expected_result, validation_records[0]) - - def check_eval_result(eval_result, model_location): - self.assertEqual(eval_result.model_location, model_location) - self.assertEqual(eval_result.data_location, data_location) - self.assertLen(eval_result.slicing_metrics, 1) - got_slice_key, got_metrics = eval_result.slicing_metrics[0] - self.assertEqual(got_slice_key, ()) - self.assertIn('output_1', got_metrics) - self.assertIn('auc', got_metrics['output_1']['']) - self.assertIn('output_2', got_metrics) - self.assertIn('auc', got_metrics['output_2']['']) - # Aggregate metrics - self.assertIn('', got_metrics) - self.assertIn('auc', got_metrics['']['']) - - # TODO(b/173657964): assert exception for the missing baseline but non - # rubber stamping test. 
- self.assertLen(eval_results._results, 2) - eval_result_0, eval_result_1 = eval_results._results - check_eval_result(eval_result_0, model.model_path) - check_eval_result(eval_result_1, baseline.model_path) - - def testRunModelAnalysisWithQueryBasedMetrics(self): - input_layer = tf_keras.layers.Input(shape=(1,), name='age') - output_layer = tf_keras.layers.Dense(1, activation=tf.nn.sigmoid)( - input_layer - ) - model = tf_keras.models.Model(input_layer, output_layer) - model.compile( - optimizer=tf_keras.optimizers.Adam(lr=0.001), - loss=tf_keras.losses.binary_crossentropy, - ) + validation_result_pb2.ValidationResult(), + ) + self.assertProtoEquals(expected_result, validation_records[0]) + + def check_eval_result(eval_result, model_location): + self.assertEqual(eval_result.model_location, model_location) + self.assertEqual(eval_result.data_location, data_location) + self.assertLen(eval_result.slicing_metrics, 1) + got_slice_key, got_metrics = eval_result.slicing_metrics[0] + self.assertEqual(got_slice_key, ()) + self.assertIn("output_1", got_metrics) + self.assertIn("auc", got_metrics["output_1"][""]) + self.assertIn("output_2", got_metrics) + self.assertIn("auc", got_metrics["output_2"][""]) + # Aggregate metrics + self.assertIn("", got_metrics) + self.assertIn("auc", got_metrics[""][""]) + + # TODO(b/173657964): assert exception for the missing baseline but non + # rubber stamping test. + self.assertLen(eval_results._results, 2) + eval_result_0, eval_result_1 = eval_results._results + check_eval_result(eval_result_0, model.model_path) + check_eval_result(eval_result_1, baseline.model_path) + + def testRunModelAnalysisWithQueryBasedMetrics(self): + input_layer = tf_keras.layers.Input(shape=(1,), name="age") + output_layer = tf_keras.layers.Dense(1, activation=tf.nn.sigmoid)(input_layer) + model = tf_keras.models.Model(input_layer, output_layer) + model.compile( + optimizer=tf_keras.optimizers.Adam(lr=0.001), + loss=tf_keras.losses.binary_crossentropy, + ) - features = {'age': [[20.0]]} - labels = [[1]] - example_weights = [1.0] - dataset = tf.data.Dataset.from_tensor_slices( - (features, labels, example_weights) - ) - dataset = dataset.shuffle(buffer_size=1).repeat().batch(1) - model.fit(dataset, steps_per_epoch=1) + features = {"age": [[20.0]]} + labels = [[1]] + example_weights = [1.0] + dataset = tf.data.Dataset.from_tensor_slices( + (features, labels, example_weights) + ) + dataset = dataset.shuffle(buffer_size=1).repeat().batch(1) + model.fit(dataset, steps_per_epoch=1) - model_location = os.path.join(self._getTempDir(), 'export_dir') - model.export(model_location) + model_location = os.path.join(self._getTempDir(), "export_dir") + model.export(model_location) - schema = text_format.Parse( - """ + schema = text_format.Parse( + """ tensor_representation_group { key: "" value { @@ -1024,341 +1008,345 @@ def testRunModelAnalysisWithQueryBasedMetrics(self): type: INT } """, - schema_pb2.Schema(), - ) - examples = [ - self._makeExample(age=3.0, language='english', label=1.0, varlen=[0]), - self._makeExample(age=5.0, language='chinese', label=0.0, varlen=[1]), - self._makeExample(age=3.0, language='english', label=0.0, varlen=[2]), - self._makeExample( - age=5.0, language='chinese', label=1.0, varlen=[3, 4] - ), - ] - data_location = self._writeTFExamplesToTFRecords(examples) - slicing_specs = [config_pb2.SlicingSpec()] - # Test with both a TFMA metric (NDCG), a keras metric (Recall). 
- metrics = [ - ndcg.NDCG(gain_key='age', name='ndcg', top_k_list=[1, 2]), - tf_keras.metrics.Recall(top_k=1), - ] - # If tensorflow-ranking imported add MRRMetric. - if _TFR_IMPORTED: - metrics.append(tfr.keras.metrics.MRRMetric()) - metrics_specs = metric_specs.specs_from_metrics( - metrics, query_key='language', include_weighted_example_count=True - ) - metrics_specs.append( - config_pb2.MetricsSpec( - metrics=[ - config_pb2.MetricConfig( - class_name='ExampleCount', - threshold=config_pb2.MetricThreshold( - value_threshold=config_pb2.GenericValueThreshold( - lower_bound={'value': 0} - ) - ), - ) - ] + schema_pb2.Schema(), ) - ) - eval_config = config_pb2.EvalConfig( - model_specs=[config_pb2.ModelSpec(label_key='label')], - slicing_specs=slicing_specs, - metrics_specs=metrics_specs, - ) - eval_shared_model = model_eval_lib.default_eval_shared_model( - eval_saved_model_path=model_location, eval_config=eval_config - ) - output_path = self._getTempDir() - eval_result = model_eval_lib.run_model_analysis( - eval_config=eval_config, - eval_shared_model=eval_shared_model, - data_location=data_location, - output_path=output_path, - evaluators=[ - metrics_plots_and_validations_evaluator.MetricsPlotsAndValidationsEvaluator( - eval_config=eval_config, eval_shared_model=eval_shared_model - ) - ], - schema=schema, - ) - - # Directly check validation file since it is not in EvalResult. - validations_file = os.path.join( - output_path, f'{constants.VALIDATIONS_KEY}.tfrecord' - ) - self.assertTrue(os.path.exists(validations_file)) - validation_records = [] - for record in tf.compat.v1.python_io.tf_record_iterator(validations_file): - validation_records.append( - validation_result_pb2.ValidationResult.FromString(record) - ) - self.assertLen(validation_records, 1) - self.assertTrue(validation_records[0].validation_ok) - - self.assertEqual(eval_result.model_location, model_location) - self.assertEqual(eval_result.data_location, data_location) - self.assertLen(eval_result.slicing_metrics, 1) - got_slice_key, got_metrics = eval_result.slicing_metrics[0] - self.assertEqual(got_slice_key, ()) - self.assertIn('', got_metrics) # output_name - got_metrics = got_metrics[''] - expected_metrics = { - '': { - 'example_count': True, - 'weighted_example_count': True, - }, - 'topK:1': { - 'ndcg': True, - 'recall': True, - }, - 'topK:2': { - 'ndcg': True, - }, - } - if _TFR_IMPORTED: - expected_metrics['']['mrr_metric'] = True - for group in expected_metrics: - self.assertIn(group, got_metrics) - for k in expected_metrics[group]: - self.assertIn(k, got_metrics[group]) - - # PR 189: Remove the `skip` mark if the test passes for all supported versions - # of python - @unittest.skip('Fails for some versions of Python, including 3.9') - def testRunModelAnalysisWithUncertainty(self): - examples = [ - self._makeExample(age=3.0, language='english', label=1.0), - self._makeExample(age=3.0, language='chinese', label=0.0), - self._makeExample(age=4.0, language='english', label=1.0), - self._makeExample(age=5.0, language='chinese', label=1.0), - self._makeExample(age=5.0, language='hindi', label=1.0), - ] - classifier = example_keras_model.get_example_classifier_model( - example_keras_model.LANGUAGE - ) - classifier.compile(optimizer=keras.optimizers.Adam(), loss='mse') - classifier.fit( - tf.constant([e.SerializeToString() for e in examples]), - np.array([ - e.features.feature[example_keras_model.LABEL].float_list.value[:][0] - for e in examples - ]), - batch_size=1, - ) - eval_config = config_pb2.EvalConfig( - model_specs=[ - 
config_pb2.ModelSpec( - name='model1', example_weight_key='age', label_key='label' - ) - ], - slicing_specs=[config_pb2.SlicingSpec(feature_keys=['language'])], - metrics_specs=[ + examples = [ + self._makeExample(age=3.0, language="english", label=1.0, varlen=[0]), + self._makeExample(age=5.0, language="chinese", label=0.0, varlen=[1]), + self._makeExample(age=3.0, language="english", label=0.0, varlen=[2]), + self._makeExample(age=5.0, language="chinese", label=1.0, varlen=[3, 4]), + ] + data_location = self._writeTFExamplesToTFRecords(examples) + slicing_specs = [config_pb2.SlicingSpec()] + # Test with both a TFMA metric (NDCG), a keras metric (Recall). + metrics = [ + ndcg.NDCG(gain_key="age", name="ndcg", top_k_list=[1, 2]), + tf_keras.metrics.Recall(top_k=1), + ] + # If tensorflow-ranking imported add MRRMetric. + if _TFR_IMPORTED: + metrics.append(tfr.keras.metrics.MRRMetric()) + metrics_specs = metric_specs.specs_from_metrics( + metrics, query_key="language", include_weighted_example_count=True + ) + metrics_specs.append( config_pb2.MetricsSpec( metrics=[ config_pb2.MetricConfig( - class_name='ExampleCount', - ), - config_pb2.MetricConfig( - class_name='Accuracy', - ), + class_name="ExampleCount", + threshold=config_pb2.MetricThreshold( + value_threshold=config_pb2.GenericValueThreshold( + lower_bound={"value": 0} + ) + ), + ) ] ) - ], - options=config_pb2.Options( - compute_confidence_intervals=wrappers_pb2.BoolValue(value=True), - min_slice_size=wrappers_pb2.Int32Value(value=2), - ), - ) - model_location = self._exportKerasModel(classifier) - data_location = self._writeTFExamplesToTFRecords(examples) - eval_result = model_eval_lib.run_model_analysis( - eval_shared_model=model_eval_lib.default_eval_shared_model( + ) + eval_config = config_pb2.EvalConfig( + model_specs=[config_pb2.ModelSpec(label_key="label")], + slicing_specs=slicing_specs, + metrics_specs=metrics_specs, + ) + eval_shared_model = model_eval_lib.default_eval_shared_model( eval_saved_model_path=model_location, eval_config=eval_config - ), - data_location=data_location, - eval_config=eval_config, - output_path=self._getTempDir(), - ) - # We only check some of the metrics to ensure that the end-to-end - # pipeline works. - expected = { - (('language', 'hindi'),): { - '__ERROR__': { - 'debugMessage': ( - 'Example count for this slice key is lower than the ' - 'minimum required value: 2. No data is aggregated for ' - 'this slice.' + ) + output_path = self._getTempDir() + eval_result = model_eval_lib.run_model_analysis( + eval_config=eval_config, + eval_shared_model=eval_shared_model, + data_location=data_location, + output_path=output_path, + evaluators=[ + metrics_plots_and_validations_evaluator.MetricsPlotsAndValidationsEvaluator( + eval_config=eval_config, eval_shared_model=eval_shared_model ) + ], + schema=schema, + ) + + # Directly check validation file since it is not in EvalResult. 
+ validations_file = os.path.join( + output_path, f"{constants.VALIDATIONS_KEY}.tfrecord" + ) + self.assertTrue(os.path.exists(validations_file)) + validation_records = [] + for record in tf.compat.v1.python_io.tf_record_iterator(validations_file): + validation_records.append( + validation_result_pb2.ValidationResult.FromString(record) + ) + self.assertLen(validation_records, 1) + self.assertTrue(validation_records[0].validation_ok) + + self.assertEqual(eval_result.model_location, model_location) + self.assertEqual(eval_result.data_location, data_location) + self.assertLen(eval_result.slicing_metrics, 1) + got_slice_key, got_metrics = eval_result.slicing_metrics[0] + self.assertEqual(got_slice_key, ()) + self.assertIn("", got_metrics) # output_name + got_metrics = got_metrics[""] + expected_metrics = { + "": { + "example_count": True, + "weighted_example_count": True, }, - }, - (('language', 'english'),): { - 'accuracy': { - 'boundedValue': { - 'lowerBound': 0.0, - 'upperBound': 0.0, - 'value': 0.0, - } + "topK:1": { + "ndcg": True, + "recall": True, }, - 'example_count': {'doubleValue': 7.0}, - }, - (('language', 'chinese'),): { - 'accuracy': { - 'boundedValue': { - 'lowerBound': 0.0, - 'upperBound': 0.0, - 'value': 0.0, - } + "topK:2": { + "ndcg": True, }, - 'example_count': {'doubleValue': 8.0}, - }, - } - self.assertEqual(eval_result.model_location, model_location) - self.assertEqual(eval_result.data_location, data_location) - self.assertEqual( - eval_result.config.slicing_specs[0], - config_pb2.SlicingSpec(feature_keys=['language']), - ) - self.assertMetricsAlmostEqual(eval_result.slicing_metrics, expected) - for _, plot in eval_result.plots: - self.assertFalse(plot) - - def testRunModelAnalysisWithDeterministicConfidenceIntervals(self): - examples = [ - self._makeExample(age=3.0, language='english', label=1.0), - self._makeExample(age=3.0, language='chinese', label=0.0), - self._makeExample(age=4.0, language='english', label=1.0), - self._makeExample(age=5.0, language='chinese', label=1.0), - self._makeExample(age=5.0, language='hindi', label=1.0), - ] - classifier = example_keras_model.get_example_classifier_model( - example_keras_model.LANGUAGE - ) - classifier.compile(optimizer=keras.optimizers.Adam(), loss='mse') - classifier.fit( - tf.constant([e.SerializeToString() for e in examples]), - np.array([ - e.features.feature[example_keras_model.LABEL].float_list.value[:][0] - for e in examples - ]), - batch_size=1, - ) - model_location = self._exportKerasModel(classifier) - data_location = self._writeTFExamplesToTFRecords(examples) - eval_config = config_pb2.EvalConfig( - model_specs=[ - config_pb2.ModelSpec( - name='model1', example_weight_key='age', label_key='label' - ) - ], - slicing_specs=[config_pb2.SlicingSpec(feature_keys=['language'])], - metrics_specs=[ - config_pb2.MetricsSpec( - metrics=[ - config_pb2.MetricConfig( - class_name='ExampleCount', - ), - config_pb2.MetricConfig( - class_name='Accuracy', - ), + } + if _TFR_IMPORTED: + expected_metrics[""]["mrr_metric"] = True + for group in expected_metrics: + self.assertIn(group, got_metrics) + for k in expected_metrics[group]: + self.assertIn(k, got_metrics[group]) + + # PR 189: Remove the `skip` mark if the test passes for all supported versions + # of python + @unittest.skip("Fails for some versions of Python, including 3.9") + def testRunModelAnalysisWithUncertainty(self): + examples = [ + self._makeExample(age=3.0, language="english", label=1.0), + self._makeExample(age=3.0, language="chinese", label=0.0), + 
self._makeExample(age=4.0, language="english", label=1.0), + self._makeExample(age=5.0, language="chinese", label=1.0), + self._makeExample(age=5.0, language="hindi", label=1.0), + ] + classifier = example_keras_model.get_example_classifier_model( + example_keras_model.LANGUAGE + ) + classifier.compile(optimizer=keras.optimizers.Adam(), loss="mse") + classifier.fit( + tf.constant([e.SerializeToString() for e in examples]), + np.array( + [ + e.features.feature[example_keras_model.LABEL].float_list.value[:][0] + for e in examples ] - ) - ], - options=config_pb2.Options( - compute_confidence_intervals=wrappers_pb2.BoolValue(value=True), - min_slice_size=wrappers_pb2.Int32Value(value=2), - ), - ) - eval_result = model_eval_lib.run_model_analysis( - eval_shared_model=model_eval_lib.default_eval_shared_model( - eval_saved_model_path=model_location, eval_config=eval_config - ), - data_location=data_location, - output_path=self._getTempDir(), - eval_config=eval_config, - random_seed_for_testing=_TEST_SEED, - ) - # We only check some of the metrics to ensure that the end-to-end - # pipeline works. - expected = { - (('language', 'hindi'),): { - '__ERROR__': { - 'debugMessage': ( - 'Example count for this slice key is lower than the ' - 'minimum required value: 2. No data is aggregated for ' - 'this slice.' + ), + batch_size=1, + ) + eval_config = config_pb2.EvalConfig( + model_specs=[ + config_pb2.ModelSpec( + name="model1", example_weight_key="age", label_key="label" + ) + ], + slicing_specs=[config_pb2.SlicingSpec(feature_keys=["language"])], + metrics_specs=[ + config_pb2.MetricsSpec( + metrics=[ + config_pb2.MetricConfig( + class_name="ExampleCount", + ), + config_pb2.MetricConfig( + class_name="Accuracy", + ), + ] ) + ], + options=config_pb2.Options( + compute_confidence_intervals=wrappers_pb2.BoolValue(value=True), + min_slice_size=wrappers_pb2.Int32Value(value=2), + ), + ) + model_location = self._exportKerasModel(classifier) + data_location = self._writeTFExamplesToTFRecords(examples) + eval_result = model_eval_lib.run_model_analysis( + eval_shared_model=model_eval_lib.default_eval_shared_model( + eval_saved_model_path=model_location, eval_config=eval_config + ), + data_location=data_location, + eval_config=eval_config, + output_path=self._getTempDir(), + ) + # We only check some of the metrics to ensure that the end-to-end + # pipeline works. + expected = { + (("language", "hindi"),): { + "__ERROR__": { + "debugMessage": ( + "Example count for this slice key is lower than the " + "minimum required value: 2. No data is aggregated for " + "this slice." 
+ ) + }, }, - }, - (('language', 'english'),): { - 'accuracy': { - 'boundedValue': { - 'lowerBound': 0.0, - 'upperBound': 0.0, - 'value': 0.0, - } + (("language", "english"),): { + "accuracy": { + "boundedValue": { + "lowerBound": 0.0, + "upperBound": 0.0, + "value": 0.0, + } + }, + "example_count": {"doubleValue": 7.0}, }, - 'example_count': {'doubleValue': 7.0}, - }, - (('language', 'chinese'),): { - 'accuracy': { - 'boundedValue': { - 'lowerBound': 0.0, - 'upperBound': 0.0, - 'value': 0.0, - } + (("language", "chinese"),): { + "accuracy": { + "boundedValue": { + "lowerBound": 0.0, + "upperBound": 0.0, + "value": 0.0, + } + }, + "example_count": {"doubleValue": 8.0}, }, - 'example_count': {'doubleValue': 8.0}, - }, - } - self.assertEqual(eval_result.model_location, model_location) - self.assertEqual(eval_result.data_location, data_location) - self.assertEqual( - eval_result.config.slicing_specs[0], - config_pb2.SlicingSpec(feature_keys=['language']), - ) - self.assertMetricsAlmostEqual(eval_result.slicing_metrics, expected) - - for key, value in eval_result.slicing_metrics: - if (('language', 'english'),) == key: - metric = value['']['']['accuracy'] - self.assertAlmostEqual(0.0, metric['boundedValue']['value'], delta=0.1) - - for _, plot in eval_result.plots: - self.assertFalse(plot) - - # TODO(b/350996394): Add test for plots and CSVtext with Keras model. - - def testRunModelAnalysisWithSchema(self): - examples = [ - self._makeExample(age=3.0, language='english', label=2.0), - self._makeExample(age=3.0, language='chinese', label=1.0), - self._makeExample(age=4.0, language='english', label=2.0), - self._makeExample(age=5.0, language='chinese', label=2.0), - self._makeExample(age=5.0, language='hindi', label=2.0), - ] - data_location = self._writeTFExamplesToTFRecords(examples) - classifier = example_keras_model.get_example_classifier_model( - example_keras_model.LANGUAGE - ) - classifier.compile(optimizer=keras.optimizers.Adam(), loss='mse') - classifier.fit( - tf.constant([e.SerializeToString() for e in examples]), - np.array([ - e.features.feature[example_keras_model.LABEL].float_list.value[:][0] - for e in examples - ]), - batch_size=1, - ) - model_location = self._exportKerasModel(classifier) - eval_config = config_pb2.EvalConfig( - model_specs=[config_pb2.ModelSpec(label_key='label')], - metrics_specs=metric_specs.specs_from_metrics( - [calibration_plot.CalibrationPlot(num_buckets=4)] - ), - ) - schema = text_format.Parse( - """ + } + self.assertEqual(eval_result.model_location, model_location) + self.assertEqual(eval_result.data_location, data_location) + self.assertEqual( + eval_result.config.slicing_specs[0], + config_pb2.SlicingSpec(feature_keys=["language"]), + ) + self.assertMetricsAlmostEqual(eval_result.slicing_metrics, expected) + for _, plot in eval_result.plots: + self.assertFalse(plot) + + def testRunModelAnalysisWithDeterministicConfidenceIntervals(self): + examples = [ + self._makeExample(age=3.0, language="english", label=1.0), + self._makeExample(age=3.0, language="chinese", label=0.0), + self._makeExample(age=4.0, language="english", label=1.0), + self._makeExample(age=5.0, language="chinese", label=1.0), + self._makeExample(age=5.0, language="hindi", label=1.0), + ] + classifier = example_keras_model.get_example_classifier_model( + example_keras_model.LANGUAGE + ) + classifier.compile(optimizer=keras.optimizers.Adam(), loss="mse") + classifier.fit( + tf.constant([e.SerializeToString() for e in examples]), + np.array( + [ + 
e.features.feature[example_keras_model.LABEL].float_list.value[:][0] + for e in examples + ] + ), + batch_size=1, + ) + model_location = self._exportKerasModel(classifier) + data_location = self._writeTFExamplesToTFRecords(examples) + eval_config = config_pb2.EvalConfig( + model_specs=[ + config_pb2.ModelSpec( + name="model1", example_weight_key="age", label_key="label" + ) + ], + slicing_specs=[config_pb2.SlicingSpec(feature_keys=["language"])], + metrics_specs=[ + config_pb2.MetricsSpec( + metrics=[ + config_pb2.MetricConfig( + class_name="ExampleCount", + ), + config_pb2.MetricConfig( + class_name="Accuracy", + ), + ] + ) + ], + options=config_pb2.Options( + compute_confidence_intervals=wrappers_pb2.BoolValue(value=True), + min_slice_size=wrappers_pb2.Int32Value(value=2), + ), + ) + eval_result = model_eval_lib.run_model_analysis( + eval_shared_model=model_eval_lib.default_eval_shared_model( + eval_saved_model_path=model_location, eval_config=eval_config + ), + data_location=data_location, + output_path=self._getTempDir(), + eval_config=eval_config, + random_seed_for_testing=_TEST_SEED, + ) + # We only check some of the metrics to ensure that the end-to-end + # pipeline works. + expected = { + (("language", "hindi"),): { + "__ERROR__": { + "debugMessage": ( + "Example count for this slice key is lower than the " + "minimum required value: 2. No data is aggregated for " + "this slice." + ) + }, + }, + (("language", "english"),): { + "accuracy": { + "boundedValue": { + "lowerBound": 0.0, + "upperBound": 0.0, + "value": 0.0, + } + }, + "example_count": {"doubleValue": 7.0}, + }, + (("language", "chinese"),): { + "accuracy": { + "boundedValue": { + "lowerBound": 0.0, + "upperBound": 0.0, + "value": 0.0, + } + }, + "example_count": {"doubleValue": 8.0}, + }, + } + self.assertEqual(eval_result.model_location, model_location) + self.assertEqual(eval_result.data_location, data_location) + self.assertEqual( + eval_result.config.slicing_specs[0], + config_pb2.SlicingSpec(feature_keys=["language"]), + ) + self.assertMetricsAlmostEqual(eval_result.slicing_metrics, expected) + + for key, value in eval_result.slicing_metrics: + if key == (("language", "english"),): + metric = value[""][""]["accuracy"] + self.assertAlmostEqual(0.0, metric["boundedValue"]["value"], delta=0.1) + + for _, plot in eval_result.plots: + self.assertFalse(plot) + + # TODO(b/350996394): Add test for plots and CSVtext with Keras model. 
+ + def testRunModelAnalysisWithSchema(self): + examples = [ + self._makeExample(age=3.0, language="english", label=2.0), + self._makeExample(age=3.0, language="chinese", label=1.0), + self._makeExample(age=4.0, language="english", label=2.0), + self._makeExample(age=5.0, language="chinese", label=2.0), + self._makeExample(age=5.0, language="hindi", label=2.0), + ] + data_location = self._writeTFExamplesToTFRecords(examples) + classifier = example_keras_model.get_example_classifier_model( + example_keras_model.LANGUAGE + ) + classifier.compile(optimizer=keras.optimizers.Adam(), loss="mse") + classifier.fit( + tf.constant([e.SerializeToString() for e in examples]), + np.array( + [ + e.features.feature[example_keras_model.LABEL].float_list.value[:][0] + for e in examples + ] + ), + batch_size=1, + ) + model_location = self._exportKerasModel(classifier) + eval_config = config_pb2.EvalConfig( + model_specs=[config_pb2.ModelSpec(label_key="label")], + metrics_specs=metric_specs.specs_from_metrics( + [calibration_plot.CalibrationPlot(num_buckets=4)] + ), + ) + schema = text_format.Parse( + """ feature { name: "label" type: FLOAT @@ -1368,109 +1356,108 @@ def testRunModelAnalysisWithSchema(self): } } """, - schema_pb2.Schema(), - ) - eval_result = model_eval_lib.run_model_analysis( - eval_config=eval_config, - schema=schema, - eval_shared_model=model_eval_lib.default_eval_shared_model( - eval_saved_model_path=model_location, eval_config=eval_config - ), - data_location=data_location, - output_path=self._getTempDir(), - ) + schema_pb2.Schema(), + ) + eval_result = model_eval_lib.run_model_analysis( + eval_config=eval_config, + schema=schema, + eval_shared_model=model_eval_lib.default_eval_shared_model( + eval_saved_model_path=model_location, eval_config=eval_config + ), + data_location=data_location, + output_path=self._getTempDir(), + ) - expected_metrics = { - (): { - 'example_count': {'doubleValue': 5.0}, + expected_metrics = { + (): { + "example_count": {"doubleValue": 5.0}, + } } - } - self.assertMetricsAlmostEqual(eval_result.slicing_metrics, expected_metrics) - self.assertLen(eval_result.plots, 1) - slice_key, plots = eval_result.plots[0] - self.assertEqual((), slice_key) - got_buckets = plots['']['']['calibrationHistogramBuckets']['buckets'] - # buckets include (-inf, left) and (right, inf) by default, but we are - # interested in the values of left and right - self.assertEqual(1.0, got_buckets[1]['lowerThresholdInclusive']) - self.assertEqual(2.0, got_buckets[-2]['upperThresholdExclusive']) - - # PR 189: Remove the `expectedFailure` mark if the test passes - @unittest.expectedFailure - def testLoadValidationResult(self): - result = validation_result_pb2.ValidationResult(validation_ok=True) - path = os.path.join(absltest.get_default_test_tmpdir(), 'results.tfrecord') - with tf.io.TFRecordWriter(path) as writer: - writer.write(result.SerializeToString()) - loaded_result = model_eval_lib.load_validation_result(path) - self.assertTrue(loaded_result.validation_ok) - - # PR 189: Remove the `expectedFailure` mark if the test passes - @unittest.expectedFailure - def testLoadValidationResultDir(self): - result = validation_result_pb2.ValidationResult(validation_ok=True) - path = os.path.join( - absltest.get_default_test_tmpdir(), constants.VALIDATIONS_KEY - ) - with tf.io.TFRecordWriter(path) as writer: - writer.write(result.SerializeToString()) - loaded_result = model_eval_lib.load_validation_result(os.path.dirname(path)) - self.assertTrue(loaded_result.validation_ok) - - # PR 189: Remove the 
`expectedFailure` mark if the test passes - @unittest.expectedFailure - def testLoadValidationResultEmptyFile(self): - path = os.path.join( - absltest.get_default_test_tmpdir(), constants.VALIDATIONS_KEY - ) - with tf.io.TFRecordWriter(path): - pass - with self.assertRaises(AssertionError): - model_eval_lib.load_validation_result(path) - - def testAnalyzeRawData(self): - - # Data - # age language label prediction - # 17 english 0 0 - # 30 spanish 1 1 - dict_data = [ - {'age': 17, 'language': 'english', 'prediction': 0, 'label': 0}, - {'age': 30, 'language': 'spanish', 'prediction': 1, 'label': 1}, - ] - df_data = pd.DataFrame(dict_data) - - # Expected Output - expected_slicing_metrics = { - (('language', 'english'),): { - '': { - '': { - 'binary_accuracy': {'doubleValue': 1.0}, - 'example_count': {'doubleValue': 1.0}, + self.assertMetricsAlmostEqual(eval_result.slicing_metrics, expected_metrics) + self.assertLen(eval_result.plots, 1) + slice_key, plots = eval_result.plots[0] + self.assertEqual((), slice_key) + got_buckets = plots[""][""]["calibrationHistogramBuckets"]["buckets"] + # buckets include (-inf, left) and (right, inf) by default, but we are + # interested in the values of left and right + self.assertEqual(1.0, got_buckets[1]["lowerThresholdInclusive"]) + self.assertEqual(2.0, got_buckets[-2]["upperThresholdExclusive"]) + + # PR 189: Remove the `expectedFailure` mark if the test passes + @unittest.expectedFailure + def testLoadValidationResult(self): + result = validation_result_pb2.ValidationResult(validation_ok=True) + path = os.path.join(absltest.get_default_test_tmpdir(), "results.tfrecord") + with tf.io.TFRecordWriter(path) as writer: + writer.write(result.SerializeToString()) + loaded_result = model_eval_lib.load_validation_result(path) + self.assertTrue(loaded_result.validation_ok) + + # PR 189: Remove the `expectedFailure` mark if the test passes + @unittest.expectedFailure + def testLoadValidationResultDir(self): + result = validation_result_pb2.ValidationResult(validation_ok=True) + path = os.path.join( + absltest.get_default_test_tmpdir(), constants.VALIDATIONS_KEY + ) + with tf.io.TFRecordWriter(path) as writer: + writer.write(result.SerializeToString()) + loaded_result = model_eval_lib.load_validation_result(os.path.dirname(path)) + self.assertTrue(loaded_result.validation_ok) + + # PR 189: Remove the `expectedFailure` mark if the test passes + @unittest.expectedFailure + def testLoadValidationResultEmptyFile(self): + path = os.path.join( + absltest.get_default_test_tmpdir(), constants.VALIDATIONS_KEY + ) + with tf.io.TFRecordWriter(path): + pass + with self.assertRaises(AssertionError): + model_eval_lib.load_validation_result(path) + + def testAnalyzeRawData(self): + # Data + # age language label prediction + # 17 english 0 0 + # 30 spanish 1 1 + dict_data = [ + {"age": 17, "language": "english", "prediction": 0, "label": 0}, + {"age": 30, "language": "spanish", "prediction": 1, "label": 1}, + ] + df_data = pd.DataFrame(dict_data) + + # Expected Output + expected_slicing_metrics = { + (("language", "english"),): { + "": { + "": { + "binary_accuracy": {"doubleValue": 1.0}, + "example_count": {"doubleValue": 1.0}, + } } - } - }, - (('language', 'spanish'),): { - '': { - '': { - 'binary_accuracy': {'doubleValue': 1.0}, - 'example_count': {'doubleValue': 1.0}, + }, + (("language", "spanish"),): { + "": { + "": { + "binary_accuracy": {"doubleValue": 1.0}, + "example_count": {"doubleValue": 1.0}, + } } - } - }, - (): { - '': { - '': { - 'binary_accuracy': {'doubleValue': 
1.0}, - 'example_count': {'doubleValue': 2.0}, + }, + (): { + "": { + "": { + "binary_accuracy": {"doubleValue": 1.0}, + "example_count": {"doubleValue": 2.0}, + } } - } - }, - } + }, + } - # Actual Output - eval_config = text_format.Parse( - """ + # Actual Output + eval_config = text_format.Parse( + """ model_specs { label_key: 'label' prediction_key: 'prediction' @@ -1484,111 +1471,118 @@ def testAnalyzeRawData(self): feature_keys: 'language' } """, - config_pb2.EvalConfig(), - ) - eval_result = model_eval_lib.analyze_raw_data(df_data, eval_config) + config_pb2.EvalConfig(), + ) + eval_result = model_eval_lib.analyze_raw_data(df_data, eval_config) - # Compare Actual and Expected - self.assertEqual( - len(eval_result.slicing_metrics), len(expected_slicing_metrics) - ) - for slicing_metric in eval_result.slicing_metrics: - slice_key, slice_val = slicing_metric - self.assertIn(slice_key, expected_slicing_metrics) - self.assertDictEqual(slice_val, expected_slicing_metrics[slice_key]) - - def testAnalyzeRawDataWithoutPrediction(self): - model_specs = [ - config_pb2.ModelSpec(prediction_key='nonexistent_prediction_key') - ] - metrics_specs = [ - config_pb2.MetricsSpec( - metrics=[config_pb2.MetricConfig(class_name='Accuracy')] + # Compare Actual and Expected + self.assertEqual( + len(eval_result.slicing_metrics), len(expected_slicing_metrics) ) - ] - eval_config = config_pb2.EvalConfig( - model_specs=model_specs, metrics_specs=metrics_specs - ) - df_data = pd.DataFrame([{ - 'prediction': 0, - 'label': 0, - }]) - with self.assertRaises(KeyError): - model_eval_lib.analyze_raw_data(df_data, eval_config) - - def testAnalyzeRawDataWithoutLabel(self): - model_specs = [config_pb2.ModelSpec(prediction_key='nonexistent_label_key')] - metrics_specs = [ - config_pb2.MetricsSpec( - metrics=[config_pb2.MetricConfig(class_name='Accuracy')] + for slicing_metric in eval_result.slicing_metrics: + slice_key, slice_val = slicing_metric + self.assertIn(slice_key, expected_slicing_metrics) + self.assertDictEqual(slice_val, expected_slicing_metrics[slice_key]) + + def testAnalyzeRawDataWithoutPrediction(self): + model_specs = [ + config_pb2.ModelSpec(prediction_key="nonexistent_prediction_key") + ] + metrics_specs = [ + config_pb2.MetricsSpec( + metrics=[config_pb2.MetricConfig(class_name="Accuracy")] + ) + ] + eval_config = config_pb2.EvalConfig( + model_specs=model_specs, metrics_specs=metrics_specs ) - ] - eval_config = config_pb2.EvalConfig( - model_specs=model_specs, metrics_specs=metrics_specs - ) - df_data = pd.DataFrame([{ - 'prediction': 0, - 'label': 0, - }]) - with self.assertRaises(KeyError): - model_eval_lib.analyze_raw_data(df_data, eval_config) - - def testBytesProcessedCountForSerializedExamples(self): - examples = [ - self._makeExample(age=3.0, language='english', label=1.0), - self._makeExample(age=3.0, language='chinese', label=0.0), - self._makeExample(age=4.0, language='english', label=1.0), - self._makeExample(age=5.0, language='chinese', label=1.0), - self._makeExample(age=5.0, language='hindi', label=1.0), - ] - serialized_examples = [example.SerializeToString() for example in examples] - expected_num_bytes = sum([len(se) for se in serialized_examples]) - with beam.Pipeline() as p: - _ = ( - p - | beam.Create(serialized_examples) - | 'InputsToExtracts' >> model_eval_lib.InputsToExtracts() - | 'ExtractAndEvaluate' - >> model_eval_lib.ExtractAndEvaluate(extractors=[], evaluators=[]) - ) - pipeline_result = p.run() - metrics = pipeline_result.metrics() - actual_counter = metrics.query( - 
beam.metrics.metric.MetricsFilter().with_name('extract_input_bytes') - )['counters'] - self.assertLen(actual_counter, 1) - self.assertEqual(actual_counter[0].committed, expected_num_bytes) - - def testBytesProcessedCountForRecordBatches(self): - examples = [ - self._makeExample(age=3.0, language='english', label=1.0), - self._makeExample(age=3.0, language='chinese', label=0.0), - self._makeExample(age=4.0, language='english', label=1.0), - self._makeExample(age=5.0, language='chinese', label=1.0), - self._makeExample(age=5.0, language='hindi', label=1.0), - ] - examples = [example.SerializeToString() for example in examples] - decoder = example_coder.ExamplesToRecordBatchDecoder() - record_batch = decoder.DecodeBatch(examples) - expected_num_bytes = record_batch.nbytes - with beam.Pipeline() as p: - _ = ( - p - | beam.Create(record_batch) - | 'BatchedInputsToExtracts' - >> model_eval_lib.BatchedInputsToExtracts() - | 'ExtractAndEvaluate' - >> model_eval_lib.ExtractAndEvaluate(extractors=[], evaluators=[]) - ) - pipeline_result = p.run() - metrics = pipeline_result.metrics() - actual_counter = metrics.query( - beam.metrics.metric.MetricsFilter().with_name('extract_input_bytes') - )[metrics.COUNTERS] - self.assertLen(actual_counter, 1) - self.assertEqual(actual_counter[0].committed, expected_num_bytes) - - -if __name__ == '__main__': - tf.compat.v1.enable_v2_behavior() - tf.test.main() + df_data = pd.DataFrame( + [ + { + "prediction": 0, + "label": 0, + } + ] + ) + with self.assertRaises(KeyError): + model_eval_lib.analyze_raw_data(df_data, eval_config) + + def testAnalyzeRawDataWithoutLabel(self): + model_specs = [config_pb2.ModelSpec(prediction_key="nonexistent_label_key")] + metrics_specs = [ + config_pb2.MetricsSpec( + metrics=[config_pb2.MetricConfig(class_name="Accuracy")] + ) + ] + eval_config = config_pb2.EvalConfig( + model_specs=model_specs, metrics_specs=metrics_specs + ) + df_data = pd.DataFrame( + [ + { + "prediction": 0, + "label": 0, + } + ] + ) + with self.assertRaises(KeyError): + model_eval_lib.analyze_raw_data(df_data, eval_config) + + def testBytesProcessedCountForSerializedExamples(self): + examples = [ + self._makeExample(age=3.0, language="english", label=1.0), + self._makeExample(age=3.0, language="chinese", label=0.0), + self._makeExample(age=4.0, language="english", label=1.0), + self._makeExample(age=5.0, language="chinese", label=1.0), + self._makeExample(age=5.0, language="hindi", label=1.0), + ] + serialized_examples = [example.SerializeToString() for example in examples] + expected_num_bytes = sum([len(se) for se in serialized_examples]) + with beam.Pipeline() as p: + _ = ( + p + | beam.Create(serialized_examples) + | "InputsToExtracts" >> model_eval_lib.InputsToExtracts() + | "ExtractAndEvaluate" + >> model_eval_lib.ExtractAndEvaluate(extractors=[], evaluators=[]) + ) + pipeline_result = p.run() + metrics = pipeline_result.metrics() + actual_counter = metrics.query( + beam.metrics.metric.MetricsFilter().with_name("extract_input_bytes") + )["counters"] + self.assertLen(actual_counter, 1) + self.assertEqual(actual_counter[0].committed, expected_num_bytes) + + def testBytesProcessedCountForRecordBatches(self): + examples = [ + self._makeExample(age=3.0, language="english", label=1.0), + self._makeExample(age=3.0, language="chinese", label=0.0), + self._makeExample(age=4.0, language="english", label=1.0), + self._makeExample(age=5.0, language="chinese", label=1.0), + self._makeExample(age=5.0, language="hindi", label=1.0), + ] + examples = 
[example.SerializeToString() for example in examples] + decoder = example_coder.ExamplesToRecordBatchDecoder() + record_batch = decoder.DecodeBatch(examples) + expected_num_bytes = record_batch.nbytes + with beam.Pipeline() as p: + _ = ( + p + | beam.Create(record_batch) + | "BatchedInputsToExtracts" >> model_eval_lib.BatchedInputsToExtracts() + | "ExtractAndEvaluate" + >> model_eval_lib.ExtractAndEvaluate(extractors=[], evaluators=[]) + ) + pipeline_result = p.run() + metrics = pipeline_result.metrics() + actual_counter = metrics.query( + beam.metrics.metric.MetricsFilter().with_name("extract_input_bytes") + )[metrics.COUNTERS] + self.assertLen(actual_counter, 1) + self.assertEqual(actual_counter[0].committed, expected_num_bytes) + + +if __name__ == "__main__": + tf.compat.v1.enable_v2_behavior() + tf.test.main() diff --git a/tensorflow_model_analysis/api/types.py b/tensorflow_model_analysis/api/types.py index 81fe98fedd..0d285ad8a9 100644 --- a/tensorflow_model_analysis/api/types.py +++ b/tensorflow_model_analysis/api/types.py @@ -16,191 +16,203 @@ import abc import datetime import operator -from typing import Any, Callable, Dict, Iterable, List, MutableMapping, NamedTuple, Optional, Tuple, TypeVar, Union +from typing import ( + Any, + Callable, + Dict, + Iterable, + List, + MutableMapping, + NamedTuple, + Optional, + Tuple, + TypeVar, + Union, +) -from apache_beam.utils import shared import numpy as np import six import tensorflow as tf +from apache_beam.utils import shared + from tensorflow_model_analysis.proto import metrics_for_slice_pb2 class RaggedTensorValue( NamedTuple( - 'RaggedTensorValue', - [('values', np.ndarray), ('nested_row_splits', List[np.ndarray])], + "RaggedTensorValue", + [("values", np.ndarray), ("nested_row_splits", List[np.ndarray])], ) ): - """RaggedTensorValue encapsulates a batch of ragged tensor values. + """RaggedTensorValue encapsulates a batch of ragged tensor values. - Attributes: - values: A np.ndarray of values. - nested_row_splits: A list of np.ndarray values representing the row splits - (one per dimension including the batch dimension). - """ + Attributes + ---------- + values: A np.ndarray of values. + nested_row_splits: A list of np.ndarray values representing the row splits + (one per dimension including the batch dimension). + """ class SparseTensorValue( NamedTuple( - 'SparseTensorValue', + "SparseTensorValue", [ - ('values', np.ndarray), - ('indices', np.ndarray), - ('dense_shape', np.ndarray), + ("values", np.ndarray), + ("indices", np.ndarray), + ("dense_shape", np.ndarray), ], ) ): - """SparseTensorValue encapsulates a batch of sparse tensor values. + """SparseTensorValue encapsulates a batch of sparse tensor values. - Attributes: - values: A np.ndarray of values. - indices: A np.ndarray of indices. - dense_shape: A np.ndarray representing the dense shape. - """ + Attributes + ---------- + values: A np.ndarray of values. + indices: A np.ndarray of indices. + dense_shape: A np.ndarray representing the dense shape. + """ class VarLenTensorValue( NamedTuple( - 'VarLenTensorValue', + "VarLenTensorValue", [ - ('values', np.ndarray), - ('indices', np.ndarray), - ('dense_shape', np.ndarray), + ("values", np.ndarray), + ("indices", np.ndarray), + ("dense_shape", np.ndarray), ], ) ): - """VarLenTensorValue encapsulates a batch of varlen dense tensor values. - - Attributes: - values: A np.ndarray of values. - indices: A np.ndarray of indices. - dense_shape: A np.ndarray representing the dense shape of the entire tensor. - Note that each row (i.e. 
set of values sharing the same value for the - first / batch dimension) is considered to have its own shape based on the - presence of values. - """ - - def __new__( - cls, values: np.ndarray, indices: np.ndarray, dense_shape: np.ndarray - ): - # we keep the sparse representation despite not needing it so that we can - # convert back to TF sparse tensors for free. - if len(dense_shape) != 2: - raise ValueError( - 'A VarLenTensorValue can only be used to represent a ' - '2D tensor in which the size of the second dimension ' - 'varies over rows. However, the provided dense_shape ' - f'({dense_shape}) implies a {dense_shape.size}D tensor' - ) - row_index_diffs = np.diff(indices[:, 0]) - column_index_diffs = np.diff(indices[:, 1]) - # Enforce row-major ordering of indices by checking that row indices are - # always increasing, and column indices within the same row are also always - # increasing. - bad_index_mask = (row_index_diffs < 0) | ( - (row_index_diffs == 0) & (column_index_diffs < 0) - ) - if np.any(bad_index_mask): - raise ValueError( - 'The values and indices arrays must be provided in a ' - 'row major order, and represent a set of variable ' - 'length dense lists. However, indices[' - f'{np.nonzero(bad_index_mask)[0] + 1}, :] did not ' - 'follow this pattern. The full indices array was: ' - f'{indices}.' - ) - return super().__new__( - cls, values=values, indices=indices, dense_shape=dense_shape - ) - - class DenseRowIterator: - """An Iterator over rows of a VarLenTensorValue as dense np.arrays. - - Because the VarLenTensorValue was created from a set of variable length - (dense) arrays, we can invert this process to turn a VarLenTensorValue back - into the original dense arrays. + """VarLenTensorValue encapsulates a batch of varlen dense tensor values. + + Attributes + ---------- + values: A np.ndarray of values. + indices: A np.ndarray of indices. + dense_shape: A np.ndarray representing the dense shape of the entire tensor. + Note that each row (i.e. set of values sharing the same value for the + first / batch dimension) is considered to have its own shape based on the + presence of values. """ - def __init__(self, tensor): - self._tensor = tensor - self._offset = 0 - - def __iter__(self): - return self - - def __next__(self): - if ( - not self._tensor.indices.size - or self._offset >= self._tensor.dense_shape[0] - ): - raise StopIteration - row_mask = self._tensor.indices[:, 0] == self._offset - self._offset += 1 - if not row_mask.any(): - # handle empty rows - return np.array([]) - # we rely on slice indexing (a[start:end] rather than fancy indexing - # (a[mask]) to avoid making a copy of each row. For details, see: - # https://scipy-cookbook.readthedocs.io/items/ViewsVsCopies.html - row_mask_indices = np.nonzero(row_mask)[0] - row_start_index = row_mask_indices[0] - row_end = row_mask_indices[-1] + 1 - assert (row_end - row_start_index) == len(row_mask_indices), ( - 'The values for each row in the represented tensor must be ' - 'contiguous in the values and indices arrays but found ' - f'row_start_index: {row_start_index}, row_end: {row_end}' - f'len(row_mask_indices): {len(row_mask_indices)}' - ) - return self._tensor.values[row_start_index:row_end] - - def dense_rows(self): - return self.DenseRowIterator(self) - - @classmethod - def from_dense_rows( - cls, dense_rows: Iterable[np.ndarray] - ) -> 'VarLenTensorValue': - """Converts a collection of variable length dense arrays into a tensor. - - Args: - dense_rows: A sequence of possibly variable length 1D arrays. 
- - Returns: - A new VarLenTensorValue containing the sparse representation of the - vertically stacked dense rows. The dense_shape attribute on the result - will be (num_rows, max_row_len). - """ - rows = [] - index_arrays = [] - max_row_len = 0 - num_rows = 0 - for i, row in enumerate(dense_rows): - num_rows += 1 - if row.size: - if row.ndim <= 1: - # Add a dimension for unsized numpy array. This will solve the problem - # where scalar numpy arrays like np.array(None), np.array(0) can not - # be merged with other numpy arrays. - row = row.reshape(-1) - rows.append(row) + def __new__(cls, values: np.ndarray, indices: np.ndarray, dense_shape: np.ndarray): + # we keep the sparse representation despite not needing it so that we can + # convert back to TF sparse tensors for free. + if len(dense_shape) != 2: + raise ValueError( + "A VarLenTensorValue can only be used to represent a " + "2D tensor in which the size of the second dimension " + "varies over rows. However, the provided dense_shape " + f"({dense_shape}) implies a {dense_shape.size}D tensor" + ) + row_index_diffs = np.diff(indices[:, 0]) + column_index_diffs = np.diff(indices[:, 1]) + # Enforce row-major ordering of indices by checking that row indices are + # always increasing, and column indices within the same row are also always + # increasing. + bad_index_mask = (row_index_diffs < 0) | ( + (row_index_diffs == 0) & (column_index_diffs < 0) + ) + if np.any(bad_index_mask): + raise ValueError( + "The values and indices arrays must be provided in a " + "row major order, and represent a set of variable " + "length dense lists. However, indices[" + f"{np.nonzero(bad_index_mask)[0] + 1}, :] did not " + "follow this pattern. The full indices array was: " + f"{indices}." + ) + return super().__new__( + cls, values=values, indices=indices, dense_shape=dense_shape + ) + + class DenseRowIterator: + """An Iterator over rows of a VarLenTensorValue as dense np.arrays. + + Because the VarLenTensorValue was created from a set of variable length + (dense) arrays, we can invert this process to turn a VarLenTensorValue back + into the original dense arrays. + """ + + def __init__(self, tensor): + self._tensor = tensor + self._offset = 0 + + def __iter__(self): + return self + + def __next__(self): + if ( + not self._tensor.indices.size + or self._offset >= self._tensor.dense_shape[0] + ): + raise StopIteration + row_mask = self._tensor.indices[:, 0] == self._offset + self._offset += 1 + if not row_mask.any(): + # handle empty rows + return np.array([]) + # we rely on slice indexing (a[start:end] rather than fancy indexing + # (a[mask]) to avoid making a copy of each row. For details, see: + # https://scipy-cookbook.readthedocs.io/items/ViewsVsCopies.html + row_mask_indices = np.nonzero(row_mask)[0] + row_start_index = row_mask_indices[0] + row_end = row_mask_indices[-1] + 1 + assert (row_end - row_start_index) == len(row_mask_indices), ( + "The values for each row in the represented tensor must be " + "contiguous in the values and indices arrays but found " + f"row_start_index: {row_start_index}, row_end: {row_end}" + f"len(row_mask_indices): {len(row_mask_indices)}" + ) + return self._tensor.values[row_start_index:row_end] + + def dense_rows(self): + return self.DenseRowIterator(self) + + @classmethod + def from_dense_rows(cls, dense_rows: Iterable[np.ndarray]) -> "VarLenTensorValue": + """Converts a collection of variable length dense arrays into a tensor. + + Args: + ---- + dense_rows: A sequence of possibly variable length 1D arrays. 
+ + Returns: + ------- + A new VarLenTensorValue containing the sparse representation of the + vertically stacked dense rows. The dense_shape attribute on the result + will be (num_rows, max_row_len). + """ + rows = [] + index_arrays = [] + max_row_len = 0 + num_rows = 0 + for i, row in enumerate(dense_rows): + num_rows += 1 + if row.size: + if row.ndim <= 1: + # Add a dimension for unsized numpy array. This will solve the problem + # where scalar numpy arrays like np.array(None), np.array(0) can not + # be merged with other numpy arrays. + row = row.reshape(-1) + rows.append(row) + else: + raise ValueError( + "Each non-empty dense row should be 1D or scalar but" + f" found row with shape {row.shape}." + ) + index_arrays.append(np.array([[i, j] for j in range(len(row))])) + max_row_len = max(max_row_len, row.size) + if index_arrays: + values = np.concatenate(rows, axis=0) + indices = np.concatenate(index_arrays, axis=0) else: - raise ValueError( - 'Each non-empty dense row should be 1D or scalar but' - f' found row with shape {row.shape}.' - ) - index_arrays.append(np.array([[i, j] for j in range(len(row))])) - max_row_len = max(max_row_len, row.size) - if index_arrays: - values = np.concatenate(rows, axis=0) - indices = np.concatenate(index_arrays, axis=0) - else: - # empty case - values = np.array([]) - indices = np.empty((0, 2)) - dense_shape = np.array([num_rows, max_row_len]) - return cls.__new__( - cls, values=values, indices=indices, dense_shape=dense_shape - ) + # empty case + values = np.array([]) + indices = np.empty((0, 2)) + dense_shape = np.array([num_rows, max_row_len]) + return cls.__new__(cls, values=values, indices=indices, dense_shape=dense_shape) # pylint: disable=invalid-name @@ -210,9 +222,7 @@ def from_dense_rows( DictOfTensorType = Dict[str, TensorType] TensorTypeMaybeDict = Union[TensorType, DictOfTensorType] DictOfTensorTypeMaybeDict = Dict[str, TensorTypeMaybeDict] -TensorTypeMaybeMultiLevelDict = Union[ - TensorTypeMaybeDict, DictOfTensorTypeMaybeDict -] +TensorTypeMaybeMultiLevelDict = Union[TensorTypeMaybeDict, DictOfTensorTypeMaybeDict] DictOfTypeSpec = Dict[str, tf.TypeSpec] TypeSpecMaybeDict = Union[tf.TypeSpec, DictOfTypeSpec] @@ -229,181 +239,180 @@ def from_dense_rows( DictOfTensorValue = Dict[str, TensorValue] TensorValueMaybeDict = Union[TensorValue, DictOfTensorValue] DictOfTensorValueMaybeDict = Dict[str, TensorValueMaybeDict] -TensorValueMaybeMultiLevelDict = Union[ - TensorValueMaybeDict, DictOfTensorValueMaybeDict -] +TensorValueMaybeMultiLevelDict = Union[TensorValueMaybeDict, DictOfTensorValueMaybeDict] MetricVariablesType = List[Any] PrimitiveMetricValueType = Union[float, int, np.number] ConcreteStructuredMetricValue = TypeVar( - 'ConcreteStructuredMetricValue', bound='StructuredMetricValue' + "ConcreteStructuredMetricValue", bound="StructuredMetricValue" ) class StructuredMetricValue(abc.ABC): - """The base class for all structured metrics used within TFMA. - - This class allows custom metrics to control how proto serialization happens, - and how to handle basic algebraic operations used in computing confidence - intervals and model diffs. By implementing the _apply_binary_op methods, - subclasses can then be treated like primitive numeric types. - """ - - @abc.abstractmethod - def to_proto(self) -> metrics_for_slice_pb2.MetricValue: - ... 
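As a quick illustration of the `VarLenTensorValue` API reformatted above, the sketch below packs variable-length dense rows into one value and recovers them with `dense_rows()`. The input arrays are invented for the example; the behavior matches the round-trip exercised in `types_test.py`.

```python
import numpy as np

from tensorflow_model_analysis.api import types

# Rows of length 0, 1, and 2 are packed into one sparse-backed value.
tensor = types.VarLenTensorValue.from_dense_rows(
    [np.array([]), np.array([1]), np.array([1, 2])]
)
print(tensor.values)       # [1 1 2]
print(tensor.dense_shape)  # [3 2], i.e. (num_rows, max_row_len)

# dense_rows() inverts the conversion, yielding each original row in order.
for row in tensor.dense_rows():
    print(row)  # [], then [1], then [1 2]
```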
- - @abc.abstractmethod - def _apply_binary_op_elementwise( - self: ConcreteStructuredMetricValue, - other: ConcreteStructuredMetricValue, - op: Callable[[float, float], float], - ) -> ConcreteStructuredMetricValue: - """Applies the binary operator elementwise on self and `other`. - - Given two structures of the same type, this function's job is to find - corresponding pairs of elements within both structures, invoke `op` on each - pair, and store the result in a corresponding location within a new - structure. For example, to implement for a list, this function could be - implemented as: - - return [op(elem, other_elem) for elem, other_elem in zip(self, other)] - - Args: - other: A structure containing elements which should be the second operand - when applying `op`. `Other` must be a structured metric of the same type - as self. - op: A binary operator which should be applied elementwise to corresponding - primitive values in self and `other`. - - Returns: - A new structured metric that is the result of elementwise applying `op` - on corresponding elements within self and `other`. - """ - ... - - @abc.abstractmethod - def _apply_binary_op_broadcast( - self: ConcreteStructuredMetricValue, - other: float, - op: Callable[[float, float], float], - ) -> ConcreteStructuredMetricValue: - """Applies the binary operator on each element in self and a single float. - - This function supports broadcasting operations on the structured metric by - applying `op` on each element in self, paired with the primitive value - `other`. This makes it possible do things like add a fixed quantity to every - element in a structure. For example, to implement for a list, this function - could be implemented as: - - return [op(elem, other) for elem in self] - - Args: - other: The value to be used as the second operand when applying `op`. - op: A binary operator which should be applied elementwise to each element - in self and `other`. - - Returns: - A new structured metric that is the result of applying `op` on each - element within self and a single value, `other`. + """The base class for all structured metrics used within TFMA. + + This class allows custom metrics to control how proto serialization happens, + and how to handle basic algebraic operations used in computing confidence + intervals and model diffs. By implementing the _apply_binary_op methods, + subclasses can then be treated like primitive numeric types. """ - ... - - def _apply_binary_op( - self: ConcreteStructuredMetricValue, - other: Union[PrimitiveMetricValueType, ConcreteStructuredMetricValue], - op: Callable[[float, float], float], - ) -> ConcreteStructuredMetricValue: - if type(other) is type(self): # pylint: disable=unidiomatic-typecheck - return self._apply_binary_op_elementwise(other, op) - elif isinstance(other, (float, int, np.number)): - return self._apply_binary_op_broadcast(float(other), op) - else: - raise ValueError( - 'Binary ops can only be applied elementwise on two instances of the ' - 'same StructuredMetricValue subclass or using broadcasting with one ' - 'StructuredMetricValue and a primitive numeric type (int, float, ' - 'np.number). 
Cannot apply binary op on objects of type ' - '{} and {}'.format(type(self), type(other)) - ) - - def __add__( - self: ConcreteStructuredMetricValue, - other: Union[ConcreteStructuredMetricValue, float], - ): - return self._apply_binary_op(other, operator.add) - - def __sub__( - self: ConcreteStructuredMetricValue, - other: Union[ConcreteStructuredMetricValue, float], - ): - return self._apply_binary_op(other, operator.sub) - - def __mul__( - self: ConcreteStructuredMetricValue, - other: Union[ConcreteStructuredMetricValue, float], - ): - return self._apply_binary_op(other, operator.mul) - - def __truediv__( - self: ConcreteStructuredMetricValue, - other: Union[ConcreteStructuredMetricValue, float], - ): - return self._apply_binary_op(other, operator.truediv) - - def __pow__( - self: ConcreteStructuredMetricValue, - other: Union[ConcreteStructuredMetricValue, float], - ): - return self._apply_binary_op(other, operator.pow) - - -MetricValueType = Union[ - PrimitiveMetricValueType, np.ndarray, StructuredMetricValue -] + + @abc.abstractmethod + def to_proto(self) -> metrics_for_slice_pb2.MetricValue: ... + + @abc.abstractmethod + def _apply_binary_op_elementwise( + self: ConcreteStructuredMetricValue, + other: ConcreteStructuredMetricValue, + op: Callable[[float, float], float], + ) -> ConcreteStructuredMetricValue: + """Applies the binary operator elementwise on self and `other`. + + Given two structures of the same type, this function's job is to find + corresponding pairs of elements within both structures, invoke `op` on each + pair, and store the result in a corresponding location within a new + structure. For example, to implement for a list, this function could be + implemented as: + + return [op(elem, other_elem) for elem, other_elem in zip(self, other)] + + Args: + ---- + other: A structure containing elements which should be the second operand + when applying `op`. `Other` must be a structured metric of the same type + as self. + op: A binary operator which should be applied elementwise to corresponding + primitive values in self and `other`. + + Returns: + ------- + A new structured metric that is the result of elementwise applying `op` + on corresponding elements within self and `other`. + """ + ... + + @abc.abstractmethod + def _apply_binary_op_broadcast( + self: ConcreteStructuredMetricValue, + other: float, + op: Callable[[float, float], float], + ) -> ConcreteStructuredMetricValue: + """Applies the binary operator on each element in self and a single float. + + This function supports broadcasting operations on the structured metric by + applying `op` on each element in self, paired with the primitive value + `other`. This makes it possible do things like add a fixed quantity to every + element in a structure. For example, to implement for a list, this function + could be implemented as: + + return [op(elem, other) for elem in self] + + Args: + ---- + other: The value to be used as the second operand when applying `op`. + op: A binary operator which should be applied elementwise to each element + in self and `other`. + + Returns: + ------- + A new structured metric that is the result of applying `op` on each + element within self and a single value, `other`. + """ + ... 
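A minimal sketch of how a subclass typically fills in these hooks so the inherited arithmetic operators work. `ListMetricValue` is a hypothetical example, not a metric type defined in TFMA or touched by this change.

```python
from typing import Callable, List

from tensorflow_model_analysis.api import types
from tensorflow_model_analysis.proto import metrics_for_slice_pb2


class ListMetricValue(types.StructuredMetricValue):
    """Hypothetical structured value wrapping a flat list of floats."""

    def __init__(self, elements: List[float]):
        self.elements = list(elements)

    def to_proto(self) -> metrics_for_slice_pb2.MetricValue:
        # Real subclasses serialize their structure; an empty MetricValue
        # keeps this sketch self-contained.
        return metrics_for_slice_pb2.MetricValue()

    def _apply_binary_op_elementwise(
        self, other: "ListMetricValue", op: Callable[[float, float], float]
    ) -> "ListMetricValue":
        # Pair up corresponding elements of the two structures.
        return ListMetricValue(
            [op(a, b) for a, b in zip(self.elements, other.elements)]
        )

    def _apply_binary_op_broadcast(
        self, other: float, op: Callable[[float, float], float]
    ) -> "ListMetricValue":
        # Apply the same primitive operand to every element.
        return ListMetricValue([op(a, other) for a in self.elements])


# The inherited dunder methods now work both elementwise and broadcast:
diff = ListMetricValue([1.0, 2.0]) - ListMetricValue([0.5, 0.5])  # [0.5, 1.5]
scaled = ListMetricValue([1.0, 2.0]) * 2.0                        # [2.0, 4.0]
```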
+ + def _apply_binary_op( + self: ConcreteStructuredMetricValue, + other: Union[PrimitiveMetricValueType, ConcreteStructuredMetricValue], + op: Callable[[float, float], float], + ) -> ConcreteStructuredMetricValue: + if type(other) is type(self): # pylint: disable=unidiomatic-typecheck + return self._apply_binary_op_elementwise(other, op) + elif isinstance(other, (float, int, np.number)): + return self._apply_binary_op_broadcast(float(other), op) + else: + raise ValueError( + "Binary ops can only be applied elementwise on two instances of the " + "same StructuredMetricValue subclass or using broadcasting with one " + "StructuredMetricValue and a primitive numeric type (int, float, " + "np.number). Cannot apply binary op on objects of type " + f"{type(self)} and {type(other)}" + ) + + def __add__( + self: ConcreteStructuredMetricValue, + other: Union[ConcreteStructuredMetricValue, float], + ): + return self._apply_binary_op(other, operator.add) + + def __sub__( + self: ConcreteStructuredMetricValue, + other: Union[ConcreteStructuredMetricValue, float], + ): + return self._apply_binary_op(other, operator.sub) + + def __mul__( + self: ConcreteStructuredMetricValue, + other: Union[ConcreteStructuredMetricValue, float], + ): + return self._apply_binary_op(other, operator.mul) + + def __truediv__( + self: ConcreteStructuredMetricValue, + other: Union[ConcreteStructuredMetricValue, float], + ): + return self._apply_binary_op(other, operator.truediv) + + def __pow__( + self: ConcreteStructuredMetricValue, + other: Union[ConcreteStructuredMetricValue, float], + ): + return self._apply_binary_op(other, operator.pow) + + +MetricValueType = Union[PrimitiveMetricValueType, np.ndarray, StructuredMetricValue] class ValueWithTDistribution( NamedTuple( - 'ValueWithTDistribution', + "ValueWithTDistribution", [ - ('sample_mean', MetricValueType), - ('sample_standard_deviation', MetricValueType), - ('sample_degrees_of_freedom', int), - ('unsampled_value', MetricValueType), + ("sample_mean", MetricValueType), + ("sample_standard_deviation", MetricValueType), + ("sample_degrees_of_freedom", int), + ("unsampled_value", MetricValueType), ], ) ): - r"""Represents the t-distribution value. - - It includes sample_mean, sample_standard_deviation, - sample_degrees_of_freedom. And also unsampled_value is also stored here to - record the value calculated without bootstrapping. - The sample_standard_deviation is calculated as: - \sqrt{ \frac{1}{N-1} \sum_{i=1}^{N}{(x_i - \bar{x})^2} } - """ - - def __new__( - cls, - sample_mean: float, - sample_standard_deviation: Optional[float] = None, - sample_degrees_of_freedom: Optional[int] = None, - unsampled_value: Optional[float] = None, - ): - return super(ValueWithTDistribution, cls).__new__( - cls, - sample_mean, - sample_standard_deviation, - sample_degrees_of_freedom, - unsampled_value, - ) + r"""Represents the t-distribution value. - def __float__(self): - # unsampled_value can be numpy.float which is a subclass of float, but here - # need to return a strict float. - return float(self.unsampled_value) + It includes sample_mean, sample_standard_deviation, + sample_degrees_of_freedom. And also unsampled_value is also stored here to + record the value calculated without bootstrapping. 
+ The sample_standard_deviation is calculated as: + \sqrt{ \frac{1}{N-1} \sum_{i=1}^{N}{(x_i - \bar{x})^2} } + """ + + def __new__( + cls, + sample_mean: float, + sample_standard_deviation: Optional[float] = None, + sample_degrees_of_freedom: Optional[int] = None, + unsampled_value: Optional[float] = None, + ): + return super(ValueWithTDistribution, cls).__new__( + cls, + sample_mean, + sample_standard_deviation, + sample_degrees_of_freedom, + unsampled_value, + ) + + def __float__(self): + # unsampled_value can be numpy.float which is a subclass of float, but here + # need to return a strict float. + return float(self.unsampled_value) # AddMetricsCallback should have the following prototype: @@ -416,9 +425,10 @@ def __float__(self): # necessarily dictionaries - they might also be Tensors, depending on what the # model's eval_input_receiver_fn returns. # pyformat: disable -AddMetricsCallbackType = Callable[[ - TensorTypeMaybeDict, TensorTypeMaybeDict, TensorTypeMaybeDict -], Dict[str, Tuple[TensorType, TensorType]]] +AddMetricsCallbackType = Callable[ + [TensorTypeMaybeDict, TensorTypeMaybeDict, TensorTypeMaybeDict], + Dict[str, Tuple[TensorType, TensorType]], +] # pyformat: enable # Type of keys we support for prediction, label and features dictionaries. @@ -429,30 +439,22 @@ def __float__(self): # FeaturesPredictionsLabels, new code should use DictOfTensorValue instead. DictOfFetchedTensorValues = Dict[FPLKeyType, Dict[str, TensorValue]] -FeaturesPredictionsLabels = NamedTuple( - 'FeaturesPredictionsLabels', - [ - ('input_ref', int), - ('features', DictOfFetchedTensorValues), - ('predictions', DictOfFetchedTensorValues), - ('labels', DictOfFetchedTensorValues), - ], -) + +class FeaturesPredictionsLabels(NamedTuple): + input_ref: int + features: DictOfFetchedTensorValues + predictions: DictOfFetchedTensorValues + labels: DictOfFetchedTensorValues + # Used in building the model diagnostics table, a MaterializedColumn is a value # inside of Extracts that will be emitted to file. Note that for strings, the # values are raw byte strings rather than unicode strings. This is by design, as # features can have arbitrary bytes values. -MaterializedColumn = NamedTuple( - 'MaterializedColumn', - [ - ('name', str), - ( - 'value', - Union[List[bytes], List[int], List[float], bytes, int, float], - ), - ], -) +class MaterializedColumn(NamedTuple): + name: str + value: Union[List[bytes], List[int], List[float], bytes, int, float] + # Extracts represent data extracted during pipeline processing. In order to # provide a flexible API, these types are just dicts where the keys are defined @@ -465,172 +467,174 @@ def __float__(self): class ModelLoader: - """Model loader is responsible for loading shared model types. - - Attributes: - construct_fn: A callable which creates the model instance. The callable - should take no args as input (typically a closure is used to capture - necessary parameters). - tags: Optional model tags (e.g. 'serve' for serving or 'eval' for - EvalSavedModel). - """ - - __slots__ = ['construct_fn', 'tags', '_shared_handle'] - - def __init__( - self, construct_fn: Callable[[], Any], tags: Optional[List[str]] = None - ): - self.construct_fn = construct_fn - self.tags = tags - self._shared_handle = shared.Shared() - - def load( - self, model_load_time_callback: Optional[Callable[[int], None]] = None - ) -> Any: - """Returns loaded model. - - Args: - model_load_time_callback: Optional callback to track load time. + """Model loader is responsible for loading shared model types. 
+ + Attributes + ---------- + construct_fn: A callable which creates the model instance. The callable + should take no args as input (typically a closure is used to capture + necessary parameters). + tags: Optional model tags (e.g. 'serve' for serving or 'eval' for + EvalSavedModel). """ - if model_load_time_callback: - construct_fn = self._construct_fn_with_load_time(model_load_time_callback) - else: - construct_fn = self.construct_fn - return self._shared_handle.acquire(construct_fn) - def _construct_fn_with_load_time( - self, model_load_time_callback: Callable[[int], None] - ) -> Callable[[], Any]: - """Wraps actual construct fn to allow for load time metrics.""" + __slots__ = ["construct_fn", "tags", "_shared_handle"] + + def __init__( + self, construct_fn: Callable[[], Any], tags: Optional[List[str]] = None + ): + self.construct_fn = construct_fn + self.tags = tags + self._shared_handle = shared.Shared() + + def load( + self, model_load_time_callback: Optional[Callable[[int], None]] = None + ) -> Any: + """Returns loaded model. + + Args: + ---- + model_load_time_callback: Optional callback to track load time. + """ + if model_load_time_callback: + construct_fn = self._construct_fn_with_load_time(model_load_time_callback) + else: + construct_fn = self.construct_fn + return self._shared_handle.acquire(construct_fn) + + def _construct_fn_with_load_time( + self, model_load_time_callback: Callable[[int], None] + ) -> Callable[[], Any]: + """Wraps actual construct fn to allow for load time metrics.""" - def with_load_times(): - start_time = datetime.datetime.now() - model = self.construct_fn() - end_time = datetime.datetime.now() - model_load_time_callback(int((end_time - start_time).total_seconds())) - return model + def with_load_times(): + start_time = datetime.datetime.now() + model = self.construct_fn() + end_time = datetime.datetime.now() + model_load_time_callback(int((end_time - start_time).total_seconds())) + return model - return with_load_times + return with_load_times class EvalSharedModel( NamedTuple( - 'EvalSharedModel', + "EvalSharedModel", [ - ('model_path', str), + ("model_path", str), ( - 'add_metrics_callbacks', + "add_metrics_callbacks", List[Callable], ), # List[AnyMetricsCallbackType] - ('include_default_metrics', bool), - ('example_weight_key', Union[str, Dict[str, str]]), - ('additional_fetches', List[str]), - ('model_loader', ModelLoader), - ('model_name', str), - ('model_type', str), - ('rubber_stamp', bool), - ('is_baseline', bool), - ('resource_hints', Optional[Dict[str, Any]]), - ('backend_config', Optional[Any]), + ("include_default_metrics", bool), + ("example_weight_key", Union[str, Dict[str, str]]), + ("additional_fetches", List[str]), + ("model_loader", ModelLoader), + ("model_name", str), + ("model_type", str), + ("rubber_stamp", bool), + ("is_baseline", bool), + ("resource_hints", Optional[Dict[str, Any]]), + ("backend_config", Optional[Any]), ], ) ): - # pyformat: disable - """Shared model used during extraction and evaluation. - - Attributes: - model_path: Path to EvalSavedModel (containing the saved_model.pb file). - add_metrics_callbacks: Optional list of callbacks for adding additional - metrics to the graph. The names of the metrics added by the callbacks - should not conflict with existing metrics. See below for more details - about what each callback should do. The callbacks are only used during - evaluation. - include_default_metrics: True to include the default metrics that are part - of the saved model graph during evaluation. 
- example_weight_key: Example weight key (single-output model) or dict of - example weight keys (multi-output model) keyed by output_name. - additional_fetches: Prefixes of additional tensors stored in - signature_def.inputs that should be fetched at prediction time. The - "features" and "labels" tensors are handled automatically and should not - be included in this list. - model_loader: Model loader. - model_name: Model name (should align with ModelSpecs.name). - model_type: Model type (tfma.TF_KERAS, tfma.TF_LITE, tfma.TF_ESTIMATOR, ..). - rubber_stamp: True if this model is being rubber stamped. When a - model is rubber stamped diff thresholds will be ignored if an associated - baseline model is not passed. - is_baseline: The model is the baseline for comparison or not. - resource_hints: The beam resource hints to apply to the PTransform which - runs inference for this model. - backend_config: The backend config for running model inference. - - - More details on add_metrics_callbacks: - - Each add_metrics_callback should have the following prototype: - def add_metrics_callback(features_dict, predictions_dict, labels_dict): - - Note that features_dict, predictions_dict and labels_dict are not - necessarily dictionaries - they might also be Tensors, depending on what the - model's eval_input_receiver_fn returns. - - It should create and return a metric_ops dictionary, such that - metric_ops['metric_name'] = (value_op, update_op), just as in the Trainer. - - Short example: - - def add_metrics_callback(features_dict, predictions_dict, labels): - metrics_ops = {} - metric_ops['mean_label'] = tf.metrics.mean(labels) - metric_ops['mean_probability'] = tf.metrics.mean(tf.slice( - predictions_dict['probabilities'], [0, 1], [2, 1])) - return metric_ops - """ - # pyformat: enable - - def __new__( - cls, - model_path: Optional[str] = None, - add_metrics_callbacks: Optional[List[AddMetricsCallbackType]] = None, - include_default_metrics: Optional[bool] = True, - example_weight_key: Optional[Union[str, Dict[str, str]]] = None, - additional_fetches: Optional[List[str]] = None, - model_loader: Optional[ModelLoader] = None, - model_name: str = '', - model_type: str = '', - rubber_stamp: bool = False, - is_baseline: bool = False, - resource_hints: Optional[Dict[str, Any]] = None, - backend_config: Optional[Any] = None, - construct_fn: Optional[Callable[[], Any]] = None, - ): - if not add_metrics_callbacks: - add_metrics_callbacks = [] - if model_loader and construct_fn: - raise ValueError( - 'only one of model_loader or construct_fn should be used' - ) - if construct_fn: - model_loader = ModelLoader(tags=None, construct_fn=construct_fn) - if model_path is not None: - model_path = six.ensure_str(model_path) - if is_baseline and rubber_stamp: - raise ValueError('Baseline model cannot be rubber stamped.') - return super(EvalSharedModel, cls).__new__( + # pyformat: disable + """Shared model used during extraction and evaluation. + + Attributes + ---------- + model_path: Path to EvalSavedModel (containing the saved_model.pb file). + add_metrics_callbacks: Optional list of callbacks for adding additional + metrics to the graph. The names of the metrics added by the callbacks + should not conflict with existing metrics. See below for more details + about what each callback should do. The callbacks are only used during + evaluation. + include_default_metrics: True to include the default metrics that are part + of the saved model graph during evaluation. 
+ example_weight_key: Example weight key (single-output model) or dict of + example weight keys (multi-output model) keyed by output_name. + additional_fetches: Prefixes of additional tensors stored in + signature_def.inputs that should be fetched at prediction time. The + "features" and "labels" tensors are handled automatically and should not + be included in this list. + model_loader: Model loader. + model_name: Model name (should align with ModelSpecs.name). + model_type: Model type (tfma.TF_KERAS, tfma.TF_LITE, tfma.TF_ESTIMATOR, ..). + rubber_stamp: True if this model is being rubber stamped. When a + model is rubber stamped diff thresholds will be ignored if an associated + baseline model is not passed. + is_baseline: The model is the baseline for comparison or not. + resource_hints: The beam resource hints to apply to the PTransform which + runs inference for this model. + backend_config: The backend config for running model inference. + + + More details on add_metrics_callbacks: + + Each add_metrics_callback should have the following prototype: + def add_metrics_callback(features_dict, predictions_dict, labels_dict): + + Note that features_dict, predictions_dict and labels_dict are not + necessarily dictionaries - they might also be Tensors, depending on what the + model's eval_input_receiver_fn returns. + + It should create and return a metric_ops dictionary, such that + metric_ops['metric_name'] = (value_op, update_op), just as in the Trainer. + + Short example: + + def add_metrics_callback(features_dict, predictions_dict, labels): + metrics_ops = {} + metric_ops['mean_label'] = tf.metrics.mean(labels) + metric_ops['mean_probability'] = tf.metrics.mean(tf.slice( + predictions_dict['probabilities'], [0, 1], [2, 1])) + return metric_ops + """ + + # pyformat: enable + + def __new__( cls, - model_path, - add_metrics_callbacks, - include_default_metrics, - example_weight_key, - additional_fetches, - model_loader, - model_name, - model_type, - rubber_stamp, - is_baseline, - resource_hints, - backend_config, - ) + model_path: Optional[str] = None, + add_metrics_callbacks: Optional[List[AddMetricsCallbackType]] = None, + include_default_metrics: Optional[bool] = True, + example_weight_key: Optional[Union[str, Dict[str, str]]] = None, + additional_fetches: Optional[List[str]] = None, + model_loader: Optional[ModelLoader] = None, + model_name: str = "", + model_type: str = "", + rubber_stamp: bool = False, + is_baseline: bool = False, + resource_hints: Optional[Dict[str, Any]] = None, + backend_config: Optional[Any] = None, + construct_fn: Optional[Callable[[], Any]] = None, + ): + if not add_metrics_callbacks: + add_metrics_callbacks = [] + if model_loader and construct_fn: + raise ValueError("only one of model_loader or construct_fn should be used") + if construct_fn: + model_loader = ModelLoader(tags=None, construct_fn=construct_fn) + if model_path is not None: + model_path = six.ensure_str(model_path) + if is_baseline and rubber_stamp: + raise ValueError("Baseline model cannot be rubber stamped.") + return super(EvalSharedModel, cls).__new__( + cls, + model_path, + add_metrics_callbacks, + include_default_metrics, + example_weight_key, + additional_fetches, + model_loader, + model_name, + model_type, + rubber_stamp, + is_baseline, + resource_hints, + backend_config, + ) # MaybeMultipleEvalSharedModels represents a parameter that can take on a single @@ -642,37 +646,37 @@ def __new__( ] __all__ = [ - 'AddMetricsCallbackType', - 'ConcreteStructuredMetricValue', - 
'DictOfFetchedTensorValues', - 'DictOfTensorType', - 'DictOfTensorTypeMaybeDict', - 'DictOfTensorValue', - 'DictOfTensorValueMaybeDict', - 'DictOfTypeSpec', - 'DictOfTypeSpecMaybeDict', - 'EvalSharedModel', - 'Extracts', - 'FeaturesPredictionsLabels', - 'FPLKeyType', - 'MaterializedColumn', - 'MaybeMultipleEvalSharedModels', - 'MetricValueType', - 'MetricVariablesType', - 'ModelLoader', - 'PrimitiveMetricValueType', - 'RaggedTensorValue', - 'SparseTensorValue', - 'StructuredMetricValue', - 'TensorOrOperationType', - 'TensorType', - 'TensorTypeMaybeDict', - 'TensorTypeMaybeMultiLevelDict', - 'TensorValue', - 'TensorValueMaybeDict', - 'TensorValueMaybeMultiLevelDict', - 'TypeSpecMaybeDict', - 'TypeSpecMaybeMultiLevelDict', - 'ValueWithTDistribution', - 'VarLenTensorValue' + "AddMetricsCallbackType", + "ConcreteStructuredMetricValue", + "DictOfFetchedTensorValues", + "DictOfTensorType", + "DictOfTensorTypeMaybeDict", + "DictOfTensorValue", + "DictOfTensorValueMaybeDict", + "DictOfTypeSpec", + "DictOfTypeSpecMaybeDict", + "EvalSharedModel", + "Extracts", + "FeaturesPredictionsLabels", + "FPLKeyType", + "MaterializedColumn", + "MaybeMultipleEvalSharedModels", + "MetricValueType", + "MetricVariablesType", + "ModelLoader", + "PrimitiveMetricValueType", + "RaggedTensorValue", + "SparseTensorValue", + "StructuredMetricValue", + "TensorOrOperationType", + "TensorType", + "TensorTypeMaybeDict", + "TensorTypeMaybeMultiLevelDict", + "TensorValue", + "TensorValueMaybeDict", + "TensorValueMaybeMultiLevelDict", + "TypeSpecMaybeDict", + "TypeSpecMaybeMultiLevelDict", + "ValueWithTDistribution", + "VarLenTensorValue", ] diff --git a/tensorflow_model_analysis/api/types_test.py b/tensorflow_model_analysis/api/types_test.py index 2cc1cf12c9..c7c2db45f1 100644 --- a/tensorflow_model_analysis/api/types_test.py +++ b/tensorflow_model_analysis/api/types_test.py @@ -13,83 +13,85 @@ # limitations under the License. 
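For reference, a small usage sketch of one of the exported types above, `ValueWithTDistribution`; the numbers are invented.

```python
from tensorflow_model_analysis.api import types

# A bootstrapped metric value: summary statistics over resamples plus the
# point estimate computed without bootstrapping.
value = types.ValueWithTDistribution(
    sample_mean=0.71,
    sample_standard_deviation=0.02,
    sample_degrees_of_freedom=19,
    unsampled_value=0.70,
)
# float() deliberately returns the unsampled point estimate.
assert float(value) == 0.70
```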
"""Tests for types.""" -from absl.testing import absltest import numpy as np +from absl.testing import absltest + from tensorflow_model_analysis.api import types class TypesTest(absltest.TestCase): + def testVarLenTensorValueFromDenseRows(self): + tensor = types.VarLenTensorValue.from_dense_rows( + [np.array([]), np.array([1]), np.array([1, 2])] + ) + np.testing.assert_array_equal(np.array([1, 1, 2]), tensor.values) + np.testing.assert_array_equal( + np.array([[1, 0], [2, 0], [2, 1]]), tensor.indices + ) + np.testing.assert_array_equal(np.array([3, 2]), tensor.dense_shape) - def testVarLenTensorValueFromDenseRows(self): - tensor = types.VarLenTensorValue.from_dense_rows( - [np.array([]), np.array([1]), np.array([1, 2])] - ) - np.testing.assert_array_equal(np.array([1, 1, 2]), tensor.values) - np.testing.assert_array_equal( - np.array([[1, 0], [2, 0], [2, 1]]), tensor.indices - ) - np.testing.assert_array_equal(np.array([3, 2]), tensor.dense_shape) - - def testVarLenTensorValueToDenseRows(self): - tensor = types.VarLenTensorValue( - values=np.array([1, 2, 3, 4]), - indices=np.array([[0, 0], [0, 1], [2, 0], [2, 1]]), - dense_shape=np.array([3, 2]), - ) - dense_rows = list(tensor.dense_rows()) - self.assertLen(dense_rows, 3) - np.testing.assert_array_equal(np.array([1, 2]), dense_rows[0]) - np.testing.assert_array_equal(np.array([]), dense_rows[1]) - np.testing.assert_array_equal(np.array([3, 4]), dense_rows[2]) + def testVarLenTensorValueToDenseRows(self): + tensor = types.VarLenTensorValue( + values=np.array([1, 2, 3, 4]), + indices=np.array([[0, 0], [0, 1], [2, 0], [2, 1]]), + dense_shape=np.array([3, 2]), + ) + dense_rows = list(tensor.dense_rows()) + self.assertLen(dense_rows, 3) + np.testing.assert_array_equal(np.array([1, 2]), dense_rows[0]) + np.testing.assert_array_equal(np.array([]), dense_rows[1]) + np.testing.assert_array_equal(np.array([3, 4]), dense_rows[2]) - def testVarLenTensorValueInvalidShape(self): - with self.assertRaisesRegex( - ValueError, r'A VarLenTensorValue .* \(\[2 2 2\]\)' - ): - types.VarLenTensorValue( - values=np.array([1, 2, 3, 4, 5, 6, 7, 8]), - indices=np.array([ - [0, 0, 0], - [0, 0, 1], - [0, 1, 0], - [0, 1, 1], - [1, 0, 0], - [1, 0, 1], - [1, 1, 0], - [1, 1, 1], - ]), - dense_shape=np.array([2, 2, 2]), - ) + def testVarLenTensorValueInvalidShape(self): + with self.assertRaisesRegex( + ValueError, r"A VarLenTensorValue .* \(\[2 2 2\]\)" + ): + types.VarLenTensorValue( + values=np.array([1, 2, 3, 4, 5, 6, 7, 8]), + indices=np.array( + [ + [0, 0, 0], + [0, 0, 1], + [0, 1, 0], + [0, 1, 1], + [1, 0, 0], + [1, 0, 1], + [1, 1, 0], + [1, 1, 1], + ] + ), + dense_shape=np.array([2, 2, 2]), + ) - def testVarLenTensorValueInvalidRowIndices(self): - with self.assertRaisesRegex( - ValueError, r'The values and .* indices\[\[2\], :\]' - ): - types.VarLenTensorValue( - values=np.array([1, 2, 3, 4]), - # rows indices are reversed - indices=np.array([[1, 0], [1, 1], [0, 0], [0, 1]]), - dense_shape=np.array([2, 2]), - ) + def testVarLenTensorValueInvalidRowIndices(self): + with self.assertRaisesRegex( + ValueError, r"The values and .* indices\[\[2\], :\]" + ): + types.VarLenTensorValue( + values=np.array([1, 2, 3, 4]), + # rows indices are reversed + indices=np.array([[1, 0], [1, 1], [0, 0], [0, 1]]), + dense_shape=np.array([2, 2]), + ) - def testVarLenTensorValueInvalidColumnIndices(self): - with self.assertRaisesRegex( - ValueError, r'The values and .* indices\[\[1\], :\]' - ): - types.VarLenTensorValue( - values=np.array([1, 2, 3, 4]), - # columns indices in the first row 
are reversed - indices=np.array([[0, 1], [0, 0], [1, 0], [1, 1]]), - dense_shape=np.array([2, 2]), - ) + def testVarLenTensorValueInvalidColumnIndices(self): + with self.assertRaisesRegex( + ValueError, r"The values and .* indices\[\[1\], :\]" + ): + types.VarLenTensorValue( + values=np.array([1, 2, 3, 4]), + # columns indices in the first row are reversed + indices=np.array([[0, 1], [0, 0], [1, 0], [1, 1]]), + dense_shape=np.array([2, 2]), + ) - def testVarLenTensorValueEmpty(self): - types.VarLenTensorValue( - values=np.array([]), - indices=np.empty((0, 2)), - dense_shape=np.array([2, 2]), - ) + def testVarLenTensorValueEmpty(self): + types.VarLenTensorValue( + values=np.array([]), + indices=np.empty((0, 2)), + dense_shape=np.array([2, 2]), + ) -if __name__ == '__main__': - absltest.main() +if __name__ == "__main__": + absltest.main() diff --git a/tensorflow_model_analysis/api/verifier_lib.py b/tensorflow_model_analysis/api/verifier_lib.py index 08e1d812ef..ba1b3ef6b8 100644 --- a/tensorflow_model_analysis/api/verifier_lib.py +++ b/tensorflow_model_analysis/api/verifier_lib.py @@ -16,6 +16,7 @@ from typing import Any, Dict, List import apache_beam as beam + from tensorflow_model_analysis.api import types from tensorflow_model_analysis.validators import validator @@ -28,25 +29,27 @@ def Validate( # pylint: disable=invalid-name alternatives: Dict[str, beam.PTransform], validators: List[validator.Validator], ) -> validator.Validation: - """Performs validation of alternative evaluations. - - Args: - extracts: PCollection of extracts. - alternatives: Dict of PTransforms (Extracts -> Evaluation) whose output will - be compared for validation purposes (e.g. 'baseline' vs 'candidate'). - validators: List of validators for validating the output from running the - alternatives. The Validation outputs produced by the validators will be - merged into a single output. If there are overlapping output keys, later - outputs will replace earlier outputs sharing the same key. - - Returns: - Validation dict. - """ - evaluations = {} - for key in alternatives: - evaluations[key] = extracts | 'Evaluate(%s)' % key >> alternatives[key] - - validation = {} - for v in validators: - validation.update(evaluations | v.stage_name >> v.ptransform) - return validation + """Performs validation of alternative evaluations. + + Args: + ---- + extracts: PCollection of extracts. + alternatives: Dict of PTransforms (Extracts -> Evaluation) whose output will + be compared for validation purposes (e.g. 'baseline' vs 'candidate'). + validators: List of validators for validating the output from running the + alternatives. The Validation outputs produced by the validators will be + merged into a single output. If there are overlapping output keys, later + outputs will replace earlier outputs sharing the same key. + + Returns: + ------- + Validation dict. 
+ """ + evaluations = {} + for key in alternatives: + evaluations[key] = extracts | "Evaluate(%s)" % key >> alternatives[key] + + validation = {} + for v in validators: + validation.update(evaluations | v.stage_name >> v.ptransform) + return validation diff --git a/tensorflow_model_analysis/constants.py b/tensorflow_model_analysis/constants.py index 7023b2ef27..5718dcf63f 100644 --- a/tensorflow_model_analysis/constants.py +++ b/tensorflow_model_analysis/constants.py @@ -15,24 +15,23 @@ from tfx_bsl.telemetry import util - # Mode for multiple model analysis runs -UNKNOWN_EVAL_MODE = 'unknown_eval_mode' -MODEL_CENTRIC_MODE = 'model_centric_mode' -DATA_CENTRIC_MODE = 'data_centric_mode' +UNKNOWN_EVAL_MODE = "unknown_eval_mode" +MODEL_CENTRIC_MODE = "model_centric_mode" +DATA_CENTRIC_MODE = "data_centric_mode" # Types of placeholders -PLACEHOLDER = 'placeholder' -SPARSE_PLACEHOLDER = 'sparse_placeholder' +PLACEHOLDER = "placeholder" +SPARSE_PLACEHOLDER = "sparse_placeholder" # Types of TF models -TFMA_EVAL = 'tfma_eval' -TF_ESTIMATOR = 'tf_estimator' -TF_KERAS = 'tf_keras' -TF_GENERIC = 'tf_generic' -TF_LITE = 'tf_lite' -TF_JS = 'tf_js' -MATERIALIZED_PREDICTION = 'materialized_prediction' +TFMA_EVAL = "tfma_eval" +TF_ESTIMATOR = "tf_estimator" +TF_KERAS = "tf_keras" +TF_GENERIC = "tf_generic" +TF_LITE = "tf_lite" +TF_JS = "tf_js" +MATERIALIZED_PREDICTION = "materialized_prediction" VALID_TF_MODEL_TYPES = ( TFMA_EVAL, TF_GENERIC, @@ -44,71 +43,71 @@ ) # This constant is only used for telemetry -MODEL_AGNOSTIC = 'model_agnostic' +MODEL_AGNOSTIC = "model_agnostic" # LINT.IfChange -METRICS_NAMESPACE = util.MakeTfxNamespace(['ModelAnalysis']) +METRICS_NAMESPACE = util.MakeTfxNamespace(["ModelAnalysis"]) # LINT.ThenChange(../../../learning/fairness/infra/plx/scripts/tfma_metrics_computed_tracker_macros.sql) # Keys for Extracts dictionary (keys starting with _ will not be materialized). # Input key. Could be a serialized tf.train.Example, a CSV row, JSON data, etc # depending on what the EvalInputReceiver was configured to accept as input. -INPUT_KEY = 'input' +INPUT_KEY = "input" # This holds an Arrow RecordBatch representing a batch of examples. -ARROW_RECORD_BATCH_KEY = 'arrow_record_batch' +ARROW_RECORD_BATCH_KEY = "arrow_record_batch" # This holds the column name containing the raw input (Could be a serialized # tf.train.Example, a CSV row, JSON data, etc) in an Arrow RecordBatch. -ARROW_INPUT_COLUMN = '__raw_record__' +ARROW_INPUT_COLUMN = "__raw_record__" # Features, predictions, and labels key. -FEATURES_PREDICTIONS_LABELS_KEY = '_fpl' +FEATURES_PREDICTIONS_LABELS_KEY = "_fpl" # Contains SliceKeyTypes that are used to fanout and aggregate. -SLICE_KEY_TYPES_KEY = '_slice_key_types' +SLICE_KEY_TYPES_KEY = "_slice_key_types" # Human-readable slice strings that are written to the diagnostic table for # analysis. -SLICE_KEYS_KEY = 'slice_keys' +SLICE_KEYS_KEY = "slice_keys" # Features key. -FEATURES_KEY = 'features' +FEATURES_KEY = "features" # Transformed features key. -TRANSFORMED_FEATURES_KEY = 'transformed_features' +TRANSFORMED_FEATURES_KEY = "transformed_features" # Labels key. -LABELS_KEY = 'labels' +LABELS_KEY = "labels" # Predictions key. -PREDICTIONS_KEY = 'predictions' +PREDICTIONS_KEY = "predictions" # Example weights key. -EXAMPLE_WEIGHTS_KEY = 'example_weights' +EXAMPLE_WEIGHTS_KEY = "example_weights" # Attributions key. -ATTRIBUTIONS_KEY = 'attributions' +ATTRIBUTIONS_KEY = "attributions" # Prediction log key. 
-SPLIT_KEY = 'split' +SPLIT_KEY = "split" # Keys used for standard attribution scores -BASELINE_SCORE_KEY = 'baseline_score' -EXAMPLE_SCORE_KEY = 'example_score' +BASELINE_SCORE_KEY = "baseline_score" +EXAMPLE_SCORE_KEY = "example_score" # Keys for Evaluation/Validation dictionaries # Metrics output key. -METRICS_KEY = 'metrics' +METRICS_KEY = "metrics" # Plots output key. -PLOTS_KEY = 'plots' +PLOTS_KEY = "plots" # Validations key. -VALIDATIONS_KEY = 'validations' +VALIDATIONS_KEY = "validations" # Analysis output key. -ANALYSIS_KEY = 'analysis' +ANALYSIS_KEY = "analysis" # Keys for validation alternatives -BASELINE_KEY = 'baseline' -CANDIDATE_KEY = 'candidate' +BASELINE_KEY = "baseline" +CANDIDATE_KEY = "candidate" -MATERIALIZE_COLUMNS = 'materialize' +MATERIALIZE_COLUMNS = "materialize" # Key used to save and store prediction logs. -PREDICTION_LOG_KEY = 'prediction_log' +PREDICTION_LOG_KEY = "prediction_log" # Not actually for any metric, just used for communicating errors. -ERROR_METRIC_NAME = '__ERROR__' +ERROR_METRIC_NAME = "__ERROR__" diff --git a/tensorflow_model_analysis/contrib/aggregates/binary_confusion_matrices.py b/tensorflow_model_analysis/contrib/aggregates/binary_confusion_matrices.py index 691929b0df..b5145a2293 100644 --- a/tensorflow_model_analysis/contrib/aggregates/binary_confusion_matrices.py +++ b/tensorflow_model_analysis/contrib/aggregates/binary_confusion_matrices.py @@ -17,250 +17,256 @@ DEFAULT_NUM_EXAMPLE_IDS = 100 -Matrix = NamedTuple( - 'Matrix', [('tp', float), ('tn', float), ('fp', float), ('fn', float)] -) - -_ThresholdEntry = NamedTuple( - '_ThresholdEntry', - [ - ('matrix', Matrix), - ('tp_examples', List[str]), - ('tn_examples', List[str]), - ('fp_examples', List[str]), - ('fn_examples', List[str]), - ], -) + +class Matrix(NamedTuple): + tp: float + tn: float + fp: float + fn: float + + +class _ThresholdEntry(NamedTuple): + matrix: Matrix + tp_examples: List[str] + tn_examples: List[str] + fp_examples: List[str] + fn_examples: List[str] + MatrixAccumulator = Dict[float, _ThresholdEntry] class BinaryConfusionMatrices: - """Computes binary confusion matrix.""" - - def __init__( - self, - thresholds: Sequence[float], - example_ids_count: int = DEFAULT_NUM_EXAMPLE_IDS, - enable_fractional_labels: bool = True, - ): - """Initializes the class. - - Args: - thresholds: A specific set of thresholds to use. The caller is responsible - for marking the boundaries with +/-epsilon if desired. - example_ids_count: Max number of example ids to be extracted for each - result in the binary confusion matrix (tp, tn, fp, and fn). - enable_fractional_labels: If false, labels will be compared to the - threshold in the same way predictions are. If true, each incoming tuple - of (label, prediction, and example weight) will be split into two tuples - as follows (where l, p, w represent the resulting label, prediction, and - example weight values): (1) l = 0.0, p = prediction, and w = - example_weight * (1.0 - label) (2) l = 1.0, p = prediction, and w = - example_weight * label. If enabled, an exception will be raised if - labels are not within [0, 1]. The implementation is such that tuples - associated with a weight of zero are not yielded. This means it is safe - to enable fractional labels even when the labels only take on the values - of 0.0 or 1.0. 
- """ - self._thresholds = thresholds - self._example_ids_count = example_ids_count - self._enable_fractional_labels = enable_fractional_labels - - def create_accumulator(self) -> MatrixAccumulator: - return {} - - def _merge_example_ids( - self, list_1: List[str], list_2: List[str] - ) -> List[str]: - result = list_1[: self._example_ids_count] - result.extend(list_2[: self._example_ids_count - len(result)]) - return result - - def _merge_entry( - self, - accumulator: MatrixAccumulator, - threshold: float, - entry: _ThresholdEntry, - ) -> _ThresholdEntry: - if threshold not in accumulator: - return entry - - return _ThresholdEntry( - matrix=Matrix( - tp=accumulator[threshold].matrix.tp + entry.matrix.tp, - tn=accumulator[threshold].matrix.tn + entry.matrix.tn, - fp=accumulator[threshold].matrix.fp + entry.matrix.fp, - fn=accumulator[threshold].matrix.fn + entry.matrix.fn, - ), - tp_examples=self._merge_example_ids( - accumulator[threshold].tp_examples, entry.tp_examples - ), - tn_examples=self._merge_example_ids( - accumulator[threshold].tn_examples, entry.tn_examples - ), - fp_examples=self._merge_example_ids( - accumulator[threshold].fp_examples, entry.fp_examples - ), - fn_examples=self._merge_example_ids( - accumulator[threshold].fn_examples, entry.fn_examples - ), - ) - - def add_input( - self, - accumulator: MatrixAccumulator, - labels: Sequence[float], - predictions: Sequence[float], - example_weights: Optional[Sequence[float]], - example_id: Optional[str], - ) -> MatrixAccumulator: - """Adds a single example input to the accumulator. - - Args: - accumulator: Accumulator to add input to. - labels: Expected values. - predictions: Predicted values. - example_weights: Weights for this example. - example_id: ID for this example. - - Returns: - Merged MatrixAccumulator of the original accumulator and the added inputs. - """ - if example_weights is None or all(w is None for w in example_weights): - example_weights = [1] * len(labels) - - for threshold in self._thresholds: - tp = 0.0 - tn = 0.0 - fp = 0.0 - fn = 0.0 - tp_example = None - tn_example = None - fp_example = None - fn_example = None - # We need to iterate here even though it is one example because one - # example can contain multiple labels/predictions/example_weights. - for label, prediction, example_weight in zip( - labels, predictions, example_weights - ): - if ( - label == 1.0 - if self._enable_fractional_labels - else label > threshold - ): - if prediction > threshold: - tp += example_weight - tp_example = example_id - else: - fn += example_weight - fn_example = example_id - else: - if prediction > threshold: - fp += example_weight - fp_example = example_id - else: - tn += example_weight - tn_example = example_id - - accumulator[threshold] = self._merge_entry( - accumulator=accumulator, - threshold=threshold, - entry=_ThresholdEntry( - Matrix(tp=tp, tn=tn, fp=fp, fn=fn), - tp_examples=[tp_example] if tp_example is not None else [], - tn_examples=[tn_example] if tn_example is not None else [], - fp_examples=[fp_example] if fp_example is not None else [], - fn_examples=[fn_example] if fn_example is not None else [], - ), - ) - - return accumulator - - def add_inputs( - self, - accumulator: MatrixAccumulator, - labels: Sequence[Sequence[float]], - predictions: Sequence[Sequence[float]], - example_weights: Optional[Sequence[Sequence[float]]], - example_ids: Optional[Sequence[str]], - ) -> MatrixAccumulator: - """Adds a batch of inputs to the accumulator. - - Args: - accumulator: Accumulator to add input to. 
- labels: Expected values. - predictions: Predicted values. - example_weights: Weights for each example. - example_ids: IDs For each example. - - Returns: - Merged MatrixAccumulator of the original accumulator and the added inputs. - """ - make_iter = lambda ex: ex if hasattr(ex, '__iter__') else [ex] - - if example_weights is None: - example_weights = [None] * len(labels) - - if example_ids is None: - example_ids = [None] * len(labels) - - for label, prediction, example_weight, example_id in zip( - labels, predictions, example_weights, example_ids + """Computes binary confusion matrix.""" + + def __init__( + self, + thresholds: Sequence[float], + example_ids_count: int = DEFAULT_NUM_EXAMPLE_IDS, + enable_fractional_labels: bool = True, ): - # Calls self.add_input() for each example within the batch. - accumulator = self.add_input( - accumulator=accumulator, - labels=make_iter(label), - predictions=make_iter(prediction), - example_weights=make_iter(example_weight), - example_id=example_id, - ) - - return accumulator - - def merge_accumulators( - self, - accumulators: Iterable[MatrixAccumulator], - ) -> MatrixAccumulator: - """Merges accumulators. - - Args: - accumulators: Accumulators to be merged - - Returns: - The merged accumulator. - """ - accumulators = iter(accumulators) - result = next(accumulators) - - for accumulator in accumulators: - for threshold in self._thresholds: - # We need to check if threshold is in the accumulator because the - # accumulator can be empty (i.e. no input was been added). - if threshold in accumulator: - result[threshold] = self._merge_entry( - accumulator=result, - threshold=threshold, - entry=accumulator[threshold], - ) - - return result - - def extract_output(self, accumulator: MatrixAccumulator) -> MatrixAccumulator: - for threshold in self._thresholds: - if threshold not in accumulator: - accumulator[threshold] = _ThresholdEntry( - Matrix(tp=0.0, tn=0.0, fp=0.0, fn=0.0), - tp_examples=[], - tn_examples=[], - fp_examples=[], - fn_examples=[], + """Initializes the class. + + Args: + ---- + thresholds: A specific set of thresholds to use. The caller is responsible + for marking the boundaries with +/-epsilon if desired. + example_ids_count: Max number of example ids to be extracted for each + result in the binary confusion matrix (tp, tn, fp, and fn). + enable_fractional_labels: If false, labels will be compared to the + threshold in the same way predictions are. If true, each incoming tuple + of (label, prediction, and example weight) will be split into two tuples + as follows (where l, p, w represent the resulting label, prediction, and + example weight values): (1) l = 0.0, p = prediction, and w = + example_weight * (1.0 - label) (2) l = 1.0, p = prediction, and w = + example_weight * label. If enabled, an exception will be raised if + labels are not within [0, 1]. The implementation is such that tuples + associated with a weight of zero are not yielded. This means it is safe + to enable fractional labels even when the labels only take on the values + of 0.0 or 1.0. 
+ """ + self._thresholds = thresholds + self._example_ids_count = example_ids_count + self._enable_fractional_labels = enable_fractional_labels + + def create_accumulator(self) -> MatrixAccumulator: + return {} + + def _merge_example_ids(self, list_1: List[str], list_2: List[str]) -> List[str]: + result = list_1[: self._example_ids_count] + result.extend(list_2[: self._example_ids_count - len(result)]) + return result + + def _merge_entry( + self, + accumulator: MatrixAccumulator, + threshold: float, + entry: _ThresholdEntry, + ) -> _ThresholdEntry: + if threshold not in accumulator: + return entry + + return _ThresholdEntry( + matrix=Matrix( + tp=accumulator[threshold].matrix.tp + entry.matrix.tp, + tn=accumulator[threshold].matrix.tn + entry.matrix.tn, + fp=accumulator[threshold].matrix.fp + entry.matrix.fp, + fn=accumulator[threshold].matrix.fn + entry.matrix.fn, + ), + tp_examples=self._merge_example_ids( + accumulator[threshold].tp_examples, entry.tp_examples + ), + tn_examples=self._merge_example_ids( + accumulator[threshold].tn_examples, entry.tn_examples + ), + fp_examples=self._merge_example_ids( + accumulator[threshold].fp_examples, entry.fp_examples + ), + fn_examples=self._merge_example_ids( + accumulator[threshold].fn_examples, entry.fn_examples + ), ) - return accumulator - def __call__(self, *inputs, **named_inputs): - """Directly apply aggregate on inputs.""" - return self.extract_output( - self.add_inputs(self.create_accumulator(), *inputs, **named_inputs) - ) + def add_input( + self, + accumulator: MatrixAccumulator, + labels: Sequence[float], + predictions: Sequence[float], + example_weights: Optional[Sequence[float]], + example_id: Optional[str], + ) -> MatrixAccumulator: + """Adds a single example input to the accumulator. + + Args: + ---- + accumulator: Accumulator to add input to. + labels: Expected values. + predictions: Predicted values. + example_weights: Weights for this example. + example_id: ID for this example. + + Returns: + ------- + Merged MatrixAccumulator of the original accumulator and the added inputs. + """ + if example_weights is None or all(w is None for w in example_weights): + example_weights = [1] * len(labels) + + for threshold in self._thresholds: + tp = 0.0 + tn = 0.0 + fp = 0.0 + fn = 0.0 + tp_example = None + tn_example = None + fp_example = None + fn_example = None + # We need to iterate here even though it is one example because one + # example can contain multiple labels/predictions/example_weights. 
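+            # For example, with the default enable_fractional_labels=True and
+            # threshold=0.5, a (label=1.0, prediction=0.9) pair adds its
+            # example_weight to tp, while (label=1.0, prediction=0.3) adds it
+            # to fn instead.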
+ for label, prediction, example_weight in zip( + labels, predictions, example_weights + ): + if ( + label == 1.0 + if self._enable_fractional_labels + else label > threshold + ): + if prediction > threshold: + tp += example_weight + tp_example = example_id + else: + fn += example_weight + fn_example = example_id + else: + if prediction > threshold: + fp += example_weight + fp_example = example_id + else: + tn += example_weight + tn_example = example_id + + accumulator[threshold] = self._merge_entry( + accumulator=accumulator, + threshold=threshold, + entry=_ThresholdEntry( + Matrix(tp=tp, tn=tn, fp=fp, fn=fn), + tp_examples=[tp_example] if tp_example is not None else [], + tn_examples=[tn_example] if tn_example is not None else [], + fp_examples=[fp_example] if fp_example is not None else [], + fn_examples=[fn_example] if fn_example is not None else [], + ), + ) + + return accumulator + + def add_inputs( + self, + accumulator: MatrixAccumulator, + labels: Sequence[Sequence[float]], + predictions: Sequence[Sequence[float]], + example_weights: Optional[Sequence[Sequence[float]]], + example_ids: Optional[Sequence[str]], + ) -> MatrixAccumulator: + """Adds a batch of inputs to the accumulator. + + Args: + ---- + accumulator: Accumulator to add input to. + labels: Expected values. + predictions: Predicted values. + example_weights: Weights for each example. + example_ids: IDs For each example. + + Returns: + ------- + Merged MatrixAccumulator of the original accumulator and the added inputs. + """ + make_iter = lambda ex: ex if hasattr(ex, "__iter__") else [ex] + + if example_weights is None: + example_weights = [None] * len(labels) + + if example_ids is None: + example_ids = [None] * len(labels) + + for label, prediction, example_weight, example_id in zip( + labels, predictions, example_weights, example_ids + ): + # Calls self.add_input() for each example within the batch. + accumulator = self.add_input( + accumulator=accumulator, + labels=make_iter(label), + predictions=make_iter(prediction), + example_weights=make_iter(example_weight), + example_id=example_id, + ) + + return accumulator + + def merge_accumulators( + self, + accumulators: Iterable[MatrixAccumulator], + ) -> MatrixAccumulator: + """Merges accumulators. + + Args: + ---- + accumulators: Accumulators to be merged + + Returns: + ------- + The merged accumulator. + """ + accumulators = iter(accumulators) + result = next(accumulators) + + for accumulator in accumulators: + for threshold in self._thresholds: + # We need to check if threshold is in the accumulator because the + # accumulator can be empty (i.e. no input was been added). 
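+                # Thresholds missing from an accumulator are simply skipped here;
+                # extract_output() backfills any still-missing thresholds with
+                # empty _ThresholdEntry values.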
+ if threshold in accumulator: + result[threshold] = self._merge_entry( + accumulator=result, + threshold=threshold, + entry=accumulator[threshold], + ) + + return result + + def extract_output(self, accumulator: MatrixAccumulator) -> MatrixAccumulator: + for threshold in self._thresholds: + if threshold not in accumulator: + accumulator[threshold] = _ThresholdEntry( + Matrix(tp=0.0, tn=0.0, fp=0.0, fn=0.0), + tp_examples=[], + tn_examples=[], + fp_examples=[], + fn_examples=[], + ) + return accumulator + + def __call__(self, *inputs, **named_inputs): + """Directly apply aggregate on inputs.""" + return self.extract_output( + self.add_inputs(self.create_accumulator(), *inputs, **named_inputs) + ) diff --git a/tensorflow_model_analysis/contrib/aggregates/binary_confusion_matrices_test.py b/tensorflow_model_analysis/contrib/aggregates/binary_confusion_matrices_test.py index e32adc4d07..2ffb2fc229 100644 --- a/tensorflow_model_analysis/contrib/aggregates/binary_confusion_matrices_test.py +++ b/tensorflow_model_analysis/contrib/aggregates/binary_confusion_matrices_test.py @@ -13,8 +13,9 @@ # limitations under the License. """Tests for binary confusion matrix.""" -from absl.testing import parameterized import tensorflow as tf +from absl.testing import parameterized + from tensorflow_model_analysis.contrib.aggregates import binary_confusion_matrices from tensorflow_model_analysis.utils import test_util @@ -22,283 +23,282 @@ class BinaryConfusionMatricesTest( test_util.TensorflowModelAnalysisTest, parameterized.TestCase ): - - @parameterized.parameters( - dict( - thresholds=[0, 0.5, 1], - example_ids_count=100, - example_weights=(None, None, None, None), - example_ids=(None, None, None, None), - expected_result={ - 0: binary_confusion_matrices._ThresholdEntry( - matrix=binary_confusion_matrices.Matrix( - tp=2.0, tn=1.0, fp=1.0, fn=0.0 - ), - tp_examples=[], - tn_examples=[], - fp_examples=[], - fn_examples=[], - ), - 0.5: binary_confusion_matrices._ThresholdEntry( - matrix=binary_confusion_matrices.Matrix( - tp=1.0, tn=2.0, fp=0.0, fn=1.0 - ), - tp_examples=[], - tn_examples=[], - fp_examples=[], - fn_examples=[], - ), - 1: binary_confusion_matrices._ThresholdEntry( - matrix=binary_confusion_matrices.Matrix( - tp=0.0, tn=2.0, fp=0.0, fn=2.0 - ), - tp_examples=[], - tn_examples=[], - fp_examples=[], - fn_examples=[], - ), - }, - ), - dict( - thresholds=[0.1, 0.9], - example_ids_count=1, - example_weights=(1, 1, 1, 1), - example_ids=('id_1', 'id_2', 'id_3', 'id_4'), - expected_result={ - 0.1: binary_confusion_matrices._ThresholdEntry( - matrix=binary_confusion_matrices.Matrix( - tp=2.0, tn=1.0, fp=1.0, fn=0.0 - ), - tp_examples=['id_3'], - tn_examples=['id_1'], - fp_examples=['id_2'], - fn_examples=[], - ), - 0.9: binary_confusion_matrices._ThresholdEntry( - matrix=binary_confusion_matrices.Matrix( - tp=0.0, tn=2.0, fp=0.0, fn=2.0 - ), - tp_examples=[], - tn_examples=['id_1'], - fp_examples=[], - fn_examples=['id_3'], - ), - }, - ), - dict( - thresholds=[0.1, 0.9], - example_ids_count=2, - example_weights=(1, 1, 1, 1), - example_ids=('id_1', 'id_2', 'id_3', 'id_4'), - expected_result={ - 0.1: binary_confusion_matrices._ThresholdEntry( - matrix=binary_confusion_matrices.Matrix( - tp=2.0, tn=1.0, fp=1.0, fn=0.0 - ), - tp_examples=['id_3', 'id_4'], - tn_examples=['id_1'], - fp_examples=['id_2'], - fn_examples=[], - ), - 0.9: binary_confusion_matrices._ThresholdEntry( - matrix=binary_confusion_matrices.Matrix( - tp=0.0, tn=2.0, fp=0.0, fn=2.0 - ), - tp_examples=[], - tn_examples=['id_1', 'id_2'], - 
fp_examples=[], - fn_examples=['id_3', 'id_4'], - ), - }, - ), - dict( - thresholds=[0.25, 0.75], - example_ids_count=100, - example_weights=(0.2, 0.3, 0.5, 0.7), - example_ids=(None, None, None, None), - expected_result={ - 0.25: binary_confusion_matrices._ThresholdEntry( - matrix=binary_confusion_matrices.Matrix( - tp=1.2, tn=0.2, fp=0.3, fn=0.0 - ), - tp_examples=[], - tn_examples=[], - fp_examples=[], - fn_examples=[], - ), - 0.75: binary_confusion_matrices._ThresholdEntry( - matrix=binary_confusion_matrices.Matrix( - tp=0.7, tn=0.5, fp=0.0, fn=0.5 - ), - tp_examples=[], - tn_examples=[], - fp_examples=[], - fn_examples=[], - ), - }, - ), - ) - def testBinaryConfusionMatricesPerRow( - self, - thresholds, - example_ids_count, - example_weights, - example_ids, - expected_result, - ): - labels = (0, 0, 1, 1) - predictions = (0, 0.5, 0.3, 0.9) - - confusion_matrix = binary_confusion_matrices.BinaryConfusionMatrices( - thresholds=thresholds, - example_ids_count=example_ids_count, + @parameterized.parameters( + dict( + thresholds=[0, 0.5, 1], + example_ids_count=100, + example_weights=(None, None, None, None), + example_ids=(None, None, None, None), + expected_result={ + 0: binary_confusion_matrices._ThresholdEntry( + matrix=binary_confusion_matrices.Matrix( + tp=2.0, tn=1.0, fp=1.0, fn=0.0 + ), + tp_examples=[], + tn_examples=[], + fp_examples=[], + fn_examples=[], + ), + 0.5: binary_confusion_matrices._ThresholdEntry( + matrix=binary_confusion_matrices.Matrix( + tp=1.0, tn=2.0, fp=0.0, fn=1.0 + ), + tp_examples=[], + tn_examples=[], + fp_examples=[], + fn_examples=[], + ), + 1: binary_confusion_matrices._ThresholdEntry( + matrix=binary_confusion_matrices.Matrix( + tp=0.0, tn=2.0, fp=0.0, fn=2.0 + ), + tp_examples=[], + tn_examples=[], + fp_examples=[], + fn_examples=[], + ), + }, + ), + dict( + thresholds=[0.1, 0.9], + example_ids_count=1, + example_weights=(1, 1, 1, 1), + example_ids=("id_1", "id_2", "id_3", "id_4"), + expected_result={ + 0.1: binary_confusion_matrices._ThresholdEntry( + matrix=binary_confusion_matrices.Matrix( + tp=2.0, tn=1.0, fp=1.0, fn=0.0 + ), + tp_examples=["id_3"], + tn_examples=["id_1"], + fp_examples=["id_2"], + fn_examples=[], + ), + 0.9: binary_confusion_matrices._ThresholdEntry( + matrix=binary_confusion_matrices.Matrix( + tp=0.0, tn=2.0, fp=0.0, fn=2.0 + ), + tp_examples=[], + tn_examples=["id_1"], + fp_examples=[], + fn_examples=["id_3"], + ), + }, + ), + dict( + thresholds=[0.1, 0.9], + example_ids_count=2, + example_weights=(1, 1, 1, 1), + example_ids=("id_1", "id_2", "id_3", "id_4"), + expected_result={ + 0.1: binary_confusion_matrices._ThresholdEntry( + matrix=binary_confusion_matrices.Matrix( + tp=2.0, tn=1.0, fp=1.0, fn=0.0 + ), + tp_examples=["id_3", "id_4"], + tn_examples=["id_1"], + fp_examples=["id_2"], + fn_examples=[], + ), + 0.9: binary_confusion_matrices._ThresholdEntry( + matrix=binary_confusion_matrices.Matrix( + tp=0.0, tn=2.0, fp=0.0, fn=2.0 + ), + tp_examples=[], + tn_examples=["id_1", "id_2"], + fp_examples=[], + fn_examples=["id_3", "id_4"], + ), + }, + ), + dict( + thresholds=[0.25, 0.75], + example_ids_count=100, + example_weights=(0.2, 0.3, 0.5, 0.7), + example_ids=(None, None, None, None), + expected_result={ + 0.25: binary_confusion_matrices._ThresholdEntry( + matrix=binary_confusion_matrices.Matrix( + tp=1.2, tn=0.2, fp=0.3, fn=0.0 + ), + tp_examples=[], + tn_examples=[], + fp_examples=[], + fn_examples=[], + ), + 0.75: binary_confusion_matrices._ThresholdEntry( + matrix=binary_confusion_matrices.Matrix( + tp=0.7, tn=0.5, fp=0.0, 
fn=0.5 + ), + tp_examples=[], + tn_examples=[], + fp_examples=[], + fn_examples=[], + ), + }, + ), ) - accumulator = confusion_matrix.create_accumulator() - for label, prediction, example_weight, example_id in zip( - labels, predictions, example_weights, example_ids + def testBinaryConfusionMatricesPerRow( + self, + thresholds, + example_ids_count, + example_weights, + example_ids, + expected_result, ): - accumulator = confusion_matrix.add_input( - accumulator=accumulator, - labels=[label], - predictions=[prediction], - example_weights=[example_weight] if example_weight else None, - example_id=example_id, - ) - self.assertDictEqual(accumulator, expected_result) + labels = (0, 0, 1, 1) + predictions = (0, 0.5, 0.3, 0.9) - @parameterized.parameters( - dict( - thresholds=[0, 0.5, 1], - example_ids_count=100, - example_weights=(None, None, None, None), - example_ids=(None, None, None, None), - expected_result={ - 0: binary_confusion_matrices._ThresholdEntry( - matrix=binary_confusion_matrices.Matrix( - tp=2.0, tn=1.0, fp=1.0, fn=0.0 - ), - tp_examples=[], - tn_examples=[], - fp_examples=[], - fn_examples=[], - ), - 0.5: binary_confusion_matrices._ThresholdEntry( - matrix=binary_confusion_matrices.Matrix( - tp=1.0, tn=2.0, fp=0.0, fn=1.0 - ), - tp_examples=[], - tn_examples=[], - fp_examples=[], - fn_examples=[], - ), - 1: binary_confusion_matrices._ThresholdEntry( - matrix=binary_confusion_matrices.Matrix( - tp=0.0, tn=2.0, fp=0.0, fn=2.0 - ), - tp_examples=[], - tn_examples=[], - fp_examples=[], - fn_examples=[], - ), - }, - ), - dict( - thresholds=[0.1, 0.9], - example_ids_count=1, - example_weights=(1, 1, 1, 1), - example_ids=('id_1', 'id_2', 'id_3', 'id_4'), - expected_result={ - 0.1: binary_confusion_matrices._ThresholdEntry( - matrix=binary_confusion_matrices.Matrix( - tp=2.0, tn=1.0, fp=1.0, fn=0.0 - ), - tp_examples=['id_3'], - tn_examples=['id_1'], - fp_examples=['id_2'], - fn_examples=[], - ), - 0.9: binary_confusion_matrices._ThresholdEntry( - matrix=binary_confusion_matrices.Matrix( - tp=0.0, tn=2.0, fp=0.0, fn=2.0 - ), - tp_examples=[], - tn_examples=['id_1'], - fp_examples=[], - fn_examples=['id_3'], - ), - }, - ), - dict( - thresholds=[0.1, 0.9], - example_ids_count=2, - example_weights=(1, 1, 1, 1), - example_ids=('id_1', 'id_2', 'id_3', 'id_4'), - expected_result={ - 0.1: binary_confusion_matrices._ThresholdEntry( - matrix=binary_confusion_matrices.Matrix( - tp=2.0, tn=1.0, fp=1.0, fn=0.0 - ), - tp_examples=['id_3', 'id_4'], - tn_examples=['id_1'], - fp_examples=['id_2'], - fn_examples=[], - ), - 0.9: binary_confusion_matrices._ThresholdEntry( - matrix=binary_confusion_matrices.Matrix( - tp=0.0, tn=2.0, fp=0.0, fn=2.0 - ), - tp_examples=[], - tn_examples=['id_1', 'id_2'], - fp_examples=[], - fn_examples=['id_3', 'id_4'], - ), - }, - ), - dict( - thresholds=[0.25, 0.75], - example_ids_count=100, - example_weights=(0.2, 0.3, 0.5, 0.7), - example_ids=(None, None, None, None), - expected_result={ - 0.25: binary_confusion_matrices._ThresholdEntry( - matrix=binary_confusion_matrices.Matrix( - tp=1.2, tn=0.2, fp=0.3, fn=0.0 - ), - tp_examples=[], - tn_examples=[], - fp_examples=[], - fn_examples=[], - ), - 0.75: binary_confusion_matrices._ThresholdEntry( - matrix=binary_confusion_matrices.Matrix( - tp=0.7, tn=0.5, fp=0.0, fn=0.5 - ), - tp_examples=[], - tn_examples=[], - fp_examples=[], - fn_examples=[], - ), - }, - ), - ) - def testBinaryConfusionMatricesInProcess( - self, - thresholds, - example_ids_count, - example_weights, - example_ids, - expected_result, - ): - labels = (0, 0, 
1, 1) - predictions = (0, 0.5, 0.3, 0.9) + confusion_matrix = binary_confusion_matrices.BinaryConfusionMatrices( + thresholds=thresholds, + example_ids_count=example_ids_count, + ) + accumulator = confusion_matrix.create_accumulator() + for label, prediction, example_weight, example_id in zip( + labels, predictions, example_weights, example_ids + ): + accumulator = confusion_matrix.add_input( + accumulator=accumulator, + labels=[label], + predictions=[prediction], + example_weights=[example_weight] if example_weight else None, + example_id=example_id, + ) + self.assertDictEqual(accumulator, expected_result) - confusion_matrix = binary_confusion_matrices.BinaryConfusionMatrices( - thresholds=thresholds, - example_ids_count=example_ids_count, + @parameterized.parameters( + dict( + thresholds=[0, 0.5, 1], + example_ids_count=100, + example_weights=(None, None, None, None), + example_ids=(None, None, None, None), + expected_result={ + 0: binary_confusion_matrices._ThresholdEntry( + matrix=binary_confusion_matrices.Matrix( + tp=2.0, tn=1.0, fp=1.0, fn=0.0 + ), + tp_examples=[], + tn_examples=[], + fp_examples=[], + fn_examples=[], + ), + 0.5: binary_confusion_matrices._ThresholdEntry( + matrix=binary_confusion_matrices.Matrix( + tp=1.0, tn=2.0, fp=0.0, fn=1.0 + ), + tp_examples=[], + tn_examples=[], + fp_examples=[], + fn_examples=[], + ), + 1: binary_confusion_matrices._ThresholdEntry( + matrix=binary_confusion_matrices.Matrix( + tp=0.0, tn=2.0, fp=0.0, fn=2.0 + ), + tp_examples=[], + tn_examples=[], + fp_examples=[], + fn_examples=[], + ), + }, + ), + dict( + thresholds=[0.1, 0.9], + example_ids_count=1, + example_weights=(1, 1, 1, 1), + example_ids=("id_1", "id_2", "id_3", "id_4"), + expected_result={ + 0.1: binary_confusion_matrices._ThresholdEntry( + matrix=binary_confusion_matrices.Matrix( + tp=2.0, tn=1.0, fp=1.0, fn=0.0 + ), + tp_examples=["id_3"], + tn_examples=["id_1"], + fp_examples=["id_2"], + fn_examples=[], + ), + 0.9: binary_confusion_matrices._ThresholdEntry( + matrix=binary_confusion_matrices.Matrix( + tp=0.0, tn=2.0, fp=0.0, fn=2.0 + ), + tp_examples=[], + tn_examples=["id_1"], + fp_examples=[], + fn_examples=["id_3"], + ), + }, + ), + dict( + thresholds=[0.1, 0.9], + example_ids_count=2, + example_weights=(1, 1, 1, 1), + example_ids=("id_1", "id_2", "id_3", "id_4"), + expected_result={ + 0.1: binary_confusion_matrices._ThresholdEntry( + matrix=binary_confusion_matrices.Matrix( + tp=2.0, tn=1.0, fp=1.0, fn=0.0 + ), + tp_examples=["id_3", "id_4"], + tn_examples=["id_1"], + fp_examples=["id_2"], + fn_examples=[], + ), + 0.9: binary_confusion_matrices._ThresholdEntry( + matrix=binary_confusion_matrices.Matrix( + tp=0.0, tn=2.0, fp=0.0, fn=2.0 + ), + tp_examples=[], + tn_examples=["id_1", "id_2"], + fp_examples=[], + fn_examples=["id_3", "id_4"], + ), + }, + ), + dict( + thresholds=[0.25, 0.75], + example_ids_count=100, + example_weights=(0.2, 0.3, 0.5, 0.7), + example_ids=(None, None, None, None), + expected_result={ + 0.25: binary_confusion_matrices._ThresholdEntry( + matrix=binary_confusion_matrices.Matrix( + tp=1.2, tn=0.2, fp=0.3, fn=0.0 + ), + tp_examples=[], + tn_examples=[], + fp_examples=[], + fn_examples=[], + ), + 0.75: binary_confusion_matrices._ThresholdEntry( + matrix=binary_confusion_matrices.Matrix( + tp=0.7, tn=0.5, fp=0.0, fn=0.5 + ), + tp_examples=[], + tn_examples=[], + fp_examples=[], + fn_examples=[], + ), + }, + ), ) - actual = confusion_matrix(labels, predictions, example_weights, example_ids) - self.assertDictEqual(actual, expected_result) + def 
testBinaryConfusionMatricesInProcess( + self, + thresholds, + example_ids_count, + example_weights, + example_ids, + expected_result, + ): + labels = (0, 0, 1, 1) + predictions = (0, 0.5, 0.3, 0.9) + + confusion_matrix = binary_confusion_matrices.BinaryConfusionMatrices( + thresholds=thresholds, + example_ids_count=example_ids_count, + ) + actual = confusion_matrix(labels, predictions, example_weights, example_ids) + self.assertDictEqual(actual, expected_result) -if __name__ == '__main__': - tf.test.main() +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow_model_analysis/eval_metrics_graph/eval_metrics_graph.py b/tensorflow_model_analysis/eval_metrics_graph/eval_metrics_graph.py index c680745db7..19cacb8ffe 100644 --- a/tensorflow_model_analysis/eval_metrics_graph/eval_metrics_graph.py +++ b/tensorflow_model_analysis/eval_metrics_graph/eval_metrics_graph.py @@ -33,437 +33,428 @@ import apache_beam as beam import tensorflow as tf + from tensorflow_model_analysis import constants from tensorflow_model_analysis.api import types from tensorflow_model_analysis.eval_saved_model import constants as eval_constants from tensorflow_model_analysis.eval_saved_model import util from tensorflow_model_analysis.utils import util as general_util + # Config for defining the input tensor feed into the EvalMetricsGraph. This # is needed for model agnostic use cases where a graph must be constructed. -FPLFeedConfig = NamedTuple( # pylint: disable=invalid-name - 'FPLFeedConfig', - [ - ('features', Dict[str, Any]), - ('predictions', Dict[str, Any]), - ('labels', Dict[str, Any]), - ], -) +class FPLFeedConfig(NamedTuple): + features: Dict[str, Any] + predictions: Dict[str, Any] + labels: Dict[str, Any] class EvalMetricsGraph: # pytype: disable=ignored-metaclass - """Abstraction for a graph that is used for computing and aggregating metrics. - - This abstract class contains methods and lays out the API to handle metrics - computation and aggregation as part of the TFMA flow. Inheritors of this class - are responsible for setting up the metrics graph and setting the class - variables which are required to do metric calculations. - """ - - __metaclass__ = abc.ABCMeta - - def __init__(self): - """Initializes this class and attempts to create the graph. + """Abstraction for a graph that is used for computing and aggregating metrics. - This method attempts to create the graph through _construct_graph and - also creates all class variables that need to be populated by the override - function _construct_graph. + This abstract class contains methods and lays out the API to handle metrics + computation and aggregation as part of the TFMA flow. Inheritors of this class + are responsible for setting up the metrics graph and setting the class + variables which are required to do metric calculations. """ - self._graph = tf.Graph() - self._session = tf.compat.v1.Session(graph=self._graph) - - # This lock is for multi-threaded contexts where multiple threads - # share the same EvalSavedModel. - # - # Locking is required in the case where there are multiple threads using - # the same EvalMetricsGraph. Because the metrics variables are part of the - # session, and all threads share the same session, without a lock, the - # "reset-update-get" steps may not be atomic and there can be races. - # - # Having each thread have its own session would also work, but would - # require a bigger refactor. - # TODO(b/131727905): Investigate whether it's possible / better to have - # each thread have its own session. 
- self._lock = threading.Lock() - - # Variables that need to be populated. - - # The names of the metric. - self._metric_names = [] - - # Ops associated with reading and writing the metric variables. - self._metric_value_ops = [] - self._metric_update_ops = [] - self._metric_variable_assign_ops = [] - - # Nodes associated with the metric variables. - self._metric_variable_nodes = [] - - # Placeholders and feed input for the metric variables. - self._metric_variable_placeholders = [] - self._perform_metrics_update_fn_feed_list = [] - self._perform_metrics_update_fn_feed_list_keys = [] - - # OrderedDicts that map features, predictions, and labels keys to their - # tensors. - self._features_map = {} - self._predictions_map = {} - self._labels_map = {} - - # Ops to set/update/reset all metric variables. - self._all_metric_variable_assign_ops = None - self._all_metric_update_ops = None - self._reset_variables_op = None - - # Callable to perform metric update. - self._perform_metrics_update_fn = None - - # OrderedDict produced by graph_ref's load_(legacy_)inputs, mapping input - # key to tensor value. - self._input_map = None - - self._batch_size = beam.metrics.Metrics.distribution( - constants.METRICS_NAMESPACE, 'batch_size' - ) - self._batch_size_failed = beam.metrics.Metrics.distribution( - constants.METRICS_NAMESPACE, 'batch_size_failed' - ) - - try: - self._construct_graph() - except ( - RuntimeError, - TypeError, - ValueError, - tf.errors.OpError, - ) as exception: - general_util.reraise_augmented(exception, 'Failed to create graph.') - - @abc.abstractmethod - def _construct_graph(self): - """Abstract function that is responsible for graph construction. - - This method is called as part of init. Subclasses are also responsible for - populating the variables in the __init__ method as part of graph - construction. - """ - raise NotImplementedError - - def register_add_metric_callbacks( - self, add_metrics_callbacks: List[types.AddMetricsCallbackType] - ) -> None: - """Register additional metric callbacks. - - Runs the given list of callbacks for adding additional metrics to the graph. - - For more details about add_metrics_callbacks, see the docstring for - EvalSharedModel.add_metrics_callbacks in types.py. - Args: - add_metrics_callbacks: A list of metric callbacks to add to the metrics - graph. - - Raises: - ValueError: There was a metric name conflict: a callback tried to add a - metric whose name conflicted with a metric that was added by an earlier - callback. - """ - with self._graph.as_default(): - features_dict, predictions_dict, labels_dict = ( - self.get_features_predictions_labels_dicts() - ) - features_dict = util.wrap_tensor_or_dict_of_tensors_in_identity( - features_dict - ) - predictions_dict = util.wrap_tensor_or_dict_of_tensors_in_identity( - predictions_dict - ) - labels_dict = util.wrap_tensor_or_dict_of_tensors_in_identity(labels_dict) - - metric_ops = {} - for add_metrics_callback in add_metrics_callbacks: - new_metric_ops = add_metrics_callback( - features_dict, predictions_dict, labels_dict + __metaclass__ = abc.ABCMeta + + def __init__(self): + """Initializes this class and attempts to create the graph. + + This method attempts to create the graph through _construct_graph and + also creates all class variables that need to be populated by the override + function _construct_graph. 
+ """ + self._graph = tf.Graph() + self._session = tf.compat.v1.Session(graph=self._graph) + + # This lock is for multi-threaded contexts where multiple threads + # share the same EvalSavedModel. + # + # Locking is required in the case where there are multiple threads using + # the same EvalMetricsGraph. Because the metrics variables are part of the + # session, and all threads share the same session, without a lock, the + # "reset-update-get" steps may not be atomic and there can be races. + # + # Having each thread have its own session would also work, but would + # require a bigger refactor. + # TODO(b/131727905): Investigate whether it's possible / better to have + # each thread have its own session. + self._lock = threading.Lock() + + # Variables that need to be populated. + + # The names of the metric. + self._metric_names = [] + + # Ops associated with reading and writing the metric variables. + self._metric_value_ops = [] + self._metric_update_ops = [] + self._metric_variable_assign_ops = [] + + # Nodes associated with the metric variables. + self._metric_variable_nodes = [] + + # Placeholders and feed input for the metric variables. + self._metric_variable_placeholders = [] + self._perform_metrics_update_fn_feed_list = [] + self._perform_metrics_update_fn_feed_list_keys = [] + + # OrderedDicts that map features, predictions, and labels keys to their + # tensors. + self._features_map = {} + self._predictions_map = {} + self._labels_map = {} + + # Ops to set/update/reset all metric variables. + self._all_metric_variable_assign_ops = None + self._all_metric_update_ops = None + self._reset_variables_op = None + + # Callable to perform metric update. + self._perform_metrics_update_fn = None + + # OrderedDict produced by graph_ref's load_(legacy_)inputs, mapping input + # key to tensor value. + self._input_map = None + + self._batch_size = beam.metrics.Metrics.distribution( + constants.METRICS_NAMESPACE, "batch_size" ) - overlap = set(new_metric_ops) & set(metric_ops) - if overlap: - raise ValueError( - 'metric keys should not conflict, but an ' - 'earlier callback already added the metrics ' - 'named %s' % overlap - ) - metric_ops.update(new_metric_ops) - self.register_additional_metric_ops(metric_ops) - - def graph_as_default(self): - return self._graph.as_default() - - def graph_finalize(self): - self._graph.finalize() - - def register_additional_metric_ops( - self, metric_ops: Dict[str, Tuple[tf.Tensor, tf.Tensor]] - ) -> None: - """Register additional metric ops that were added. - - Args: - metric_ops: Dictionary of metric ops, just like in the Trainer. - - Raises: - ValueError: One or more of the metric ops already exist in the graph. - """ - for metric_name, (value_op, update_op) in metric_ops.items(): - if metric_name in self._metric_names: - raise ValueError( - 'tried to register new metric with name %s, but a ' - 'metric with that name already exists.' % metric_name + self._batch_size_failed = beam.metrics.Metrics.distribution( + constants.METRICS_NAMESPACE, "batch_size_failed" ) - self._metric_names.append(metric_name) - self._metric_value_ops.append(value_op) - self._metric_update_ops.append(update_op) - - # Update metric variables incrementally with only the new elements in the - # metric_variables collection. 
- collection = self._graph.get_collection( - tf.compat.v1.GraphKeys.METRIC_VARIABLES - ) - collection = collection[len(self._metric_variable_nodes) :] - - # Note that this is a node_list - it's not something that TFMA - # configures, but something that TF.Learn configures. - # - # As such, we also use graph.get_tensor_by_name directly, instead of - # TFMA's version which expects names encoded by TFMA. - for node in collection: - self._metric_variable_nodes.append(node) - with self._graph.as_default(): - placeholder = tf.compat.v1.placeholder( - dtype=node.dtype, shape=node.get_shape() + + try: + self._construct_graph() + except ( + RuntimeError, + TypeError, + ValueError, + tf.errors.OpError, + ) as exception: + general_util.reraise_augmented(exception, "Failed to create graph.") + + @abc.abstractmethod + def _construct_graph(self): + """Abstract function that is responsible for graph construction. + + This method is called as part of init. Subclasses are also responsible for + populating the variables in the __init__ method as part of graph + construction. + """ + raise NotImplementedError + + def register_add_metric_callbacks( + self, add_metrics_callbacks: List[types.AddMetricsCallbackType] + ) -> None: + """Register additional metric callbacks. + + Runs the given list of callbacks for adding additional metrics to the graph. + + For more details about add_metrics_callbacks, see the docstring for + EvalSharedModel.add_metrics_callbacks in types.py. + + Args: + ---- + add_metrics_callbacks: A list of metric callbacks to add to the metrics + graph. + + Raises: + ------ + ValueError: There was a metric name conflict: a callback tried to add a + metric whose name conflicted with a metric that was added by an earlier + callback. + """ + with self._graph.as_default(): + features_dict, predictions_dict, labels_dict = ( + self.get_features_predictions_labels_dicts() + ) + features_dict = util.wrap_tensor_or_dict_of_tensors_in_identity( + features_dict + ) + predictions_dict = util.wrap_tensor_or_dict_of_tensors_in_identity( + predictions_dict + ) + labels_dict = util.wrap_tensor_or_dict_of_tensors_in_identity(labels_dict) + + metric_ops = {} + for add_metrics_callback in add_metrics_callbacks: + new_metric_ops = add_metrics_callback( + features_dict, predictions_dict, labels_dict + ) + overlap = set(new_metric_ops) & set(metric_ops) + if overlap: + raise ValueError( + "metric keys should not conflict, but an " + "earlier callback already added the metrics " + "named %s" % overlap + ) + metric_ops.update(new_metric_ops) + self.register_additional_metric_ops(metric_ops) + + def graph_as_default(self): + return self._graph.as_default() + + def graph_finalize(self): + self._graph.finalize() + + def register_additional_metric_ops( + self, metric_ops: Dict[str, Tuple[tf.Tensor, tf.Tensor]] + ) -> None: + """Register additional metric ops that were added. + + Args: + ---- + metric_ops: Dictionary of metric ops, just like in the Trainer. + + Raises: + ------ + ValueError: One or more of the metric ops already exist in the graph. + """ + for metric_name, (value_op, update_op) in metric_ops.items(): + if metric_name in self._metric_names: + raise ValueError( + "tried to register new metric with name %s, but a " + "metric with that name already exists." % metric_name + ) + self._metric_names.append(metric_name) + self._metric_value_ops.append(value_op) + self._metric_update_ops.append(update_op) + + # Update metric variables incrementally with only the new elements in the + # metric_variables collection. 
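+        # Each newly collected node gets a placeholder and an assign op below,
+        # so its value can later be restored via set_metric_variables().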
+ collection = self._graph.get_collection(tf.compat.v1.GraphKeys.METRIC_VARIABLES) + collection = collection[len(self._metric_variable_nodes) :] + + # Note that this is a node_list - it's not something that TFMA + # configures, but something that TF.Learn configures. + # + # As such, we also use graph.get_tensor_by_name directly, instead of + # TFMA's version which expects names encoded by TFMA. + for node in collection: + self._metric_variable_nodes.append(node) + with self._graph.as_default(): + placeholder = tf.compat.v1.placeholder( + dtype=node.dtype, shape=node.get_shape() + ) + self._metric_variable_placeholders.append(placeholder) + self._metric_variable_assign_ops.append( + tf.compat.v1.assign(node, placeholder) + ) + + with self._graph.as_default(): + self._all_metric_variable_assign_ops = tf.group( + *self._metric_variable_assign_ops + ) + self._all_metric_update_ops = tf.group(*self._metric_update_ops) + self._reset_variables_op = tf.compat.v1.local_variables_initializer() + self._session.run(self._reset_variables_op) + + self._perform_metrics_update_fn = self._session.make_callable( + fetches=self._all_metric_update_ops, + feed_list=self._perform_metrics_update_fn_feed_list, ) - self._metric_variable_placeholders.append(placeholder) - self._metric_variable_assign_ops.append( - tf.compat.v1.assign(node, placeholder) + + def _log_debug_message_for_tracing_feed_errors( + self, + fetches: List[types.TensorOrOperationType], + feed_list: List[types.TensorOrOperationType], + ) -> None: + """Logs debug message for tracing feed errors.""" + + def create_tuple_list(tensor: types.TensorOrOperationType): + """Create a list of tuples describing a Tensor.""" + result = None + if isinstance(tensor, tf.Operation): + result = [("Op", tensor.name)] + elif isinstance(tensor, tf.SparseTensor): + result = [ + ("SparseTensor.indices", tensor.indices.name), + ("SparseTensor.values", tensor.values.name), + ("SparseTensor.dense_shape", tensor.dense_shape.name), + ] + elif isinstance(tensor, tf.Tensor): + result = [("Tensor", tensor.name)] + else: + result = [("Unknown", str(tensor))] + return result + + def flatten(target: List[List[Any]]) -> List[Any]: + return list(itertools.chain.from_iterable(target)) + + def log_list(name: str, target: List[Any]) -> None: + tf.compat.v1.logging.info("%s = [", name) + for elem_type, elem_name in flatten([create_tuple_list(x) for x in target]): + tf.compat.v1.logging.info("('%s', '%s'),", elem_type, elem_name) + tf.compat.v1.logging.info("]") + + tf.compat.v1.logging.info("-------------------- fetches and feeds information") + log_list("fetches", fetches) + tf.compat.v1.logging.info("") + log_list("feed_list", feed_list) + tf.compat.v1.logging.info( + "-------------------- end fetches and feeds information" ) - with self._graph.as_default(): - self._all_metric_variable_assign_ops = tf.group( - *self._metric_variable_assign_ops - ) - self._all_metric_update_ops = tf.group(*self._metric_update_ops) - self._reset_variables_op = tf.compat.v1.local_variables_initializer() - self._session.run(self._reset_variables_op) - - self._perform_metrics_update_fn = self._session.make_callable( - fetches=self._all_metric_update_ops, - feed_list=self._perform_metrics_update_fn_feed_list, - ) - - def _log_debug_message_for_tracing_feed_errors( - self, - fetches: List[types.TensorOrOperationType], - feed_list: List[types.TensorOrOperationType], - ) -> None: - """Logs debug message for tracing feed errors.""" - - def create_tuple_list(tensor: types.TensorOrOperationType): - """Create a 
list of tuples describing a Tensor.""" - result = None - if isinstance(tensor, tf.Operation): - result = [('Op', tensor.name)] - elif isinstance(tensor, tf.SparseTensor): - result = [ - ('SparseTensor.indices', tensor.indices.name), - ('SparseTensor.values', tensor.values.name), - ('SparseTensor.dense_shape', tensor.dense_shape.name), - ] - elif isinstance(tensor, tf.Tensor): - result = [('Tensor', tensor.name)] - else: - result = [('Unknown', str(tensor))] - return result - - def flatten(target: List[List[Any]]) -> List[Any]: - return list(itertools.chain.from_iterable(target)) - - def log_list(name: str, target: List[Any]) -> None: - tf.compat.v1.logging.info('%s = [', name) - for elem_type, elem_name in flatten( - [create_tuple_list(x) for x in target] - ): - tf.compat.v1.logging.info("('%s', '%s'),", elem_type, elem_name) - tf.compat.v1.logging.info(']') - - tf.compat.v1.logging.info( - '-------------------- fetches and feeds information' - ) - log_list('fetches', fetches) - tf.compat.v1.logging.info('') - log_list('feed_list', feed_list) - tf.compat.v1.logging.info( - '-------------------- end fetches and feeds information' - ) - - def get_features_predictions_labels_dicts( - self, - ) -> Tuple[ - types.TensorTypeMaybeDict, - types.TensorTypeMaybeDict, - types.TensorTypeMaybeDict, - ]: - """Returns features, predictions, labels dictionaries (or values). - - The dictionaries contain references to the nodes, so they can be used - to construct new metrics similarly to how metrics can be constructed in - the Trainer. - - Returns: - Tuple of features, predictions, labels dictionaries (or values). - """ - features = {} - for key, value in self._features_map.items(): - features[key] = value - - predictions = {} - for key, value in self._predictions_map.items(): - predictions[key] = value - # Unnest if it wasn't a dictionary to begin with. - default_predictions_key = util.default_dict_key( - eval_constants.PREDICTIONS_NAME - ) - if list(predictions) == [default_predictions_key]: - predictions = predictions[default_predictions_key] - - labels = {} - for key, value in self._labels_map.items(): - labels[key] = value - # Unnest if it wasn't a dictionary to begin with. - default_labels_key = util.default_dict_key(eval_constants.LABELS_NAME) - if list(labels) == [default_labels_key]: - labels = labels[default_labels_key] - - return (features, predictions, labels) - - def _perform_metrics_update_list(self, examples_list: List[Any]) -> None: - """Run a metrics update on a list of examples.""" - try: - if self._perform_metrics_update_fn is None: - raise ValueError('_perform_metrics_update_fn is None.') - self._perform_metrics_update_fn(*[examples_list]) - - except ( - RuntimeError, - TypeError, - ValueError, - tf.errors.OpError, - ) as exception: - general_util.reraise_augmented( - exception, 'raw_input = %s' % examples_list - ) - - def metrics_reset_update_get( - self, features_predictions_labels: types.FeaturesPredictionsLabels - ) -> List[Any]: - """Run the metrics reset, update, get operations on a single FPL.""" - return self.metrics_reset_update_get_list([features_predictions_labels]) - - def metrics_reset_update_get_list( - self, examples_list: List[Any] - ) -> List[Any]: - """Run the metrics reset, update, get operations on a list of FPLs.""" - with self._lock: - # Note that due to tf op reordering issues on some hardware, DO NOT merge - # these operations into a single atomic reset_update_get operation. - # - # Try to run the entire batch size through. 
If we hit a functional issue, - # attempt to run the examples through serially - batch_size = len(examples_list) - try: - self._reset_metric_variables() - self._perform_metrics_update_list(examples_list) - self._batch_size.update(batch_size) - except ( - ValueError, - tf.errors.InvalidArgumentError, - tf.errors.ResourceExhaustedError, - ) as e: - self._reset_metric_variables() - self._batch_size_failed.update(batch_size) - tf.compat.v1.logging.warning( - 'Large batch_size %s failed with error %s. ' - 'Attempting to run batch through serially.', - batch_size, - e, + def get_features_predictions_labels_dicts( + self, + ) -> Tuple[ + types.TensorTypeMaybeDict, + types.TensorTypeMaybeDict, + types.TensorTypeMaybeDict, + ]: + """Returns features, predictions, labels dictionaries (or values). + + The dictionaries contain references to the nodes, so they can be used + to construct new metrics similarly to how metrics can be constructed in + the Trainer. + + Returns + ------- + Tuple of features, predictions, labels dictionaries (or values). + """ + features = {} + for key, value in self._features_map.items(): + features[key] = value + + predictions = {} + for key, value in self._predictions_map.items(): + predictions[key] = value + # Unnest if it wasn't a dictionary to begin with. + default_predictions_key = util.default_dict_key(eval_constants.PREDICTIONS_NAME) + if list(predictions) == [default_predictions_key]: + predictions = predictions[default_predictions_key] + + labels = {} + for key, value in self._labels_map.items(): + labels[key] = value + # Unnest if it wasn't a dictionary to begin with. + default_labels_key = util.default_dict_key(eval_constants.LABELS_NAME) + if list(labels) == [default_labels_key]: + labels = labels[default_labels_key] + + return (features, predictions, labels) + + def _perform_metrics_update_list(self, examples_list: List[Any]) -> None: + """Run a metrics update on a list of examples.""" + try: + if self._perform_metrics_update_fn is None: + raise ValueError("_perform_metrics_update_fn is None.") + self._perform_metrics_update_fn(*[examples_list]) + + except ( + RuntimeError, + TypeError, + ValueError, + tf.errors.OpError, + ) as exception: + general_util.reraise_augmented(exception, "raw_input = %s" % examples_list) + + def metrics_reset_update_get( + self, features_predictions_labels: types.FeaturesPredictionsLabels + ) -> List[Any]: + """Run the metrics reset, update, get operations on a single FPL.""" + return self.metrics_reset_update_get_list([features_predictions_labels]) + + def metrics_reset_update_get_list(self, examples_list: List[Any]) -> List[Any]: + """Run the metrics reset, update, get operations on a list of FPLs.""" + with self._lock: + # Note that due to tf op reordering issues on some hardware, DO NOT merge + # these operations into a single atomic reset_update_get operation. + # + # Try to run the entire batch size through. If we hit a functional issue, + # attempt to run the examples through serially + batch_size = len(examples_list) + try: + self._reset_metric_variables() + self._perform_metrics_update_list(examples_list) + self._batch_size.update(batch_size) + except ( + ValueError, + tf.errors.InvalidArgumentError, + tf.errors.ResourceExhaustedError, + ) as e: + self._reset_metric_variables() + self._batch_size_failed.update(batch_size) + tf.compat.v1.logging.warning( + "Large batch_size %s failed with error %s. 
" + "Attempting to run batch through serially.", + batch_size, + e, + ) + for example in examples_list: + self._perform_metrics_update_list([example]) + self._batch_size.update(1) + return self._get_metric_variables() + + def _get_metric_variables(self) -> List[Any]: + # Lock should be acquired before calling this function. + return self._session.run(fetches=self._metric_variable_nodes) + + def get_metric_variables(self) -> List[Any]: + """Returns a list containing the metric variable values.""" + with self._lock: + return self._get_metric_variables() + + def _create_feed_for_metric_variables( + self, metric_variable_values: List[Any] + ) -> Dict[types.TensorType, Any]: + """Returns a feed dict for feeding metric variables values to set them. + + Args: + ---- + metric_variable_values: Metric variable values retrieved using + get_metric_variables, for instance. + + Returns: + ------- + A feed dict for feeding metric variables values to the placeholders + constructed for setting the metric variable values to the fed values. + """ + result = {} + for node, value in zip( + self._metric_variable_placeholders, metric_variable_values + ): + result[node] = value + return result + + def _set_metric_variables(self, metric_variable_values: List[Any]) -> None: + # Lock should be acquired before calling this function. + return self._session.run( + fetches=self._all_metric_variable_assign_ops, + feed_dict=self._create_feed_for_metric_variables(metric_variable_values), ) - for example in examples_list: - self._perform_metrics_update_list([example]) - self._batch_size.update(1) - return self._get_metric_variables() - - def _get_metric_variables(self) -> List[Any]: - # Lock should be acquired before calling this function. - return self._session.run(fetches=self._metric_variable_nodes) - - def get_metric_variables(self) -> List[Any]: - """Returns a list containing the metric variable values.""" - with self._lock: - return self._get_metric_variables() - - def _create_feed_for_metric_variables( - self, metric_variable_values: List[Any] - ) -> Dict[types.TensorType, Any]: - """Returns a feed dict for feeding metric variables values to set them. - - Args: - metric_variable_values: Metric variable values retrieved using - get_metric_variables, for instance. - - Returns: - A feed dict for feeding metric variables values to the placeholders - constructed for setting the metric variable values to the fed values. - """ - result = {} - for node, value in zip( - self._metric_variable_placeholders, metric_variable_values - ): - result[node] = value - return result - - def _set_metric_variables(self, metric_variable_values: List[Any]) -> None: - # Lock should be acquired before calling this function. - return self._session.run( - fetches=self._all_metric_variable_assign_ops, - feed_dict=self._create_feed_for_metric_variables( - metric_variable_values - ), - ) - - def set_metric_variables(self, metric_variable_values: List[Any]) -> None: - """Set metric variable values to the given values.""" - with self._lock: - self._set_metric_variables(metric_variable_values) - - def _reset_metric_variables(self) -> None: - # Lock should be acquired before calling this function. - self._session.run(self._reset_variables_op) - - def reset_metric_variables(self) -> None: - """Reset metric variable values to their initial values.""" - with self._lock: - self._reset_metric_variables() - - def _get_metric_values(self) -> Dict[str, Any]: - # Lock should be acquired before calling this function. 
- metric_values = self._session.run(fetches=self._metric_value_ops) - return dict(zip(self._metric_names, metric_values)) - - def get_metric_values(self) -> Dict[str, Any]: - """Retrieve metric values.""" - with self._lock: - return self._get_metric_values() - - def metrics_set_variables_and_get_values( - self, metric_variable_values: List[Any] - ) -> Dict[str, Any]: - with self._lock: - self._set_metric_variables(metric_variable_values) - return self._get_metric_values() + + def set_metric_variables(self, metric_variable_values: List[Any]) -> None: + """Set metric variable values to the given values.""" + with self._lock: + self._set_metric_variables(metric_variable_values) + + def _reset_metric_variables(self) -> None: + # Lock should be acquired before calling this function. + self._session.run(self._reset_variables_op) + + def reset_metric_variables(self) -> None: + """Reset metric variable values to their initial values.""" + with self._lock: + self._reset_metric_variables() + + def _get_metric_values(self) -> Dict[str, Any]: + # Lock should be acquired before calling this function. + metric_values = self._session.run(fetches=self._metric_value_ops) + return dict(zip(self._metric_names, metric_values)) + + def get_metric_values(self) -> Dict[str, Any]: + """Retrieve metric values.""" + with self._lock: + return self._get_metric_values() + + def metrics_set_variables_and_get_values( + self, metric_variable_values: List[Any] + ) -> Dict[str, Any]: + with self._lock: + self._set_metric_variables(metric_variable_values) + return self._get_metric_values() diff --git a/tensorflow_model_analysis/eval_saved_model/constants.py b/tensorflow_model_analysis/eval_saved_model/constants.py index 09f8379c72..6fddf6545a 100644 --- a/tensorflow_model_analysis/eval_saved_model/constants.py +++ b/tensorflow_model_analysis/eval_saved_model/constants.py @@ -13,26 +13,26 @@ # limitations under the License. 
"""Constants for the EvalSavedModel.""" -EVAL_SAVED_MODEL_EXPORT_NAME = 'TFMA' -EVAL_SAVED_MODEL_TAG = 'eval_saved_model' +EVAL_SAVED_MODEL_EXPORT_NAME = "TFMA" +EVAL_SAVED_MODEL_TAG = "eval_saved_model" -SIGNATURE_DEF_INPUTS_PREFIX = 'inputs' -SIGNATURE_DEF_INPUT_REFS_KEY = 'input_refs' -SIGNATURE_DEF_ITERATOR_INITIALIZER_KEY = 'iterator_initializer' -SIGNATURE_DEF_TFMA_VERSION_KEY = 'tfma/version' +SIGNATURE_DEF_INPUTS_PREFIX = "inputs" +SIGNATURE_DEF_INPUT_REFS_KEY = "input_refs" +SIGNATURE_DEF_ITERATOR_INITIALIZER_KEY = "iterator_initializer" +SIGNATURE_DEF_TFMA_VERSION_KEY = "tfma/version" -FEATURES_NAME = 'features' -LABELS_NAME = 'labels' +FEATURES_NAME = "features" +LABELS_NAME = "labels" # TODO(b/79777718): Really tf.saved_model.tag_constants.EVAL -EVAL_TAG = 'eval' +EVAL_TAG = "eval" # TODO(b/79777718): Really model_fn.EXPORT_TAG_MAP[ModeKeys.EVAL] -DEFAULT_EVAL_SIGNATURE_DEF_KEY = 'eval' +DEFAULT_EVAL_SIGNATURE_DEF_KEY = "eval" # TODO(b/79777718): Really tf.estimator.export.EvalOutput.PREDICTIONS_NAME -PREDICTIONS_NAME = 'predictions' +PREDICTIONS_NAME = "predictions" # TODO(b/79777718): Really tf.estimator.export.EvalOutput.METRICS_NAME -METRICS_NAME = 'metrics' +METRICS_NAME = "metrics" # TODO(b/79777718): Really tf.estimator.export.EvalOutput.METRIC_VALUE_SUFFIX -METRIC_VALUE_SUFFIX = 'value' +METRIC_VALUE_SUFFIX = "value" # TODO(b/79777718): Really tf.estimator.export.EvalOutput.METRIC_UPDATE_SUFFIX -METRIC_UPDATE_SUFFIX = 'update_op' +METRIC_UPDATE_SUFFIX = "update_op" diff --git a/tensorflow_model_analysis/evaluators/__init__.py b/tensorflow_model_analysis/evaluators/__init__.py index fa774c312e..67f3ae3b0e 100644 --- a/tensorflow_model_analysis/evaluators/__init__.py +++ b/tensorflow_model_analysis/evaluators/__init__.py @@ -14,17 +14,22 @@ """Init module for TensorFlow Model Analysis evaluators.""" # pylint: disable=g-importing-member -from tensorflow_model_analysis.evaluators.analysis_table_evaluator import AnalysisTableEvaluator -from tensorflow_model_analysis.evaluators.evaluator import Evaluation -from tensorflow_model_analysis.evaluators.evaluator import Evaluator -from tensorflow_model_analysis.evaluators.evaluator import verify_evaluator -from tensorflow_model_analysis.evaluators.metrics_plots_and_validations_evaluator import MetricsPlotsAndValidationsEvaluator +from tensorflow_model_analysis.evaluators.analysis_table_evaluator import ( + AnalysisTableEvaluator, +) +from tensorflow_model_analysis.evaluators.evaluator import ( + Evaluation, + Evaluator, + verify_evaluator, +) +from tensorflow_model_analysis.evaluators.metrics_plots_and_validations_evaluator import ( + MetricsPlotsAndValidationsEvaluator, +) __all__ = [ - "AnalysisTableEvaluator", - "Evaluation", - "Evaluator" - "MetricsAndPlotsEvaluator", - "MetricsPlotsAndValidationsEvaluator", - "verify_evaluator", + "AnalysisTableEvaluator", + "Evaluation", + "Evaluator" "MetricsAndPlotsEvaluator", + "MetricsPlotsAndValidationsEvaluator", + "verify_evaluator", ] diff --git a/tensorflow_model_analysis/evaluators/analysis_table_evaluator.py b/tensorflow_model_analysis/evaluators/analysis_table_evaluator.py index 55fddcebc2..b889862f27 100644 --- a/tensorflow_model_analysis/evaluators/analysis_table_evaluator.py +++ b/tensorflow_model_analysis/evaluators/analysis_table_evaluator.py @@ -16,6 +16,7 @@ from typing import Any, Dict, Iterable, Optional, Union import apache_beam as beam + from tensorflow_model_analysis import constants from tensorflow_model_analysis.api import types from 
tensorflow_model_analysis.evaluators import evaluator @@ -28,42 +29,45 @@ def AnalysisTableEvaluator( # pylint: disable=invalid-name include: Optional[Union[Iterable[str], Dict[str, Any]]] = None, exclude: Optional[Union[Iterable[str], Dict[str, Any]]] = None, ) -> evaluator.Evaluator: - """Creates an Evaluator for returning Extracts data for analysis. + """Creates an Evaluator for returning Extracts data for analysis. - If both include and exclude are None then tfma.INPUT_KEY extracts will be - excluded by default. + If both include and exclude are None then tfma.INPUT_KEY extracts will be + excluded by default. - Args: - key: Name to use for key in Evaluation output. - run_after: Extractor to run after (None means before any extractors). - include: List or map of keys to include in output. Keys starting with '_' - are automatically filtered out at write time. If a map of keys is passed - then the keys and sub-keys that exist in the map will be included in the - output. An empty dict behaves as a wildcard matching all keys or the value - itself. Since matching on feature values is not currently supported, an - empty dict must be used to represent the leaf nodes. For example: {'key1': - {'key1-subkey': {}}, 'key2': {}}. - exclude: List or map of keys to exclude from output. If a map of keys is - passed then the keys and sub-keys that exist in the map will be excluded - from the output. An empty dict behaves as a wildcard matching all keys or - the value itself. Since matching on feature values is not currently - supported, an empty dict must be used to represent the leaf nodes. For - example, {'key1': {'key1-subkey': {}}, 'key2': {}}. + Args: + ---- + key: Name to use for key in Evaluation output. + run_after: Extractor to run after (None means before any extractors). + include: List or map of keys to include in output. Keys starting with '_' + are automatically filtered out at write time. If a map of keys is passed + then the keys and sub-keys that exist in the map will be included in the + output. An empty dict behaves as a wildcard matching all keys or the value + itself. Since matching on feature values is not currently supported, an + empty dict must be used to represent the leaf nodes. For example: {'key1': + {'key1-subkey': {}}, 'key2': {}}. + exclude: List or map of keys to exclude from output. If a map of keys is + passed then the keys and sub-keys that exist in the map will be excluded + from the output. An empty dict behaves as a wildcard matching all keys or + the value itself. Since matching on feature values is not currently + supported, an empty dict must be used to represent the leaf nodes. For + example, {'key1': {'key1-subkey': {}}, 'key2': {}}. - Returns: - Evaluator for collecting analysis data. The output is stored under the key - 'analysis'. + Returns: + ------- + Evaluator for collecting analysis data. The output is stored under the key + 'analysis'. - Raises: - ValueError: If both include and exclude are used. - """ - # pylint: disable=no-value-for-parameter - return evaluator.Evaluator( - stage_name='EvaluateExtracts', - run_after=run_after, - ptransform=EvaluateExtracts(key=key, include=include, exclude=exclude), - ) - # pylint: enable=no-value-for-parameter + Raises: + ------ + ValueError: If both include and exclude are used. 
+ """ + # pylint: disable=no-value-for-parameter + return evaluator.Evaluator( + stage_name="EvaluateExtracts", + run_after=run_after, + ptransform=EvaluateExtracts(key=key, include=include, exclude=exclude), + ) + # pylint: enable=no-value-for-parameter @beam.ptransform_fn @@ -75,34 +79,36 @@ def EvaluateExtracts( # pylint: disable=invalid-name include: Optional[Union[Iterable[str], Dict[str, Any]]] = None, exclude: Optional[Union[Iterable[str], Dict[str, Any]]] = None, ) -> evaluator.Evaluation: - """Creates Evaluation output for extracts. + """Creates Evaluation output for extracts. - If both include and exclude are None then tfma.INPUT_KEY extracts will be - excluded by default. + If both include and exclude are None then tfma.INPUT_KEY extracts will be + excluded by default. - Args: - extracts: PCollection of Extracts. - key: Name to use for key in Evaluation output. - include: List or map of keys to include in output. Keys starting with '_' - are automatically filtered out at write time. If a map of keys is passed - then the keys and sub-keys that exist in the map will be included in the - output. An empty dict behaves as a wildcard matching all keys or the value - itself. Since matching on feature values is not currently supported, an - empty dict must be used to represent the leaf nodes. For example: {'key1': - {'key1-subkey': {}}, 'key2': {}}. - exclude: List or map of keys to exclude from output. If a map of keys is - passed then the keys and sub-keys that exist in the map will be excluded - from the output. An empty dict behaves as a wildcard matching all keys or - the value itself. Since matching on feature values is not currently - supported, an empty dict must be used to represent the leaf nodes. For - example, {'key1': {'key1-subkey': {}}, 'key2': {}}. + Args: + ---- + extracts: PCollection of Extracts. + key: Name to use for key in Evaluation output. + include: List or map of keys to include in output. Keys starting with '_' + are automatically filtered out at write time. If a map of keys is passed + then the keys and sub-keys that exist in the map will be included in the + output. An empty dict behaves as a wildcard matching all keys or the value + itself. Since matching on feature values is not currently supported, an + empty dict must be used to represent the leaf nodes. For example: {'key1': + {'key1-subkey': {}}, 'key2': {}}. + exclude: List or map of keys to exclude from output. If a map of keys is + passed then the keys and sub-keys that exist in the map will be excluded + from the output. An empty dict behaves as a wildcard matching all keys or + the value itself. Since matching on feature values is not currently + supported, an empty dict must be used to represent the leaf nodes. For + example, {'key1': {'key1-subkey': {}}, 'key2': {}}. - Returns: - Evaluation containing PCollection of Extracts. - """ - if include is None and exclude is None: - exclude = [constants.INPUT_KEY] - filtered = extracts - if include or exclude: - filtered = extracts | extractor.Filter(include=include, exclude=exclude) - return {key: filtered} + Returns: + ------- + Evaluation containing PCollection of Extracts. 
+ """ + if include is None and exclude is None: + exclude = [constants.INPUT_KEY] + filtered = extracts + if include or exclude: + filtered = extracts | extractor.Filter(include=include, exclude=exclude) + return {key: filtered} diff --git a/tensorflow_model_analysis/evaluators/analysis_table_evaluator_test.py b/tensorflow_model_analysis/evaluators/analysis_table_evaluator_test.py index 55dba4d2b2..dbbefa6c0f 100644 --- a/tensorflow_model_analysis/evaluators/analysis_table_evaluator_test.py +++ b/tensorflow_model_analysis/evaluators/analysis_table_evaluator_test.py @@ -14,84 +14,82 @@ """Tests for analysis_table_evaluator.""" import apache_beam as beam -from apache_beam.testing import util import tensorflow as tf +from apache_beam.testing import util + from tensorflow_model_analysis import constants from tensorflow_model_analysis.evaluators import analysis_table_evaluator from tensorflow_model_analysis.utils import test_util class AnalysisTableEvaulatorTest(test_util.TensorflowModelAnalysisTest): - - def testIncludeFilter(self): - with beam.Pipeline() as pipeline: - got = ( - pipeline - | 'Create' >> beam.Create([{'a': 1, 'b': 2}]) - | 'EvaluateExtracts' - >> analysis_table_evaluator.EvaluateExtracts(include=['a']) - ) - - def check_result(got): - try: - self.assertEqual(got, [{'a': 1}]) - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that(got[constants.ANALYSIS_KEY], check_result) - - def testExcludeFilter(self): - with beam.Pipeline() as pipeline: - got = ( - pipeline - | 'Create' >> beam.Create([{'a': 1, 'b': 2}]) - | 'EvaluateExtracts' - >> analysis_table_evaluator.EvaluateExtracts(exclude=['a']) - ) - - def check_result(got): - try: - self.assertEqual(got, [{'b': 2}]) - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that(got[constants.ANALYSIS_KEY], check_result) - - def testNoIncludeOrExcludeFilters(self): - with beam.Pipeline() as pipeline: - got = ( - pipeline - | 'Create' - >> beam.Create([{constants.INPUT_KEY: 'input', 'other': 2}]) - | 'EvaluateExtracts' >> analysis_table_evaluator.EvaluateExtracts() - ) - - def check_result(got): - try: - self.assertEqual(got, [{'other': 2}]) - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that(got[constants.ANALYSIS_KEY], check_result) - - def testEmptyExcludeFilters(self): - with beam.Pipeline() as pipeline: - got = ( - pipeline - | 'Create' - >> beam.Create([{constants.INPUT_KEY: 'input', 'other': 2}]) - | 'EvaluateExtracts' - >> analysis_table_evaluator.EvaluateExtracts(exclude=[]) - ) - - def check_result(got): - try: - self.assertEqual(got, [{constants.INPUT_KEY: 'input', 'other': 2}]) - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that(got[constants.ANALYSIS_KEY], check_result) - - -if __name__ == '__main__': - tf.test.main() + def testIncludeFilter(self): + with beam.Pipeline() as pipeline: + got = ( + pipeline + | "Create" >> beam.Create([{"a": 1, "b": 2}]) + | "EvaluateExtracts" + >> analysis_table_evaluator.EvaluateExtracts(include=["a"]) + ) + + def check_result(got): + try: + self.assertEqual(got, [{"a": 1}]) + except AssertionError as err: + raise util.BeamAssertException(err) + + util.assert_that(got[constants.ANALYSIS_KEY], check_result) + + def testExcludeFilter(self): + with beam.Pipeline() as pipeline: + got = ( + pipeline + | "Create" >> beam.Create([{"a": 1, "b": 2}]) + | "EvaluateExtracts" + >> analysis_table_evaluator.EvaluateExtracts(exclude=["a"]) + ) + + def 
check_result(got): + try: + self.assertEqual(got, [{"b": 2}]) + except AssertionError as err: + raise util.BeamAssertException(err) + + util.assert_that(got[constants.ANALYSIS_KEY], check_result) + + def testNoIncludeOrExcludeFilters(self): + with beam.Pipeline() as pipeline: + got = ( + pipeline + | "Create" >> beam.Create([{constants.INPUT_KEY: "input", "other": 2}]) + | "EvaluateExtracts" >> analysis_table_evaluator.EvaluateExtracts() + ) + + def check_result(got): + try: + self.assertEqual(got, [{"other": 2}]) + except AssertionError as err: + raise util.BeamAssertException(err) + + util.assert_that(got[constants.ANALYSIS_KEY], check_result) + + def testEmptyExcludeFilters(self): + with beam.Pipeline() as pipeline: + got = ( + pipeline + | "Create" >> beam.Create([{constants.INPUT_KEY: "input", "other": 2}]) + | "EvaluateExtracts" + >> analysis_table_evaluator.EvaluateExtracts(exclude=[]) + ) + + def check_result(got): + try: + self.assertEqual(got, [{constants.INPUT_KEY: "input", "other": 2}]) + except AssertionError as err: + raise util.BeamAssertException(err) + + util.assert_that(got[constants.ANALYSIS_KEY], check_result) + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow_model_analysis/evaluators/confidence_intervals_util.py b/tensorflow_model_analysis/evaluators/confidence_intervals_util.py index e16c92fd6b..729abc8c0a 100644 --- a/tensorflow_model_analysis/evaluators/confidence_intervals_util.py +++ b/tensorflow_model_analysis/evaluators/confidence_intervals_util.py @@ -16,177 +16,182 @@ import collections import numbers from typing import Iterable, NamedTuple, Optional, Sequence, Set, Tuple + import apache_beam as beam import numpy as np + from tensorflow_model_analysis import constants from tensorflow_model_analysis.api import types from tensorflow_model_analysis.metrics import metric_types -SampleMetrics = NamedTuple( - 'SampleMetrics', [('metrics', metric_types.MetricsDict), ('sample_id', int)] -) + +class SampleMetrics(NamedTuple): + metrics: metric_types.MetricsDict + sample_id: int def mean_and_std( values: Sequence[types.MetricValueType], ddof: int ) -> Tuple[types.MetricValueType, types.MetricValueType]: - """Computes mean and standard deviation for (structued) metric values. - - Args: - values: An iterable of values for which to compute the mean and standard - deviation - ddof: The difference in degrees of freedom to use for the standard deviation - computation, relative to the length of values. For example, if len(values) - == 10, and ddof is 1, the standard deviation will be computed with 9 - degreees of freedom - - Returns: - A 2-tuple in which the first element is the mean and the second element is - the standard deviation. The types of the mean and standard deviation will - be the same as the type of each element in values. 
- """ - total = None - for value in values: - if total is None: - total = value - else: - total = total + value - mean = total / len(values) - squared_residual_total = None - for value in values: - squared_residual = (value - mean) ** 2 - if squared_residual_total is None: - squared_residual_total = squared_residual - else: - squared_residual_total = squared_residual_total + squared_residual - std = (squared_residual_total / (len(values) - ddof)) ** 0.5 - return mean, std - - -class SampleCombineFn(beam.CombineFn): - """Computes the standard deviation for each metric from samples.""" - - class SampleAccumulator: - - __slots__ = ['point_estimates', 'num_samples', 'metric_samples'] - - def __init__(self): - self.point_estimates = None - self.num_samples = 0 - self.metric_samples = collections.defaultdict(list) - - def __init__( - self, - num_samples: int, - full_sample_id: int, - skip_ci_metric_keys: Optional[Set[metric_types.MetricKey]] = None, - ): - """Initializes a SampleCombineFn. + """Computes mean and standard deviation for (structued) metric values. Args: - num_samples: The number of samples computed per slice. - full_sample_id: The sample_id corresponding to the unsampled metrics. - skip_ci_metric_keys: Set of metric keys for which to skip confidence - interval computation. For metric keys in this set, just the unsampled - value will be returned. + ---- + values: An iterable of values for which to compute the mean and standard + deviation + ddof: The difference in degrees of freedom to use for the standard deviation + computation, relative to the length of values. For example, if len(values) + == 10, and ddof is 1, the standard deviation will be computed with 9 + degreees of freedom + + Returns: + ------- + A 2-tuple in which the first element is the mean and the second element is + the standard deviation. The types of the mean and standard deviation will + be the same as the type of each element in values. 
""" - self._num_samples = num_samples - self._full_sample_id = full_sample_id - self._skip_ci_metric_keys = skip_ci_metric_keys - self._num_slices_counter = beam.metrics.Metrics.counter( - constants.METRICS_NAMESPACE, 'num_slices' - ) - self._missing_samples_counter = beam.metrics.Metrics.counter( - constants.METRICS_NAMESPACE, 'num_slices_missing_samples' - ) - self._missing_metric_samples_counter = beam.metrics.Metrics.counter( - constants.METRICS_NAMESPACE, 'num_slices_missing_metric_samples' - ) - - def create_accumulator(self) -> 'SampleCombineFn.SampleAccumulator': - return SampleCombineFn.SampleAccumulator() - - def add_input( - self, - accumulator: 'SampleCombineFn.SampleAccumulator', - sample: SampleMetrics, - ) -> 'SampleCombineFn.SampleAccumulator': - sample_id = sample.sample_id - sample = sample.metrics - if sample_id == self._full_sample_id: - accumulator.point_estimates = sample - else: - accumulator.num_samples += 1 - for metric_key, value in sample.items(): - if not ( - isinstance(value, (numbers.Number, types.StructuredMetricValue)) - or ( - isinstance(value, np.ndarray) - and np.issubdtype(value.dtype, np.number) + total = None + for value in values: + if total is None: + total = value + else: + total = total + value + mean = total / len(values) + squared_residual_total = None + for value in values: + squared_residual = (value - mean) ** 2 + if squared_residual_total is None: + squared_residual_total = squared_residual + else: + squared_residual_total = squared_residual_total + squared_residual + std = (squared_residual_total / (len(values) - ddof)) ** 0.5 + return mean, std + + +class SampleCombineFn(beam.CombineFn): + """Computes the standard deviation for each metric from samples.""" + + class SampleAccumulator: + __slots__ = ["point_estimates", "num_samples", "metric_samples"] + + def __init__(self): + self.point_estimates = None + self.num_samples = 0 + self.metric_samples = collections.defaultdict(list) + + def __init__( + self, + num_samples: int, + full_sample_id: int, + skip_ci_metric_keys: Optional[Set[metric_types.MetricKey]] = None, + ): + """Initializes a SampleCombineFn. + + Args: + ---- + num_samples: The number of samples computed per slice. + full_sample_id: The sample_id corresponding to the unsampled metrics. + skip_ci_metric_keys: Set of metric keys for which to skip confidence + interval computation. For metric keys in this set, just the unsampled + value will be returned. 
+ """ + self._num_samples = num_samples + self._full_sample_id = full_sample_id + self._skip_ci_metric_keys = skip_ci_metric_keys + self._num_slices_counter = beam.metrics.Metrics.counter( + constants.METRICS_NAMESPACE, "num_slices" + ) + self._missing_samples_counter = beam.metrics.Metrics.counter( + constants.METRICS_NAMESPACE, "num_slices_missing_samples" + ) + self._missing_metric_samples_counter = beam.metrics.Metrics.counter( + constants.METRICS_NAMESPACE, "num_slices_missing_metric_samples" + ) + + def create_accumulator(self) -> "SampleCombineFn.SampleAccumulator": + return SampleCombineFn.SampleAccumulator() + + def add_input( + self, + accumulator: "SampleCombineFn.SampleAccumulator", + sample: SampleMetrics, + ) -> "SampleCombineFn.SampleAccumulator": + sample_id = sample.sample_id + sample = sample.metrics + if sample_id == self._full_sample_id: + accumulator.point_estimates = sample + else: + accumulator.num_samples += 1 + for metric_key, value in sample.items(): + if not ( + isinstance(value, (numbers.Number, types.StructuredMetricValue)) + or ( + isinstance(value, np.ndarray) + and np.issubdtype(value.dtype, np.number) + ) + ): + # A value must be a number, a StructuredMetricValue, or a numeric + # NumPy array. If none of those matches, skip. The absence of any + # sample for a specific metric_key will cause _validate_accumulator to + # remove all samples, which will result in no CI computation for that + # metric key. + continue + if ( + self._skip_ci_metric_keys + and metric_key in self._skip_ci_metric_keys + ): + continue + accumulator.metric_samples[metric_key].append(value) + return accumulator + + def merge_accumulators( + self, accumulators: Iterable["SampleCombineFn.SampleAccumulator"] + ) -> "SampleCombineFn.SampleAccumulator": + # treat as iterator to enforce streaming processing + accumulators = iter(accumulators) + result = next(accumulators) + for accumulator in accumulators: + if accumulator.point_estimates is not None: + result.point_estimates = accumulator.point_estimates + result.num_samples += accumulator.num_samples + for metric_key, sample_values in accumulator.metric_samples.items(): + result.metric_samples[metric_key].extend(sample_values) + return result + + def _validate_accumulator( + self, accumulator: "SampleCombineFn.SampleAccumulator" + ) -> "SampleCombineFn.SampleAccumulator": + self._num_slices_counter.inc(1) + error_metric_key = metric_types.MetricKey(constants.ERROR_METRIC_NAME) + if accumulator.num_samples < self._num_samples: + self._missing_samples_counter.inc(1) + accumulator.point_estimates[error_metric_key] = ( + f"CI not computed because only {accumulator.num_samples} samples " + f"were non-empty. Expected {self._num_samples}." + ) + # If we are missing samples, clear samples for all metrics as they are all + # unusable. + accumulator.metric_samples = {} + # Check that all metrics were present in all samples + metric_incorrect_sample_counts = {} + for metric_key in accumulator.point_estimates: + if metric_key in accumulator.metric_samples: + actual_num_samples = len(accumulator.metric_samples[metric_key]) + if actual_num_samples != self._num_samples: + # If we are missing a per-metric sample, clear samples for tha metric + # as it is unusable. 
+ del accumulator.metric_samples[metric_key] + metric_incorrect_sample_counts[metric_key] = actual_num_samples + if metric_incorrect_sample_counts: + accumulator.point_estimates[error_metric_key] = ( + "CI not computed for the following metrics due to incorrect number " + f'of samples: "{metric_incorrect_sample_counts}".' + f"Expected {self._num_samples}." ) - ): - # A value must be a number, a StructuredMetricValue, or a numeric - # NumPy array. If none of those matches, skip. The absence of any - # sample for a specific metric_key will cause _validate_accumulator to - # remove all samples, which will result in no CI computation for that - # metric key. - continue - if ( - self._skip_ci_metric_keys - and metric_key in self._skip_ci_metric_keys - ): - continue - accumulator.metric_samples[metric_key].append(value) - return accumulator - - def merge_accumulators( - self, accumulators: Iterable['SampleCombineFn.SampleAccumulator'] - ) -> 'SampleCombineFn.SampleAccumulator': - # treat as iterator to enforce streaming processing - accumulators = iter(accumulators) - result = next(accumulators) - for accumulator in accumulators: - if accumulator.point_estimates is not None: - result.point_estimates = accumulator.point_estimates - result.num_samples += accumulator.num_samples - for metric_key, sample_values in accumulator.metric_samples.items(): - result.metric_samples[metric_key].extend(sample_values) - return result - - def _validate_accumulator( - self, accumulator: 'SampleCombineFn.SampleAccumulator' - ) -> 'SampleCombineFn.SampleAccumulator': - self._num_slices_counter.inc(1) - error_metric_key = metric_types.MetricKey(constants.ERROR_METRIC_NAME) - if accumulator.num_samples < self._num_samples: - self._missing_samples_counter.inc(1) - accumulator.point_estimates[error_metric_key] = ( - f'CI not computed because only {accumulator.num_samples} samples ' - f'were non-empty. Expected {self._num_samples}.' - ) - # If we are missing samples, clear samples for all metrics as they are all - # unusable. - accumulator.metric_samples = {} - # Check that all metrics were present in all samples - metric_incorrect_sample_counts = {} - for metric_key in accumulator.point_estimates: - if metric_key in accumulator.metric_samples: - actual_num_samples = len(accumulator.metric_samples[metric_key]) - if actual_num_samples != self._num_samples: - # If we are missing a per-metric sample, clear samples for tha metric - # as it is unusable. - del accumulator.metric_samples[metric_key] - metric_incorrect_sample_counts[metric_key] = actual_num_samples - if metric_incorrect_sample_counts: - accumulator.point_estimates[error_metric_key] = ( - 'CI not computed for the following metrics due to incorrect number ' - f'of samples: "{metric_incorrect_sample_counts}".' - f'Expected {self._num_samples}.' 
- ) - return accumulator - - # TODO(b/195132951): replace with @abc.abstractmethod - def extract_output( - self, accumulator: 'SampleCombineFn.SampleAccumulator' - ) -> metric_types.MetricsDict: - raise NotImplementedError('Must be implemented in subclasses.') + return accumulator + + # TODO(b/195132951): replace with @abc.abstractmethod + def extract_output( + self, accumulator: "SampleCombineFn.SampleAccumulator" + ) -> metric_types.MetricsDict: + raise NotImplementedError("Must be implemented in subclasses.") diff --git a/tensorflow_model_analysis/evaluators/confidence_intervals_util_test.py b/tensorflow_model_analysis/evaluators/confidence_intervals_util_test.py index 9517fe8cce..e46c8478ab 100644 --- a/tensorflow_model_analysis/evaluators/confidence_intervals_util_test.py +++ b/tensorflow_model_analysis/evaluators/confidence_intervals_util_test.py @@ -13,317 +13,310 @@ # limitations under the License. """Tests for confidence_intervals_util.""" -from absl.testing import absltest -from absl.testing import parameterized import apache_beam as beam -from apache_beam.testing import util import numpy as np +from absl.testing import absltest, parameterized +from apache_beam.testing import util from numpy import testing + from tensorflow_model_analysis.evaluators import confidence_intervals_util -from tensorflow_model_analysis.metrics import binary_confusion_matrices -from tensorflow_model_analysis.metrics import metric_types +from tensorflow_model_analysis.metrics import binary_confusion_matrices, metric_types _FULL_SAMPLE_ID = -1 class _ValidateSampleCombineFn(confidence_intervals_util.SampleCombineFn): - - def extract_output( - self, - accumulator: confidence_intervals_util.SampleCombineFn.SampleAccumulator, - ) -> confidence_intervals_util.SampleCombineFn.SampleAccumulator: - return self._validate_accumulator(accumulator) + def extract_output( + self, + accumulator: confidence_intervals_util.SampleCombineFn.SampleAccumulator, + ) -> confidence_intervals_util.SampleCombineFn.SampleAccumulator: + return self._validate_accumulator(accumulator) class ConfidenceIntervalsUtilTest(parameterized.TestCase): - - @parameterized.named_parameters( - { - 'testcase_name': '_ints', - 'values': [0, 1, 2], - 'ddof': 1, - 'expected_mean': 1, - 'expected_std': np.std([0, 1, 2], ddof=1), - }, - { - 'testcase_name': '_ndarrays', - 'values': [np.array([0]), np.array([1]), np.array([2])], - 'ddof': 1, - 'expected_mean': np.array([1]), - 'expected_std': np.array([np.std([0, 1, 2], ddof=1)]), - }, - { - 'testcase_name': '_confusion_matrices', - 'values': [ - binary_confusion_matrices.Matrices( - thresholds=[0.5], tp=[0], fp=[1], tn=[2], fn=[3] - ), - binary_confusion_matrices.Matrices( - thresholds=[0.5], tp=[4], fp=[5], tn=[6], fn=[7] - ), - binary_confusion_matrices.Matrices( - thresholds=[0.5], tp=[8], fp=[9], tn=[10], fn=[11] - ), - ], - 'ddof': 1, - 'expected_mean': binary_confusion_matrices.Matrices( - thresholds=[0.5], - tp=np.mean([0, 4, 8]), - fp=np.mean([1, 5, 9]), - tn=np.mean([2, 6, 10]), - fn=np.mean([3, 7, 11]), - ), - 'expected_std': binary_confusion_matrices.Matrices( - thresholds=[0.5], - tp=np.std([0, 4, 8], ddof=1), - fp=np.std([1, 5, 9], ddof=1), - tn=np.std([2, 6, 10], ddof=1), - fn=np.std([3, 7, 11], ddof=1), - ), - }, - ) - def test_mean_and_std(self, values, ddof, expected_mean, expected_std): - actual_mean, actual_std = confidence_intervals_util.mean_and_std( - values, ddof + @parameterized.named_parameters( + { + "testcase_name": "_ints", + "values": [0, 1, 2], + "ddof": 1, + 
"expected_mean": 1, + "expected_std": np.std([0, 1, 2], ddof=1), + }, + { + "testcase_name": "_ndarrays", + "values": [np.array([0]), np.array([1]), np.array([2])], + "ddof": 1, + "expected_mean": np.array([1]), + "expected_std": np.array([np.std([0, 1, 2], ddof=1)]), + }, + { + "testcase_name": "_confusion_matrices", + "values": [ + binary_confusion_matrices.Matrices( + thresholds=[0.5], tp=[0], fp=[1], tn=[2], fn=[3] + ), + binary_confusion_matrices.Matrices( + thresholds=[0.5], tp=[4], fp=[5], tn=[6], fn=[7] + ), + binary_confusion_matrices.Matrices( + thresholds=[0.5], tp=[8], fp=[9], tn=[10], fn=[11] + ), + ], + "ddof": 1, + "expected_mean": binary_confusion_matrices.Matrices( + thresholds=[0.5], + tp=np.mean([0, 4, 8]), + fp=np.mean([1, 5, 9]), + tn=np.mean([2, 6, 10]), + fn=np.mean([3, 7, 11]), + ), + "expected_std": binary_confusion_matrices.Matrices( + thresholds=[0.5], + tp=np.std([0, 4, 8], ddof=1), + fp=np.std([1, 5, 9], ddof=1), + tn=np.std([2, 6, 10], ddof=1), + fn=np.std([3, 7, 11], ddof=1), + ), + }, ) - self.assertEqual(expected_mean, actual_mean) - self.assertEqual(expected_std, actual_std) + def test_mean_and_std(self, values, ddof, expected_mean, expected_std): + actual_mean, actual_std = confidence_intervals_util.mean_and_std(values, ddof) + self.assertEqual(expected_mean, actual_mean) + self.assertEqual(expected_std, actual_std) - def test_sample_combine_fn(self): - metric_key = metric_types.MetricKey('metric') - array_metric_key = metric_types.MetricKey('array_metric') - missing_sample_metric_key = metric_types.MetricKey('missing_metric') - non_numeric_metric_key = metric_types.MetricKey('non_numeric_metric') - non_numeric_array_metric_key = metric_types.MetricKey('non_numeric_array') - mixed_type_array_metric_key = metric_types.MetricKey('mixed_type_array') - skipped_metric_key = metric_types.MetricKey('skipped_metric') - slice_key1 = (('slice_feature', 1),) - slice_key2 = (('slice_feature', 2),) - # the sample value is irrelevant for this test as we only verify counters. - samples = [ - # unsampled value for slice 1 - ( - slice_key1, - confidence_intervals_util.SampleMetrics( - sample_id=_FULL_SAMPLE_ID, - metrics={ - metric_key: 2.1, - array_metric_key: np.array([1, 2]), - missing_sample_metric_key: 3, - non_numeric_metric_key: 'a', - non_numeric_array_metric_key: np.array(['a', 'aaa']), - mixed_type_array_metric_key: np.array(['a']), - skipped_metric_key: 16, - }, + def test_sample_combine_fn(self): + metric_key = metric_types.MetricKey("metric") + array_metric_key = metric_types.MetricKey("array_metric") + missing_sample_metric_key = metric_types.MetricKey("missing_metric") + non_numeric_metric_key = metric_types.MetricKey("non_numeric_metric") + non_numeric_array_metric_key = metric_types.MetricKey("non_numeric_array") + mixed_type_array_metric_key = metric_types.MetricKey("mixed_type_array") + skipped_metric_key = metric_types.MetricKey("skipped_metric") + slice_key1 = (("slice_feature", 1),) + slice_key2 = (("slice_feature", 2),) + # the sample value is irrelevant for this test as we only verify counters. 
+ samples = [ + # unsampled value for slice 1 + ( + slice_key1, + confidence_intervals_util.SampleMetrics( + sample_id=_FULL_SAMPLE_ID, + metrics={ + metric_key: 2.1, + array_metric_key: np.array([1, 2]), + missing_sample_metric_key: 3, + non_numeric_metric_key: "a", + non_numeric_array_metric_key: np.array(["a", "aaa"]), + mixed_type_array_metric_key: np.array(["a"]), + skipped_metric_key: 16, + }, + ), ), - ), - # sample values for slice 1 - ( - slice_key1, - confidence_intervals_util.SampleMetrics( - sample_id=0, - metrics={ - metric_key: 1, - array_metric_key: np.array([2, 3]), - missing_sample_metric_key: 2, - non_numeric_metric_key: 'b', - non_numeric_array_metric_key: np.array(['a', 'aaa']), - # one sample is an empty float array - mixed_type_array_metric_key: np.array([], dtype=float), - skipped_metric_key: 7, - }, + # sample values for slice 1 + ( + slice_key1, + confidence_intervals_util.SampleMetrics( + sample_id=0, + metrics={ + metric_key: 1, + array_metric_key: np.array([2, 3]), + missing_sample_metric_key: 2, + non_numeric_metric_key: "b", + non_numeric_array_metric_key: np.array(["a", "aaa"]), + # one sample is an empty float array + mixed_type_array_metric_key: np.array([], dtype=float), + skipped_metric_key: 7, + }, + ), ), - ), - # sample values for slice 1 missing missing_sample_metric_key - ( - slice_key1, - confidence_intervals_util.SampleMetrics( - sample_id=1, - metrics={ - metric_key: 2, - array_metric_key: np.array([0, 1]), - non_numeric_metric_key: 'c', - non_numeric_array_metric_key: np.array(['a', 'aaa']), - # one sample is a unicode array - mixed_type_array_metric_key: np.array(['a']), - skipped_metric_key: 8, - }, + # sample values for slice 1 missing missing_sample_metric_key + ( + slice_key1, + confidence_intervals_util.SampleMetrics( + sample_id=1, + metrics={ + metric_key: 2, + array_metric_key: np.array([0, 1]), + non_numeric_metric_key: "c", + non_numeric_array_metric_key: np.array(["a", "aaa"]), + # one sample is a unicode array + mixed_type_array_metric_key: np.array(["a"]), + skipped_metric_key: 8, + }, + ), ), - ), - # unsampled value for slice 2 - ( - slice_key2, - confidence_intervals_util.SampleMetrics( - sample_id=_FULL_SAMPLE_ID, - metrics={ - metric_key: 6.3, - array_metric_key: np.array([10, 20]), - missing_sample_metric_key: 6, - non_numeric_metric_key: 'd', - non_numeric_array_metric_key: np.array(['a', 'aaa']), - mixed_type_array_metric_key: np.array(['a']), - skipped_metric_key: 10000, - }, + # unsampled value for slice 2 + ( + slice_key2, + confidence_intervals_util.SampleMetrics( + sample_id=_FULL_SAMPLE_ID, + metrics={ + metric_key: 6.3, + array_metric_key: np.array([10, 20]), + missing_sample_metric_key: 6, + non_numeric_metric_key: "d", + non_numeric_array_metric_key: np.array(["a", "aaa"]), + mixed_type_array_metric_key: np.array(["a"]), + skipped_metric_key: 10000, + }, + ), ), - ), - # Only 1 sample value (missing sample ID 1) for slice 2 - ( - slice_key2, - confidence_intervals_util.SampleMetrics( - sample_id=0, - metrics={ - metric_key: 3, - array_metric_key: np.array([20, 30]), - missing_sample_metric_key: 12, - non_numeric_metric_key: 'd', - non_numeric_array_metric_key: np.array(['a', 'aaa']), - mixed_type_array_metric_key: np.array(['a']), - skipped_metric_key: 5000, - }, + # Only 1 sample value (missing sample ID 1) for slice 2 + ( + slice_key2, + confidence_intervals_util.SampleMetrics( + sample_id=0, + metrics={ + metric_key: 3, + array_metric_key: np.array([20, 30]), + missing_sample_metric_key: 12, + 
non_numeric_metric_key: "d", + non_numeric_array_metric_key: np.array(["a", "aaa"]), + mixed_type_array_metric_key: np.array(["a"]), + skipped_metric_key: 5000, + }, + ), ), - ), - ] + ] - with beam.Pipeline() as pipeline: - result = ( - pipeline - | 'Create' >> beam.Create(samples, reshuffle=False) - | 'CombineSamplesPerKey' - >> beam.CombinePerKey( - _ValidateSampleCombineFn( - num_samples=2, - full_sample_id=_FULL_SAMPLE_ID, - skip_ci_metric_keys=[skipped_metric_key], - ) - ) - ) + with beam.Pipeline() as pipeline: + result = ( + pipeline + | "Create" >> beam.Create(samples, reshuffle=False) + | "CombineSamplesPerKey" + >> beam.CombinePerKey( + _ValidateSampleCombineFn( + num_samples=2, + full_sample_id=_FULL_SAMPLE_ID, + skip_ci_metric_keys=[skipped_metric_key], + ) + ) + ) - def check_result(got_pcoll): - self.assertLen(got_pcoll, 2) - accumulators_by_slice = dict(got_pcoll) + def check_result(got_pcoll): + self.assertLen(got_pcoll, 2) + accumulators_by_slice = dict(got_pcoll) - self.assertIn(slice_key1, accumulators_by_slice) - slice1_accumulator = accumulators_by_slice[slice_key1] - # check unsampled value - self.assertIn(metric_key, slice1_accumulator.point_estimates) - self.assertEqual(2.1, slice1_accumulator.point_estimates[metric_key]) - # check numeric case sample_values - self.assertIn(metric_key, slice1_accumulator.metric_samples) - self.assertEqual([1, 2], slice1_accumulator.metric_samples[metric_key]) - # check numeric array in sample_values - self.assertIn(array_metric_key, slice1_accumulator.metric_samples) - array_metric_samples = slice1_accumulator.metric_samples[ - array_metric_key - ] - self.assertLen(array_metric_samples, 2) - testing.assert_array_equal(np.array([2, 3]), array_metric_samples[0]) - testing.assert_array_equal(np.array([0, 1]), array_metric_samples[1]) - # check that non-numeric metric sample_values are not present - self.assertIn( - non_numeric_metric_key, slice1_accumulator.point_estimates - ) - self.assertNotIn( - non_numeric_metric_key, slice1_accumulator.metric_samples - ) - self.assertIn( - non_numeric_array_metric_key, slice1_accumulator.point_estimates - ) - self.assertNotIn( - non_numeric_array_metric_key, slice1_accumulator.metric_samples - ) - self.assertIn( - mixed_type_array_metric_key, slice1_accumulator.point_estimates - ) - self.assertNotIn( - mixed_type_array_metric_key, slice1_accumulator.metric_samples - ) - # check that single metric missing samples generates error - error_key = metric_types.MetricKey('__ERROR__') - self.assertIn(error_key, slice1_accumulator.point_estimates) - self.assertRegex( - slice1_accumulator.point_estimates[error_key], - 'CI not computed for.*missing_metric.*', - ) - # check that skipped metrics have no samples - self.assertNotIn(skipped_metric_key, slice1_accumulator.metric_samples) + self.assertIn(slice_key1, accumulators_by_slice) + slice1_accumulator = accumulators_by_slice[slice_key1] + # check unsampled value + self.assertIn(metric_key, slice1_accumulator.point_estimates) + self.assertEqual(2.1, slice1_accumulator.point_estimates[metric_key]) + # check numeric case sample_values + self.assertIn(metric_key, slice1_accumulator.metric_samples) + self.assertEqual([1, 2], slice1_accumulator.metric_samples[metric_key]) + # check numeric array in sample_values + self.assertIn(array_metric_key, slice1_accumulator.metric_samples) + array_metric_samples = slice1_accumulator.metric_samples[ + array_metric_key + ] + self.assertLen(array_metric_samples, 2) + testing.assert_array_equal(np.array([2, 3]), 
array_metric_samples[0]) + testing.assert_array_equal(np.array([0, 1]), array_metric_samples[1]) + # check that non-numeric metric sample_values are not present + self.assertIn( + non_numeric_metric_key, slice1_accumulator.point_estimates + ) + self.assertNotIn( + non_numeric_metric_key, slice1_accumulator.metric_samples + ) + self.assertIn( + non_numeric_array_metric_key, slice1_accumulator.point_estimates + ) + self.assertNotIn( + non_numeric_array_metric_key, slice1_accumulator.metric_samples + ) + self.assertIn( + mixed_type_array_metric_key, slice1_accumulator.point_estimates + ) + self.assertNotIn( + mixed_type_array_metric_key, slice1_accumulator.metric_samples + ) + # check that single metric missing samples generates error + error_key = metric_types.MetricKey("__ERROR__") + self.assertIn(error_key, slice1_accumulator.point_estimates) + self.assertRegex( + slice1_accumulator.point_estimates[error_key], + "CI not computed for.*missing_metric.*", + ) + # check that skipped metrics have no samples + self.assertNotIn(skipped_metric_key, slice1_accumulator.metric_samples) - self.assertIn(slice_key2, accumulators_by_slice) - slice2_accumulator = accumulators_by_slice[slice_key2] - # check unsampled value - self.assertIn(metric_key, slice2_accumulator.point_estimates) - self.assertEqual(6.3, slice2_accumulator.point_estimates[metric_key]) - # check that entirely missing sample generates error - self.assertIn( - metric_types.MetricKey('__ERROR__'), - slice2_accumulator.point_estimates, - ) - self.assertRegex( - slice2_accumulator.point_estimates[error_key], - 'CI not computed because only 1.*Expected 2.*', - ) + self.assertIn(slice_key2, accumulators_by_slice) + slice2_accumulator = accumulators_by_slice[slice_key2] + # check unsampled value + self.assertIn(metric_key, slice2_accumulator.point_estimates) + self.assertEqual(6.3, slice2_accumulator.point_estimates[metric_key]) + # check that entirely missing sample generates error + self.assertIn( + metric_types.MetricKey("__ERROR__"), + slice2_accumulator.point_estimates, + ) + self.assertRegex( + slice2_accumulator.point_estimates[error_key], + "CI not computed because only 1.*Expected 2.*", + ) - util.assert_that(result, check_result) + util.assert_that(result, check_result) - runner_result = pipeline.run() - # we expect one missing samples counter increment for slice2, since we - # expected 2 samples, but only saw 1. - metric_filter = beam.metrics.metric.MetricsFilter().with_name( - 'num_slices_missing_samples' - ) - counters = runner_result.metrics().query(filter=metric_filter)['counters'] - self.assertLen(counters, 1) - self.assertEqual(1, counters[0].committed) + runner_result = pipeline.run() + # we expect one missing samples counter increment for slice2, since we + # expected 2 samples, but only saw 1. 
+ metric_filter = beam.metrics.metric.MetricsFilter().with_name( + "num_slices_missing_samples" + ) + counters = runner_result.metrics().query(filter=metric_filter)["counters"] + self.assertLen(counters, 1) + self.assertEqual(1, counters[0].committed) - # verify total slice counter - metric_filter = beam.metrics.metric.MetricsFilter().with_name( - 'num_slices' - ) - counters = runner_result.metrics().query(filter=metric_filter)['counters'] - self.assertLen(counters, 1) - self.assertEqual(2, counters[0].committed) + # verify total slice counter + metric_filter = beam.metrics.metric.MetricsFilter().with_name("num_slices") + counters = runner_result.metrics().query(filter=metric_filter)["counters"] + self.assertLen(counters, 1) + self.assertEqual(2, counters[0].committed) - def test_sample_combine_fn_no_input(self): - slice_key = (('slice_feature', 1),) - samples = [ - ( - slice_key, - confidence_intervals_util.SampleMetrics( - sample_id=_FULL_SAMPLE_ID, metrics={} + def test_sample_combine_fn_no_input(self): + slice_key = (("slice_feature", 1),) + samples = [ + ( + slice_key, + confidence_intervals_util.SampleMetrics( + sample_id=_FULL_SAMPLE_ID, metrics={} + ), ), - ), - ( - slice_key, - confidence_intervals_util.SampleMetrics(sample_id=0, metrics={}), - ), - ( - slice_key, - confidence_intervals_util.SampleMetrics(sample_id=1, metrics={}), - ), - ] + ( + slice_key, + confidence_intervals_util.SampleMetrics(sample_id=0, metrics={}), + ), + ( + slice_key, + confidence_intervals_util.SampleMetrics(sample_id=1, metrics={}), + ), + ] - with beam.Pipeline() as pipeline: - result = ( - pipeline - | 'Create' >> beam.Create(samples) - | 'CombineSamplesPerKey' - >> beam.CombinePerKey( - _ValidateSampleCombineFn( - num_samples=2, full_sample_id=_FULL_SAMPLE_ID - ) - ) - ) + with beam.Pipeline() as pipeline: + result = ( + pipeline + | "Create" >> beam.Create(samples) + | "CombineSamplesPerKey" + >> beam.CombinePerKey( + _ValidateSampleCombineFn( + num_samples=2, full_sample_id=_FULL_SAMPLE_ID + ) + ) + ) - def check_result(got_pcoll): - self.assertLen(got_pcoll, 1) - accumulators_by_slice = dict(got_pcoll) - self.assertIn(slice_key, accumulators_by_slice) - accumulator = accumulators_by_slice[slice_key] - self.assertEqual(2, accumulator.num_samples) - self.assertIsInstance(accumulator.point_estimates, dict) - self.assertIsInstance(accumulator.metric_samples, dict) + def check_result(got_pcoll): + self.assertLen(got_pcoll, 1) + accumulators_by_slice = dict(got_pcoll) + self.assertIn(slice_key, accumulators_by_slice) + accumulator = accumulators_by_slice[slice_key] + self.assertEqual(2, accumulator.num_samples) + self.assertIsInstance(accumulator.point_estimates, dict) + self.assertIsInstance(accumulator.metric_samples, dict) - util.assert_that(result, check_result) + util.assert_that(result, check_result) -if __name__ == '__main__': - absltest.main() +if __name__ == "__main__": + absltest.main() diff --git a/tensorflow_model_analysis/evaluators/counter_util.py b/tensorflow_model_analysis/evaluators/counter_util.py index afd3c794fb..91f0ea0c4f 100644 --- a/tensorflow_model_analysis/evaluators/counter_util.py +++ b/tensorflow_model_analysis/evaluators/counter_util.py @@ -16,19 +16,20 @@ from typing import List, Set import apache_beam as beam + from tensorflow_model_analysis import constants from tensorflow_model_analysis.api import types from tensorflow_model_analysis.proto import config_pb2 def _IncrementMetricsCounters(metric_name: str, version: str, model_type: str): - # LINT.IfChange - metric_name = 
'metric_computed_%s_%s_%s' % (metric_name, version, model_type) - # LINT.ThenChange(../../../../learning/fairness/infra/plx/scripts/tfma_metrics_computed_tracker_macros.sql) - metrics_counter = beam.metrics.Metrics.counter( - constants.METRICS_NAMESPACE, metric_name - ) - metrics_counter.inc(1) + # LINT.IfChange + metric_name = "metric_computed_%s_%s_%s" % (metric_name, version, model_type) + # LINT.ThenChange(../../../../learning/fairness/infra/plx/scripts/tfma_metrics_computed_tracker_macros.sql) + metrics_counter = beam.metrics.Metrics.counter( + constants.METRICS_NAMESPACE, metric_name + ) + metrics_counter.inc(1) @beam.ptransform_fn @@ -39,18 +40,18 @@ def IncrementMetricsCallbacksCounters( metrics_callbacks: List[types.AddMetricsCallbackType], model_type: str, ) -> beam.PCollection[None]: - """To track count of all the metrics being computed using TFMA.""" + """To track count of all the metrics being computed using TFMA.""" - def _MakeAndIncrementCounters(_): - for callback in metrics_callbacks: - if hasattr(callback, 'name'): - _IncrementMetricsCounters(callback.name, 'v1', model_type) + def _MakeAndIncrementCounters(_): + for callback in metrics_callbacks: + if hasattr(callback, "name"): + _IncrementMetricsCounters(callback.name, "v1", model_type) - return ( - pipeline - | 'CreateSole' >> beam.Create([None]) - | 'Count' >> beam.Map(_MakeAndIncrementCounters) - ) + return ( + pipeline + | "CreateSole" >> beam.Create([None]) + | "Count" >> beam.Map(_MakeAndIncrementCounters) + ) @beam.ptransform_fn @@ -59,23 +60,23 @@ def _MakeAndIncrementCounters(_): def IncrementSliceSpecCounters( pipeline: beam.Pipeline, ) -> beam.PCollection[None]: - """To track count of all slicing spec computed using TFMA.""" + """To track count of all slicing spec computed using TFMA.""" - def _MakeAndIncrementCounters(slice_list): - for slice_key, slice_value in slice_list: - # LINT.IfChange - slice_name = 'slice_computed_%s_%s' % (slice_key, slice_value) - # LINT.ThenChange(../../../../learning/fairness/infra/plx/scripts/tfma_metrics_computed_tracker_macros.sql) - slice_counter = beam.metrics.Metrics.counter( - constants.METRICS_NAMESPACE, slice_name - ) - slice_counter.inc(1) + def _MakeAndIncrementCounters(slice_list): + for slice_key, slice_value in slice_list: + # LINT.IfChange + slice_name = "slice_computed_%s_%s" % (slice_key, slice_value) + # LINT.ThenChange(../../../../learning/fairness/infra/plx/scripts/tfma_metrics_computed_tracker_macros.sql) + slice_counter = beam.metrics.Metrics.counter( + constants.METRICS_NAMESPACE, slice_name + ) + slice_counter.inc(1) - return ( - pipeline - | 'GetSliceCountKeys' >> beam.Keys() - | 'Count' >> beam.Map(_MakeAndIncrementCounters) - ) + return ( + pipeline + | "GetSliceCountKeys" >> beam.Keys() + | "Count" >> beam.Map(_MakeAndIncrementCounters) + ) @beam.ptransform_fn @@ -86,16 +87,16 @@ def IncrementMetricsSpecsCounters( metrics_specs: List[config_pb2.MetricsSpec], model_types: Set[str], ) -> beam.PCollection[None]: - """To track count of all metrics specs in TFMA.""" - - def _MakeAndIncrementCounters(_): - for model_type in model_types: - for metrics_spec in metrics_specs: - for metric in metrics_spec.metrics: - _IncrementMetricsCounters(metric.class_name, 'v2', model_type) - - return ( - pipeline - | 'CreateSole' >> beam.Create([None]) - | 'Count' >> beam.Map(_MakeAndIncrementCounters) - ) + """To track count of all metrics specs in TFMA.""" + + def _MakeAndIncrementCounters(_): + for model_type in model_types: + for metrics_spec in metrics_specs: + for 
metric in metrics_spec.metrics: + _IncrementMetricsCounters(metric.class_name, "v2", model_type) + + return ( + pipeline + | "CreateSole" >> beam.Create([None]) + | "Count" >> beam.Map(_MakeAndIncrementCounters) + ) diff --git a/tensorflow_model_analysis/evaluators/counter_util_test.py b/tensorflow_model_analysis/evaluators/counter_util_test.py index 36dfe5bd34..d1788144c6 100644 --- a/tensorflow_model_analysis/evaluators/counter_util_test.py +++ b/tensorflow_model_analysis/evaluators/counter_util_test.py @@ -15,59 +15,57 @@ import apache_beam as beam import tensorflow as tf + from tensorflow_model_analysis import constants from tensorflow_model_analysis.evaluators import counter_util from tensorflow_model_analysis.proto import config_pb2 class CounterUtilTest(tf.test.TestCase): + def testSliceSpecBeamCounter(self): + with beam.Pipeline() as pipeline: + _ = ( + pipeline + | beam.Create([((("slice_key", "first_slice"),), 2)]) + | counter_util.IncrementSliceSpecCounters() + ) - def testSliceSpecBeamCounter(self): - with beam.Pipeline() as pipeline: - _ = ( - pipeline - | beam.Create([((('slice_key', 'first_slice'),), 2)]) - | counter_util.IncrementSliceSpecCounters() - ) - - result = pipeline.run() + result = pipeline.run() - slice_spec_filter = ( - beam.metrics.metric.MetricsFilter() - .with_namespace(constants.METRICS_NAMESPACE) - .with_name('slice_computed_slice_key_first_slice') - ) - slice_count = ( - result.metrics() - .query(filter=slice_spec_filter)['counters'][0] - .committed - ) - self.assertEqual(slice_count, 1) + slice_spec_filter = ( + beam.metrics.metric.MetricsFilter() + .with_namespace(constants.METRICS_NAMESPACE) + .with_name("slice_computed_slice_key_first_slice") + ) + slice_count = ( + result.metrics().query(filter=slice_spec_filter)["counters"][0].committed + ) + self.assertEqual(slice_count, 1) - def testMetricsSpecBeamCounter(self): - with beam.Pipeline() as pipeline: - metrics_spec = config_pb2.MetricsSpec( - metrics=[config_pb2.MetricConfig(class_name='FairnessIndicators')] - ) - model_types = set(['tf_js', 'tf_keras']) - _ = pipeline | counter_util.IncrementMetricsSpecsCounters( - [metrics_spec], model_types - ) + def testMetricsSpecBeamCounter(self): + with beam.Pipeline() as pipeline: + metrics_spec = config_pb2.MetricsSpec( + metrics=[config_pb2.MetricConfig(class_name="FairnessIndicators")] + ) + model_types = set(["tf_js", "tf_keras"]) + _ = pipeline | counter_util.IncrementMetricsSpecsCounters( + [metrics_spec], model_types + ) - result = pipeline.run() + result = pipeline.run() - for model_type in model_types: - metric_filter = ( - beam.metrics.metric.MetricsFilter() - .with_namespace(constants.METRICS_NAMESPACE) - .with_name('metric_computed_FairnessIndicators_v2_' + model_type) - ) - actual_metrics_count = ( - result.metrics().query(filter=metric_filter)['counters'][0].committed - ) + for model_type in model_types: + metric_filter = ( + beam.metrics.metric.MetricsFilter() + .with_namespace(constants.METRICS_NAMESPACE) + .with_name("metric_computed_FairnessIndicators_v2_" + model_type) + ) + actual_metrics_count = ( + result.metrics().query(filter=metric_filter)["counters"][0].committed + ) - self.assertEqual(actual_metrics_count, 1) + self.assertEqual(actual_metrics_count, 1) -if __name__ == '__main__': - tf.test.main() +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow_model_analysis/evaluators/evaluator.py b/tensorflow_model_analysis/evaluators/evaluator.py index 6d4c58d400..07ab22103b 100644 --- 
a/tensorflow_model_analysis/evaluators/evaluator.py +++ b/tensorflow_model_analysis/evaluators/evaluator.py @@ -16,25 +16,20 @@ from typing import Any, Dict, Iterable, List, NamedTuple, Optional import apache_beam as beam + from tensorflow_model_analysis.extractors import extractor + # An evaluator is a PTransform that takes Extracts as input and produces an # Evaluation as output. A typical example of an evaluator is the # MetricsAndPlotsEvaluator that takes the 'features', 'labels', and # 'predictions' extracts from the PredictExtractor and evaluates them using post # export metrics to produce serialized metrics and plots. -Evaluator = NamedTuple( # pylint: disable=invalid-name - 'Evaluator', - [ - ('stage_name', str), - # Extractor.stage_name. If None then evaluation is run before any - # extractors are run. If LAST_EXTRACTOR_STAGE_NAME then evaluation is - # run after the last extractor has run. - ('run_after', Optional[str]), - # PTransform Extracts -> Evaluation - ('ptransform', beam.PTransform), - ], -) +class Evaluator(NamedTuple): + stage_name: str + run_after: Optional[str] + ptransform: beam.PTransform + # An Evaluation represents the output from evaluating the Extracts at a # particular point in the pipeline. The evaluation outputs are keyed by their @@ -43,103 +38,103 @@ Evaluation = Dict[str, beam.pvalue.PCollection] -def verify_evaluator( - evaluator: Evaluator, extractors: List[extractor.Extractor] -): - """Verifies evaluator is matched with an extractor. +def verify_evaluator(evaluator: Evaluator, extractors: List[extractor.Extractor]): + """Verifies evaluator is matched with an extractor. - Args: - evaluator: Evaluator to verify. - extractors: Extractors to use in verification. + Args: + ---- + evaluator: Evaluator to verify. + extractors: Extractors to use in verification. - Raises: - ValueError: If an Extractor cannot be found for the Evaluator. - """ - if ( - evaluator.run_after - and evaluator.run_after != extractor.LAST_EXTRACTOR_STAGE_NAME - and not any(evaluator.run_after == x.stage_name for x in extractors) - ): - raise ValueError( - 'Extractor matching run_after=%s for Evaluator %s not found' - % (evaluator.run_after, evaluator.stage_name) - ) + Raises: + ------ + ValueError: If an Extractor cannot be found for the Evaluator. 
+ """ + if ( + evaluator.run_after + and evaluator.run_after != extractor.LAST_EXTRACTOR_STAGE_NAME + and not any(evaluator.run_after == x.stage_name for x in extractors) + ): + raise ValueError( + "Extractor matching run_after=%s for Evaluator %s not found" + % (evaluator.run_after, evaluator.stage_name) + ) class _CombineEvaluationDictionariesFn(beam.CombineFn): - """CombineFn to combine dictionaries generated by different evaluators.""" - - def create_accumulator(self) -> Dict[str, Any]: - return {} - - def _merge( - self, accumulator: Dict[str, Any], output_dict: Dict[str, Any] - ) -> None: - intersection = set(accumulator) & set(output_dict) - if intersection: - raise ValueError( - 'Dictionaries generated by different evaluators should have ' - 'different keys, but keys %s appeared in the output of multiple ' - 'evaluators' % intersection - ) - accumulator.update(output_dict) - - def add_input( - self, accumulator: Dict[str, Any], output_dict: Dict[str, Any] - ) -> Dict[str, Any]: - if not isinstance(output_dict, dict): - raise TypeError( - 'for outputs written to by multiple evaluators, the outputs must all ' - 'be dictionaries, but got output of type %s, value %s' - % (type(output_dict), str(output_dict)) - ) - self._merge(accumulator, output_dict) - return accumulator - - def merge_accumulators( - self, accumulators: Iterable[Dict[str, Any]] - ) -> Dict[str, Any]: - accumulators = iter(accumulators) - result = next(accumulators) - for acc in accumulators: - self._merge(result, acc) - return result - - def extract_output(self, accumulator: Dict[str, Any]) -> Dict[str, Any]: - return accumulator + """CombineFn to combine dictionaries generated by different evaluators.""" + + def create_accumulator(self) -> Dict[str, Any]: + return {} + + def _merge(self, accumulator: Dict[str, Any], output_dict: Dict[str, Any]) -> None: + intersection = set(accumulator) & set(output_dict) + if intersection: + raise ValueError( + "Dictionaries generated by different evaluators should have " + "different keys, but keys %s appeared in the output of multiple " + "evaluators" % intersection + ) + accumulator.update(output_dict) + + def add_input( + self, accumulator: Dict[str, Any], output_dict: Dict[str, Any] + ) -> Dict[str, Any]: + if not isinstance(output_dict, dict): + raise TypeError( + "for outputs written to by multiple evaluators, the outputs must all " + "be dictionaries, but got output of type %s, value %s" + % (type(output_dict), str(output_dict)) + ) + self._merge(accumulator, output_dict) + return accumulator + + def merge_accumulators( + self, accumulators: Iterable[Dict[str, Any]] + ) -> Dict[str, Any]: + accumulators = iter(accumulators) + result = next(accumulators) + for acc in accumulators: + self._merge(result, acc) + return result + + def extract_output(self, accumulator: Dict[str, Any]) -> Dict[str, Any]: + return accumulator def combine_dict_based_evaluations( evaluations: Dict[str, List[beam.pvalue.PCollection]], ) -> Evaluation: - """Combines multiple evaluation outputs together when the outputs are dicts. - - Note that the dict here refers to the output in the PCollection. The - evaluations themselves are dicts of PCollections keyed by category ('metrics', - 'plots', 'analysis', etc). This util is used to group the outputs of one or - more of these evaluations where the PCollections themselves must be dicts. For - example, a 'metrics' evaluation might store its output in PCollection of dicts - containing metric keys and metric values. 
This util would be used to group the - outputs from running two or more independent metrics evaluations together into - a single PCollection. - - Args: - evaluations: Dict of lists of PCollections of outputs from different - evaluators keyed by type of output ('metrics', 'plots', 'analysis', etc). - - Returns: - Dict of consolidated PCollections of outputs keyed by type of output. - """ - result = {} - for k, v in evaluations.items(): - if len(v) == 1: - result[k] = v[0] - continue - - result[k] = ( - v - | 'FlattenEvaluationOutput(%s)' % k >> beam.Flatten() - | 'CombineEvaluationOutput(%s)' % k - >> beam.CombinePerKey(_CombineEvaluationDictionariesFn()) - ) - return result + """Combines multiple evaluation outputs together when the outputs are dicts. + + Note that the dict here refers to the output in the PCollection. The + evaluations themselves are dicts of PCollections keyed by category ('metrics', + 'plots', 'analysis', etc). This util is used to group the outputs of one or + more of these evaluations where the PCollections themselves must be dicts. For + example, a 'metrics' evaluation might store its output in PCollection of dicts + containing metric keys and metric values. This util would be used to group the + outputs from running two or more independent metrics evaluations together into + a single PCollection. + + Args: + ---- + evaluations: Dict of lists of PCollections of outputs from different + evaluators keyed by type of output ('metrics', 'plots', 'analysis', etc). + + Returns: + ------- + Dict of consolidated PCollections of outputs keyed by type of output. + """ + result = {} + for k, v in evaluations.items(): + if len(v) == 1: + result[k] = v[0] + continue + + result[k] = ( + v + | "FlattenEvaluationOutput(%s)" % k >> beam.Flatten() + | "CombineEvaluationOutput(%s)" % k + >> beam.CombinePerKey(_CombineEvaluationDictionariesFn()) + ) + return result diff --git a/tensorflow_model_analysis/evaluators/evaluator_test.py b/tensorflow_model_analysis/evaluators/evaluator_test.py index a5d95dd559..816de8c763 100644 --- a/tensorflow_model_analysis/evaluators/evaluator_test.py +++ b/tensorflow_model_analysis/evaluators/evaluator_test.py @@ -14,36 +14,36 @@ """Test for evaluator.""" import tensorflow as tf + from tensorflow_model_analysis.evaluators import evaluator from tensorflow_model_analysis.extractors import extractor from tensorflow_model_analysis.utils import test_util class EvaluatorTest(test_util.TensorflowModelAnalysisTest): + def testVerifyEvaluatorRaisesValueError(self): + extractors = [ + extractor.Extractor(stage_name="ExtractorThatExists", ptransform=None) + ] + evaluator.verify_evaluator( + evaluator.Evaluator( + stage_name="EvaluatorWithoutError", + run_after="ExtractorThatExists", + ptransform=None, + ), + extractors, + ) - def testVerifyEvaluatorRaisesValueError(self): - extractors = [ - extractor.Extractor(stage_name='ExtractorThatExists', ptransform=None) - ] - evaluator.verify_evaluator( - evaluator.Evaluator( - stage_name='EvaluatorWithoutError', - run_after='ExtractorThatExists', - ptransform=None, - ), - extractors, - ) - - with self.assertRaises(ValueError): - evaluator.verify_evaluator( - evaluator.Evaluator( - stage_name='EvaluatorWithError', - run_after='ExtractorThatDoesNotExist', - ptransform=None, - ), - extractors, - ) + with self.assertRaises(ValueError): + evaluator.verify_evaluator( + evaluator.Evaluator( + stage_name="EvaluatorWithError", + run_after="ExtractorThatDoesNotExist", + ptransform=None, + ), + extractors, + ) -if __name__ == 
'__main__': - tf.test.main() +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow_model_analysis/evaluators/jackknife.py b/tensorflow_model_analysis/evaluators/jackknife.py index b0a5fda921..726ecd0104 100644 --- a/tensorflow_model_analysis/evaluators/jackknife.py +++ b/tensorflow_model_analysis/evaluators/jackknife.py @@ -17,6 +17,7 @@ import apache_beam as beam import numpy as np + from tensorflow_model_analysis.api import types from tensorflow_model_analysis.evaluators import confidence_intervals_util from tensorflow_model_analysis.metrics import metric_types @@ -25,121 +26,122 @@ _FULL_SAMPLE_ID = -1 -_AccumulatorType = TypeVar('_AccumulatorType') +_AccumulatorType = TypeVar("_AccumulatorType") class _AccumulateOnlyCombineFn(beam_util.DelegatingCombineFn): - """A combine_fn wrapper which returns the accumulator as the output value. + """A combine_fn wrapper which returns the accumulator as the output value. - This is intended to allow invoking CombineFns in settings where you might need - to subsequently merge the results before calling extract_output. A typical use - of a _AccumulateOnlyCombineFn might look like: + This is intended to allow invoking CombineFns in settings where you might need + to subsequently merge the results before calling extract_output. A typical use + of a _AccumulateOnlyCombineFn might look like: - c = OtherCombineFn() + c = OtherCombineFn() - # combine per key, but don't call extract_output() - accumulators = [p | beam.CombinePerKey(_AccumulateOnlyCombineFn(c) - for i in range(3)] + # combine per key, but don't call extract_output() + accumulators = [p | beam.CombinePerKey(_AccumulateOnlyCombineFn(c) + for i in range(3)] - # create two different lists of accumulator PCollections - even_accumulators = [a for i, a in enumerate(accumulators) if i % 2 == 0] + # create two different lists of accumulator PCollections + even_accumulators = [a for i, a in enumerate(accumulators) if i % 2 == 0] - # extract output on the two different sets of accumulators without - # recomputing the even accumulators - even_output = (even_accumulators | beam.Flatten() - | beam.CombinePerKey(_AccumulatorCombineFn(c))) - all_output = (accumulators | beam.Flatten() - | beam.CombinePerKey(_AccumulatorCombineFn(c))) - """ + # extract output on the two different sets of accumulators without + # recomputing the even accumulators + even_output = (even_accumulators | beam.Flatten() + | beam.CombinePerKey(_AccumulatorCombineFn(c))) + all_output = (accumulators | beam.Flatten() + | beam.CombinePerKey(_AccumulatorCombineFn(c))) + """ - def extract_output(self, accumulator: _AccumulatorType) -> _AccumulatorType: - return accumulator + def extract_output(self, accumulator: _AccumulatorType) -> _AccumulatorType: + return accumulator class _AccumulatorCombineFn(beam_util.DelegatingCombineFn): - """A CombineFn wrapper that takes accumulators as add_input elements. + """A CombineFn wrapper that takes accumulators as add_input elements. - In combination with _AccumulateOnlyCombineFn, this makes it possible to - operate on a CombineFn's accumulators prior to calling extract input. See the - _AccumulateOnlyCombineFn docstring for more details. - """ + In combination with _AccumulateOnlyCombineFn, this makes it possible to + operate on a CombineFn's accumulators prior to calling extract input. See the + _AccumulateOnlyCombineFn docstring for more details. 
+ """ - def add_input( - self, accumulator: _AccumulatorType, element: _AccumulatorType - ) -> _AccumulatorType: - return self._combine_fn.merge_accumulators([accumulator, element]) + def add_input( + self, accumulator: _AccumulatorType, element: _AccumulatorType + ) -> _AccumulatorType: + return self._combine_fn.merge_accumulators([accumulator, element]) class _JackknifeSampleCombineFn(confidence_intervals_util.SampleCombineFn): - """Computes the jackknife standard error for each metric from samples.""" - - def __init__( - self, - num_jackknife_samples: int, - skip_ci_metric_keys: Optional[Set[metric_types.MetricKey]] = None, - ): - """Initializes a _JackknifeSampleCombineFn. - - Args: - num_jackknife_samples: The expected number of samples computed per slice. - skip_ci_metric_keys: Set of metric keys for which to skip confidence - interval computation. For metric keys in this set, just the unsampled - value will be returned. - """ - super().__init__( - num_samples=num_jackknife_samples, - full_sample_id=_FULL_SAMPLE_ID, - skip_ci_metric_keys=skip_ci_metric_keys, - ) - - def extract_output( - self, - accumulator: confidence_intervals_util.SampleCombineFn.SampleAccumulator, - ) -> metric_types.MetricsDict: - accumulator = self._validate_accumulator(accumulator) - result = {} - num_buckets = self._num_samples - for key, point_estimate in accumulator.point_estimates.items(): - if key not in accumulator.metric_samples: - result[key] = point_estimate - else: - # See jackknife cookie bucket method described in: - # go/rasta-confidence-intervals - pseudo_values = [] - total = None - for sample_value in accumulator.metric_samples[key]: - if total is None: - total = sample_value - else: - total = total + sample_value - pseudo_values.append( - point_estimate * num_buckets - sample_value * (num_buckets - 1) - ) - _, std_dev = confidence_intervals_util.mean_and_std( - pseudo_values, ddof=1 + """Computes the jackknife standard error for each metric from samples.""" + + def __init__( + self, + num_jackknife_samples: int, + skip_ci_metric_keys: Optional[Set[metric_types.MetricKey]] = None, + ): + """Initializes a _JackknifeSampleCombineFn. + + Args: + ---- + num_jackknife_samples: The expected number of samples computed per slice. + skip_ci_metric_keys: Set of metric keys for which to skip confidence + interval computation. For metric keys in this set, just the unsampled + value will be returned. + """ + super().__init__( + num_samples=num_jackknife_samples, + full_sample_id=_FULL_SAMPLE_ID, + skip_ci_metric_keys=skip_ci_metric_keys, ) - # Here we use Student's t-distribution to estimate the standard - # error with n - 1 degrees of freedom as S.E. = S.D. / sqrt(n)a - # In the case of the delete-d jackknife, the standard error is inversely - # proprotional to the square root of the number of data partitions. 
- std_error = std_dev / (num_buckets**0.5) - mean = total / num_buckets - result[key] = types.ValueWithTDistribution( - sample_mean=mean, - sample_standard_deviation=std_error, - unsampled_value=point_estimate, - sample_degrees_of_freedom=num_buckets - 1, - ) - return result # pytype: disable=bad-return-type # numpy-scalars + + def extract_output( + self, + accumulator: confidence_intervals_util.SampleCombineFn.SampleAccumulator, + ) -> metric_types.MetricsDict: + accumulator = self._validate_accumulator(accumulator) + result = {} + num_buckets = self._num_samples + for key, point_estimate in accumulator.point_estimates.items(): + if key not in accumulator.metric_samples: + result[key] = point_estimate + else: + # See jackknife cookie bucket method described in: + # go/rasta-confidence-intervals + pseudo_values = [] + total = None + for sample_value in accumulator.metric_samples[key]: + if total is None: + total = sample_value + else: + total = total + sample_value + pseudo_values.append( + point_estimate * num_buckets - sample_value * (num_buckets - 1) + ) + _, std_dev = confidence_intervals_util.mean_and_std( + pseudo_values, ddof=1 + ) + # Here we use Student's t-distribution to estimate the standard + # error with n - 1 degrees of freedom as S.E. = S.D. / sqrt(n)a + # In the case of the delete-d jackknife, the standard error is inversely + # proprotional to the square root of the number of data partitions. + std_error = std_dev / (num_buckets**0.5) + mean = total / num_buckets + result[key] = types.ValueWithTDistribution( + sample_mean=mean, + sample_standard_deviation=std_error, + unsampled_value=point_estimate, + sample_degrees_of_freedom=num_buckets - 1, + ) + return result # pytype: disable=bad-return-type # numpy-scalars def _add_sample_id( slice_key, metrics_dict: metric_types.MetricsDict, sample_id: int = 0 ): - # sample_id has a default value in order to satisfy requirement of MapTuple - return slice_key, confidence_intervals_util.SampleMetrics( - metrics=metrics_dict, sample_id=sample_id - ) + # sample_id has a default value in order to satisfy requirement of MapTuple + return slice_key, confidence_intervals_util.SampleMetrics( + metrics=metrics_dict, sample_id=sample_id + ) @beam.ptransform_fn @@ -151,34 +153,34 @@ def _ComputeJackknifeSample( # pylint: disable=invalid-name computations_combine_fn: beam.CombineFn, derived_metrics_ptransform: beam.PTransform, ) -> beam.PCollection[confidence_intervals_util.SampleMetrics]: - """Computes a single jackknife delete-d sample from partition accumulators. - - Args: - sample_accumulators: A PCollections of combiner accumulators to be used for - a given sample. - sample_id: The sample_id to generate. This is used to determine which - partition accumulators to skip. - computations_combine_fn: a beam.CombineFn instance that takes input elements - of type Extracts and returns a MetricsDict. This will be invoked as part - of a CombinePerKey in which the key is a slice key. - derived_metrics_ptransform: A PTransform which adds derived metrics to the - results of the computations_combine_fn. This PTransform should both input - and output a single PCollection with elements of type MetricsDict where - the output MetricsDict includes additional derived metrics. 
- - Returns: - A single sample tuple containing the sample_id and the metric dictionary for - that sample - """ - - return ( - sample_accumulators - | 'MergePartitionsPerSlice' - >> beam.CombinePerKey(_AccumulatorCombineFn(computations_combine_fn)) - | 'AddDerivedMetrics' >> derived_metrics_ptransform - | 'AddSampleIdToValue' - >> beam.MapTuple(_add_sample_id, sample_id=sample_id) - ) + """Computes a single jackknife delete-d sample from partition accumulators. + + Args: + ---- + sample_accumulators: A PCollections of combiner accumulators to be used for + a given sample. + sample_id: The sample_id to generate. This is used to determine which + partition accumulators to skip. + computations_combine_fn: a beam.CombineFn instance that takes input elements + of type Extracts and returns a MetricsDict. This will be invoked as part + of a CombinePerKey in which the key is a slice key. + derived_metrics_ptransform: A PTransform which adds derived metrics to the + results of the computations_combine_fn. This PTransform should both input + and output a single PCollection with elements of type MetricsDict where + the output MetricsDict includes additional derived metrics. + + Returns: + ------- + A single sample tuple containing the sample_id and the metric dictionary for + that sample + """ + return ( + sample_accumulators + | "MergePartitionsPerSlice" + >> beam.CombinePerKey(_AccumulatorCombineFn(computations_combine_fn)) + | "AddDerivedMetrics" >> derived_metrics_ptransform + | "AddSampleIdToValue" >> beam.MapTuple(_add_sample_id, sample_id=sample_id) + ) @beam.ptransform_fn @@ -194,95 +196,96 @@ def ComputeWithConfidenceIntervals( # pylint: disable=invalid-name ) -> beam.pvalue.PCollection[ Tuple[slicer.SliceKeyOrCrossSliceKeyType, metric_types.MetricsDict] ]: - """Computes base metrics and derived metrics and adds std error estimates. - - Args: - sliced_extracts: Incoming PCollection consisting of slice key and extracts. - computations_combine_fn: a beam.CombineFn instance that takes input elements - of type Extracts and returns a MetricsDict. This will be invoked as part - of a CombinePerKey in which the key is a slice key. - derived_metrics_ptransform: A PTransform which adds derived metrics to the - results of the computations_combine_fn. This PTransform should both input - and output a single PCollection with elements of type MetricsDict where - the output MetricsDict includes additional derived metrics. - num_jackknife_samples: The number of jackknife replicates to use in - computing the jackknife standard error. - skip_ci_metric_keys: Set of metric keys for which to skip confidence - interval computation. For metric keys in this set, just the unsampled - value will be returned. - random_seed_for_testing: Seed to use for unit testing, because - nondeterministic tests stink. Each partition will use this value + i. - - Returns: - A PCollection of sliced metrics containing standard error estimates for - each numeric metric. - """ - - random_state = np.random.RandomState(random_seed_for_testing) - - def partition_fn(_, num_partitions): - return random_state.randint(num_partitions) - - # Partition the data - # List[PCollection[Tuple[slicer.SliceKeyType, types.Extracts]]] - partitions = ( - sliced_extracts - | f'Partition({num_jackknife_samples})' - >> beam.Partition(partition_fn, num_jackknife_samples) - ) - - # Within each partition, partially combine per slice key to get accumulators - # and partition sizes; add partition_id for determinism. 
- # List[PCollection[Tuple[slicer.SliceKeyType, AccumulatorType]]] - partition_accumulators = [] - for i, partition in enumerate(partitions): - partition_accumulators.append( - partition - | f'CombinePartitionPerSlice[{i}]' - >> beam.CombinePerKey(_AccumulateOnlyCombineFn(computations_combine_fn)) + """Computes base metrics and derived metrics and adds std error estimates. + + Args: + ---- + sliced_extracts: Incoming PCollection consisting of slice key and extracts. + computations_combine_fn: a beam.CombineFn instance that takes input elements + of type Extracts and returns a MetricsDict. This will be invoked as part + of a CombinePerKey in which the key is a slice key. + derived_metrics_ptransform: A PTransform which adds derived metrics to the + results of the computations_combine_fn. This PTransform should both input + and output a single PCollection with elements of type MetricsDict where + the output MetricsDict includes additional derived metrics. + num_jackknife_samples: The number of jackknife replicates to use in + computing the jackknife standard error. + skip_ci_metric_keys: Set of metric keys for which to skip confidence + interval computation. For metric keys in this set, just the unsampled + value will be returned. + random_seed_for_testing: Seed to use for unit testing, because + nondeterministic tests stink. Each partition will use this value + i. + + Returns: + ------- + A PCollection of sliced metrics containing standard error estimates for + each numeric metric. + """ + random_state = np.random.RandomState(random_seed_for_testing) + + def partition_fn(_, num_partitions): + return random_state.randint(num_partitions) + + # Partition the data + # List[PCollection[Tuple[slicer.SliceKeyType, types.Extracts]]] + partitions = ( + sliced_extracts + | f"Partition({num_jackknife_samples})" + >> beam.Partition(partition_fn, num_jackknife_samples) ) - unsampled_metrics = ( - partition_accumulators - | 'FlattenPartitions' >> beam.Flatten() - | 'MergePartitionsPerSlice' - >> beam.CombinePerKey(_AccumulatorCombineFn(computations_combine_fn)) - | 'AddDerivedMetrics' >> derived_metrics_ptransform - | 'AddSampleIdToValue' - >> beam.MapTuple(_add_sample_id, sample_id=_FULL_SAMPLE_ID) - ) - - # Compute the combine_fn output for the delete-d samples by merging all but - # one partitions. - # List[PCollection[Tuple[slicer.SliceKeyType, SampleMetrics]]] - delete_d_samples = [] - for sample_id in range(num_jackknife_samples): - # TODO(b/194732335): Push filter and Flatten into _ComputeJackknifeSample. - # TODO(b/130032676): Replace the 'ExcludePartition' step with for-loop - # exclusion after cl/435922775 (or equivalent) is submitted. - sample_accumulators = [ - acc | f'ExcludePartition[{sample_id}]' >> beam.Filter(lambda _: False) - if i == sample_id - else acc - for i, acc in enumerate(partition_accumulators) - ] - delete_d_samples.append( - sample_accumulators - | f'FlattenPartitions[{sample_id}]' >> beam.Flatten() - | f'ComputeJackknifeSample[{sample_id}]' - >> _ComputeJackknifeSample( # pylint: disable=no-value-for-parameter - sample_id=sample_id, - computations_combine_fn=computations_combine_fn, - derived_metrics_ptransform=derived_metrics_ptransform, + # Within each partition, partially combine per slice key to get accumulators + # and partition sizes; add partition_id for determinism. 
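
# Illustrative sketch (not part of the patch): the partition / leave-one-out scheme
# used here, restated with plain Python lists instead of Beam PCollections.
# `merge` is a hypothetical stand-in for merging pre-combined partition accumulators.
import random

def merge(accumulators):
    # Stand-in for CombineFn.merge_accumulators over per-partition accumulators.
    return sum(accumulators)

examples = list(range(10))
num_jackknife_samples = 5

# 1. Randomly assign each example to one of num_jackknife_samples partitions.
partitions = [[] for _ in range(num_jackknife_samples)]
for x in examples:
    partitions[random.randrange(num_jackknife_samples)].append(x)

# 2. Partially combine each partition once; these accumulators are reused below.
partition_accumulators = [sum(p) for p in partitions]

# 3. The full-sample estimate merges every partition; delete-d sample i merges
#    every partition except partition i, so no per-example work is repeated.
full_sample = merge(partition_accumulators)
delete_d_samples = [
    merge(a for j, a in enumerate(partition_accumulators) if j != i)
    for i in range(num_jackknife_samples)
]
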
+ # List[PCollection[Tuple[slicer.SliceKeyType, AccumulatorType]]] + partition_accumulators = [] + for i, partition in enumerate(partitions): + partition_accumulators.append( + partition + | f"CombinePartitionPerSlice[{i}]" + >> beam.CombinePerKey(_AccumulateOnlyCombineFn(computations_combine_fn)) ) + + unsampled_metrics = ( + partition_accumulators + | "FlattenPartitions" >> beam.Flatten() + | "MergePartitionsPerSlice" + >> beam.CombinePerKey(_AccumulatorCombineFn(computations_combine_fn)) + | "AddDerivedMetrics" >> derived_metrics_ptransform + | "AddSampleIdToValue" + >> beam.MapTuple(_add_sample_id, sample_id=_FULL_SAMPLE_ID) ) - # PCollection[Tuple[slicer.SliceKeyType, metric_types.MetricsDict]] - return ( - delete_d_samples + [unsampled_metrics] - | 'FlattenSamples' >> beam.Flatten() - | 'CombineJackknifeSamplesPerSlice' - >> beam.CombinePerKey( - _JackknifeSampleCombineFn(num_jackknife_samples, skip_ci_metric_keys) - ) - ) + # Compute the combine_fn output for the delete-d samples by merging all but + # one partitions. + # List[PCollection[Tuple[slicer.SliceKeyType, SampleMetrics]]] + delete_d_samples = [] + for sample_id in range(num_jackknife_samples): + # TODO(b/194732335): Push filter and Flatten into _ComputeJackknifeSample. + # TODO(b/130032676): Replace the 'ExcludePartition' step with for-loop + # exclusion after cl/435922775 (or equivalent) is submitted. + sample_accumulators = [ + acc | f"ExcludePartition[{sample_id}]" >> beam.Filter(lambda _: False) + if i == sample_id + else acc + for i, acc in enumerate(partition_accumulators) + ] + delete_d_samples.append( + sample_accumulators + | f"FlattenPartitions[{sample_id}]" >> beam.Flatten() + | f"ComputeJackknifeSample[{sample_id}]" + >> _ComputeJackknifeSample( # pylint: disable=no-value-for-parameter + sample_id=sample_id, + computations_combine_fn=computations_combine_fn, + derived_metrics_ptransform=derived_metrics_ptransform, + ) + ) + + # PCollection[Tuple[slicer.SliceKeyType, metric_types.MetricsDict]] + return ( + delete_d_samples + [unsampled_metrics] + | "FlattenSamples" >> beam.Flatten() + | "CombineJackknifeSamplesPerSlice" + >> beam.CombinePerKey( + _JackknifeSampleCombineFn(num_jackknife_samples, skip_ci_metric_keys) + ) + ) diff --git a/tensorflow_model_analysis/evaluators/jackknife_test.py b/tensorflow_model_analysis/evaluators/jackknife_test.py index 2427566c68..2b2f3956d0 100644 --- a/tensorflow_model_analysis/evaluators/jackknife_test.py +++ b/tensorflow_model_analysis/evaluators/jackknife_test.py @@ -15,262 +15,257 @@ import functools -from absl.testing import absltest import apache_beam as beam +from absl.testing import absltest from apache_beam.testing import util + from tensorflow_model_analysis.api import types -from tensorflow_model_analysis.evaluators import confidence_intervals_util -from tensorflow_model_analysis.evaluators import jackknife -from tensorflow_model_analysis.metrics import binary_confusion_matrices -from tensorflow_model_analysis.metrics import metric_types +from tensorflow_model_analysis.evaluators import confidence_intervals_util, jackknife +from tensorflow_model_analysis.metrics import binary_confusion_matrices, metric_types class ListCombineFn(beam.CombineFn): + def __init__(self, extract_output_append=None): + self._extract_output_append = extract_output_append - def __init__(self, extract_output_append=None): - self._extract_output_append = extract_output_append - - def create_accumulator(self): - return [] + def create_accumulator(self): + return [] - def add_input(self, 
accumulator, element): - return accumulator + [element] + def add_input(self, accumulator, element): + return accumulator + [element] - def merge_accumulators(self, accumulators): - return functools.reduce(list.__add__, accumulators) + def merge_accumulators(self, accumulators): + return functools.reduce(list.__add__, accumulators) - def extract_output(self, accumulator): - if self._extract_output_append: - return accumulator + [self._extract_output_append] - else: - return accumulator + def extract_output(self, accumulator): + if self._extract_output_append: + return accumulator + [self._extract_output_append] + else: + return accumulator class ListCombineFnExtractOutputNotImplemented(ListCombineFn): - - def extract_output(self, accumulator): - raise NotImplementedError( - 'extract_output intentionally not implement to verify behavior. We ' - 'would like to be able to mock a combine_fn and then call ' - 'combine_fn.extract_output.assert_not_called().' - ) + def extract_output(self, accumulator): + raise NotImplementedError( + "extract_output intentionally not implement to verify behavior. We " + "would like to be able to mock a combine_fn and then call " + "combine_fn.extract_output.assert_not_called()." + ) class ListCombineFnAddInputNotImplemented(ListCombineFn): - - def add_input(self, accumulator, element): - raise NotImplementedError( - 'add_input intentionally not implement to verify behavior. We would ' - 'like to be able to mock a combine_fn and then call ' - 'combine_fn.add_input.assert_not_called().' - ) + def add_input(self, accumulator, element): + raise NotImplementedError( + "add_input intentionally not implement to verify behavior. We would " + "like to be able to mock a combine_fn and then call " + "combine_fn.add_input.assert_not_called()." 
+ ) class JackknifeTest(absltest.TestCase): + def test_accumulate_only_combiner(self): + with beam.Pipeline() as pipeline: + result = ( + pipeline + | "Create" >> beam.Create([1, 2]) + | "AccumulateOnlyCombine" + >> beam.CombineGlobally( + jackknife._AccumulateOnlyCombineFn( + ListCombineFnExtractOutputNotImplemented( + extract_output_append=3 + ) + ) + ) + ) - def test_accumulate_only_combiner(self): - with beam.Pipeline() as pipeline: - result = ( - pipeline - | 'Create' >> beam.Create([1, 2]) - | 'AccumulateOnlyCombine' - >> beam.CombineGlobally( - jackknife._AccumulateOnlyCombineFn( - ListCombineFnExtractOutputNotImplemented( - extract_output_append=3 - ) - ) - ) - ) + def check_result(got_pcoll): + self.assertLen(got_pcoll, 1) + self.assertEqual(got_pcoll[0], [1, 2]) - def check_result(got_pcoll): - self.assertLen(got_pcoll, 1) - self.assertEqual(got_pcoll[0], [1, 2]) + util.assert_that(result, check_result) - util.assert_that(result, check_result) + def test_accumulator_combiner(self): + with beam.Pipeline() as pipeline: + result = ( + pipeline + | "Create" >> beam.Create([[1], [2]]) + | "AccumulatorCombine" + >> beam.CombineGlobally( + jackknife._AccumulatorCombineFn( + ListCombineFnAddInputNotImplemented(extract_output_append=3) + ) + ) + ) - def test_accumulator_combiner(self): - with beam.Pipeline() as pipeline: - result = ( - pipeline - | 'Create' >> beam.Create([[1], [2]]) - | 'AccumulatorCombine' - >> beam.CombineGlobally( - jackknife._AccumulatorCombineFn( - ListCombineFnAddInputNotImplemented(extract_output_append=3) - ) - ) - ) + def check_result(got_pcoll): + self.assertLen(got_pcoll, 1) + self.assertEqual(got_pcoll[0], [1, 2, 3]) - def check_result(got_pcoll): - self.assertLen(got_pcoll, 1) - self.assertEqual(got_pcoll[0], [1, 2, 3]) + util.assert_that(result, check_result) - util.assert_that(result, check_result) - - def test_jackknife_sample_combine_fn(self): - x_key = metric_types.MetricKey('x') - y_key = metric_types.MetricKey('y') - cm_key = metric_types.MetricKey('confusion_matrix') - cm_metric = binary_confusion_matrices.Matrices( - thresholds=[0.5], tp=[0], fp=[1], tn=[2], fn=[3] - ) - slice_key1 = (('slice_feature', 1),) - slice_key2 = (('slice_feature', 2),) - samples = [ - # point estimate for slice 1 - ( - slice_key1, - confidence_intervals_util.SampleMetrics( - sample_id=jackknife._FULL_SAMPLE_ID, - metrics={ - x_key: 1.6, - y_key: 16, - cm_key: cm_metric, - }, - ), - ), - # sample values 1 of 2 for slice 1 - ( - slice_key1, - confidence_intervals_util.SampleMetrics( - sample_id=0, - metrics={ - x_key: 1, - y_key: 10, - cm_key: cm_metric - 1, - }, - ), - ), - # sample values 2 of 2 for slice 1 - ( - slice_key1, - confidence_intervals_util.SampleMetrics( - sample_id=1, - metrics={ - x_key: 2, - y_key: 20, - cm_key: cm_metric + 1, - }, + def test_jackknife_sample_combine_fn(self): + x_key = metric_types.MetricKey("x") + y_key = metric_types.MetricKey("y") + cm_key = metric_types.MetricKey("confusion_matrix") + cm_metric = binary_confusion_matrices.Matrices( + thresholds=[0.5], tp=[0], fp=[1], tn=[2], fn=[3] + ) + slice_key1 = (("slice_feature", 1),) + slice_key2 = (("slice_feature", 2),) + samples = [ + # point estimate for slice 1 + ( + slice_key1, + confidence_intervals_util.SampleMetrics( + sample_id=jackknife._FULL_SAMPLE_ID, + metrics={ + x_key: 1.6, + y_key: 16, + cm_key: cm_metric, + }, + ), ), - ), - # point estimate for slice 2 - ( - slice_key2, - confidence_intervals_util.SampleMetrics( - sample_id=jackknife._FULL_SAMPLE_ID, - metrics={ - x_key: 3.3, 
- y_key: 33, - cm_key: cm_metric, - }, + # sample values 1 of 2 for slice 1 + ( + slice_key1, + confidence_intervals_util.SampleMetrics( + sample_id=0, + metrics={ + x_key: 1, + y_key: 10, + cm_key: cm_metric - 1, + }, + ), ), - ), - # sample values 1 of 2 for slice 2 - ( - slice_key2, - confidence_intervals_util.SampleMetrics( - sample_id=0, - metrics={ - x_key: 2, - y_key: 20, - cm_key: cm_metric - 10, - }, + # sample values 2 of 2 for slice 1 + ( + slice_key1, + confidence_intervals_util.SampleMetrics( + sample_id=1, + metrics={ + x_key: 2, + y_key: 20, + cm_key: cm_metric + 1, + }, + ), ), - ), - # sample values 2 of 2 for slice 2 - ( - slice_key2, - confidence_intervals_util.SampleMetrics( - sample_id=1, - metrics={ - x_key: 4, - y_key: 40, - cm_key: cm_metric + 10, - }, + # point estimate for slice 2 + ( + slice_key2, + confidence_intervals_util.SampleMetrics( + sample_id=jackknife._FULL_SAMPLE_ID, + metrics={ + x_key: 3.3, + y_key: 33, + cm_key: cm_metric, + }, + ), ), - ), - ] - - with beam.Pipeline() as pipeline: - result = ( - pipeline - | 'Create' >> beam.Create(samples, reshuffle=False) - | 'CombineJackknifeSamplesPerKey' - >> beam.CombinePerKey( - jackknife._JackknifeSampleCombineFn(num_jackknife_samples=2) - ) - ) - - # WARNING: Do not change this test without carefully considering the - # impact on clients due to changed CI bounds. The current implementation - # follows jackknife cookie bucket method described in: - # go/rasta-confidence-intervals - def check_result(got_pcoll): - expected_pcoll = [ + # sample values 1 of 2 for slice 2 ( - slice_key1, - { - x_key: types.ValueWithTDistribution( - sample_mean=1.5, - sample_standard_deviation=0.5, - sample_degrees_of_freedom=1, - unsampled_value=1.6, - ), - y_key: types.ValueWithTDistribution( - sample_mean=15.0, - sample_standard_deviation=5, - sample_degrees_of_freedom=1, - unsampled_value=16, - ), - cm_key: types.ValueWithTDistribution( - sample_mean=cm_metric, - sample_standard_deviation=( - binary_confusion_matrices.Matrices( - thresholds=[0.5], tp=[1], fp=[1], tn=[1], fn=[1] - ) - ), - sample_degrees_of_freedom=1, - unsampled_value=cm_metric, - ), - }, + slice_key2, + confidence_intervals_util.SampleMetrics( + sample_id=0, + metrics={ + x_key: 2, + y_key: 20, + cm_key: cm_metric - 10, + }, + ), ), + # sample values 2 of 2 for slice 2 ( slice_key2, - { - x_key: types.ValueWithTDistribution( - sample_mean=3.0, - sample_standard_deviation=1, - sample_degrees_of_freedom=1, - unsampled_value=3.3, - ), - y_key: types.ValueWithTDistribution( - sample_mean=30.0, - sample_standard_deviation=10, - sample_degrees_of_freedom=1, - unsampled_value=33, - ), - cm_key: types.ValueWithTDistribution( - sample_mean=cm_metric, - sample_standard_deviation=( - binary_confusion_matrices.Matrices( - thresholds=[0.5], - tp=[10], - fp=[10], - tn=[10], - fn=[10], - ) - ), - sample_degrees_of_freedom=1, - unsampled_value=cm_metric, - ), - }, + confidence_intervals_util.SampleMetrics( + sample_id=1, + metrics={ + x_key: 4, + y_key: 40, + cm_key: cm_metric + 10, + }, + ), ), ] - self.assertCountEqual(expected_pcoll, got_pcoll) - util.assert_that(result, check_result) + with beam.Pipeline() as pipeline: + result = ( + pipeline + | "Create" >> beam.Create(samples, reshuffle=False) + | "CombineJackknifeSamplesPerKey" + >> beam.CombinePerKey( + jackknife._JackknifeSampleCombineFn(num_jackknife_samples=2) + ) + ) + + # WARNING: Do not change this test without carefully considering the + # impact on clients due to changed CI bounds. 
The current implementation + # follows jackknife cookie bucket method described in: + # go/rasta-confidence-intervals + def check_result(got_pcoll): + expected_pcoll = [ + ( + slice_key1, + { + x_key: types.ValueWithTDistribution( + sample_mean=1.5, + sample_standard_deviation=0.5, + sample_degrees_of_freedom=1, + unsampled_value=1.6, + ), + y_key: types.ValueWithTDistribution( + sample_mean=15.0, + sample_standard_deviation=5, + sample_degrees_of_freedom=1, + unsampled_value=16, + ), + cm_key: types.ValueWithTDistribution( + sample_mean=cm_metric, + sample_standard_deviation=( + binary_confusion_matrices.Matrices( + thresholds=[0.5], tp=[1], fp=[1], tn=[1], fn=[1] + ) + ), + sample_degrees_of_freedom=1, + unsampled_value=cm_metric, + ), + }, + ), + ( + slice_key2, + { + x_key: types.ValueWithTDistribution( + sample_mean=3.0, + sample_standard_deviation=1, + sample_degrees_of_freedom=1, + unsampled_value=3.3, + ), + y_key: types.ValueWithTDistribution( + sample_mean=30.0, + sample_standard_deviation=10, + sample_degrees_of_freedom=1, + unsampled_value=33, + ), + cm_key: types.ValueWithTDistribution( + sample_mean=cm_metric, + sample_standard_deviation=( + binary_confusion_matrices.Matrices( + thresholds=[0.5], + tp=[10], + fp=[10], + tn=[10], + fn=[10], + ) + ), + sample_degrees_of_freedom=1, + unsampled_value=cm_metric, + ), + }, + ), + ] + self.assertCountEqual(expected_pcoll, got_pcoll) + + util.assert_that(result, check_result) -if __name__ == '__main__': - absltest.main() +if __name__ == "__main__": + absltest.main() diff --git a/tensorflow_model_analysis/evaluators/keras_util.py b/tensorflow_model_analysis/evaluators/keras_util.py index c802a2d995..68bd4854a9 100644 --- a/tensorflow_model_analysis/evaluators/keras_util.py +++ b/tensorflow_model_analysis/evaluators/keras_util.py @@ -20,14 +20,16 @@ import apache_beam as beam import numpy as np import tensorflow as tf + from tensorflow_model_analysis import constants from tensorflow_model_analysis.api import types -from tensorflow_model_analysis.metrics import metric_types -from tensorflow_model_analysis.metrics import metric_util -from tensorflow_model_analysis.metrics import tf_metric_accumulators +from tensorflow_model_analysis.metrics import ( + metric_types, + metric_util, + tf_metric_accumulators, +) from tensorflow_model_analysis.proto import config_pb2 -from tensorflow_model_analysis.utils import model_util -from tensorflow_model_analysis.utils import util +from tensorflow_model_analysis.utils import model_util, util from tensorflow_model_analysis.utils.keras_lib import tf_keras @@ -38,74 +40,77 @@ def metric_computations_using_keras_saved_model( eval_config: Optional[config_pb2.EvalConfig], batch_size: Optional[int] = None, ) -> metric_types.MetricComputations: - """Returns computations for computing metrics natively using keras. - - Args: - model_name: Name of model. - model_loader: Loader for shared model containing keras saved model to use - for metric computations. - eval_config: Eval config. - batch_size: Batch size to use during evaluation (testing only). - """ - model = model_loader.load() - # If metrics were only added using model.compile then use - # model.compiled_metrics and model.compiled_loss to compute the metrics, - # otherwise custom metrics added via model.add_metric were also used and we - # need to call model.evaluate. 
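
# Illustrative sketch (not part of the patch): the decision described in the comment
# above, restated as a hypothetical helper. `uses_only_compiled_metrics` is not a
# TFMA function, and how model.metrics is populated varies across TF versions, so
# treat this purely as a paraphrase of the check that follows.
def uses_only_compiled_metrics(model) -> bool:
    """True when every metric on `model` came from compile(), i.e. add_metric() was not used."""
    return bool(
        getattr(model, "compiled_metrics", None)
        and getattr(model, "compiled_loss", None)
        and len(model.compiled_metrics.metrics) + len(model.compiled_loss.metrics)
        == len(model.metrics)
    )
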
- if not model.metrics: - return [] - elif ( - hasattr(model, 'compiled_metrics') - and model.compiled_metrics - and hasattr(model, 'compiled_loss') - and model.compiled_loss - and len(model.compiled_metrics.metrics) + len(model.compiled_loss.metrics) - == len(model.metrics) - ): - if hasattr(model, 'output_names') and model.output_names: - output_names = model.output_names - else: - output_names = [] - keys = _metric_keys( - chain(model.compiled_metrics.metrics, model.compiled_loss.metrics), - model_name, - output_names, - ) - return [ - metric_types.MetricComputation( - keys=keys, - preprocessors=None, - combiner=_KerasCompiledMetricsCombiner( - keys, model_name, model_loader, eval_config, batch_size - ), + """Returns computations for computing metrics natively using keras. + + Args: + ---- + model_name: Name of model. + model_loader: Loader for shared model containing keras saved model to use + for metric computations. + eval_config: Eval config. + batch_size: Batch size to use during evaluation (testing only). + """ + model = model_loader.load() + # If metrics were only added using model.compile then use + # model.compiled_metrics and model.compiled_loss to compute the metrics, + # otherwise custom metrics added via model.add_metric were also used and we + # need to call model.evaluate. + if not model.metrics: + return [] + elif ( + hasattr(model, "compiled_metrics") + and model.compiled_metrics + and hasattr(model, "compiled_loss") + and model.compiled_loss + and len(model.compiled_metrics.metrics) + len(model.compiled_loss.metrics) + == len(model.metrics) + ): + if hasattr(model, "output_names") and model.output_names: + output_names = model.output_names + else: + output_names = [] + keys = _metric_keys( + chain(model.compiled_metrics.metrics, model.compiled_loss.metrics), + model_name, + output_names, ) - ] - else: - if hasattr(model, 'output_names') and model.output_names: - output_names = model.output_names + return [ + metric_types.MetricComputation( + keys=keys, + preprocessors=None, + combiner=_KerasCompiledMetricsCombiner( + keys, model_name, model_loader, eval_config, batch_size + ), + ) + ] else: - output_names = [] - keys = _metric_keys(model.metrics, model_name, output_names) - specs = model_util.get_input_specs(model, signature_name=None) - feature_keys = list(specs) if specs else [] - return [ - metric_types.MetricComputation( - keys=keys, - preprocessors=[ - metric_types.StandardMetricInputsPreprocessorList([ - metric_types.FeaturePreprocessor( - feature_keys=feature_keys, model_names=[model_name] - ), - metric_types.TransformedFeaturePreprocessor( - feature_keys=feature_keys, model_names=[model_name] - ), - ]) - ], - combiner=_KerasEvaluateCombiner( - keys, model_name, model_loader, eval_config, batch_size - ), - ) - ] + if hasattr(model, "output_names") and model.output_names: + output_names = model.output_names + else: + output_names = [] + keys = _metric_keys(model.metrics, model_name, output_names) + specs = model_util.get_input_specs(model, signature_name=None) + feature_keys = list(specs) if specs else [] + return [ + metric_types.MetricComputation( + keys=keys, + preprocessors=[ + metric_types.StandardMetricInputsPreprocessorList( + [ + metric_types.FeaturePreprocessor( + feature_keys=feature_keys, model_names=[model_name] + ), + metric_types.TransformedFeaturePreprocessor( + feature_keys=feature_keys, model_names=[model_name] + ), + ] + ) + ], + combiner=_KerasEvaluateCombiner( + keys, model_name, model_loader, eval_config, batch_size + ), + ) + ] def 
_metric_keys( @@ -113,410 +118,400 @@ def _metric_keys( model_name: str, output_names: Iterable[str], ) -> List[metric_types.MetricKey]: - """Returns metric keys for given metrics.""" - # We need to use the metric name to determine the associated output because - # keras does not provide an API (see b/149780822). Keras names its metrics - # using the following format: - # _[weighted]_ - result = [] - for metric in metrics: - sub_key = None - if hasattr(metric, 'class_id') and metric.class_id is not None: - sub_key = metric_types.SubKey(class_id=metric.class_id) - elif hasattr(metric, 'top_k') and metric.top_k is not None: - sub_key = metric_types.SubKey(top_k=metric.top_k) - for output_name in output_names or []: - if metric.name.startswith(output_name + '_'): - # TODO(b/171559113): Output prefixes used to be added multiple times. - # Remove this while loop after the last TF version with the issue is - # no longer supported. - name = metric.name - while name.startswith(output_name + '_'): - name = name[len(output_name) + 1 :] - result.append( - metric_types.MetricKey( - name=name, - model_name=model_name, - output_name=output_name, - sub_key=sub_key, - example_weighted=None, + """Returns metric keys for given metrics.""" + # We need to use the metric name to determine the associated output because + # keras does not provide an API (see b/149780822). Keras names its metrics + # using the following format: + # _[weighted]_ + result = [] + for metric in metrics: + sub_key = None + if hasattr(metric, "class_id") and metric.class_id is not None: + sub_key = metric_types.SubKey(class_id=metric.class_id) + elif hasattr(metric, "top_k") and metric.top_k is not None: + sub_key = metric_types.SubKey(top_k=metric.top_k) + for output_name in output_names or []: + if metric.name.startswith(output_name + "_"): + # TODO(b/171559113): Output prefixes used to be added multiple times. + # Remove this while loop after the last TF version with the issue is + # no longer supported. + name = metric.name + while name.startswith(output_name + "_"): + name = name[len(output_name) + 1 :] + result.append( + metric_types.MetricKey( + name=name, + model_name=model_name, + output_name=output_name, + sub_key=sub_key, + example_weighted=None, + ) + ) + break + else: + result.append( + metric_types.MetricKey( + name=metric.name, + model_name=model_name, + sub_key=sub_key, + example_weighted=None, + ) ) - ) - break - else: - result.append( - metric_types.MetricKey( - name=metric.name, - model_name=model_name, - sub_key=sub_key, - example_weighted=None, - ) - ) - return result + return result @beam.typehints.with_input_types(metric_types.StandardMetricInputs) @beam.typehints.with_output_types(Dict[metric_types.MetricKey, np.ndarray]) class _KerasCombiner(model_util.CombineFnWithModels): - """Base combiner for aggregating metrics for keras based models.""" - - def __init__( - self, - keys: List[metric_types.MetricKey], - model_name: str, - model_loader: types.ModelLoader, - eval_config: Optional[config_pb2.EvalConfig], - desired_batch_size: Optional[int] = None, - beam_metrics_prefix: str = '', - ): - super().__init__({model_name: model_loader}) - self._keys = keys - self._model_name = model_name - self._eval_config = eval_config - self._desired_batch_size = desired_batch_size - self._model = None - # This combiner makes use of the TFMetricsAccumulator to track the inputs - # and outputs. 
While the TFMetricsAccumulator is designed to store output - # weights for each input, this doesn't work well with metrics from the keras - # model because all the outputs get mixed together. So in this case the - # _output_names will contain ['', output1, output2, ...] and _output_counts - # will contain the corresponding counts of metrics for each output (with '' - # including the counts of metrics that are aggregations over multiple - # outputs - e.g. weighted loss). For the actual computations, we will store - # the inputs under the respective output indices (with '' having no inputs), - # but store all the metric weights under output index 0. - self._output_names = sorted(set(key.output_name or '' for key in keys)) - counts = collections.defaultdict(int) - for key in keys: - if key.output_name: - counts[key.output_name] += 1 - counts[''] = len(keys) - self._output_counts = [counts[name] for name in self._output_names] - self._batch_size_beam_metric_dist = beam.metrics.Metrics.distribution( - constants.METRICS_NAMESPACE, - '{}_combine_batch_size'.format(beam_metrics_prefix), - ) - self._total_input_byte_size_beam_metric_dist = ( - beam.metrics.Metrics.distribution( + """Base combiner for aggregating metrics for keras based models.""" + + def __init__( + self, + keys: List[metric_types.MetricKey], + model_name: str, + model_loader: types.ModelLoader, + eval_config: Optional[config_pb2.EvalConfig], + desired_batch_size: Optional[int] = None, + beam_metrics_prefix: str = "", + ): + super().__init__({model_name: model_loader}) + self._keys = keys + self._model_name = model_name + self._eval_config = eval_config + self._desired_batch_size = desired_batch_size + self._model = None + # This combiner makes use of the TFMetricsAccumulator to track the inputs + # and outputs. While the TFMetricsAccumulator is designed to store output + # weights for each input, this doesn't work well with metrics from the keras + # model because all the outputs get mixed together. So in this case the + # _output_names will contain ['', output1, output2, ...] and _output_counts + # will contain the corresponding counts of metrics for each output (with '' + # including the counts of metrics that are aggregations over multiple + # outputs - e.g. weighted loss). For the actual computations, we will store + # the inputs under the respective output indices (with '' having no inputs), + # but store all the metric weights under output index 0. + self._output_names = sorted(set(key.output_name or "" for key in keys)) + counts = collections.defaultdict(int) + for key in keys: + if key.output_name: + counts[key.output_name] += 1 + counts[""] = len(keys) + self._output_counts = [counts[name] for name in self._output_names] + self._batch_size_beam_metric_dist = beam.metrics.Metrics.distribution( constants.METRICS_NAMESPACE, - '{}_combine_batch_bytes_size'.format(beam_metrics_prefix), + f"{beam_metrics_prefix}_combine_batch_size", + ) + self._total_input_byte_size_beam_metric_dist = ( + beam.metrics.Metrics.distribution( + constants.METRICS_NAMESPACE, + f"{beam_metrics_prefix}_combine_batch_bytes_size", + ) + ) + self._num_compacts = beam.metrics.Metrics.counter( + constants.METRICS_NAMESPACE, "num_compacts" ) - ) - self._num_compacts = beam.metrics.Metrics.counter( - constants.METRICS_NAMESPACE, 'num_compacts' - ) - - def setup(self): - if self._model is None: - # TODO(b/179500321): We are skipping the shared handle here to ensure that - # we don't have issues with sharing the model between threads. 
This is - # very inefficient, we should just clone the model but - # tf_keras.models.clone_model removes compiled_metrics. - self._model = self._model_loaders[ # pylint: disable=protected-access - self._model_name - ]._construct_fn_with_load_time(self._set_model_load_seconds)() - - def _metrics(self) -> Iterable[tf_keras.metrics.Metric]: - """Returns metrics used by combiner.""" - raise NotImplementedError('Subclasses are expected to override this.') - - def _create_accumulator(self) -> tf_metric_accumulators.TFMetricsAccumulator: - """Returns a new accumulator.""" - raise NotImplementedError('Subclasses are expected to override this.') - - def _add_input( - self, - accumulator: tf_metric_accumulators.TFMetricsAccumulator, - element: metric_types.StandardMetricInputs, - ) -> tf_metric_accumulators.TFMetricsAccumulator: - """Add input to the accumulator.""" - raise NotImplementedError('Subclasses are expected to override this.') - - def _update_state( - self, accumulator: tf_metric_accumulators.TFMetricsAccumulator - ): - """Updates state for metrics associated with model.""" - raise NotImplementedError('Subclasses are expected to override this.') - - def _process_batch( - self, accumulator: tf_metric_accumulators.TFMetricsAccumulator - ): - if accumulator.len_inputs() == 0: - return - self._batch_size_beam_metric_dist.update(accumulator.len_inputs()) - self._total_input_byte_size_beam_metric_dist.update( - accumulator.get_size_estimate() - ) - for metric_index, metric in enumerate(self._metrics()): - metric.reset_states() - self._update_state(accumulator) - # For metrics stored with the model, the outputs get encoded in the - # metric names so we will use a single output for the weights and parse the - # names at the end to separate metrics by output. - for metric_index, metric in enumerate(self._metrics()): - accumulator.add_weights(0, metric_index, metric.get_weights()) - accumulator.clear_inputs() - - def create_accumulator(self) -> tf_metric_accumulators.TFMetricsAccumulator: - return self._create_accumulator() - - def add_input( - self, - accumulator: tf_metric_accumulators.TFMetricsAccumulator, - element: metric_types.StandardMetricInputs, - ) -> tf_metric_accumulators.TFMetricsAccumulator: - accumulator = self._add_input(accumulator, element) - if accumulator.should_flush(): - self._process_batch(accumulator) - return accumulator - - def merge_accumulators( - self, accumulators: Iterable[tf_metric_accumulators.TFMetricsAccumulator] - ) -> tf_metric_accumulators.TFMetricsAccumulator: - accumulators = iter(accumulators) - result = next(accumulators) - # Finish processing last batch - self._process_batch(result) - for accumulator in accumulators: - # Finish processing last batch - self._process_batch(accumulator) - # Merge the weights - for metric_index in range(len(self._keys)): - weights = accumulator.get_weights(0, metric_index) - if weights is None: - # It is possible for beam to create an accumulator but pass no - # inputs to it resulting in empty weights. In theory all weights - # should be empty but we check on a per metric weights basis. 
- continue - result.add_weights(0, metric_index, weights) - return result - def compact( - self, accumulator: tf_metric_accumulators.TFMetricsAccumulator - ) -> tf_metric_accumulators.TFMetricsAccumulator: - self._process_batch(accumulator) - self._num_compacts.inc(1) - return accumulator - - def extract_output( - self, accumulator: tf_metric_accumulators.TFMetricsAccumulator - ) -> Dict[metric_types.MetricKey, np.ndarray]: - # Finish processing last batch - self._process_batch(accumulator) - result = {} - for metric_index, metric in enumerate(self._metrics()): - key = self._keys[metric_index] - weights = accumulator.get_weights(0, metric_index) - if weights is not None: - metric.set_weights(weights) - else: - metric.reset_states() - result[key] = metric.result().numpy() - return result + def setup(self): + if self._model is None: + # TODO(b/179500321): We are skipping the shared handle here to ensure that + # we don't have issues with sharing the model between threads. This is + # very inefficient, we should just clone the model but + # tf_keras.models.clone_model removes compiled_metrics. + self._model = self._model_loaders[ # pylint: disable=protected-access + self._model_name + ]._construct_fn_with_load_time(self._set_model_load_seconds)() + + def _metrics(self) -> Iterable[tf_keras.metrics.Metric]: + """Returns metrics used by combiner.""" + raise NotImplementedError("Subclasses are expected to override this.") + + def _create_accumulator(self) -> tf_metric_accumulators.TFMetricsAccumulator: + """Returns a new accumulator.""" + raise NotImplementedError("Subclasses are expected to override this.") + + def _add_input( + self, + accumulator: tf_metric_accumulators.TFMetricsAccumulator, + element: metric_types.StandardMetricInputs, + ) -> tf_metric_accumulators.TFMetricsAccumulator: + """Add input to the accumulator.""" + raise NotImplementedError("Subclasses are expected to override this.") + + def _update_state(self, accumulator: tf_metric_accumulators.TFMetricsAccumulator): + """Updates state for metrics associated with model.""" + raise NotImplementedError("Subclasses are expected to override this.") + + def _process_batch(self, accumulator: tf_metric_accumulators.TFMetricsAccumulator): + if accumulator.len_inputs() == 0: + return + self._batch_size_beam_metric_dist.update(accumulator.len_inputs()) + self._total_input_byte_size_beam_metric_dist.update( + accumulator.get_size_estimate() + ) + for metric_index, metric in enumerate(self._metrics()): + metric.reset_states() + self._update_state(accumulator) + # For metrics stored with the model, the outputs get encoded in the + # metric names so we will use a single output for the weights and parse the + # names at the end to separate metrics by output. 
+ for metric_index, metric in enumerate(self._metrics()): + accumulator.add_weights(0, metric_index, metric.get_weights()) + accumulator.clear_inputs() + + def create_accumulator(self) -> tf_metric_accumulators.TFMetricsAccumulator: + return self._create_accumulator() + + def add_input( + self, + accumulator: tf_metric_accumulators.TFMetricsAccumulator, + element: metric_types.StandardMetricInputs, + ) -> tf_metric_accumulators.TFMetricsAccumulator: + accumulator = self._add_input(accumulator, element) + if accumulator.should_flush(): + self._process_batch(accumulator) + return accumulator + + def merge_accumulators( + self, accumulators: Iterable[tf_metric_accumulators.TFMetricsAccumulator] + ) -> tf_metric_accumulators.TFMetricsAccumulator: + accumulators = iter(accumulators) + result = next(accumulators) + # Finish processing last batch + self._process_batch(result) + for accumulator in accumulators: + # Finish processing last batch + self._process_batch(accumulator) + # Merge the weights + for metric_index in range(len(self._keys)): + weights = accumulator.get_weights(0, metric_index) + if weights is None: + # It is possible for beam to create an accumulator but pass no + # inputs to it resulting in empty weights. In theory all weights + # should be empty but we check on a per metric weights basis. + continue + result.add_weights(0, metric_index, weights) + return result + + def compact( + self, accumulator: tf_metric_accumulators.TFMetricsAccumulator + ) -> tf_metric_accumulators.TFMetricsAccumulator: + self._process_batch(accumulator) + self._num_compacts.inc(1) + return accumulator + + def extract_output( + self, accumulator: tf_metric_accumulators.TFMetricsAccumulator + ) -> Dict[metric_types.MetricKey, np.ndarray]: + # Finish processing last batch + self._process_batch(accumulator) + result = {} + for metric_index, metric in enumerate(self._metrics()): + key = self._keys[metric_index] + weights = accumulator.get_weights(0, metric_index) + if weights is not None: + metric.set_weights(weights) + else: + metric.reset_states() + result[key] = metric.result().numpy() + return result @beam.typehints.with_input_types(metric_types.StandardMetricInputs) @beam.typehints.with_output_types(Dict[metric_types.MetricKey, np.ndarray]) class _KerasCompiledMetricsCombiner(_KerasCombiner): - """Aggregates metrics using keras compiled_metrics and compiled_loss.""" - - def __init__( - self, - keys: List[metric_types.MetricKey], - model_name: str, - model_loader: types.ModelLoader, - eval_config: Optional[config_pb2.EvalConfig], - desired_batch_size: Optional[int] = None, - ): - super().__init__( - keys, - model_name, - model_loader, - eval_config, - desired_batch_size, - 'keras_compiled_metrics_combine', - ) - - def _metrics(self) -> Iterable[tf_keras.metrics.Metric]: - return chain( - self._model.compiled_metrics.metrics, self._model.compiled_loss.metrics - ) - - def _create_accumulator( - self, - ) -> tf_metric_accumulators.TFCompilableMetricsAccumulator: - padding_options = None - if self._eval_config is not None: - model_spec = model_util.get_model_spec( - self._eval_config, self._model_name - ) - if model_spec is not None and model_spec.HasField('padding_options'): - padding_options = model_spec.padding_options - return tf_metric_accumulators.TFCompilableMetricsAccumulator( - padding_options, - self._output_counts, - desired_batch_size=self._desired_batch_size, - ) - - def _add_input( - self, - accumulator: tf_metric_accumulators.TFCompilableMetricsAccumulator, - element: 
metric_types.StandardMetricInputs, - ) -> tf_metric_accumulators.TFCompilableMetricsAccumulator: - for i, output_name in enumerate(self._output_names): - if not output_name and len(self._output_names) > 1: - # The first output_name for multi-output models is '' and is used to - # store combined metric weights for all outputs, but is not for inputs. - labels, predictions, example_weights = None, None, None - else: - labels, predictions, example_weights = next( - metric_util.to_label_prediction_example_weight( - element, - self._eval_config, - self._model_name, - output_name, - flatten=False, - example_weighted=True, - ) + """Aggregates metrics using keras compiled_metrics and compiled_loss.""" + + def __init__( + self, + keys: List[metric_types.MetricKey], + model_name: str, + model_loader: types.ModelLoader, + eval_config: Optional[config_pb2.EvalConfig], + desired_batch_size: Optional[int] = None, + ): + super().__init__( + keys, + model_name, + model_loader, + eval_config, + desired_batch_size, + "keras_compiled_metrics_combine", ) - accumulator.add_input(i, labels, predictions, example_weights) - return accumulator - - def _update_state( - self, accumulator: tf_metric_accumulators.TFCompilableMetricsAccumulator - ): - if len(self._output_names) == 1: - # Single-output models don't use dicts. - l, p, w = accumulator.get_inputs(0) - labels = tf.convert_to_tensor(l) - predictions = tf.convert_to_tensor(p) - example_weights = tf.convert_to_tensor(w) - else: - labels = {} - predictions = {} - example_weights = {} - for i, output_name in enumerate(self._output_names): - if not output_name: - # The empty output_name for multi-output models is not used for inputs - continue - l, p, w = accumulator.get_inputs(i) - labels[output_name] = tf.convert_to_tensor(l) - predictions[output_name] = tf.convert_to_tensor(p) - example_weights[output_name] = tf.convert_to_tensor(w) - self._model.compiled_metrics.update_state( - labels, predictions, sample_weight=example_weights - ) - self._model.compiled_loss( - labels, predictions, sample_weight=example_weights - ) + + def _metrics(self) -> Iterable[tf_keras.metrics.Metric]: + return chain( + self._model.compiled_metrics.metrics, self._model.compiled_loss.metrics + ) + + def _create_accumulator( + self, + ) -> tf_metric_accumulators.TFCompilableMetricsAccumulator: + padding_options = None + if self._eval_config is not None: + model_spec = model_util.get_model_spec(self._eval_config, self._model_name) + if model_spec is not None and model_spec.HasField("padding_options"): + padding_options = model_spec.padding_options + return tf_metric_accumulators.TFCompilableMetricsAccumulator( + padding_options, + self._output_counts, + desired_batch_size=self._desired_batch_size, + ) + + def _add_input( + self, + accumulator: tf_metric_accumulators.TFCompilableMetricsAccumulator, + element: metric_types.StandardMetricInputs, + ) -> tf_metric_accumulators.TFCompilableMetricsAccumulator: + for i, output_name in enumerate(self._output_names): + if not output_name and len(self._output_names) > 1: + # The first output_name for multi-output models is '' and is used to + # store combined metric weights for all outputs, but is not for inputs. 
+ labels, predictions, example_weights = None, None, None + else: + labels, predictions, example_weights = next( + metric_util.to_label_prediction_example_weight( + element, + self._eval_config, + self._model_name, + output_name, + flatten=False, + example_weighted=True, + ) + ) + accumulator.add_input(i, labels, predictions, example_weights) + return accumulator + + def _update_state( + self, accumulator: tf_metric_accumulators.TFCompilableMetricsAccumulator + ): + if len(self._output_names) == 1: + # Single-output models don't use dicts. + l, p, w = accumulator.get_inputs(0) + labels = tf.convert_to_tensor(l) + predictions = tf.convert_to_tensor(p) + example_weights = tf.convert_to_tensor(w) + else: + labels = {} + predictions = {} + example_weights = {} + for i, output_name in enumerate(self._output_names): + if not output_name: + # The empty output_name for multi-output models is not used for inputs + continue + l, p, w = accumulator.get_inputs(i) + labels[output_name] = tf.convert_to_tensor(l) + predictions[output_name] = tf.convert_to_tensor(p) + example_weights[output_name] = tf.convert_to_tensor(w) + self._model.compiled_metrics.update_state( + labels, predictions, sample_weight=example_weights + ) + self._model.compiled_loss(labels, predictions, sample_weight=example_weights) @beam.typehints.with_input_types(metric_types.StandardMetricInputs) @beam.typehints.with_output_types(Dict[metric_types.MetricKey, np.ndarray]) class _KerasEvaluateCombiner(_KerasCombiner): - """Aggregates metrics using keras model.evaluate method.""" - - def __init__( - self, - keys: List[metric_types.MetricKey], - model_name: str, - model_loader: types.ModelLoader, - eval_config: Optional[config_pb2.EvalConfig], - desired_batch_size: Optional[int] = None, - ): - super().__init__( - keys, - model_name, - model_loader, - eval_config, - desired_batch_size, - 'keras_evaluate_combine', - ) - - def _metrics(self) -> Iterable[tf_keras.metrics.Metric]: - return self._model.metrics - - def _create_accumulator(self) -> tf_metric_accumulators.TFMetricsAccumulator: - return tf_metric_accumulators.TFMetricsAccumulator( - # Separate inputs are tracked for (inputs, labels, example_weights). - # Since the inputs are the same for each output, only the first output - # index will set the input data. - input_counts=[3] * len(self._output_counts), - metric_counts=self._output_counts, - size_estimator_fn=len, - desired_batch_size=self._desired_batch_size, - ) - - def _add_input( - self, - accumulator: tf_metric_accumulators.TFMetricsAccumulator, - element: metric_types.StandardMetricInputs, - ) -> tf_metric_accumulators.TFMetricsAccumulator: - for i, output_name in enumerate(self._output_names): - if not output_name and len(self._output_names) > 1: - # The first output_name for multi-output models is '' and is used to - # store combined metric weights for all outputs, but is not for labels - # and example weights. 
- labels, example_weights = None, None - else: - labels, _, example_weights = next( - metric_util.to_label_prediction_example_weight( - element, - self._eval_config, - self._model_name, - output_name, - flatten=False, - example_weighted=True, - ) + """Aggregates metrics using keras model.evaluate method.""" + + def __init__( + self, + keys: List[metric_types.MetricKey], + model_name: str, + model_loader: types.ModelLoader, + eval_config: Optional[config_pb2.EvalConfig], + desired_batch_size: Optional[int] = None, + ): + super().__init__( + keys, + model_name, + model_loader, + eval_config, + desired_batch_size, + "keras_evaluate_combine", ) - if i == 0: - if element.transformed_features: - features = {} - features.update(element.features) - features.update(element.transformed_features) - else: - features = element.features - else: - features = None - accumulator.add_input(i, features, labels, example_weights) - - return accumulator - - def _update_state( - self, accumulator: tf_metric_accumulators.TFMetricsAccumulator - ): - features = {} - labels = {} - example_weights = {} - for i, output_name in enumerate(self._output_names): - f, l, w = accumulator.get_inputs(i) - if i == 0: - features = util.merge_extracts(f) - if not output_name and len(self._output_names) > 1: - # The empty output_name for multi-output models is not used for inputs. - continue - labels[output_name] = np.array(l) - weights = np.array(w) - # TFv1 will not squeeze the weights, so must do manually - if weights.shape[-1] == 1: - weights = weights.squeeze(axis=-1) - example_weights[output_name] = weights - if len(self._output_names) == 1: - # Single-output models don't use dicts. - labels = next(iter(labels.values())) - example_weights = next(iter(example_weights.values())) - input_specs = model_util.get_input_specs(self._model, signature_name=None) - inputs = model_util.get_inputs(features, input_specs) - if inputs is None: - raise ValueError( - 'unable to prepare inputs for evaluation: ' - f'input_specs={input_specs}, features={features}' - ) - self._model.evaluate( - x=inputs, - y=labels, - batch_size=util.batch_size(features), - verbose=0, - sample_weight=example_weights, - ) + def _metrics(self) -> Iterable[tf_keras.metrics.Metric]: + return self._model.metrics + + def _create_accumulator(self) -> tf_metric_accumulators.TFMetricsAccumulator: + return tf_metric_accumulators.TFMetricsAccumulator( + # Separate inputs are tracked for (inputs, labels, example_weights). + # Since the inputs are the same for each output, only the first output + # index will set the input data. + input_counts=[3] * len(self._output_counts), + metric_counts=self._output_counts, + size_estimator_fn=len, + desired_batch_size=self._desired_batch_size, + ) + + def _add_input( + self, + accumulator: tf_metric_accumulators.TFMetricsAccumulator, + element: metric_types.StandardMetricInputs, + ) -> tf_metric_accumulators.TFMetricsAccumulator: + for i, output_name in enumerate(self._output_names): + if not output_name and len(self._output_names) > 1: + # The first output_name for multi-output models is '' and is used to + # store combined metric weights for all outputs, but is not for labels + # and example weights. 
+ labels, example_weights = None, None + else: + labels, _, example_weights = next( + metric_util.to_label_prediction_example_weight( + element, + self._eval_config, + self._model_name, + output_name, + flatten=False, + example_weighted=True, + ) + ) + + if i == 0: + if element.transformed_features: + features = {} + features.update(element.features) + features.update(element.transformed_features) + else: + features = element.features + else: + features = None + accumulator.add_input(i, features, labels, example_weights) + + return accumulator + + def _update_state(self, accumulator: tf_metric_accumulators.TFMetricsAccumulator): + features = {} + labels = {} + example_weights = {} + for i, output_name in enumerate(self._output_names): + f, l, w = accumulator.get_inputs(i) + if i == 0: + features = util.merge_extracts(f) + if not output_name and len(self._output_names) > 1: + # The empty output_name for multi-output models is not used for inputs. + continue + labels[output_name] = np.array(l) + weights = np.array(w) + # TFv1 will not squeeze the weights, so must do manually + if weights.shape[-1] == 1: + weights = weights.squeeze(axis=-1) + example_weights[output_name] = weights + if len(self._output_names) == 1: + # Single-output models don't use dicts. + labels = next(iter(labels.values())) + example_weights = next(iter(example_weights.values())) + input_specs = model_util.get_input_specs(self._model, signature_name=None) + inputs = model_util.get_inputs(features, input_specs) + if inputs is None: + raise ValueError( + "unable to prepare inputs for evaluation: " + f"input_specs={input_specs}, features={features}" + ) + self._model.evaluate( + x=inputs, + y=labels, + batch_size=util.batch_size(features), + verbose=0, + sample_weight=example_weights, + ) diff --git a/tensorflow_model_analysis/evaluators/legacy_poisson_bootstrap.py b/tensorflow_model_analysis/evaluators/legacy_poisson_bootstrap.py index 62c6a7671c..ebaf591d5a 100644 --- a/tensorflow_model_analysis/evaluators/legacy_poisson_bootstrap.py +++ b/tensorflow_model_analysis/evaluators/legacy_poisson_bootstrap.py @@ -17,11 +17,11 @@ import apache_beam as beam import numpy as np +from google.protobuf import message + from tensorflow_model_analysis.api import types from tensorflow_model_analysis.slicer import slicer_lib as slicer -from google.protobuf import message - DEFAULT_NUM_BOOTSTRAP_SAMPLES = 20 # TFMA v1 uses Text for its keys while TFMA v2 uses MetricKey @@ -36,181 +36,184 @@ def ComputeWithConfidenceIntervals( # pylint: disable=invalid-name compute_per_slice_metrics_cls: Type[beam.PTransform], num_bootstrap_samples: Optional[int] = DEFAULT_NUM_BOOTSTRAP_SAMPLES, random_seed_for_testing: Optional[int] = None, - **kwargs + **kwargs, ) -> beam.pvalue.PCollection: - """PTransform for computing metrics using T-Distribution values. - - Args: - sliced_extracts: Incoming PCollection consisting of slice key and extracts. - compute_per_slice_metrics_cls: PTransform class that takes a PCollection of - (slice key, extracts) as input and returns (slice key, dict of metrics) as - output. The class will be instantiated multiple times to compute metrics - both with and without sampling. The class will be initialized using kwargs - 'compute_with_sampling' and 'random_seed_for_testing' along with any - kwargs passed in **kwargs. - num_bootstrap_samples: Number of replicas to use in calculating uncertainty - using bootstrapping. If 1 is provided (default), aggregate metrics will be - calculated with no uncertainty. 
If num_bootstrap_samples is > 0, multiple - samples of each slice will be calculated using the Poisson bootstrap - method. To calculate standard errors, num_bootstrap_samples should be 20 - or more in order to provide useful data. More is better, but you pay a - performance cost. - random_seed_for_testing: Seed to use for unit testing, because - nondeterministic tests stink. Each partition will use this value + i. - **kwargs: Additional args to pass to compute_per_slice_metrics_cls init. - - Returns: - PCollection of (slice key, dict of metrics) - """ - if not num_bootstrap_samples: - num_bootstrap_samples = 1 - # TODO(ckuhn): Cap the number of bootstrap samples at 20. - if num_bootstrap_samples < 1: - raise ValueError( - 'num_bootstrap_samples should be > 0, got %d' % num_bootstrap_samples - ) + """PTransform for computing metrics using T-Distribution values. + + Args: + ---- + sliced_extracts: Incoming PCollection consisting of slice key and extracts. + compute_per_slice_metrics_cls: PTransform class that takes a PCollection of + (slice key, extracts) as input and returns (slice key, dict of metrics) as + output. The class will be instantiated multiple times to compute metrics + both with and without sampling. The class will be initialized using kwargs + 'compute_with_sampling' and 'random_seed_for_testing' along with any + kwargs passed in **kwargs. + num_bootstrap_samples: Number of replicas to use in calculating uncertainty + using bootstrapping. If 1 is provided (default), aggregate metrics will be + calculated with no uncertainty. If num_bootstrap_samples is > 0, multiple + samples of each slice will be calculated using the Poisson bootstrap + method. To calculate standard errors, num_bootstrap_samples should be 20 + or more in order to provide useful data. More is better, but you pay a + performance cost. + random_seed_for_testing: Seed to use for unit testing, because + nondeterministic tests stink. Each partition will use this value + i. + **kwargs: Additional args to pass to compute_per_slice_metrics_cls init. + + Returns: + ------- + PCollection of (slice key, dict of metrics) + """ + if not num_bootstrap_samples: + num_bootstrap_samples = 1 + # TODO(ckuhn): Cap the number of bootstrap samples at 20. 
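# Illustrative sketch, not part of this patch: the Poisson bootstrap described
# in the docstring above replaces "draw N examples with replacement" with an
# independent Poisson(1) replication count per example, which also works on
# streamed data where N is not known up front. The helper below is a
# hypothetical, standalone illustration of that idea, not TFMA code.
import numpy as np


def poisson_bootstrap_means(values, num_replicates=20, seed=None):
    """Returns one bootstrap estimate of the mean per replicate."""
    rng = np.random.default_rng(seed)
    values = np.asarray(values, dtype=float)
    means = []
    for _ in range(num_replicates):
        # Each example is counted Poisson(1) times in this replicate.
        counts = rng.poisson(lam=1.0, size=values.shape[0])
        if counts.sum() == 0:
            continue  # Empty replicate; skip instead of dividing by zero.
        means.append(float(np.average(values, weights=counts)))
    return means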
+ if num_bootstrap_samples < 1: + raise ValueError( + "num_bootstrap_samples should be > 0, got %d" % num_bootstrap_samples + ) - output_results = ( - sliced_extracts - | 'ComputeUnsampledMetrics' - >> compute_per_slice_metrics_cls( - compute_with_sampling=False, random_seed_for_testing=None, **kwargs - ) - ) - - if num_bootstrap_samples > 1: - multicombine = [] - for i in range(num_bootstrap_samples): - seed = ( - None - if random_seed_for_testing is None - else random_seed_for_testing + i - ) - multicombine.append( - sliced_extracts - | 'ComputeSampledMetrics%d' % i - >> compute_per_slice_metrics_cls( - compute_with_sampling=True, random_seed_for_testing=seed, **kwargs - ) - ) output_results = ( - multicombine - | 'FlattenBootstrapPartitions' >> beam.Flatten() - | 'GroupBySlice' >> beam.GroupByKey() - | 'MergeBootstrap' - >> beam.ParDo(_MergeBootstrap(), beam.pvalue.AsDict(output_results)) + sliced_extracts + | "ComputeUnsampledMetrics" + >> compute_per_slice_metrics_cls( + compute_with_sampling=False, random_seed_for_testing=None, **kwargs + ) ) - return output_results - -class _MergeBootstrap(beam.DoFn): - """Merge the bootstrap values and fit a T-distribution to get confidence.""" - - def process( - self, - element: Tuple[slicer.SliceKeyType, Iterable[_MetricsDict]], - unsampled_results: Dict[slicer.SliceKeyType, _MetricsDict], - ) -> Iterator[Tuple[slicer.SliceKeyType, _MetricsDict]]: - """Merge the bootstrap values. - - Args: - element: The element is the tuple that contains slice key and a list of - the metrics dict. It's the output of the GroupByKey step. All the - metrics that under the same slice key are generated by - poisson-bootstrap. - unsampled_results: The unsampled_results is passed in as a side input. - It's a tuple that contains the slice key and the metrics dict from a run - of the slice with no sampling (ie, all examples in the set are - represented exactly once.) This should be identical to the values - obtained without sampling. - - Yields: - A tuple of slice key and the metrics dict which contains the unsampled - value, as well as parameters about t distribution. If the metric is a - proto only the unsampled value will be returned. - - Raises: - ValueError if the key of metrics inside element does not equal to the - key of metrics in unsampled_results. - """ - slice_key, metrics = element - # metrics should be a list of dicts, but the dataflow runner has a quirk - # that requires specific casting. - metrics = list(metrics) - if len(metrics) == 1: - yield slice_key, metrics[0] - return - - # Group the same metrics into one list. - metrics_dict = {} - for metric in metrics: - for metrics_name in metric: - if metrics_name not in metrics_dict: - metrics_dict[metrics_name] = [] - metrics_dict[metrics_name].append(metric[metrics_name]) - - unsampled_metrics_dict = unsampled_results.get(slice_key, {}) - - # The key set of the two metrics dicts must be identical. - if set(metrics_dict.keys()) != set(unsampled_metrics_dict.keys()): - raise ValueError( - 'Keys of two metrics do not match: sampled_metrics: %s. ' - 'unsampled_metrics: %s' - % (metrics_dict.keys(), unsampled_metrics_dict.keys()) - ) - - metrics_with_confidence = {} - for metrics_name in metrics_dict: - # If metric is a proto, return as is. 
- unsampled_value = unsampled_metrics_dict[metrics_name] - if isinstance(unsampled_value, message.Message): - metrics_with_confidence[metrics_name] = unsampled_value - else: - metrics_with_confidence[metrics_name] = _calculate_t_distribution( - metrics_dict[metrics_name], unsampled_value + if num_bootstrap_samples > 1: + multicombine = [] + for i in range(num_bootstrap_samples): + seed = ( + None if random_seed_for_testing is None else random_seed_for_testing + i + ) + multicombine.append( + sliced_extracts + | "ComputeSampledMetrics%d" % i + >> compute_per_slice_metrics_cls( + compute_with_sampling=True, random_seed_for_testing=seed, **kwargs + ) + ) + output_results = ( + multicombine + | "FlattenBootstrapPartitions" >> beam.Flatten() + | "GroupBySlice" >> beam.GroupByKey() + | "MergeBootstrap" + >> beam.ParDo(_MergeBootstrap(), beam.pvalue.AsDict(output_results)) ) + return output_results - yield slice_key, metrics_with_confidence + +class _MergeBootstrap(beam.DoFn): + """Merge the bootstrap values and fit a T-distribution to get confidence.""" + + def process( + self, + element: Tuple[slicer.SliceKeyType, Iterable[_MetricsDict]], + unsampled_results: Dict[slicer.SliceKeyType, _MetricsDict], + ) -> Iterator[Tuple[slicer.SliceKeyType, _MetricsDict]]: + """Merge the bootstrap values. + + Args: + ---- + element: The element is the tuple that contains slice key and a list of + the metrics dict. It's the output of the GroupByKey step. All the + metrics that under the same slice key are generated by + poisson-bootstrap. + unsampled_results: The unsampled_results is passed in as a side input. + It's a tuple that contains the slice key and the metrics dict from a run + of the slice with no sampling (ie, all examples in the set are + represented exactly once.) This should be identical to the values + obtained without sampling. + + Yields: + ------ + A tuple of slice key and the metrics dict which contains the unsampled + value, as well as parameters about t distribution. If the metric is a + proto only the unsampled value will be returned. + + Raises: + ------ + ValueError if the key of metrics inside element does not equal to the + key of metrics in unsampled_results. + """ + slice_key, metrics = element + # metrics should be a list of dicts, but the dataflow runner has a quirk + # that requires specific casting. + metrics = list(metrics) + if len(metrics) == 1: + yield slice_key, metrics[0] + return + + # Group the same metrics into one list. + metrics_dict = {} + for metric in metrics: + for metrics_name in metric: + if metrics_name not in metrics_dict: + metrics_dict[metrics_name] = [] + metrics_dict[metrics_name].append(metric[metrics_name]) + + unsampled_metrics_dict = unsampled_results.get(slice_key, {}) + + # The key set of the two metrics dicts must be identical. + if set(metrics_dict.keys()) != set(unsampled_metrics_dict.keys()): + raise ValueError( + "Keys of two metrics do not match: sampled_metrics: %s. " + "unsampled_metrics: %s" + % (metrics_dict.keys(), unsampled_metrics_dict.keys()) + ) + + metrics_with_confidence = {} + for metrics_name in metrics_dict: + # If metric is a proto, return as is. 
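# Illustrative sketch, not part of this patch: the (sample_mean, sample_std,
# degrees_of_freedom, unsampled_value) tuple built below is what a confidence
# interval can later be derived from. The hypothetical helper below centers
# the interval on the unsampled point estimate and treats the standard
# deviation of the bootstrap replicates as the standard error of the
# statistic (the usual bootstrap convention); it is not the exact TFMA
# implementation, and degrees_of_freedom should be >= 1.
import scipy.stats


def t_confidence_interval(sample_std, degrees_of_freedom, unsampled_value, alpha=0.05):
    """Two-sided (1 - alpha) interval around the unsampled point estimate."""
    t_quantile = scipy.stats.t.ppf(1.0 - alpha / 2.0, degrees_of_freedom)
    half_width = t_quantile * sample_std
    return unsampled_value - half_width, unsampled_value + half_width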
+ unsampled_value = unsampled_metrics_dict[metrics_name] + if isinstance(unsampled_value, message.Message): + metrics_with_confidence[metrics_name] = unsampled_value + else: + metrics_with_confidence[metrics_name] = _calculate_t_distribution( + metrics_dict[metrics_name], unsampled_value + ) + + yield slice_key, metrics_with_confidence def _calculate_t_distribution( # pylint: disable=invalid-name sampling_data_list: List[Union[int, float, np.ndarray]], unsampled_data: Union[int, float, np.ndarray], ): - """Calculate the confidence interval of the data. - - Args: - sampling_data_list: A list of number or np.ndarray. - unsampled_data: Individual number or np.ndarray. The format of the - unsampled_data should match the format of the element inside - sampling_data_list. - - Returns: - Confidence Interval value stored inside - types.ValueWithTDistribution. - """ - if isinstance(sampling_data_list[0], (np.ndarray, list)): - merged_data = sampling_data_list[0][:] - if isinstance(sampling_data_list[0], np.ndarray): - merged_data = merged_data.astype(object) - for index in range(len(merged_data)): - merged_data[index] = _calculate_t_distribution( - [data[index] for data in sampling_data_list], unsampled_data[index] - ) - return merged_data - else: - # Data has to be numeric. That means throw out nan values. - sampling_data_list = [ - data for data in sampling_data_list if not np.isnan(data) - ] - n_samples = len(sampling_data_list) - if n_samples: - sample_mean = np.mean(sampling_data_list) - sample_std = np.std(sampling_data_list, ddof=1) - return types.ValueWithTDistribution( - sample_mean, sample_std, n_samples - 1, unsampled_data - ) + """Calculate the confidence interval of the data. + + Args: + ---- + sampling_data_list: A list of number or np.ndarray. + unsampled_data: Individual number or np.ndarray. The format of the + unsampled_data should match the format of the element inside + sampling_data_list. + + Returns: + ------- + Confidence Interval value stored inside + types.ValueWithTDistribution. + """ + if isinstance(sampling_data_list[0], (np.ndarray, list)): + merged_data = sampling_data_list[0][:] + if isinstance(sampling_data_list[0], np.ndarray): + merged_data = merged_data.astype(object) + for index in range(len(merged_data)): + merged_data[index] = _calculate_t_distribution( + [data[index] for data in sampling_data_list], unsampled_data[index] + ) + return merged_data else: - return types.ValueWithTDistribution( - float('nan'), float('nan'), -1, float('nan') - ) + # Data has to be numeric. That means throw out nan values. 
+ sampling_data_list = [data for data in sampling_data_list if not np.isnan(data)] + n_samples = len(sampling_data_list) + if n_samples: + sample_mean = np.mean(sampling_data_list) + sample_std = np.std(sampling_data_list, ddof=1) + return types.ValueWithTDistribution( + sample_mean, sample_std, n_samples - 1, unsampled_data + ) + else: + return types.ValueWithTDistribution( + float("nan"), float("nan"), -1, float("nan") + ) diff --git a/tensorflow_model_analysis/evaluators/legacy_poisson_bootstrap_test.py b/tensorflow_model_analysis/evaluators/legacy_poisson_bootstrap_test.py index e46aa917ea..2c154a84f2 100644 --- a/tensorflow_model_analysis/evaluators/legacy_poisson_bootstrap_test.py +++ b/tensorflow_model_analysis/evaluators/legacy_poisson_bootstrap_test.py @@ -15,80 +15,86 @@ import numpy as np import tensorflow as tf + from tensorflow_model_analysis.api import types -from tensorflow_model_analysis.evaluators import legacy_poisson_bootstrap as poisson_bootstrap +from tensorflow_model_analysis.evaluators import ( + legacy_poisson_bootstrap as poisson_bootstrap, +) class PoissonBootstrapTest(tf.test.TestCase): - - def testCalculateConfidenceInterval(self): - sampling_data_list = [ - np.array([ - [0, 0, 2, 7, 0.77777779, 1], - [1, 0, 2, 6, 0.75, 0.85714287], - [4, 0, 2, 3, 0.60000002, 0.42857143], - [4, 2, 0, 3, 1, 0.42857143], - [7, 2, 0, 0, float('nan'), 0], - ]), - np.array([ - [7, 2, 0, 0, float('nan'), 0], - [0, 0, 2, 7, 0.77777779, 1], - [1, 0, 2, 6, 0.75, 0.85714287], - [4, 0, 2, 3, 0.60000002, 0.42857143], - [4, 2, 0, 3, 1, 0.42857143], - ]), - ] - unsampled_data = np.array([ - [4, 2, 0, 3, 1, 0.42857143], - [7, 2, 0, 0, float('nan'), 0], - [0, 0, 2, 7, 0.77777779, 1], - [1, 0, 2, 6, 0.75, 0.85714287], - [4, 0, 2, 3, 0.60000002, 0.42857143], - ]) - result = poisson_bootstrap._calculate_t_distribution( - sampling_data_list, unsampled_data - ) - self.assertIsInstance(result, np.ndarray) - self.assertEqual(result.shape, (5, 6)) - self.assertAlmostEqual(result[0][0].sample_mean, 3.5, delta=0.1) - self.assertAlmostEqual( - result[0][0].sample_standard_deviation, 4.94, delta=0.1 - ) - self.assertEqual(result[0][0].sample_degrees_of_freedom, 1) - self.assertEqual(result[0][0].unsampled_value, 4.0) - self.assertAlmostEqual(result[0][4].sample_mean, 0.77, delta=0.1) - self.assertTrue(np.isnan(result[0][4].sample_standard_deviation)) - self.assertEqual(result[0][4].sample_degrees_of_freedom, 0) - self.assertEqual(result[0][4].unsampled_value, 1.0) - - sampling_data_list = [ - np.array([1, 2]), - np.array([1, 2]), - np.array([1, float('nan')]), - ] - unsampled_data = np.array([1, 2]) - result = poisson_bootstrap._calculate_t_distribution( - sampling_data_list, unsampled_data - ) - self.assertIsInstance(result, np.ndarray) - self.assertEqual( - result.tolist(), - [ - types.ValueWithTDistribution( - sample_mean=1.0, - sample_standard_deviation=0.0, - sample_degrees_of_freedom=2, - unsampled_value=1, + def testCalculateConfidenceInterval(self): + sampling_data_list = [ + np.array( + [ + [0, 0, 2, 7, 0.77777779, 1], + [1, 0, 2, 6, 0.75, 0.85714287], + [4, 0, 2, 3, 0.60000002, 0.42857143], + [4, 2, 0, 3, 1, 0.42857143], + [7, 2, 0, 0, float("nan"), 0], + ] ), - types.ValueWithTDistribution( - sample_mean=2.0, - sample_standard_deviation=0.0, - sample_degrees_of_freedom=1, - unsampled_value=2, + np.array( + [ + [7, 2, 0, 0, float("nan"), 0], + [0, 0, 2, 7, 0.77777779, 1], + [1, 0, 2, 6, 0.75, 0.85714287], + [4, 0, 2, 3, 0.60000002, 0.42857143], + [4, 2, 0, 3, 1, 0.42857143], + ] ), - ], - ) + ] + 
unsampled_data = np.array( + [ + [4, 2, 0, 3, 1, 0.42857143], + [7, 2, 0, 0, float("nan"), 0], + [0, 0, 2, 7, 0.77777779, 1], + [1, 0, 2, 6, 0.75, 0.85714287], + [4, 0, 2, 3, 0.60000002, 0.42857143], + ] + ) + result = poisson_bootstrap._calculate_t_distribution( + sampling_data_list, unsampled_data + ) + self.assertIsInstance(result, np.ndarray) + self.assertEqual(result.shape, (5, 6)) + self.assertAlmostEqual(result[0][0].sample_mean, 3.5, delta=0.1) + self.assertAlmostEqual(result[0][0].sample_standard_deviation, 4.94, delta=0.1) + self.assertEqual(result[0][0].sample_degrees_of_freedom, 1) + self.assertEqual(result[0][0].unsampled_value, 4.0) + self.assertAlmostEqual(result[0][4].sample_mean, 0.77, delta=0.1) + self.assertTrue(np.isnan(result[0][4].sample_standard_deviation)) + self.assertEqual(result[0][4].sample_degrees_of_freedom, 0) + self.assertEqual(result[0][4].unsampled_value, 1.0) + + sampling_data_list = [ + np.array([1, 2]), + np.array([1, 2]), + np.array([1, float("nan")]), + ] + unsampled_data = np.array([1, 2]) + result = poisson_bootstrap._calculate_t_distribution( + sampling_data_list, unsampled_data + ) + self.assertIsInstance(result, np.ndarray) + self.assertEqual( + result.tolist(), + [ + types.ValueWithTDistribution( + sample_mean=1.0, + sample_standard_deviation=0.0, + sample_degrees_of_freedom=2, + unsampled_value=1, + ), + types.ValueWithTDistribution( + sample_mean=2.0, + sample_standard_deviation=0.0, + sample_degrees_of_freedom=1, + unsampled_value=2, + ), + ], + ) -if __name__ == '__main__': - tf.test.main() +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow_model_analysis/evaluators/metrics_plots_and_validations_evaluator.py b/tensorflow_model_analysis/evaluators/metrics_plots_and_validations_evaluator.py index 4c4d41d2de..6e35acf5a3 100644 --- a/tensorflow_model_analysis/evaluators/metrics_plots_and_validations_evaluator.py +++ b/tensorflow_model_analysis/evaluators/metrics_plots_and_validations_evaluator.py @@ -17,35 +17,46 @@ import datetime import itertools import numbers -from typing import Any, Dict, Iterable, Iterator, List, NamedTuple, Optional, Tuple, Type, TypeVar, Union +from typing import ( + Any, + Dict, + Iterable, + Iterator, + List, + NamedTuple, + Optional, + Tuple, + Type, + TypeVar, + Union, +) import apache_beam as beam import numpy as np +from tensorflow_metadata.proto.v0 import schema_pb2 + from tensorflow_model_analysis import constants from tensorflow_model_analysis.api import types -from tensorflow_model_analysis.evaluators import counter_util -from tensorflow_model_analysis.evaluators import evaluator -from tensorflow_model_analysis.evaluators import jackknife -from tensorflow_model_analysis.evaluators import keras_util -from tensorflow_model_analysis.evaluators import metrics_validator -from tensorflow_model_analysis.evaluators import poisson_bootstrap +from tensorflow_model_analysis.evaluators import ( + counter_util, + evaluator, + jackknife, + keras_util, + metrics_validator, + poisson_bootstrap, +) from tensorflow_model_analysis.extractors import slice_key_extractor -from tensorflow_model_analysis.metrics import metric_specs -from tensorflow_model_analysis.metrics import metric_types -from tensorflow_model_analysis.metrics import metric_util +from tensorflow_model_analysis.metrics import metric_specs, metric_types, metric_util from tensorflow_model_analysis.proto import config_pb2 from tensorflow_model_analysis.slicer import slicer_lib as slicer -from tensorflow_model_analysis.utils import model_util -from 
tensorflow_model_analysis.utils import util - -from tensorflow_metadata.proto.v0 import schema_pb2 +from tensorflow_model_analysis.utils import model_util, util SliceKeyTypeVar = TypeVar( - 'SliceKeyTypeVar', slicer.SliceKeyType, slicer.CrossSliceKeyType + "SliceKeyTypeVar", slicer.SliceKeyType, slicer.CrossSliceKeyType ) -_COMBINER_INPUTS_KEY = '_combiner_inputs' -_DEFAULT_COMBINER_INPUT_KEY = '_default_combiner_input' +_COMBINER_INPUTS_KEY = "_combiner_inputs" +_DEFAULT_COMBINER_INPUT_KEY = "_default_combiner_input" _DEFAULT_NUM_JACKKNIFE_BUCKETS = 20 _DEFAULT_NUM_BOOTSTRAP_SAMPLES = 20 @@ -68,127 +79,121 @@ def MetricsPlotsAndValidationsEvaluator( # pylint: disable=invalid-name schema: Optional[schema_pb2.Schema] = None, random_seed_for_testing: Optional[int] = None, ) -> evaluator.Evaluator: - """Creates an Evaluator for evaluating metrics and plots. - - Args: - eval_config: Eval config. - eval_shared_model: Optional shared model (single-model evaluation) or list - of shared models (multi-model evaluation). Only required if there are - metrics to be computed in-graph using the model. - metrics_key: Name to use for metrics key in Evaluation output. - plots_key: Name to use for plots key in Evaluation output. - attributions_key: Name to use for attributions key in Evaluation output. - run_after: Extractor to run after (None means before any extractors). - schema: A schema to use for customizing metrics and plots. - random_seed_for_testing: Seed to use for unit testing. - - Returns: - Evaluator for evaluating metrics and plots. The output will be stored under - 'metrics' and 'plots' keys. - """ - eval_shared_models = model_util.verify_and_update_eval_shared_models( - eval_shared_model - ) - if eval_shared_models: - eval_shared_models = {m.model_name: m for m in eval_shared_models} - - # pylint: disable=no-value-for-parameter - return evaluator.Evaluator( - stage_name='EvaluateMetricsAndPlots', - run_after=run_after, - ptransform=_EvaluateMetricsPlotsAndValidations( - eval_config=eval_config, - eval_shared_models=eval_shared_models, - metrics_key=metrics_key, - plots_key=plots_key, - attributions_key=attributions_key, - schema=schema, - random_seed_for_testing=random_seed_for_testing, - ), - ) - - -MetricComputations = NamedTuple( - 'MetricComputations', - [ - ('non_derived_computations', List[metric_types.MetricComputation]), - ('derived_computations', List[metric_types.DerivedMetricComputation]), - ( - 'cross_slice_computations', - List[metric_types.CrossSliceMetricComputation], - ), - ( - 'ci_derived_computations', - List[metric_types.CIDerivedMetricComputation], + """Creates an Evaluator for evaluating metrics and plots. + + Args: + ---- + eval_config: Eval config. + eval_shared_model: Optional shared model (single-model evaluation) or list + of shared models (multi-model evaluation). Only required if there are + metrics to be computed in-graph using the model. + metrics_key: Name to use for metrics key in Evaluation output. + plots_key: Name to use for plots key in Evaluation output. + attributions_key: Name to use for attributions key in Evaluation output. + run_after: Extractor to run after (None means before any extractors). + schema: A schema to use for customizing metrics and plots. + random_seed_for_testing: Seed to use for unit testing. + + Returns: + ------- + Evaluator for evaluating metrics and plots. The output will be stored under + 'metrics' and 'plots' keys. 
+ """ + eval_shared_models = model_util.verify_and_update_eval_shared_models( + eval_shared_model + ) + if eval_shared_models: + eval_shared_models = {m.model_name: m for m in eval_shared_models} + + # pylint: disable=no-value-for-parameter + return evaluator.Evaluator( + stage_name="EvaluateMetricsAndPlots", + run_after=run_after, + ptransform=_EvaluateMetricsPlotsAndValidations( + eval_config=eval_config, + eval_shared_models=eval_shared_models, + metrics_key=metrics_key, + plots_key=plots_key, + attributions_key=attributions_key, + schema=schema, + random_seed_for_testing=random_seed_for_testing, ), - ], -) + ) + + +class MetricComputations(NamedTuple): + non_derived_computations: List[metric_types.MetricComputation] + derived_computations: List[metric_types.DerivedMetricComputation] + cross_slice_computations: List[metric_types.CrossSliceMetricComputation] + ci_derived_computations: List[metric_types.CIDerivedMetricComputation] def _filter_and_separate_computations( computations: metric_types.MetricComputations, ) -> MetricComputations: - """Filters duplicate computations and separates non-derived and derived. - - All metrics are based on either direct computations using combiners or are - based on the results of one or more other computations. This code separates - the three types of computations so that only the combiner based computations - are passed to the main combiner call and the remainder are processed after - those combiners have run. Filtering is required because - DerivedMetricComputations and CrossSliceMetricComputations typically include - copies of the MetricComputations that they depend on in order to avoid having - to pre-construct and pass around all the dependencies at the time the metrics - are constructed. Instead, each derived metric creates a version of the metric - it depends on and then this code de-dups computations that are identical so - only one gets computed. - - Args: - computations: Computations. - - Returns: - Tuple of (metric computations, derived metric computations, cross slice - metric computations, CI derived metric computations). - """ - non_derived_computations = [] - processed_non_derived_computations = {} - derived_computations = [] - processed_derived_computations = {} - cross_slice_computations = [] - processed_cross_slice_computations = {} - ci_derived_computations = [] - processed_ci_derived_computations = {} - # The order of the computations matters (i.e. one computation may depend on - # another). While there shouldn't be any differences in matching computations - # the implemented hash is only based on combiner/result names and the keys. To - # ensure we don't move a metric ahead of its dependencies (see b/205314632) we - # will return the first computation added when there are duplicates. - for c in computations: - if isinstance(c, metric_types.MetricComputation): - if c not in processed_non_derived_computations: - processed_non_derived_computations[c] = len(non_derived_computations) - non_derived_computations.append(c) - # CIDerivedMetricComputation is inherited from DerivedMetricComputation, so - # order of elif's matter here. 
- elif isinstance(c, metric_types.CIDerivedMetricComputation): - if c not in processed_ci_derived_computations: - processed_ci_derived_computations[c] = len(ci_derived_computations) - ci_derived_computations.append(c) - elif isinstance(c, metric_types.DerivedMetricComputation): - if c not in processed_derived_computations: - processed_derived_computations[c] = len(derived_computations) - derived_computations.append(c) - elif isinstance(c, metric_types.CrossSliceMetricComputation): - if c not in processed_cross_slice_computations: - processed_cross_slice_computations[c] = len(cross_slice_computations) - cross_slice_computations.append(c) - else: - raise TypeError('Unsupported metric computation type: {}'.format(c)) - return MetricComputations( - non_derived_computations=non_derived_computations, - derived_computations=derived_computations, - cross_slice_computations=cross_slice_computations, - ci_derived_computations=ci_derived_computations, - ) + """Filters duplicate computations and separates non-derived and derived. + + All metrics are based on either direct computations using combiners or are + based on the results of one or more other computations. This code separates + the three types of computations so that only the combiner based computations + are passed to the main combiner call and the remainder are processed after + those combiners have run. Filtering is required because + DerivedMetricComputations and CrossSliceMetricComputations typically include + copies of the MetricComputations that they depend on in order to avoid having + to pre-construct and pass around all the dependencies at the time the metrics + are constructed. Instead, each derived metric creates a version of the metric + it depends on and then this code de-dups computations that are identical so + only one gets computed. + + Args: + ---- + computations: Computations. + + Returns: + ------- + Tuple of (metric computations, derived metric computations, cross slice + metric computations, CI derived metric computations). + """ + non_derived_computations = [] + processed_non_derived_computations = {} + derived_computations = [] + processed_derived_computations = {} + cross_slice_computations = [] + processed_cross_slice_computations = {} + ci_derived_computations = [] + processed_ci_derived_computations = {} + # The order of the computations matters (i.e. one computation may depend on + # another). While there shouldn't be any differences in matching computations + # the implemented hash is only based on combiner/result names and the keys. To + # ensure we don't move a metric ahead of its dependencies (see b/205314632) we + # will return the first computation added when there are duplicates. + for c in computations: + if isinstance(c, metric_types.MetricComputation): + if c not in processed_non_derived_computations: + processed_non_derived_computations[c] = len(non_derived_computations) + non_derived_computations.append(c) + # CIDerivedMetricComputation is inherited from DerivedMetricComputation, so + # order of elif's matter here. 
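# Illustrative note, not part of this patch: isinstance() also matches
# subclasses, so the more specific CIDerivedMetricComputation has to be
# tested before its base class DerivedMetricComputation; otherwise every
# CI-derived computation would be swallowed by the plain derived branch:
#
#   isinstance(ci_derived_computation, metric_types.DerivedMetricComputation)
#   # -> True, which is why the base-class check must come last.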
+ elif isinstance(c, metric_types.CIDerivedMetricComputation): + if c not in processed_ci_derived_computations: + processed_ci_derived_computations[c] = len(ci_derived_computations) + ci_derived_computations.append(c) + elif isinstance(c, metric_types.DerivedMetricComputation): + if c not in processed_derived_computations: + processed_derived_computations[c] = len(derived_computations) + derived_computations.append(c) + elif isinstance(c, metric_types.CrossSliceMetricComputation): + if c not in processed_cross_slice_computations: + processed_cross_slice_computations[c] = len(cross_slice_computations) + cross_slice_computations.append(c) + else: + raise TypeError(f"Unsupported metric computation type: {c}") + return MetricComputations( + non_derived_computations=non_derived_computations, + derived_computations=derived_computations, + cross_slice_computations=cross_slice_computations, + ci_derived_computations=ci_derived_computations, + ) @beam.ptransform_fn @@ -198,257 +203,257 @@ def _GroupByQueryKey( # pylint: disable=invalid-name extracts: beam.pvalue.PCollection, query_key: str, ) -> beam.pvalue.PCollection: - """PTransform for grouping extracts by a query key. - - Args: - extracts: Incoming PCollection consisting of extracts. - query_key: Query key to group extracts by. Must be a member of the dict of - features stored under tfma.FEATURES_KEY. - - Returns: - PCollection of lists of extracts where each list is associated with same - query key. - """ - missing_query_key_counter = beam.metrics.Metrics.counter( - constants.METRICS_NAMESPACE, 'missing_query_key' - ) - - def key_by_query_key( - extracts: types.Extracts, query_key: str - ) -> Tuple[str, types.Extracts]: - """Extract the query key from the extract and key by that.""" - value = metric_util.to_scalar( - util.get_by_keys( - extracts, [constants.FEATURES_KEY, query_key], optional=True - ), - tensor_name=query_key, + """PTransform for grouping extracts by a query key. + + Args: + ---- + extracts: Incoming PCollection consisting of extracts. + query_key: Query key to group extracts by. Must be a member of the dict of + features stored under tfma.FEATURES_KEY. + + Returns: + ------- + PCollection of lists of extracts where each list is associated with same + query key. 
+ """ + missing_query_key_counter = beam.metrics.Metrics.counter( + constants.METRICS_NAMESPACE, "missing_query_key" ) - if value is None: - missing_query_key_counter.inc() - return ('', extracts) - return ('{}'.format(value), extracts) - # pylint: disable=no-value-for-parameter - return ( - extracts - | 'KeyByQueryId' >> beam.Map(key_by_query_key, query_key) - | 'GroupByKey' >> beam.CombinePerKey(beam.combiners.ToListCombineFn()) - | 'DropQueryId' >> beam.Map(lambda kv: kv[1]) - | 'MergeExtracts' >> beam.Map(util.merge_extracts) - ) + def key_by_query_key( + extracts: types.Extracts, query_key: str + ) -> Tuple[str, types.Extracts]: + """Extract the query key from the extract and key by that.""" + value = metric_util.to_scalar( + util.get_by_keys( + extracts, [constants.FEATURES_KEY, query_key], optional=True + ), + tensor_name=query_key, + ) + if value is None: + missing_query_key_counter.inc() + return ("", extracts) + return (f"{value}", extracts) + + # pylint: disable=no-value-for-parameter + return ( + extracts + | "KeyByQueryId" >> beam.Map(key_by_query_key, query_key) + | "GroupByKey" >> beam.CombinePerKey(beam.combiners.ToListCombineFn()) + | "DropQueryId" >> beam.Map(lambda kv: kv[1]) + | "MergeExtracts" >> beam.Map(util.merge_extracts) + ) class _PreprocessorDoFn(beam.DoFn): - """Do function that computes initial state from extracts. - - The outputs for each preprocessor are stored under the key '_combiner_inputs' - in the overall extracts returned by this process call. These outputs are - stored as a list in same order as the computations were passed as input so - that the combiner can later access them by index. For computations that use - the default labels, predictions, and example weights as their combiner inputs, - the list entries will contain None values. A '_default_combiner_inputs' - extract will also exist (if needed) containing StandardMetricInputs. - - If a FeaturePreprocessor is used the outputs of the preprocessor will be - combined with the default labels, predictions, and example weights and stored - in the StandardMetricInputs features value under the _default_combiner_inputs - key. - - If the incoming data is a list of extracts (i.e. a query_key was used), the - output will be a single extract with the keys within the extract representing - the list as processed by the preprocessor. For example, the _slice_key_types - will be a merger of all unique _slice key_types across the extracts list - and the _default_combiner_inputs will be a list of StandardMetricInputs (one - for each example matching the query_key). - """ - - def __init__(self, computations: List[metric_types.MetricComputation]): - self._computations = computations - self._evaluate_num_instances = beam.metrics.Metrics.counter( - constants.METRICS_NAMESPACE, 'evaluate_num_instances' - ) - self._timer = beam.metrics.Metrics.distribution( - constants.METRICS_NAMESPACE, '_PreprocessorDoFn_seconds' - ) + """Do function that computes initial state from extracts. + + The outputs for each preprocessor are stored under the key '_combiner_inputs' + in the overall extracts returned by this process call. These outputs are + stored as a list in same order as the computations were passed as input so + that the combiner can later access them by index. For computations that use + the default labels, predictions, and example weights as their combiner inputs, + the list entries will contain None values. A '_default_combiner_inputs' + extract will also exist (if needed) containing StandardMetricInputs. 
+ + If a FeaturePreprocessor is used the outputs of the preprocessor will be + combined with the default labels, predictions, and example weights and stored + in the StandardMetricInputs features value under the _default_combiner_inputs + key. + + If the incoming data is a list of extracts (i.e. a query_key was used), the + output will be a single extract with the keys within the extract representing + the list as processed by the preprocessor. For example, the _slice_key_types + will be a merger of all unique _slice key_types across the extracts list + and the _default_combiner_inputs will be a list of StandardMetricInputs (one + for each example matching the query_key). + """ - def setup(self): - for computation in self._computations: - if computation.preprocessors: - for preprocessor in computation.preprocessors: - preprocessor.setup() - - def start_bundle(self): - for computation in self._computations: - if computation.preprocessors: - for preprocessor in computation.preprocessors: - preprocessor.start_bundle() - - def finish_bundle(self): - for computation in self._computations: - if computation.preprocessors: - for preprocessor in computation.preprocessors: - preprocessor.finish_bundle() - - def teardown(self): - for computation in self._computations: - if computation.preprocessors: - for preprocessor in computation.preprocessors: - preprocessor.teardown() - - def process(self, extracts: types.Extracts) -> Iterable[Any]: - start_time = datetime.datetime.now() - self._evaluate_num_instances.inc(1) - - # Any combiner_inputs that are set to None will have the default - # StandardMetricInputs passed to the combiner's add_input method. Note - # that for efficiency a single StandardMetricInputs type is created that has - # an include_filter that is a merger of the include_filter values for all - # StandardMetricInputsProcessors used by all metrics. This avoids processing - # extracts more than once, but does mean metrics may contain - # StandardMetricInputs with keys that are not part of their preprocessing - # filters. - combiner_inputs = [] - standard_preprocessors = [] - added_default_standard_preprocessor = False - for computation in self._computations: - if not computation.preprocessors: - # In this case, the combiner is requesting to be passed the default - # StandardMetricInputs (i.e. labels, predictions, and example weights). - combiner_inputs.append(None) - if not added_default_standard_preprocessor: - standard_preprocessors.append( - metric_types.StandardMetricInputsPreprocessor() - ) - added_default_standard_preprocessor = True - elif ( - len(computation.preprocessors) == 1 - and type(computation.preprocessors[0]) # pylint: disable=unidiomatic-typecheck - == metric_types.StandardMetricInputsPreprocessor - ): - # In this case a custom filter was used, but it is still part of the - # StandardMetricInputs. This will be merged into a single preprocessor - # for efficiency later, but we still use None to indicate that the - # shared StandardMetricInputs value should be passed to the combiner. - combiner_inputs.append(None) - standard_preprocessors.append(computation.preprocessors[0]) - else: - # The combiner accepts the list of outputs from preprocessors. It allows - # all the following combiners to share the same work without - # duplication. - preprocessed_extracts = [copy.copy(extracts)] - for preprocessor in computation.preprocessors: - # For each of the extract, it is processed to becomes - # a iterator of extracts. They are flattened through the chain - # operations. 
- preprocessed_extracts = list( - itertools.chain.from_iterable([ - list(preprocessor.process(extract)) - for extract in preprocessed_extracts - ]) - ) - # Combiner inputs appends a list of processed extracts - combiner_inputs.append(preprocessed_extracts) - - output = { - constants.SLICE_KEY_TYPES_KEY: extracts[constants.SLICE_KEY_TYPES_KEY], - _COMBINER_INPUTS_KEY: combiner_inputs, - } - if standard_preprocessors: - preprocessor = metric_types.StandardMetricInputsPreprocessorList( - standard_preprocessors - ) - extracts = copy.copy(extracts) - # TODO(b/229267982): Consider removing include flags. - default_combiner_input = [] - for extracts in preprocessor.process(extracts): - default_combiner_input.append( - metric_util.to_standard_metric_inputs( - extracts, - include_labels=constants.LABELS_KEY - in preprocessor.include_filter, - include_predictions=( - constants.PREDICTIONS_KEY in preprocessor.include_filter - ), - include_any_feature=( - (constants.FEATURES_KEY in preprocessor.include_filter) - or ( - constants.TRANSFORMED_FEATURES_KEY - in preprocessor.include_filter + def __init__(self, computations: List[metric_types.MetricComputation]): + self._computations = computations + self._evaluate_num_instances = beam.metrics.Metrics.counter( + constants.METRICS_NAMESPACE, "evaluate_num_instances" + ) + self._timer = beam.metrics.Metrics.distribution( + constants.METRICS_NAMESPACE, "_PreprocessorDoFn_seconds" + ) + + def setup(self): + for computation in self._computations: + if computation.preprocessors: + for preprocessor in computation.preprocessors: + preprocessor.setup() + + def start_bundle(self): + for computation in self._computations: + if computation.preprocessors: + for preprocessor in computation.preprocessors: + preprocessor.start_bundle() + + def finish_bundle(self): + for computation in self._computations: + if computation.preprocessors: + for preprocessor in computation.preprocessors: + preprocessor.finish_bundle() + + def teardown(self): + for computation in self._computations: + if computation.preprocessors: + for preprocessor in computation.preprocessors: + preprocessor.teardown() + + def process(self, extracts: types.Extracts) -> Iterable[Any]: + start_time = datetime.datetime.now() + self._evaluate_num_instances.inc(1) + + # Any combiner_inputs that are set to None will have the default + # StandardMetricInputs passed to the combiner's add_input method. Note + # that for efficiency a single StandardMetricInputs type is created that has + # an include_filter that is a merger of the include_filter values for all + # StandardMetricInputsProcessors used by all metrics. This avoids processing + # extracts more than once, but does mean metrics may contain + # StandardMetricInputs with keys that are not part of their preprocessing + # filters. + combiner_inputs = [] + standard_preprocessors = [] + added_default_standard_preprocessor = False + for computation in self._computations: + if not computation.preprocessors: + # In this case, the combiner is requesting to be passed the default + # StandardMetricInputs (i.e. labels, predictions, and example weights). 
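# Illustrative note, not part of this patch: after this loop combiner_inputs
# is index-aligned with self._computations. For example, with three
# computations where only the last one has a custom (non-standard)
# preprocessor, it would look like [None, None, [<preprocessed extracts>]];
# the None slots are later resolved from the shared
# '_default_combiner_input' entry by the combiner's add_input().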
+ combiner_inputs.append(None) + if not added_default_standard_preprocessor: + standard_preprocessors.append( + metric_types.StandardMetricInputsPreprocessor() ) - ), - include_attributions=( - constants.ATTRIBUTIONS_KEY in preprocessor.include_filter - ), + added_default_standard_preprocessor = True + elif ( + len(computation.preprocessors) == 1 + and type(computation.preprocessors[0]) # pylint: disable=unidiomatic-typecheck + == metric_types.StandardMetricInputsPreprocessor + ): + # In this case a custom filter was used, but it is still part of the + # StandardMetricInputs. This will be merged into a single preprocessor + # for efficiency later, but we still use None to indicate that the + # shared StandardMetricInputs value should be passed to the combiner. + combiner_inputs.append(None) + standard_preprocessors.append(computation.preprocessors[0]) + else: + # The combiner accepts the list of outputs from preprocessors. It allows + # all the following combiners to share the same work without + # duplication. + preprocessed_extracts = [copy.copy(extracts)] + for preprocessor in computation.preprocessors: + # For each of the extract, it is processed to becomes + # a iterator of extracts. They are flattened through the chain + # operations. + preprocessed_extracts = list( + itertools.chain.from_iterable( + [ + list(preprocessor.process(extract)) + for extract in preprocessed_extracts + ] + ) + ) + # Combiner inputs appends a list of processed extracts + combiner_inputs.append(preprocessed_extracts) + + output = { + constants.SLICE_KEY_TYPES_KEY: extracts[constants.SLICE_KEY_TYPES_KEY], + _COMBINER_INPUTS_KEY: combiner_inputs, + } + if standard_preprocessors: + preprocessor = metric_types.StandardMetricInputsPreprocessorList( + standard_preprocessors ) - ) - output[_DEFAULT_COMBINER_INPUT_KEY] = default_combiner_input - yield output + extracts = copy.copy(extracts) + # TODO(b/229267982): Consider removing include flags. + default_combiner_input = [] + for extracts in preprocessor.process(extracts): + default_combiner_input.append( + metric_util.to_standard_metric_inputs( + extracts, + include_labels=constants.LABELS_KEY + in preprocessor.include_filter, + include_predictions=( + constants.PREDICTIONS_KEY in preprocessor.include_filter + ), + include_any_feature=( + (constants.FEATURES_KEY in preprocessor.include_filter) + or ( + constants.TRANSFORMED_FEATURES_KEY + in preprocessor.include_filter + ) + ), + include_attributions=( + constants.ATTRIBUTIONS_KEY in preprocessor.include_filter + ), + ) + ) + output[_DEFAULT_COMBINER_INPUT_KEY] = default_combiner_input + yield output - self._timer.update( - int((datetime.datetime.now() - start_time).total_seconds()) - ) + self._timer.update(int((datetime.datetime.now() - start_time).total_seconds())) @beam.typehints.with_input_types(types.Extracts) @beam.typehints.with_output_types(metric_types.MetricsDict) class _ComputationsCombineFn(beam.combiners.SingleInputTupleCombineFn): - """Combine function that computes metric using initial state from extracts.""" - - def __init__(self, computations: List[metric_types.MetricComputation]): - """Init. - - Args: - computations: List of MetricComputations. - """ - super().__init__(*[c.combiner for c in computations]) - self._num_compacts = beam.metrics.Metrics.counter( - constants.METRICS_NAMESPACE, 'num_compacts' - ) + """Combine function that computes metric using initial state from extracts.""" + + def __init__(self, computations: List[metric_types.MetricComputation]): + """Init. 
+ + Args: + ---- + computations: List of MetricComputations. + """ + super().__init__(*[c.combiner for c in computations]) + self._num_compacts = beam.metrics.Metrics.counter( + constants.METRICS_NAMESPACE, "num_compacts" + ) - def add_input(self, accumulator: Any, element: types.Extracts): - - def get_combiner_input(element, i): - item = element[_COMBINER_INPUTS_KEY][i] - if item is None: - item = element[_DEFAULT_COMBINER_INPUT_KEY] - return item - - results = [] - for i, (c, a) in enumerate(zip(self._combiners, accumulator)): - try: - combiner_input = get_combiner_input(element, i) - result = c.add_inputs(a, combiner_input) - except Exception as e: - raise RuntimeError( - f'add_input failed on "{c}" with inputs:\n{combiner_input}' - ) from e - results.append(result) - return tuple(results) - - def compact(self, accumulator: Any) -> Any: - self._num_compacts.inc(1) - return super().compact(accumulator) - - def extract_output(self, accumulator: Any) -> metric_types.MetricsDict: - result = {} - for c, a in zip(self._combiners, accumulator): - result.update(c.extract_output(a)) - return result + def add_input(self, accumulator: Any, element: types.Extracts): + def get_combiner_input(element, i): + item = element[_COMBINER_INPUTS_KEY][i] + if item is None: + item = element[_DEFAULT_COMBINER_INPUT_KEY] + return item + + results = [] + for i, (c, a) in enumerate(zip(self._combiners, accumulator)): + try: + combiner_input = get_combiner_input(element, i) + result = c.add_inputs(a, combiner_input) + except Exception as e: + raise RuntimeError( + f'add_input failed on "{c}" with inputs:\n{combiner_input}' + ) from e + results.append(result) + return tuple(results) + + def compact(self, accumulator: Any) -> Any: + self._num_compacts.inc(1) + return super().compact(accumulator) + + def extract_output(self, accumulator: Any) -> metric_types.MetricsDict: + result = {} + for c, a in zip(self._combiners, accumulator): + result.update(c.extract_output(a)) + return result def _is_private_metrics(metric_key: metric_types.MetricKey): - return metric_key.name.startswith('_') and not metric_key.name.startswith( - '__' - ) + return metric_key.name.startswith("_") and not metric_key.name.startswith("__") def _remove_private_metrics( slice_key: SliceKeyTypeVar, metrics: metric_types.MetricsDict ) -> Tuple[SliceKeyTypeVar, metric_types.MetricsDict]: - return ( - slice_key, - {k: v for (k, v) in metrics.items() if not _is_private_metrics(k)}, - ) + return ( + slice_key, + {k: v for (k, v) in metrics.items() if not _is_private_metrics(k)}, + ) @beam.ptransform_fn @@ -461,124 +466,110 @@ def _AddCrossSliceMetrics( # pylint: disable=invalid-name ) -> beam.pvalue.PCollection[ Tuple[slicer.SliceKeyOrCrossSliceKeyType, metric_types.MetricsDict] ]: - """Generates CrossSlice metrics from SingleSlices.""" - - def is_slice_applicable( - sliced_combiner_output: Tuple[ - slicer.SliceKeyType, metric_types.MetricsDict - ], - slicing_specs: Union[ - config_pb2.SlicingSpec, Iterable[config_pb2.SlicingSpec] - ], - ) -> bool: - slice_key, _ = sliced_combiner_output - for slicing_spec in slicing_specs: - if slicer.SingleSliceSpec(spec=slicing_spec).is_slice_applicable( - slice_key - ): - return True - return False - - def is_not_slice_applicable( - sliced_combiner_output: Tuple[ - slicer.SliceKeyType, metric_types.MetricsDict - ], - slicing_specs: Union[ - config_pb2.SlicingSpec, Iterable[config_pb2.SlicingSpec] - ], - ) -> bool: - return not is_slice_applicable(sliced_combiner_output, slicing_specs) - - def compute_cross_slices( - 
baseline_slice: Tuple[slicer.SliceKeyType, metric_types.MetricsDict], - comparison_slices: Iterable[ - Tuple[slicer.SliceKeyType, Dict[metric_types.MetricKey, Any]] - ], - ) -> Iterator[ - Tuple[slicer.CrossSliceKeyType, Dict[metric_types.MetricKey, Any]] - ]: - baseline_slice_key, baseline_metrics = baseline_slice - for comparison_slice_key, comparison_metrics in comparison_slices: - result = {} - for ( - comparison_metric_key, - comparison_metric_value, - ) in comparison_metrics.items(): - if ( - comparison_metric_key not in baseline_metrics - or _is_private_metrics(comparison_metric_key) - or not isinstance(comparison_metric_key, metric_types.MetricKey) - or isinstance(comparison_metric_key, metric_types.PlotKey) - or isinstance(comparison_metric_key, metric_types.AttributionsKey) - ): - continue - result[comparison_metric_key] = ( - baseline_metrics[comparison_metric_key] - comparison_metric_value - ) - - # Compute cross slice comparison for CrossSliceDerivedComputations - for c in cross_slice_computations: - result.update( - c.cross_slice_comparison(baseline_metrics, comparison_metrics) + """Generates CrossSlice metrics from SingleSlices.""" + + def is_slice_applicable( + sliced_combiner_output: Tuple[slicer.SliceKeyType, metric_types.MetricsDict], + slicing_specs: Union[config_pb2.SlicingSpec, Iterable[config_pb2.SlicingSpec]], + ) -> bool: + slice_key, _ = sliced_combiner_output + for slicing_spec in slicing_specs: + if slicer.SingleSliceSpec(spec=slicing_spec).is_slice_applicable(slice_key): + return True + return False + + def is_not_slice_applicable( + sliced_combiner_output: Tuple[slicer.SliceKeyType, metric_types.MetricsDict], + slicing_specs: Union[config_pb2.SlicingSpec, Iterable[config_pb2.SlicingSpec]], + ) -> bool: + return not is_slice_applicable(sliced_combiner_output, slicing_specs) + + def compute_cross_slices( + baseline_slice: Tuple[slicer.SliceKeyType, metric_types.MetricsDict], + comparison_slices: Iterable[ + Tuple[slicer.SliceKeyType, Dict[metric_types.MetricKey, Any]] + ], + ) -> Iterator[Tuple[slicer.CrossSliceKeyType, Dict[metric_types.MetricKey, Any]]]: + baseline_slice_key, baseline_metrics = baseline_slice + for comparison_slice_key, comparison_metrics in comparison_slices: + result = {} + for ( + comparison_metric_key, + comparison_metric_value, + ) in comparison_metrics.items(): + if ( + comparison_metric_key not in baseline_metrics + or _is_private_metrics(comparison_metric_key) + or not isinstance(comparison_metric_key, metric_types.MetricKey) + or isinstance(comparison_metric_key, metric_types.PlotKey) + or isinstance(comparison_metric_key, metric_types.AttributionsKey) + ): + continue + result[comparison_metric_key] = ( + baseline_metrics[comparison_metric_key] - comparison_metric_value + ) + + # Compute cross slice comparison for CrossSliceDerivedComputations + for c in cross_slice_computations: + result.update( + c.cross_slice_comparison(baseline_metrics, comparison_metrics) + ) + + yield ((baseline_slice_key, comparison_slice_key), result) + + cross_slice_outputs = [] + for cross_slice_ind, cross_slice_spec in enumerate(cross_slice_specs): + baseline_slices = ( + sliced_combiner_outputs + | "FilterBaselineSlices(%d)" % cross_slice_ind + >> beam.Filter(is_slice_applicable, [cross_slice_spec.baseline_spec]) ) - yield ((baseline_slice_key, comparison_slice_key), result) - - cross_slice_outputs = [] - for cross_slice_ind, cross_slice_spec in enumerate(cross_slice_specs): - baseline_slices = ( - sliced_combiner_outputs - | 'FilterBaselineSlices(%d)' % 
cross_slice_ind - >> beam.Filter(is_slice_applicable, [cross_slice_spec.baseline_spec]) - ) + if cross_slice_spec.slicing_specs: + slicing_specs = list(cross_slice_spec.slicing_specs) + comparison_slices = ( + sliced_combiner_outputs + | "FilterToComparisonSlices(%d)" % cross_slice_ind + >> beam.Filter(is_slice_applicable, slicing_specs) + ) + else: + # When slicing_specs is not set, consider all available slices except the + # baseline as candidates. + comparison_slices = ( + sliced_combiner_outputs + | "FilterOutBaselineSlices(%d)" % cross_slice_ind + >> beam.Filter( + is_not_slice_applicable, [cross_slice_spec.baseline_spec] + ) + ) - if cross_slice_spec.slicing_specs: - slicing_specs = list(cross_slice_spec.slicing_specs) - comparison_slices = ( - sliced_combiner_outputs - | 'FilterToComparisonSlices(%d)' % cross_slice_ind - >> beam.Filter(is_slice_applicable, slicing_specs) - ) - else: - # When slicing_specs is not set, consider all available slices except the - # baseline as candidates. - comparison_slices = ( - sliced_combiner_outputs - | 'FilterOutBaselineSlices(%d)' % cross_slice_ind - >> beam.Filter( - is_not_slice_applicable, [cross_slice_spec.baseline_spec] - ) - ) - - cross_slice_outputs.append( - baseline_slices - | 'GenerateCrossSlices(%d)' % cross_slice_ind - >> beam.FlatMap( - compute_cross_slices, - comparison_slices=beam.pvalue.AsIter(comparison_slices), + cross_slice_outputs.append( + baseline_slices + | "GenerateCrossSlices(%d)" % cross_slice_ind + >> beam.FlatMap( + compute_cross_slices, + comparison_slices=beam.pvalue.AsIter(comparison_slices), + ) ) - ) - if cross_slice_outputs: - cross_slice_outputs = ( - cross_slice_outputs | 'FlattenCrossSliceResults' >> beam.Flatten() - ) - return [ - sliced_combiner_outputs, - cross_slice_outputs, - ] | 'CombineSingleSlicesWithCrossSlice' >> beam.Flatten() - else: - return sliced_combiner_outputs + if cross_slice_outputs: + cross_slice_outputs = ( + cross_slice_outputs | "FlattenCrossSliceResults" >> beam.Flatten() + ) + return [ + sliced_combiner_outputs, + cross_slice_outputs, + ] | "CombineSingleSlicesWithCrossSlice" >> beam.Flatten() + else: + return sliced_combiner_outputs def _is_metric_diffable(metric_value: Any): - """Check whether a metric value is a number or an ndarray of numbers.""" - return isinstance( - metric_value, (numbers.Number, types.StructuredMetricValue) - ) or ( - isinstance(metric_value, np.ndarray) - and np.issubdtype(metric_value.dtype, np.number) - ) + """Check whether a metric value is a number or an ndarray of numbers.""" + return isinstance(metric_value, (numbers.Number, types.StructuredMetricValue)) or ( + isinstance(metric_value, np.ndarray) + and np.issubdtype(metric_value.dtype, np.number) + ) @beam.ptransform_fn @@ -593,80 +584,80 @@ def _AddDerivedCrossSliceAndDiffMetrics( # pylint: disable=invalid-name ) -> beam.PCollection[ Tuple[slicer.SliceKeyOrCrossSliceKeyType, metric_types.MetricsDict] ]: - """A PTransform for adding cross slice and derived metrics. - - This PTransform uses the input PCollection of sliced metrics to compute - derived metrics, cross-slice diff metrics, and cross-model diff metrics, in - that order. This means that cross-slice metrics are computed for base and - derived metrics, and that cross-model diffs are computed for base and derived - metrics corresponding to both single slices and cross-slice pairs. - - Args: - sliced_base_metrics: A PCollection of per-slice MetricsDicts containing the - metrics to be used as inputs for derived, cross-slice, and diff metrics. 
- derived_computations: List of DerivedMetricComputations. - cross_slice_computations: List of CrossSliceMetricComputation. - cross_slice_specs: List of CrossSlicingSpec. - baseline_model_name: Name for baseline model. - - Returns: - PCollection of sliced dict of metrics, containing all base metrics (that are - non-private), derived metrics, cross-slice metrics, and diff metrics. - """ - - def add_derived_metrics( - sliced_metrics: Tuple[slicer.SliceKeyType, metric_types.MetricsDict], - derived_computations: List[metric_types.DerivedMetricComputation], - ) -> Tuple[slicer.SliceKeyType, metric_types.MetricsDict]: - """Merges per-metric dicts into single dict and adds derived metrics.""" - slice_key, metrics = sliced_metrics - result = copy.copy(metrics) - for c in derived_computations: - result.update(c.result(result)) - return slice_key, result + """A PTransform for adding cross slice and derived metrics. - def add_diff_metrics( - sliced_metrics: Tuple[ - Union[slicer.SliceKeyType, slicer.CrossSliceKeyType], - Dict[metric_types.MetricKey, Any], - ], - baseline_model_name: Optional[str], - ) -> Tuple[slicer.SliceKeyType, Dict[metric_types.MetricKey, Any]]: - """Add diff metrics if there is a baseline model.""" + This PTransform uses the input PCollection of sliced metrics to compute + derived metrics, cross-slice diff metrics, and cross-model diff metrics, in + that order. This means that cross-slice metrics are computed for base and + derived metrics, and that cross-model diffs are computed for base and derived + metrics corresponding to both single slices and cross-slice pairs. - slice_key, metrics = sliced_metrics - result = copy.copy(metrics) - - if baseline_model_name: - diff_result = {} - for k, v in result.items(): - if _is_private_metrics(k): - continue - if k.is_diff: - # For metrics which directly produce diff metrics, we skip this step - continue - if ( - k.model_name != baseline_model_name - and k.make_baseline_key(baseline_model_name) in result - ): - # Check if metric is diffable, skip plots and non-numerical values. - if _is_metric_diffable(v): - diff_result[k.make_diff_key()] = ( - v - result[k.make_baseline_key(baseline_model_name)] - ) - result.update(diff_result) - return slice_key, result + Args: + ---- + sliced_base_metrics: A PCollection of per-slice MetricsDicts containing the + metrics to be used as inputs for derived, cross-slice, and diff metrics. + derived_computations: List of DerivedMetricComputations. + cross_slice_computations: List of CrossSliceMetricComputation. + cross_slice_specs: List of CrossSlicingSpec. + baseline_model_name: Name for baseline model. + + Returns: + ------- + PCollection of sliced dict of metrics, containing all base metrics (that are + non-private), derived metrics, cross-slice metrics, and diff metrics. 
+ """ - return ( - sliced_base_metrics - | 'AddDerivedMetrics' - >> beam.Map(add_derived_metrics, derived_computations) - | 'AddCrossSliceMetrics' - >> _AddCrossSliceMetrics( # pylint: disable=no-value-for-parameter - cross_slice_specs, cross_slice_computations - ) - | 'AddDiffMetrics' >> beam.Map(add_diff_metrics, baseline_model_name) - ) + def add_derived_metrics( + sliced_metrics: Tuple[slicer.SliceKeyType, metric_types.MetricsDict], + derived_computations: List[metric_types.DerivedMetricComputation], + ) -> Tuple[slicer.SliceKeyType, metric_types.MetricsDict]: + """Merges per-metric dicts into single dict and adds derived metrics.""" + slice_key, metrics = sliced_metrics + result = copy.copy(metrics) + for c in derived_computations: + result.update(c.result(result)) + return slice_key, result + + def add_diff_metrics( + sliced_metrics: Tuple[ + Union[slicer.SliceKeyType, slicer.CrossSliceKeyType], + Dict[metric_types.MetricKey, Any], + ], + baseline_model_name: Optional[str], + ) -> Tuple[slicer.SliceKeyType, Dict[metric_types.MetricKey, Any]]: + """Add diff metrics if there is a baseline model.""" + slice_key, metrics = sliced_metrics + result = copy.copy(metrics) + + if baseline_model_name: + diff_result = {} + for k, v in result.items(): + if _is_private_metrics(k): + continue + if k.is_diff: + # For metrics which directly produce diff metrics, we skip this step + continue + if ( + k.model_name != baseline_model_name + and k.make_baseline_key(baseline_model_name) in result + ): + # Check if metric is diffable, skip plots and non-numerical values. + if _is_metric_diffable(v): + diff_result[k.make_diff_key()] = ( + v - result[k.make_baseline_key(baseline_model_name)] + ) + result.update(diff_result) + return slice_key, result + + return ( + sliced_base_metrics + | "AddDerivedMetrics" >> beam.Map(add_derived_metrics, derived_computations) + | "AddCrossSliceMetrics" + >> _AddCrossSliceMetrics( # pylint: disable=no-value-for-parameter + cross_slice_specs, cross_slice_computations + ) + | "AddDiffMetrics" >> beam.Map(add_diff_metrics, baseline_model_name) + ) def _filter_by_key_type( @@ -681,91 +672,89 @@ def _filter_by_key_type( ] ], ) -> Tuple[SliceKeyTypeVar, Dict[metric_types.MetricKey, Any]]: - """Filters metrics and plots by key type.""" - slice_value, metrics_plots_attributions = sliced_metrics_plots_attributions - output = {} - for k, v in metrics_plots_attributions.items(): - # PlotKey is a subclass of MetricKey so must check key_type based on PlotKey - if key_type == metric_types.PlotKey: - if isinstance(k, metric_types.PlotKey): - output[k] = v - # AttributionsKey is a also subclass of MetricKey - elif key_type == metric_types.AttributionsKey: - if isinstance(k, metric_types.AttributionsKey): - output[k] = v - else: - if not isinstance(k, metric_types.PlotKey) and not isinstance( - k, metric_types.AttributionsKey - ): - output[k] = v - return (slice_value, output) - - -_ConfidenceIntervalParams = NamedTuple( - '_ConfidenceIntervalParams', - [ - ('num_jackknife_samples', int), - ('num_bootstrap_samples', int), - ('skip_ci_metric_keys', Iterable[metric_types.MetricKey]), - ], -) + """Filters metrics and plots by key type.""" + slice_value, metrics_plots_attributions = sliced_metrics_plots_attributions + output = {} + for k, v in metrics_plots_attributions.items(): + # PlotKey is a subclass of MetricKey so must check key_type based on PlotKey + if key_type == metric_types.PlotKey: + if isinstance(k, metric_types.PlotKey): + output[k] = v + # AttributionsKey is a also subclass of 
MetricKey + elif key_type == metric_types.AttributionsKey: + if isinstance(k, metric_types.AttributionsKey): + output[k] = v + else: + if not isinstance(k, metric_types.PlotKey) and not isinstance( + k, metric_types.AttributionsKey + ): + output[k] = v + return (slice_value, output) + + +class _ConfidenceIntervalParams(NamedTuple): + num_jackknife_samples: int + num_bootstrap_samples: int + skip_ci_metric_keys: Iterable[metric_types.MetricKey] def _get_confidence_interval_params( eval_config: config_pb2.EvalConfig, metrics_specs: Iterable[config_pb2.MetricsSpec], ) -> _ConfidenceIntervalParams: - """Helper method for extracting confidence interval info from configs. - - Args: - eval_config: The eval_config. - metrics_specs: The metrics_specs containing either all metrics, or the ones - which share a query key. - - Returns: - A _ConfidenceIntervalParams object containing the number of jacknife samples - to use for computing a jackknife confidence interval, the number of - bootstrap samples to use for computing Poisson bootstrap confidence - intervals, and the set of metric keys which should not have confidence - intervals displayed in the output. - """ - skip_ci_metric_keys = ( - metric_specs.metric_keys_to_skip_for_confidence_intervals( - metrics_specs, eval_config=eval_config - ) - ) - num_jackknife_samples = 0 - num_bootstrap_samples = 0 - ci_method = eval_config.options.confidence_intervals.method - if eval_config.options.compute_confidence_intervals.value: - if ci_method == config_pb2.ConfidenceIntervalOptions.JACKKNIFE: - num_jackknife_samples = _DEFAULT_NUM_JACKKNIFE_BUCKETS - elif ci_method == config_pb2.ConfidenceIntervalOptions.POISSON_BOOTSTRAP: - num_bootstrap_samples = _DEFAULT_NUM_BOOTSTRAP_SAMPLES - return _ConfidenceIntervalParams( - num_jackknife_samples, num_bootstrap_samples, skip_ci_metric_keys - ) + """Helper method for extracting confidence interval info from configs. + + Args: + ---- + eval_config: The eval_config. + metrics_specs: The metrics_specs containing either all metrics, or the ones + which share a query key. + + Returns: + ------- + A _ConfidenceIntervalParams object containing the number of jacknife samples + to use for computing a jackknife confidence interval, the number of + bootstrap samples to use for computing Poisson bootstrap confidence + intervals, and the set of metric keys which should not have confidence + intervals displayed in the output. + """ + skip_ci_metric_keys = metric_specs.metric_keys_to_skip_for_confidence_intervals( + metrics_specs, eval_config=eval_config + ) + num_jackknife_samples = 0 + num_bootstrap_samples = 0 + ci_method = eval_config.options.confidence_intervals.method + if eval_config.options.compute_confidence_intervals.value: + if ci_method == config_pb2.ConfidenceIntervalOptions.JACKKNIFE: + num_jackknife_samples = _DEFAULT_NUM_JACKKNIFE_BUCKETS + elif ci_method == config_pb2.ConfidenceIntervalOptions.POISSON_BOOTSTRAP: + num_bootstrap_samples = _DEFAULT_NUM_BOOTSTRAP_SAMPLES + return _ConfidenceIntervalParams( + num_jackknife_samples, num_bootstrap_samples, skip_ci_metric_keys + ) def _add_ci_derived_metrics( sliced_metrics: Tuple[slicer.SliceKeyType, metric_types.MetricsDict], computations: List[metric_types.CIDerivedMetricComputation], ) -> Tuple[slicer.SliceKeyType, metric_types.MetricsDict]: - """PTransform to compute CI derived metrics. - - Args: - sliced_metrics: A PCollection of per-slice MetricsDicts containing the - metrics to be used as inputs for ci-derived metrics. 
- computations: List of CIDerivedMetricComputation. + """PTransform to compute CI derived metrics. - Returns: - PCollection of sliced dict of metrics updated with ci-derived metrics. - """ - slice_key, metrics = sliced_metrics - result = copy.copy(metrics) - for c in computations: - result.update(c.result(result)) - return slice_key, result + Args: + ---- + sliced_metrics: A PCollection of per-slice MetricsDicts containing the + metrics to be used as inputs for ci-derived metrics. + computations: List of CIDerivedMetricComputation. + + Returns: + ------- + PCollection of sliced dict of metrics updated with ci-derived metrics. + """ + slice_key, metrics = sliced_metrics + result = copy.copy(metrics) + for c in computations: + result.update(c.result(result)) + return slice_key, result @beam.ptransform_fn @@ -782,198 +771,195 @@ def _ComputeMetricsAndPlots( # pylint: disable=invalid-name schema: Optional[schema_pb2.Schema] = None, random_seed_for_testing: Optional[int] = None, ) -> evaluator.Evaluation: - """Computes metrics and plots. - - Args: - extracts: PCollection of Extracts. If a query_key was used then the - PCollection will contain a list of extracts. - eval_config: Eval config. - metrics_specs: Subset of the metric specs to compute metrics for. If a - query_key was used all of the metric specs will be for the same query_key. - eval_shared_models: Optional dict of shared models keyed by model name. Only - required if there are metrics to be computed in-graph using the model. - metrics_key: Name to use for metrics key in Evaluation output. - plots_key: Name to use for plots key in Evaluation output. - attributions_key: Name to use for attributions key in Evaluation output. - schema: A schema to use for customizing metrics and plots. - random_seed_for_testing: Seed to use for unit testing. - - Returns: - Evaluation containing dict of PCollections of (slice_key, results_dict) - tuples where the dict is keyed by either the metrics_key (e.g. 'metrics'), - plots_key (e.g. 'plots'), or attributions_key (e.g. 'attributions') - depending on what the results_dict contains. - """ - computations = [] - # Add default metric computations - if eval_shared_models: - # Note that there is the possibility for metric naming collisions here - # (e.g. 'auc' calculated within the model as well as by AUC metric - # computation performed outside the model). Currently all the overlapping - # metrics such as AUC that are computed outside the model are all derived - # metrics so they will override the metrics calculated by the model which is - # the desired behavior. - for model_name, eval_shared_model in eval_shared_models.items(): - if not eval_shared_model.include_default_metrics: - continue - if eval_shared_model.model_type == constants.TF_KERAS: - computations.extend( - keras_util.metric_computations_using_keras_saved_model( - model_name, eval_shared_model.model_loader, eval_config - ) - ) - # Add metric computations from specs - metric_computations = _filter_and_separate_computations( - metric_specs.to_computations( - metrics_specs, eval_config=eval_config, schema=schema - ) - ) - computations.extend(metric_computations.non_derived_computations) - - # Find out which model is baseline. 
- baseline_spec = model_util.get_baseline_model_spec(eval_config) - baseline_model_name = baseline_spec.name if baseline_spec else None - - # pylint: disable=no-value-for-parameter - - # Input: Single extract per example (or list of extracts if query_key used) - # where each item contains slice keys and other extracts from upstream - # extractors (e.g. labels, predictions, etc). - # Output: Single extract (per example) containing slice keys and initial - # combiner state returned from preprocessor. Note that even if a - # query_key was used the output is still only a single extract - # (though, that extract may contain lists of values (predictions, - # labels, etc) in its keys). - # - # Note that the output of this step is extracts instead of just a tuple of - # computation outputs because FanoutSlices takes extracts as input (and in - # many cases a subset of the extracts themselves are what is fanned out). - extracts = extracts | 'Preprocesss' >> beam.ParDo( - _PreprocessorDoFn(computations) - ) - - # Input: Single extract containing slice keys and initial combiner inputs. If - # query_key is used the extract represents multiple examples with the - # same query_key, otherwise the extract represents a single example. - # Output: Tuple (slice key, combiner inputs extracts). Notice that the per - # example (or list or examples if query_key used) input extract turns - # into n logical extracts, references to which are replicated once per - # applicable slice key. - slices = extracts | 'FanoutSlices' >> slicer.FanoutSlices() - - slices_count = ( - slices - | 'ExtractSliceKeys' >> beam.Keys() - | 'CountPerSliceKey' >> beam.combiners.Count.PerElement() - ) - - model_types = _get_model_types_for_logging(eval_shared_models) - - _ = ( - extracts.pipeline - | 'IncrementMetricsSpecsCounters' - >> counter_util.IncrementMetricsSpecsCounters(metrics_specs, model_types), - slices_count - | 'IncrementSliceSpecCounters' - >> counter_util.IncrementSliceSpecCounters(), - ) - - ci_params = _get_confidence_interval_params(eval_config, metrics_specs) - - cross_slice_specs = [] - if eval_config.cross_slicing_specs: - cross_slice_specs = eval_config.cross_slicing_specs - - computations_combine_fn = _ComputationsCombineFn(computations=computations) - derived_metrics_ptransform = _AddDerivedCrossSliceAndDiffMetrics( - metric_computations.derived_computations, - metric_computations.cross_slice_computations, - cross_slice_specs, - baseline_model_name, - ) - - # Input: Tuple of (slice key, combiner input extracts). - # Output: Tuple of (slice key, dict of computed metrics/plots/attributions). - # The dicts will be keyed by MetricKey/PlotKey/AttributionsKey and the - # values will be the result of the associated computations. A given - # MetricComputation can perform computations for multiple keys, but - # the keys should be unique across computations. - if ci_params.num_bootstrap_samples: - sliced_metrics_plots_and_attributions = ( - slices - | 'PoissonBootstrapConfidenceIntervals' - >> poisson_bootstrap.ComputeWithConfidenceIntervals( - computations_combine_fn=computations_combine_fn, - derived_metrics_ptransform=derived_metrics_ptransform, - num_bootstrap_samples=ci_params.num_bootstrap_samples, - hot_key_fanout=_COMBINE_PER_SLICE_KEY_HOT_KEY_FANOUT, - skip_ci_metric_keys=ci_params.skip_ci_metric_keys, - random_seed_for_testing=random_seed_for_testing, + """Computes metrics and plots. + + Args: + ---- + extracts: PCollection of Extracts. 
If a query_key was used then the + PCollection will contain a list of extracts. + eval_config: Eval config. + metrics_specs: Subset of the metric specs to compute metrics for. If a + query_key was used all of the metric specs will be for the same query_key. + eval_shared_models: Optional dict of shared models keyed by model name. Only + required if there are metrics to be computed in-graph using the model. + metrics_key: Name to use for metrics key in Evaluation output. + plots_key: Name to use for plots key in Evaluation output. + attributions_key: Name to use for attributions key in Evaluation output. + schema: A schema to use for customizing metrics and plots. + random_seed_for_testing: Seed to use for unit testing. + + Returns: + ------- + Evaluation containing dict of PCollections of (slice_key, results_dict) + tuples where the dict is keyed by either the metrics_key (e.g. 'metrics'), + plots_key (e.g. 'plots'), or attributions_key (e.g. 'attributions') + depending on what the results_dict contains. + """ + computations = [] + # Add default metric computations + if eval_shared_models: + # Note that there is the possibility for metric naming collisions here + # (e.g. 'auc' calculated within the model as well as by AUC metric + # computation performed outside the model). Currently all the overlapping + # metrics such as AUC that are computed outside the model are all derived + # metrics so they will override the metrics calculated by the model which is + # the desired behavior. + for model_name, eval_shared_model in eval_shared_models.items(): + if not eval_shared_model.include_default_metrics: + continue + if eval_shared_model.model_type == constants.TF_KERAS: + computations.extend( + keras_util.metric_computations_using_keras_saved_model( + model_name, eval_shared_model.model_loader, eval_config + ) + ) + # Add metric computations from specs + metric_computations = _filter_and_separate_computations( + metric_specs.to_computations( + metrics_specs, eval_config=eval_config, schema=schema ) ) - elif ci_params.num_jackknife_samples: - sliced_metrics_plots_and_attributions = ( + computations.extend(metric_computations.non_derived_computations) + + # Find out which model is baseline. + baseline_spec = model_util.get_baseline_model_spec(eval_config) + baseline_model_name = baseline_spec.name if baseline_spec else None + + # pylint: disable=no-value-for-parameter + + # Input: Single extract per example (or list of extracts if query_key used) + # where each item contains slice keys and other extracts from upstream + # extractors (e.g. labels, predictions, etc). + # Output: Single extract (per example) containing slice keys and initial + # combiner state returned from preprocessor. Note that even if a + # query_key was used the output is still only a single extract + # (though, that extract may contain lists of values (predictions, + # labels, etc) in its keys). + # + # Note that the output of this step is extracts instead of just a tuple of + # computation outputs because FanoutSlices takes extracts as input (and in + # many cases a subset of the extracts themselves are what is fanned out). + extracts = extracts | "Preprocesss" >> beam.ParDo(_PreprocessorDoFn(computations)) + + # Input: Single extract containing slice keys and initial combiner inputs. If + # query_key is used the extract represents multiple examples with the + # same query_key, otherwise the extract represents a single example. + # Output: Tuple (slice key, combiner inputs extracts). 
Notice that the per + # example (or list or examples if query_key used) input extract turns + # into n logical extracts, references to which are replicated once per + # applicable slice key. + slices = extracts | "FanoutSlices" >> slicer.FanoutSlices() + + slices_count = ( slices - | 'JackknifeConfidenceIntervals' - >> jackknife.ComputeWithConfidenceIntervals( - computations_combine_fn=computations_combine_fn, - derived_metrics_ptransform=derived_metrics_ptransform, - num_jackknife_samples=ci_params.num_jackknife_samples, - skip_ci_metric_keys=ci_params.skip_ci_metric_keys, - random_seed_for_testing=random_seed_for_testing, - ) + | "ExtractSliceKeys" >> beam.Keys() + | "CountPerSliceKey" >> beam.combiners.Count.PerElement() ) - else: - sliced_metrics_plots_and_attributions = ( - slices - | 'CombineMetricsPerSlice' - >> beam.CombinePerKey(computations_combine_fn).with_hot_key_fanout( - _COMBINE_PER_SLICE_KEY_HOT_KEY_FANOUT - ) - | 'AddDerivedCrossSliceAndDiffMetrics' >> derived_metrics_ptransform + + model_types = _get_model_types_for_logging(eval_shared_models) + + _ = ( + extracts.pipeline + | "IncrementMetricsSpecsCounters" + >> counter_util.IncrementMetricsSpecsCounters(metrics_specs, model_types), + slices_count + | "IncrementSliceSpecCounters" >> counter_util.IncrementSliceSpecCounters(), + ) + + ci_params = _get_confidence_interval_params(eval_config, metrics_specs) + + cross_slice_specs = [] + if eval_config.cross_slicing_specs: + cross_slice_specs = eval_config.cross_slicing_specs + + computations_combine_fn = _ComputationsCombineFn(computations=computations) + derived_metrics_ptransform = _AddDerivedCrossSliceAndDiffMetrics( + metric_computations.derived_computations, + metric_computations.cross_slice_computations, + cross_slice_specs, + baseline_model_name, ) - sliced_metrics_plots_and_attributions = ( - sliced_metrics_plots_and_attributions - | 'AddCIDerivedMetrics' - >> beam.Map( - _add_ci_derived_metrics, metric_computations.ci_derived_computations - ) - | 'RemovePrivateMetrics' >> beam.MapTuple(_remove_private_metrics) - ) + # Input: Tuple of (slice key, combiner input extracts). + # Output: Tuple of (slice key, dict of computed metrics/plots/attributions). + # The dicts will be keyed by MetricKey/PlotKey/AttributionsKey and the + # values will be the result of the associated computations. A given + # MetricComputation can perform computations for multiple keys, but + # the keys should be unique across computations. 
+ if ci_params.num_bootstrap_samples: + sliced_metrics_plots_and_attributions = ( + slices + | "PoissonBootstrapConfidenceIntervals" + >> poisson_bootstrap.ComputeWithConfidenceIntervals( + computations_combine_fn=computations_combine_fn, + derived_metrics_ptransform=derived_metrics_ptransform, + num_bootstrap_samples=ci_params.num_bootstrap_samples, + hot_key_fanout=_COMBINE_PER_SLICE_KEY_HOT_KEY_FANOUT, + skip_ci_metric_keys=ci_params.skip_ci_metric_keys, + random_seed_for_testing=random_seed_for_testing, + ) + ) + elif ci_params.num_jackknife_samples: + sliced_metrics_plots_and_attributions = ( + slices + | "JackknifeConfidenceIntervals" + >> jackknife.ComputeWithConfidenceIntervals( + computations_combine_fn=computations_combine_fn, + derived_metrics_ptransform=derived_metrics_ptransform, + num_jackknife_samples=ci_params.num_jackknife_samples, + skip_ci_metric_keys=ci_params.skip_ci_metric_keys, + random_seed_for_testing=random_seed_for_testing, + ) + ) + else: + sliced_metrics_plots_and_attributions = ( + slices + | "CombineMetricsPerSlice" + >> beam.CombinePerKey(computations_combine_fn).with_hot_key_fanout( + _COMBINE_PER_SLICE_KEY_HOT_KEY_FANOUT + ) + | "AddDerivedCrossSliceAndDiffMetrics" >> derived_metrics_ptransform + ) - if eval_config.options.min_slice_size.value > 1: sliced_metrics_plots_and_attributions = ( sliced_metrics_plots_and_attributions - | 'FilterSmallSlices' - >> slicer.FilterOutSlices( - slices_count, eval_config.options.min_slice_size.value + | "AddCIDerivedMetrics" + >> beam.Map( + _add_ci_derived_metrics, metric_computations.ci_derived_computations ) + | "RemovePrivateMetrics" >> beam.MapTuple(_remove_private_metrics) ) - sliced_metrics = ( - sliced_metrics_plots_and_attributions - | 'FilterByMetrics' - >> beam.Map(_filter_by_key_type, metric_types.MetricKey) - ) - sliced_plots = ( - sliced_metrics_plots_and_attributions - | 'FilterByPlots' >> beam.Map(_filter_by_key_type, metric_types.PlotKey) - ) + if eval_config.options.min_slice_size.value > 1: + sliced_metrics_plots_and_attributions = ( + sliced_metrics_plots_and_attributions + | "FilterSmallSlices" + >> slicer.FilterOutSlices( + slices_count, eval_config.options.min_slice_size.value + ) + ) - sliced_attributions = ( - sliced_metrics_plots_and_attributions - | 'FilterByAttributions' - >> beam.Map(_filter_by_key_type, metric_types.AttributionsKey) - ) + sliced_metrics = ( + sliced_metrics_plots_and_attributions + | "FilterByMetrics" >> beam.Map(_filter_by_key_type, metric_types.MetricKey) + ) + sliced_plots = sliced_metrics_plots_and_attributions | "FilterByPlots" >> beam.Map( + _filter_by_key_type, metric_types.PlotKey + ) - # pylint: enable=no-value-for-parameter + sliced_attributions = ( + sliced_metrics_plots_and_attributions + | "FilterByAttributions" + >> beam.Map(_filter_by_key_type, metric_types.AttributionsKey) + ) + + # pylint: enable=no-value-for-parameter - return { - metrics_key: sliced_metrics, - plots_key: sliced_plots, - attributions_key: sliced_attributions, - } + return { + metrics_key: sliced_metrics, + plots_key: sliced_plots, + attributions_key: sliced_attributions, + } @beam.ptransform_fn @@ -990,95 +976,98 @@ def _EvaluateMetricsPlotsAndValidations( # pylint: disable=invalid-name schema: Optional[schema_pb2.Schema] = None, random_seed_for_testing: Optional[int] = None, ) -> evaluator.Evaluation: - """Evaluates metrics, plots, and validations. - - Args: - extracts: PCollection of Extracts. 
The extracts must contain a list of - slices of type SliceKeyType keyed by tfma.SLICE_KEY_TYPES_KEY as well as - any extracts required by the metric implementations (typically this will - include labels keyed by tfma.LABELS_KEY, predictions keyed by - tfma.PREDICTIONS_KEY, and example weights keyed by - tfma.EXAMPLE_WEIGHTS_KEY). Usually these will be added by calling the - default_extractors function. - eval_config: Eval config. - eval_shared_models: Optional dict of shared models keyed by model name. Only - required if there are metrics to be computed in-graph using the model. - metrics_key: Name to use for metrics key in Evaluation output. - plots_key: Name to use for plots key in Evaluation output. - attributions_key: Name to use for attributions key in Evaluation output. - validations_key: Name to use for validation key in Evaluation output. - schema: A schema to use for customizing metrics and plots. - random_seed_for_testing: Seed to use for unit testing. - - Returns: - Evaluation containing dict of PCollections of (slice_key, results_dict) - tuples where the dict is keyed by either the metrics_key (e.g. 'metrics'), - plots_key (e.g. 'plots'), attributions_key (e.g. 'attributions'), or - validation_key (e.g. 'validations') depending on what the results_dict - contains. - """ - # Separate metrics based on query_key (which may be None). - metrics_specs_by_query_key = {} - for spec in eval_config.metrics_specs: - if spec.query_key not in metrics_specs_by_query_key: - metrics_specs_by_query_key[spec.query_key] = [] - metrics_specs_by_query_key[spec.query_key].append(spec) - - # If there are no metrics specs then add an empty one (this is required for - # cases where only the default metrics from the model are used). - if not metrics_specs_by_query_key: - metrics_specs_by_query_key[''] = [config_pb2.MetricsSpec()] - - # pylint: disable=no-value-for-parameter - - evaluations = {} - for query_key, metrics_specs in metrics_specs_by_query_key.items(): - query_key_text = query_key if query_key else '' - if query_key: - extracts_for_evaluation = extracts | 'GroupByQueryKey({})'.format( - query_key_text - ) >> _GroupByQueryKey(query_key) - include_default_metrics = False - else: - extracts_for_evaluation = extracts - include_default_metrics = eval_config and ( - not eval_config.options.HasField('include_default_metrics') - or eval_config.options.include_default_metrics.value - ) - evaluation = extracts_for_evaluation | 'ComputeMetricsAndPlots({})'.format( - query_key_text - ) >> _ComputeMetricsAndPlots( - eval_config=eval_config, - metrics_specs=metrics_specs, - eval_shared_models=( - eval_shared_models if include_default_metrics else None - ), - metrics_key=metrics_key, - plots_key=plots_key, - attributions_key=attributions_key, - schema=schema, - random_seed_for_testing=random_seed_for_testing, - ) + """Evaluates metrics, plots, and validations. - for k, v in evaluation.items(): - if k not in evaluations: - evaluations[k] = [] - evaluations[k].append(v) - evaluation_results = evaluator.combine_dict_based_evaluations(evaluations) + Args: + ---- + extracts: PCollection of Extracts. The extracts must contain a list of + slices of type SliceKeyType keyed by tfma.SLICE_KEY_TYPES_KEY as well as + any extracts required by the metric implementations (typically this will + include labels keyed by tfma.LABELS_KEY, predictions keyed by + tfma.PREDICTIONS_KEY, and example weights keyed by + tfma.EXAMPLE_WEIGHTS_KEY). Usually these will be added by calling the + default_extractors function. 
+ eval_config: Eval config. + eval_shared_models: Optional dict of shared models keyed by model name. Only + required if there are metrics to be computed in-graph using the model. + metrics_key: Name to use for metrics key in Evaluation output. + plots_key: Name to use for plots key in Evaluation output. + attributions_key: Name to use for attributions key in Evaluation output. + validations_key: Name to use for validation key in Evaluation output. + schema: A schema to use for customizing metrics and plots. + random_seed_for_testing: Seed to use for unit testing. + + Returns: + ------- + Evaluation containing dict of PCollections of (slice_key, results_dict) + tuples where the dict is keyed by either the metrics_key (e.g. 'metrics'), + plots_key (e.g. 'plots'), attributions_key (e.g. 'attributions'), or + validation_key (e.g. 'validations') depending on what the results_dict + contains. + """ + # Separate metrics based on query_key (which may be None). + metrics_specs_by_query_key = {} + for spec in eval_config.metrics_specs: + if spec.query_key not in metrics_specs_by_query_key: + metrics_specs_by_query_key[spec.query_key] = [] + metrics_specs_by_query_key[spec.query_key].append(spec) + + # If there are no metrics specs then add an empty one (this is required for + # cases where only the default metrics from the model are used). + if not metrics_specs_by_query_key: + metrics_specs_by_query_key[""] = [config_pb2.MetricsSpec()] + + # pylint: disable=no-value-for-parameter + + evaluations = {} + for query_key, metrics_specs in metrics_specs_by_query_key.items(): + query_key_text = query_key if query_key else "" + if query_key: + extracts_for_evaluation = ( + extracts + | f"GroupByQueryKey({query_key_text})" >> _GroupByQueryKey(query_key) + ) + include_default_metrics = False + else: + extracts_for_evaluation = extracts + include_default_metrics = eval_config and ( + not eval_config.options.HasField("include_default_metrics") + or eval_config.options.include_default_metrics.value + ) + evaluation = ( + extracts_for_evaluation + | f"ComputeMetricsAndPlots({query_key_text})" + >> _ComputeMetricsAndPlots( + eval_config=eval_config, + metrics_specs=metrics_specs, + eval_shared_models=( + eval_shared_models if include_default_metrics else None + ), + metrics_key=metrics_key, + plots_key=plots_key, + attributions_key=attributions_key, + schema=schema, + random_seed_for_testing=random_seed_for_testing, + ) + ) - validations = evaluation_results[metrics_key] | 'ValidateMetrics' >> beam.Map( - metrics_validator.validate_metrics, eval_config - ) - evaluation_results[validations_key] = validations - return evaluation_results + for k, v in evaluation.items(): + if k not in evaluations: + evaluations[k] = [] + evaluations[k].append(v) + evaluation_results = evaluator.combine_dict_based_evaluations(evaluations) + + validations = evaluation_results[metrics_key] | "ValidateMetrics" >> beam.Map( + metrics_validator.validate_metrics, eval_config + ) + evaluation_results[validations_key] = validations + return evaluation_results def _get_model_types_for_logging( eval_shared_models: Dict[str, types.EvalSharedModel], ): - if eval_shared_models: - return set( - [model.model_type for (name, model) in eval_shared_models.items()] - ) - else: - return set([constants.MODEL_AGNOSTIC]) + if eval_shared_models: + return set([model.model_type for (name, model) in eval_shared_models.items()]) + else: + return set([constants.MODEL_AGNOSTIC]) diff --git 
a/tensorflow_model_analysis/evaluators/metrics_plots_and_validations_evaluator_test.py b/tensorflow_model_analysis/evaluators/metrics_plots_and_validations_evaluator_test.py index 16241e3d99..a271663b37 100644 --- a/tensorflow_model_analysis/evaluators/metrics_plots_and_validations_evaluator_test.py +++ b/tensorflow_model_analysis/evaluators/metrics_plots_and_validations_evaluator_test.py @@ -15,166 +15,164 @@ import os -from absl.testing import parameterized import apache_beam as beam -from apache_beam.testing import util import numpy as np import tensorflow as tf +from absl.testing import parameterized +from apache_beam.testing import util +from google.protobuf import text_format +from tensorflow_metadata.proto.v0 import schema_pb2 +from tfx_bsl.tfxio import tensor_adapter, test_util + from tensorflow_model_analysis import constants from tensorflow_model_analysis.api import model_eval_lib from tensorflow_model_analysis.evaluators import metrics_plots_and_validations_evaluator -from tensorflow_model_analysis.extractors import example_weights_extractor -from tensorflow_model_analysis.extractors import features_extractor -from tensorflow_model_analysis.extractors import labels_extractor -from tensorflow_model_analysis.extractors import materialized_predictions_extractor -from tensorflow_model_analysis.extractors import predictions_extractor -from tensorflow_model_analysis.extractors import slice_key_extractor -from tensorflow_model_analysis.extractors import unbatch_extractor -from tensorflow_model_analysis.metrics import attributions -from tensorflow_model_analysis.metrics import binary_confusion_matrices -from tensorflow_model_analysis.metrics import calibration -from tensorflow_model_analysis.metrics import calibration_plot -from tensorflow_model_analysis.metrics import confusion_matrix_plot -from tensorflow_model_analysis.metrics import metric_specs -from tensorflow_model_analysis.metrics import metric_types -from tensorflow_model_analysis.proto import config_pb2 -from tensorflow_model_analysis.proto import validation_result_pb2 +from tensorflow_model_analysis.extractors import ( + example_weights_extractor, + features_extractor, + labels_extractor, + materialized_predictions_extractor, + predictions_extractor, + slice_key_extractor, + unbatch_extractor, +) +from tensorflow_model_analysis.metrics import ( + attributions, + binary_confusion_matrices, + calibration, + calibration_plot, + confusion_matrix_plot, + metric_specs, + metric_types, +) +from tensorflow_model_analysis.proto import config_pb2, validation_result_pb2 from tensorflow_model_analysis.utils import test_util as testutil from tensorflow_model_analysis.utils.keras_lib import tf_keras -from tfx_bsl.tfxio import tensor_adapter -from tfx_bsl.tfxio import test_util - -from google.protobuf import text_format -from tensorflow_metadata.proto.v0 import schema_pb2 -_TF_MAJOR_VERSION = int(tf.version.VERSION.split('.')[0]) +_TF_MAJOR_VERSION = int(tf.version.VERSION.split(".")[0]) class MetricsPlotsAndValidationsEvaluatorTest( testutil.TensorflowModelAnalysisTest, parameterized.TestCase ): + def _getExportDir(self): + return os.path.join(self._getTempDir(), "export_dir") + + def _getBaselineDir(self): + return os.path.join(self._getTempDir(), "baseline_export_dir") + + def _build_keras_model(self, model_name, model_dir, mul): + input_layer = tf_keras.layers.Input(shape=(1,), name="input_1") + output_layer = tf_keras.layers.Lambda( + lambda x, mul: x * mul, output_shape=(1,), arguments={"mul": mul} + )(input_layer) + model = 
tf_keras.models.Model([input_layer], output_layer) + model.compile( + optimizer=tf_keras.optimizers.Adam(lr=0.001), + loss=tf_keras.losses.BinaryCrossentropy(name="loss"), + metrics=["accuracy"], + ) + model.save(model_dir, save_format="tf") + return self.createTestEvalSharedModel( + model_name=model_name, model_path=model_dir + ) - def _getExportDir(self): - return os.path.join(self._getTempDir(), 'export_dir') - - def _getBaselineDir(self): - return os.path.join(self._getTempDir(), 'baseline_export_dir') - - def _build_keras_model(self, model_name, model_dir, mul): - input_layer = tf_keras.layers.Input(shape=(1,), name='input_1') - output_layer = tf_keras.layers.Lambda( - lambda x, mul: x * mul, output_shape=(1,), arguments={'mul': mul} - )(input_layer) - model = tf_keras.models.Model([input_layer], output_layer) - model.compile( - optimizer=tf_keras.optimizers.Adam(lr=0.001), - loss=tf_keras.losses.BinaryCrossentropy(name='loss'), - metrics=['accuracy'], - ) - model.save(model_dir, save_format='tf') - return self.createTestEvalSharedModel( - model_name=model_name, model_path=model_dir - ) - - def testFilterAndSeparateComputations(self): - eval_config = config_pb2.EvalConfig( - model_specs=[ - config_pb2.ModelSpec(name='candidate', label_key='tips'), - config_pb2.ModelSpec( - name='baseline', label_key='tips', is_baseline=True + def testFilterAndSeparateComputations(self): + eval_config = config_pb2.EvalConfig( + model_specs=[ + config_pb2.ModelSpec(name="candidate", label_key="tips"), + config_pb2.ModelSpec( + name="baseline", label_key="tips", is_baseline=True + ), + ], + cross_slicing_specs=[config_pb2.CrossSlicingSpec()], + ) + metrics_specs = metric_specs.specs_from_metrics( + [ + tf_keras.metrics.BinaryAccuracy(name="accuracy"), + tf_keras.metrics.AUC(name="auc", num_thresholds=10000), + tf_keras.metrics.AUC( + name="auc_precison_recall", curve="PR", num_thresholds=10000 + ), + tf_keras.metrics.Precision(name="precision"), + tf_keras.metrics.Recall(name="recall"), + calibration.MeanLabel(name="mean_label"), + calibration.MeanPrediction(name="mean_prediction"), + calibration.Calibration(name="calibration"), + confusion_matrix_plot.ConfusionMatrixPlot(name="confusion_matrix_plot"), + calibration_plot.CalibrationPlot(name="calibration_plot"), + ], + model_names=["candidate", "baseline"], + binarize=config_pb2.BinarizationOptions(class_ids={"values": [0, 5]}), + ) + computations = metric_specs.to_computations( + metrics_specs, eval_config=eval_config + ) + non_derived, derived, _, ci_derived = ( + metrics_plots_and_validations_evaluator._filter_and_separate_computations( + computations + ) + ) + # 2 models x 2 classes x _binary_confusion_matrix_[0.5]_100, + # 2 models x 2 classes x _CalibrationHistogramCombiner + # 2 models x 2 classes x _calibration_historgram_27 + # 2 models x 2 classes x _CompilableMetricsCombiner, + # 2 models x 2 classes x _WeightedLabelsPredictionsExamplesCombiner, + # 4 models x _ExampleCountCombiner + self.assertLen(non_derived, 16) + # 2 models x 2 classes x _binary_confusion_matrices_[0.5], + # 2 models x 2 classes x _binary_confusion_matrices_10000 + # 2 models x 2 classes x _binary_confusion_matrices_confusion_matrix_plot + # 2 models x 2 classes x precision + # 2 models x 2 classes x recall + # 2 models x 2 classes x calibration + # 2 models x 2 classes x auc_precision_recall + # 2 models x 2 classes x mean_prediction + # 2 models x 2 classes x mean_label + # 2 models x 2 classes x confusion_matrix_plot + # 2 models x 2 classes x calibration_plot + # 2 
models x 2 classes x auc + # 2 models x 2 classes x accuracy + self.assertLen(derived, 52) + # None of the metric has CIDerivedMetricComputation. + self.assertEmpty(ci_derived) + + def testFilterAndSeparateComputationsWithCIDerivedMetrics(self): + def derived_metric_fn(): + pass + + def ci_derived_fn(): + pass + + computations = [ + metric_types.DerivedMetricComputation( + [metric_types.MetricKey("key1")], derived_metric_fn ), - ], - cross_slicing_specs=[config_pb2.CrossSlicingSpec()], - ) - metrics_specs = metric_specs.specs_from_metrics( - [ - tf_keras.metrics.BinaryAccuracy(name='accuracy'), - tf_keras.metrics.AUC(name='auc', num_thresholds=10000), - tf_keras.metrics.AUC( - name='auc_precison_recall', curve='PR', num_thresholds=10000 + metric_types.CIDerivedMetricComputation( + [metric_types.MetricKey("key1")], ci_derived_fn ), - tf_keras.metrics.Precision(name='precision'), - tf_keras.metrics.Recall(name='recall'), - calibration.MeanLabel(name='mean_label'), - calibration.MeanPrediction(name='mean_prediction'), - calibration.Calibration(name='calibration'), - confusion_matrix_plot.ConfusionMatrixPlot( - name='confusion_matrix_plot' + metric_types.CIDerivedMetricComputation( + [metric_types.MetricKey("key1")], ci_derived_fn ), - calibration_plot.CalibrationPlot(name='calibration_plot'), - ], - model_names=['candidate', 'baseline'], - binarize=config_pb2.BinarizationOptions(class_ids={'values': [0, 5]}), - ) - computations = metric_specs.to_computations( - metrics_specs, eval_config=eval_config - ) - non_derived, derived, _, ci_derived = ( - metrics_plots_and_validations_evaluator._filter_and_separate_computations( - computations - ) - ) - # 2 models x 2 classes x _binary_confusion_matrix_[0.5]_100, - # 2 models x 2 classes x _CalibrationHistogramCombiner - # 2 models x 2 classes x _calibration_historgram_27 - # 2 models x 2 classes x _CompilableMetricsCombiner, - # 2 models x 2 classes x _WeightedLabelsPredictionsExamplesCombiner, - # 4 models x _ExampleCountCombiner - self.assertLen(non_derived, 16) - # 2 models x 2 classes x _binary_confusion_matrices_[0.5], - # 2 models x 2 classes x _binary_confusion_matrices_10000 - # 2 models x 2 classes x _binary_confusion_matrices_confusion_matrix_plot - # 2 models x 2 classes x precision - # 2 models x 2 classes x recall - # 2 models x 2 classes x calibration - # 2 models x 2 classes x auc_precision_recall - # 2 models x 2 classes x mean_prediction - # 2 models x 2 classes x mean_label - # 2 models x 2 classes x confusion_matrix_plot - # 2 models x 2 classes x calibration_plot - # 2 models x 2 classes x auc - # 2 models x 2 classes x accuracy - self.assertLen(derived, 52) - # None of the metric has CIDerivedMetricComputation. 
- self.assertEmpty(ci_derived) - - def testFilterAndSeparateComputationsWithCIDerivedMetrics(self): - - def derived_metric_fn(): - pass - - def ci_derived_fn(): - pass - - computations = [ - metric_types.DerivedMetricComputation( - [metric_types.MetricKey('key1')], derived_metric_fn - ), - metric_types.CIDerivedMetricComputation( - [metric_types.MetricKey('key1')], ci_derived_fn - ), - metric_types.CIDerivedMetricComputation( - [metric_types.MetricKey('key1')], ci_derived_fn - ), - ] - _, derived, _, ci_derived = ( - metrics_plots_and_validations_evaluator._filter_and_separate_computations( - computations + ] + _, derived, _, ci_derived = ( + metrics_plots_and_validations_evaluator._filter_and_separate_computations( + computations + ) ) - ) - self.assertLen(derived, 1) - self.assertLen(ci_derived, 1) + self.assertLen(derived, 1) + self.assertLen(ci_derived, 1) - def testEvaluateWithKerasAndDiffMetrics(self): - model_dir, baseline_dir = self._getExportDir(), self._getBaselineDir() - eval_shared_model = self._build_keras_model('candidate', model_dir, mul=0) - baseline_eval_shared_model = self._build_keras_model( - 'baseline', baseline_dir, mul=1 - ) + def testEvaluateWithKerasAndDiffMetrics(self): + model_dir, baseline_dir = self._getExportDir(), self._getBaselineDir() + eval_shared_model = self._build_keras_model("candidate", model_dir, mul=0) + baseline_eval_shared_model = self._build_keras_model( + "baseline", baseline_dir, mul=1 + ) - schema = text_format.Parse( - """ + schema = text_format.Parse( + """ tensor_representation_group { key: "" value { @@ -206,259 +204,259 @@ def testEvaluateWithKerasAndDiffMetrics(self): type: BYTES } """, - schema_pb2.Schema(), - ) - tfx_io = test_util.InMemoryTFExampleRecord( - schema=schema, raw_record_column_name=constants.ARROW_INPUT_COLUMN - ) - tensor_adapter_config = tensor_adapter.TensorAdapterConfig( - arrow_schema=tfx_io.ArrowSchema(), - tensor_representations=tfx_io.TensorRepresentations(), - ) + schema_pb2.Schema(), + ) + tfx_io = test_util.InMemoryTFExampleRecord( + schema=schema, raw_record_column_name=constants.ARROW_INPUT_COLUMN + ) + tensor_adapter_config = tensor_adapter.TensorAdapterConfig( + arrow_schema=tfx_io.ArrowSchema(), + tensor_representations=tfx_io.TensorRepresentations(), + ) - examples = [ - self._makeExample( - input_1=0.0, - label=1.0, - example_weight=1.0, - extra_feature='non_model_feature', - ), - self._makeExample( - input_1=1.0, - label=0.0, - example_weight=0.5, - extra_feature='non_model_feature', - ), - ] - - eval_config = config_pb2.EvalConfig( - model_specs=[ - config_pb2.ModelSpec( - name='candidate', - label_key='label', - example_weight_key='example_weight', + examples = [ + self._makeExample( + input_1=0.0, + label=1.0, + example_weight=1.0, + extra_feature="non_model_feature", ), - config_pb2.ModelSpec( - name='baseline', - label_key='label', - example_weight_key='example_weight', - is_baseline=True, + self._makeExample( + input_1=1.0, + label=0.0, + example_weight=0.5, + extra_feature="non_model_feature", ), - ], - slicing_specs=[config_pb2.SlicingSpec()], - metrics_specs=metric_specs.specs_from_metrics( - [ - calibration.MeanLabel('mean_label'), - calibration.MeanPrediction('mean_prediction'), + ] + + eval_config = config_pb2.EvalConfig( + model_specs=[ + config_pb2.ModelSpec( + name="candidate", + label_key="label", + example_weight_key="example_weight", + ), + config_pb2.ModelSpec( + name="baseline", + label_key="label", + example_weight_key="example_weight", + is_baseline=True, + ), ], - 
model_names=['candidate', 'baseline'], - ), - ) + slicing_specs=[config_pb2.SlicingSpec()], + metrics_specs=metric_specs.specs_from_metrics( + [ + calibration.MeanLabel("mean_label"), + calibration.MeanPrediction("mean_prediction"), + ], + model_names=["candidate", "baseline"], + ), + ) - eval_shared_models = [eval_shared_model, baseline_eval_shared_model] - extractors = [ - features_extractor.FeaturesExtractor( - eval_config=eval_config, - tensor_representations=tensor_adapter_config.tensor_representations, - ), - labels_extractor.LabelsExtractor(eval_config), - example_weights_extractor.ExampleWeightsExtractor(eval_config), - predictions_extractor.PredictionsExtractor( - eval_shared_model=eval_shared_models, eval_config=eval_config - ), - unbatch_extractor.UnbatchExtractor(), - slice_key_extractor.SliceKeyExtractor(eval_config=eval_config), - ] - evaluators = [ - metrics_plots_and_validations_evaluator.MetricsPlotsAndValidationsEvaluator( - eval_config=eval_config, eval_shared_model=eval_shared_models + eval_shared_models = [eval_shared_model, baseline_eval_shared_model] + extractors = [ + features_extractor.FeaturesExtractor( + eval_config=eval_config, + tensor_representations=tensor_adapter_config.tensor_representations, + ), + labels_extractor.LabelsExtractor(eval_config), + example_weights_extractor.ExampleWeightsExtractor(eval_config), + predictions_extractor.PredictionsExtractor( + eval_shared_model=eval_shared_models, eval_config=eval_config + ), + unbatch_extractor.UnbatchExtractor(), + slice_key_extractor.SliceKeyExtractor(eval_config=eval_config), + ] + evaluators = [ + metrics_plots_and_validations_evaluator.MetricsPlotsAndValidationsEvaluator( + eval_config=eval_config, eval_shared_model=eval_shared_models + ) + ] + + with beam.Pipeline() as pipeline: + # pylint: disable=no-value-for-parameter + metrics = ( + pipeline + | "Create" >> beam.Create([e.SerializeToString() for e in examples]) + | "BatchExamples" >> tfx_io.BeamSource() + | "InputsToExtracts" >> model_eval_lib.BatchedInputsToExtracts() + | "ExtractAndEvaluate" + >> model_eval_lib.ExtractAndEvaluate( + extractors=extractors, evaluators=evaluators + ) + ) + + # pylint: enable=no-value-for-parameter + + def check_metrics(got): + try: + self.assertLen(got, 1) + got_slice_key, got_metrics = got[0] + self.assertEqual(got_slice_key, ()) + # check only the diff metrics. 
+ weighted_example_count_key = metric_types.MetricKey( + name="weighted_example_count", + model_name="candidate", + is_diff=True, + example_weighted=True, + ) + prediction_key = metric_types.MetricKey( + name="mean_prediction", + model_name="candidate", + is_diff=True, + example_weighted=True, + ) + label_key = metric_types.MetricKey( + name="mean_label", + model_name="candidate", + is_diff=True, + example_weighted=True, + ) + self.assertDictElementsAlmostEqual( + got_metrics, + { + weighted_example_count_key: 0, + label_key: 0, + prediction_key: 0 - (0 * 1 + 1 * 0.5) / (1 + 0.5), + }, + ) + + except AssertionError as err: + raise util.BeamAssertException(err) + + util.assert_that( + metrics[constants.METRICS_KEY], check_metrics, label="metrics" + ) + + def testEvaluateWithAttributions(self): + eval_config = config_pb2.EvalConfig( + model_specs=[config_pb2.ModelSpec()], + metrics_specs=[ + config_pb2.MetricsSpec( + metrics=[ + config_pb2.MetricConfig( + class_name=attributions.TotalAttributions().__class__.__name__ + ) + ] + ) + ], + options=config_pb2.Options( + disabled_outputs={"values": ["eval_config_pb2.json"]} + ), ) - ] - - with beam.Pipeline() as pipeline: - # pylint: disable=no-value-for-parameter - metrics = ( - pipeline - | 'Create' >> beam.Create([e.SerializeToString() for e in examples]) - | 'BatchExamples' >> tfx_io.BeamSource() - | 'InputsToExtracts' >> model_eval_lib.BatchedInputsToExtracts() - | 'ExtractAndEvaluate' - >> model_eval_lib.ExtractAndEvaluate( - extractors=extractors, evaluators=evaluators - ) - ) - - # pylint: enable=no-value-for-parameter - - def check_metrics(got): - try: - self.assertLen(got, 1) - got_slice_key, got_metrics = got[0] - self.assertEqual(got_slice_key, ()) - # check only the diff metrics. - weighted_example_count_key = metric_types.MetricKey( - name='weighted_example_count', - model_name='candidate', - is_diff=True, - example_weighted=True, - ) - prediction_key = metric_types.MetricKey( - name='mean_prediction', - model_name='candidate', - is_diff=True, - example_weighted=True, - ) - label_key = metric_types.MetricKey( - name='mean_label', - model_name='candidate', - is_diff=True, - example_weighted=True, - ) - self.assertDictElementsAlmostEqual( - got_metrics, - { - weighted_example_count_key: 0, - label_key: 0, - prediction_key: 0 - (0 * 1 + 1 * 0.5) / (1 + 0.5), - }, - ) - - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that( - metrics[constants.METRICS_KEY], check_metrics, label='metrics' - ) - - def testEvaluateWithAttributions(self): - eval_config = config_pb2.EvalConfig( - model_specs=[config_pb2.ModelSpec()], - metrics_specs=[ - config_pb2.MetricsSpec( - metrics=[ - config_pb2.MetricConfig( - class_name=attributions.TotalAttributions().__class__.__name__ + extractors = [slice_key_extractor.SliceKeyExtractor()] + evaluators = [ + metrics_plots_and_validations_evaluator.MetricsPlotsAndValidationsEvaluator( + eval_config=eval_config + ) + ] + + example1 = { + "labels": None, + "predictions": None, + "example_weights": np.array(1.0), + "features": {}, + "attributions": {"feature1": 1.1, "feature2": 1.2}, + } + example2 = { + "labels": None, + "predictions": None, + "example_weights": np.array(1.0), + "features": {}, + "attributions": {"feature1": 2.1, "feature2": 2.2}, + } + example3 = { + "labels": None, + "predictions": None, + "example_weights": np.array(1.0), + "features": {}, + "attributions": { + "feature1": np.array([3.1]), + "feature2": np.array([3.2]), + }, + } + + with beam.Pipeline() as 
pipeline: + # pylint: disable=no-value-for-parameter + results = ( + pipeline + | "Create" >> beam.Create([example1, example2, example3]) + | "ExtractEvaluate" + >> model_eval_lib.ExtractAndEvaluate( + extractors=extractors, evaluators=evaluators + ) + ) + + # pylint: enable=no-value-for-parameter + + def check_attributions(got): + try: + self.assertLen(got, 1) + got_slice_key, got_attributions = got[0] + self.assertEqual(got_slice_key, ()) + total_attributions_key = metric_types.MetricKey( + name="total_attributions" ) - ] + self.assertIn(total_attributions_key, got_attributions) + self.assertDictElementsAlmostEqual( + got_attributions[total_attributions_key], + {"feature1": 1.1 + 2.1 + 3.1, "feature2": 1.2 + 2.2 + 3.2}, + ) + + except AssertionError as err: + raise util.BeamAssertException(err) + + util.assert_that( + results[constants.ATTRIBUTIONS_KEY], + check_attributions, + label="attributions", ) - ], - options=config_pb2.Options( - disabled_outputs={'values': ['eval_config_pb2.json']} - ), - ) - extractors = [slice_key_extractor.SliceKeyExtractor()] - evaluators = [ - metrics_plots_and_validations_evaluator.MetricsPlotsAndValidationsEvaluator( - eval_config=eval_config + + def testEvaluateWithJackknifeAndDiffMetrics(self): + model_dir, baseline_dir = self._getExportDir(), self._getBaselineDir() + eval_shared_model = self._build_keras_model("candidate", model_dir, mul=0) + baseline_eval_shared_model = self._build_keras_model( + "baseline", baseline_dir, mul=1 ) - ] - - example1 = { - 'labels': None, - 'predictions': None, - 'example_weights': np.array(1.0), - 'features': {}, - 'attributions': {'feature1': 1.1, 'feature2': 1.2}, - } - example2 = { - 'labels': None, - 'predictions': None, - 'example_weights': np.array(1.0), - 'features': {}, - 'attributions': {'feature1': 2.1, 'feature2': 2.2}, - } - example3 = { - 'labels': None, - 'predictions': None, - 'example_weights': np.array(1.0), - 'features': {}, - 'attributions': { - 'feature1': np.array([3.1]), - 'feature2': np.array([3.2]), - }, - } - - with beam.Pipeline() as pipeline: - # pylint: disable=no-value-for-parameter - results = ( - pipeline - | 'Create' >> beam.Create([example1, example2, example3]) - | 'ExtractEvaluate' - >> model_eval_lib.ExtractAndEvaluate( - extractors=extractors, evaluators=evaluators - ) - ) - - # pylint: enable=no-value-for-parameter - - def check_attributions(got): - try: - self.assertLen(got, 1) - got_slice_key, got_attributions = got[0] - self.assertEqual(got_slice_key, ()) - total_attributions_key = metric_types.MetricKey( - name='total_attributions' - ) - self.assertIn(total_attributions_key, got_attributions) - self.assertDictElementsAlmostEqual( - got_attributions[total_attributions_key], - {'feature1': 1.1 + 2.1 + 3.1, 'feature2': 1.2 + 2.2 + 3.2}, - ) - - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that( - results[constants.ATTRIBUTIONS_KEY], - check_attributions, - label='attributions', - ) - - def testEvaluateWithJackknifeAndDiffMetrics(self): - model_dir, baseline_dir = self._getExportDir(), self._getBaselineDir() - eval_shared_model = self._build_keras_model('candidate', model_dir, mul=0) - baseline_eval_shared_model = self._build_keras_model( - 'baseline', baseline_dir, mul=1 - ) - options = config_pb2.Options() - options.compute_confidence_intervals.value = True - options.confidence_intervals.method = ( - config_pb2.ConfidenceIntervalOptions.JACKKNIFE - ) + options = config_pb2.Options() + options.compute_confidence_intervals.value = True + 
options.confidence_intervals.method = ( + config_pb2.ConfidenceIntervalOptions.JACKKNIFE + ) - eval_config = config_pb2.EvalConfig( - model_specs=[ - config_pb2.ModelSpec( - name='candidate', - label_key='label', - example_weight_key='example_weight', - ), - config_pb2.ModelSpec( - name='baseline', - label_key='label', - example_weight_key='example_weight', - is_baseline=True, - ), - ], - slicing_specs=[config_pb2.SlicingSpec()], - metrics_specs=metric_specs.specs_from_metrics( - [ - calibration.MeanLabel('mean_label'), - calibration.MeanPrediction('mean_prediction'), + eval_config = config_pb2.EvalConfig( + model_specs=[ + config_pb2.ModelSpec( + name="candidate", + label_key="label", + example_weight_key="example_weight", + ), + config_pb2.ModelSpec( + name="baseline", + label_key="label", + example_weight_key="example_weight", + is_baseline=True, + ), ], - model_names=['candidate', 'baseline'], - ), - options=options, - ) + slicing_specs=[config_pb2.SlicingSpec()], + metrics_specs=metric_specs.specs_from_metrics( + [ + calibration.MeanLabel("mean_label"), + calibration.MeanPrediction("mean_prediction"), + ], + model_names=["candidate", "baseline"], + ), + options=options, + ) - eval_shared_models = { - 'candidate': eval_shared_model, - 'baseline': baseline_eval_shared_model, - } + eval_shared_models = { + "candidate": eval_shared_model, + "baseline": baseline_eval_shared_model, + } - schema = text_format.Parse( - """ + schema = text_format.Parse( + """ tensor_representation_group { key: "" value { @@ -490,170 +488,170 @@ def testEvaluateWithJackknifeAndDiffMetrics(self): type: BYTES } """, - schema_pb2.Schema(), - ) - tfx_io = test_util.InMemoryTFExampleRecord( - schema=schema, raw_record_column_name=constants.ARROW_INPUT_COLUMN - ) - tensor_adapter_config = tensor_adapter.TensorAdapterConfig( - arrow_schema=tfx_io.ArrowSchema(), - tensor_representations=tfx_io.TensorRepresentations(), - ) + schema_pb2.Schema(), + ) + tfx_io = test_util.InMemoryTFExampleRecord( + schema=schema, raw_record_column_name=constants.ARROW_INPUT_COLUMN + ) + tensor_adapter_config = tensor_adapter.TensorAdapterConfig( + arrow_schema=tfx_io.ArrowSchema(), + tensor_representations=tfx_io.TensorRepresentations(), + ) - examples = [ - self._makeExample( - input_1=0.0, - label=1.0, - example_weight=1.0, - extra_feature='non_model_feature', - ), - self._makeExample( - input_1=1.0, - label=0.0, - example_weight=0.5, - extra_feature='non_model_feature', - ), - ] + examples = [ + self._makeExample( + input_1=0.0, + label=1.0, + example_weight=1.0, + extra_feature="non_model_feature", + ), + self._makeExample( + input_1=1.0, + label=0.0, + example_weight=0.5, + extra_feature="non_model_feature", + ), + ] - extractors = [ - features_extractor.FeaturesExtractor( - eval_config=eval_config, - tensor_representations=tensor_adapter_config.tensor_representations, - ), - labels_extractor.LabelsExtractor(eval_config), - example_weights_extractor.ExampleWeightsExtractor(eval_config), - predictions_extractor.PredictionsExtractor( - eval_shared_model=eval_shared_models, eval_config=eval_config - ), - unbatch_extractor.UnbatchExtractor(), - slice_key_extractor.SliceKeyExtractor(eval_config=eval_config), - ] - evaluators = [ - metrics_plots_and_validations_evaluator.MetricsPlotsAndValidationsEvaluator( - eval_config=eval_config, eval_shared_model=eval_shared_models - ) - ] - - with beam.Pipeline() as pipeline: - # pylint: disable=no-value-for-parameter - metrics = ( - pipeline - | 'Create' - >> beam.Create([e.SerializeToString() for e 
in examples * 1000]) - | 'BatchExamples' >> tfx_io.BeamSource() - | 'InputsToExtracts' >> model_eval_lib.BatchedInputsToExtracts() - | 'ExtractAndEvaluate' - >> model_eval_lib.ExtractAndEvaluate( - extractors=extractors, evaluators=evaluators - ) - ) - - # pylint: enable=no-value-for-parameter - - def check_metrics(got): - try: - self.assertLen(got, 1) - got_slice_key, got_metrics = got[0] - self.assertEqual(got_slice_key, ()) - # check only the diff metrics. - weighted_example_count_key = metric_types.MetricKey( - name='weighted_example_count', - model_name='candidate', - is_diff=True, - example_weighted=True, - ) - prediction_key = metric_types.MetricKey( - name='mean_prediction', - model_name='candidate', - is_diff=True, - example_weighted=True, - ) - label_key = metric_types.MetricKey( - name='mean_label', - model_name='candidate', - is_diff=True, - example_weighted=True, - ) - self.assertDictElementsWithTDistributionAlmostEqual( - got_metrics, - { - weighted_example_count_key: 0, - label_key: 0, - prediction_key: 0 - (0 * 1 + 1 * 0.5) / (1 + 0.5), - }, - ) - - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that(metrics[constants.METRICS_KEY], check_metrics) - - @parameterized.named_parameters( - ('compiled_metrics', False), - ('evaluate', True), - ) - def testEvaluateWithKerasModelWithInGraphMetrics(self, add_custom_metrics): - # Custom metrics not supported in TFv1 - if _TF_MAJOR_VERSION < 2: - add_custom_metrics = False - - input1 = tf_keras.layers.Input(shape=(1,), name='input_1') - input2 = tf_keras.layers.Input(shape=(1,), name='input_2') - inputs = [input1, input2] - input_layer = tf_keras.layers.concatenate(inputs) - output_layer = tf_keras.layers.Dense( - 1, activation=tf.nn.sigmoid, name='output' - )(input_layer) - model = tf_keras.models.Model(inputs, output_layer) - # The model.evaluate API is used when custom metrics are used. Otherwise - # model.compiled_metrics is used. - if add_custom_metrics: - model.add_metric(tf.reduce_sum(input_layer), name='custom') - model.compile( - optimizer=tf_keras.optimizers.Adam(lr=0.001), - loss=tf_keras.losses.BinaryCrossentropy(name='loss'), - metrics=[tf_keras.metrics.BinaryAccuracy(name='binary_accuracy')], + extractors = [ + features_extractor.FeaturesExtractor( + eval_config=eval_config, + tensor_representations=tensor_adapter_config.tensor_representations, + ), + labels_extractor.LabelsExtractor(eval_config), + example_weights_extractor.ExampleWeightsExtractor(eval_config), + predictions_extractor.PredictionsExtractor( + eval_shared_model=eval_shared_models, eval_config=eval_config + ), + unbatch_extractor.UnbatchExtractor(), + slice_key_extractor.SliceKeyExtractor(eval_config=eval_config), + ] + evaluators = [ + metrics_plots_and_validations_evaluator.MetricsPlotsAndValidationsEvaluator( + eval_config=eval_config, eval_shared_model=eval_shared_models + ) + ] + + with beam.Pipeline() as pipeline: + # pylint: disable=no-value-for-parameter + metrics = ( + pipeline + | "Create" + >> beam.Create([e.SerializeToString() for e in examples * 1000]) + | "BatchExamples" >> tfx_io.BeamSource() + | "InputsToExtracts" >> model_eval_lib.BatchedInputsToExtracts() + | "ExtractAndEvaluate" + >> model_eval_lib.ExtractAndEvaluate( + extractors=extractors, evaluators=evaluators + ) + ) + + # pylint: enable=no-value-for-parameter + + def check_metrics(got): + try: + self.assertLen(got, 1) + got_slice_key, got_metrics = got[0] + self.assertEqual(got_slice_key, ()) + # check only the diff metrics. 
+ weighted_example_count_key = metric_types.MetricKey( + name="weighted_example_count", + model_name="candidate", + is_diff=True, + example_weighted=True, + ) + prediction_key = metric_types.MetricKey( + name="mean_prediction", + model_name="candidate", + is_diff=True, + example_weighted=True, + ) + label_key = metric_types.MetricKey( + name="mean_label", + model_name="candidate", + is_diff=True, + example_weighted=True, + ) + self.assertDictElementsWithTDistributionAlmostEqual( + got_metrics, + { + weighted_example_count_key: 0, + label_key: 0, + prediction_key: 0 - (0 * 1 + 1 * 0.5) / (1 + 0.5), + }, + ) + + except AssertionError as err: + raise util.BeamAssertException(err) + + util.assert_that(metrics[constants.METRICS_KEY], check_metrics) + + @parameterized.named_parameters( + ("compiled_metrics", False), + ("evaluate", True), ) + def testEvaluateWithKerasModelWithInGraphMetrics(self, add_custom_metrics): + # Custom metrics not supported in TFv1 + if _TF_MAJOR_VERSION < 2: + add_custom_metrics = False + + input1 = tf_keras.layers.Input(shape=(1,), name="input_1") + input2 = tf_keras.layers.Input(shape=(1,), name="input_2") + inputs = [input1, input2] + input_layer = tf_keras.layers.concatenate(inputs) + output_layer = tf_keras.layers.Dense( + 1, activation=tf.nn.sigmoid, name="output" + )(input_layer) + model = tf_keras.models.Model(inputs, output_layer) + # The model.evaluate API is used when custom metrics are used. Otherwise + # model.compiled_metrics is used. + if add_custom_metrics: + model.add_metric(tf.reduce_sum(input_layer), name="custom") + model.compile( + optimizer=tf_keras.optimizers.Adam(lr=0.001), + loss=tf_keras.losses.BinaryCrossentropy(name="loss"), + metrics=[tf_keras.metrics.BinaryAccuracy(name="binary_accuracy")], + ) - export_dir = self._getExportDir() - model.save(export_dir, save_format='tf') + export_dir = self._getExportDir() + model.save(export_dir, save_format="tf") - eval_config = config_pb2.EvalConfig( - model_specs=[ - config_pb2.ModelSpec( - label_key='label', example_weight_key='example_weight' - ) - ], - slicing_specs=[config_pb2.SlicingSpec()], - metrics_specs=metric_specs.specs_from_metrics( - [calibration.MeanLabel('mean_label')], - unweighted_metrics=[ - tf_keras.metrics.BinaryAccuracy(name='binary_accuracy'), - calibration.MeanLabel('mean_label'), + eval_config = config_pb2.EvalConfig( + model_specs=[ + config_pb2.ModelSpec( + label_key="label", example_weight_key="example_weight" + ) ], - ), - ) - eval_shared_model = self.createTestEvalSharedModel(model_path=export_dir) - - examples = [ - self._makeExample( - input_1=0.0, - input_2=1.0, - label=1.0, - example_weight=1.0, - extra_feature='non_model_feature', - ), - self._makeExample( - input_1=1.0, - input_2=0.0, - label=0.0, - example_weight=0.5, - extra_feature='non_model_feature', - ), - ] + slicing_specs=[config_pb2.SlicingSpec()], + metrics_specs=metric_specs.specs_from_metrics( + [calibration.MeanLabel("mean_label")], + unweighted_metrics=[ + tf_keras.metrics.BinaryAccuracy(name="binary_accuracy"), + calibration.MeanLabel("mean_label"), + ], + ), + ) + eval_shared_model = self.createTestEvalSharedModel(model_path=export_dir) + + examples = [ + self._makeExample( + input_1=0.0, + input_2=1.0, + label=1.0, + example_weight=1.0, + extra_feature="non_model_feature", + ), + self._makeExample( + input_1=1.0, + input_2=0.0, + label=0.0, + example_weight=0.5, + extra_feature="non_model_feature", + ), + ] - schema = text_format.Parse( - """ + schema = text_format.Parse( + """ 
tensor_representation_group { key: "" value { @@ -698,161 +696,157 @@ def testEvaluateWithKerasModelWithInGraphMetrics(self, add_custom_metrics): type: BYTES } """, - schema_pb2.Schema(), - ) - tfx_io = test_util.InMemoryTFExampleRecord( - schema=schema, raw_record_column_name=constants.ARROW_INPUT_COLUMN - ) - tensor_adapter_config = tensor_adapter.TensorAdapterConfig( - arrow_schema=tfx_io.ArrowSchema(), - tensor_representations=tfx_io.TensorRepresentations(), - ) - extractors = [ - features_extractor.FeaturesExtractor( - eval_config=eval_config, - tensor_representations=tensor_adapter_config.tensor_representations, - ), - labels_extractor.LabelsExtractor(eval_config), - example_weights_extractor.ExampleWeightsExtractor(eval_config), - predictions_extractor.PredictionsExtractor( - eval_shared_model=eval_shared_model, eval_config=eval_config - ), - unbatch_extractor.UnbatchExtractor(), - slice_key_extractor.SliceKeyExtractor(eval_config=eval_config), - ] - evaluators = [ - metrics_plots_and_validations_evaluator.MetricsPlotsAndValidationsEvaluator( - eval_config=eval_config, eval_shared_model=eval_shared_model + schema_pb2.Schema(), ) - ] - with beam.Pipeline() as pipeline: - # pylint: disable=no-value-for-parameter - metrics = ( - pipeline - | 'Create' >> beam.Create([e.SerializeToString() for e in examples]) - | 'BatchExamples' >> tfx_io.BeamSource() - | 'InputsToExtracts' >> model_eval_lib.BatchedInputsToExtracts() - | 'ExtractAndEvaluate' - >> model_eval_lib.ExtractAndEvaluate( - extractors=extractors, evaluators=evaluators - ) - ) - - # pylint: enable=no-value-for-parameter - - def check_metrics(got): - try: - self.assertLen(got, 1) - got_slice_key, got_metrics = got[0] - self.assertEqual(got_slice_key, ()) - example_count_key = metric_types.MetricKey(name='example_count') - weighted_example_count_key = metric_types.MetricKey( - name='weighted_example_count', example_weighted=True - ) - label_key = metric_types.MetricKey( - name='mean_label', example_weighted=True - ) - label_unweighted_key = metric_types.MetricKey( - name='mean_label', example_weighted=False - ) - binary_accuracy_key = metric_types.MetricKey( - name='binary_accuracy', example_weighted=False - ) - self.assertIn(binary_accuracy_key, got_metrics) - binary_accuracy_unweighted_key = metric_types.MetricKey( - name='binary_accuracy', example_weighted=False - ) - self.assertIn(binary_accuracy_unweighted_key, got_metrics) - expected_values = { - example_count_key: 2, - weighted_example_count_key: 1.0 + 0.5, - label_key: (1.0 * 1.0 + 0.0 * 0.5) / (1.0 + 0.5), - label_unweighted_key: (1.0 + 0.0) / (1.0 + 1.0), - } - self.assertDictElementsAlmostEqual(got_metrics, expected_values) - - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that( - metrics[constants.METRICS_KEY], check_metrics, label='metrics' - ) - - def testAddCrossSliceMetricsMatchAll(self): - overall_slice_key = () - slice_key1 = (('feature', 1),) - slice_key2 = (('feature', 2),) - slice_key3 = (('feature', 3),) - metrics_dict = {} - sliced_metrics = [ - (overall_slice_key, metrics_dict), - (slice_key1, metrics_dict), - (slice_key2, metrics_dict), - (slice_key3, metrics_dict), - ] - with beam.Pipeline() as pipeline: - cross_sliced_metrics = ( - pipeline - | 'CreateSlicedMetrics' >> beam.Create(sliced_metrics) - | 'AddCrossSliceMetrics' - >> metrics_plots_and_validations_evaluator._AddCrossSliceMetrics( - cross_slice_specs=[ - config_pb2.CrossSlicingSpec( - baseline_spec={}, slicing_specs=[] - ) - ], - cross_slice_computations=[], 
- ) - ) - - def check_result(got_sliced_metrics): - actual_slice_keys = [k for k, _ in got_sliced_metrics] - expected_slice_keys = [ - # cross slice keys - (overall_slice_key, slice_key1), - (overall_slice_key, slice_key2), - (overall_slice_key, slice_key3), - # single slice keys - overall_slice_key, - slice_key1, - slice_key2, - slice_key3, + tfx_io = test_util.InMemoryTFExampleRecord( + schema=schema, raw_record_column_name=constants.ARROW_INPUT_COLUMN + ) + tensor_adapter_config = tensor_adapter.TensorAdapterConfig( + arrow_schema=tfx_io.ArrowSchema(), + tensor_representations=tfx_io.TensorRepresentations(), + ) + extractors = [ + features_extractor.FeaturesExtractor( + eval_config=eval_config, + tensor_representations=tensor_adapter_config.tensor_representations, + ), + labels_extractor.LabelsExtractor(eval_config), + example_weights_extractor.ExampleWeightsExtractor(eval_config), + predictions_extractor.PredictionsExtractor( + eval_shared_model=eval_shared_model, eval_config=eval_config + ), + unbatch_extractor.UnbatchExtractor(), + slice_key_extractor.SliceKeyExtractor(eval_config=eval_config), + ] + evaluators = [ + metrics_plots_and_validations_evaluator.MetricsPlotsAndValidationsEvaluator( + eval_config=eval_config, eval_shared_model=eval_shared_model + ) ] - self.assertCountEqual(expected_slice_keys, actual_slice_keys) - - util.assert_that(cross_sliced_metrics, check_result) - - @parameterized.named_parameters( - ('IntIsDiffable', 1, True), - ('FloatIsDiffable', 1.0, True), - ('NumpyFloatDtypeIsDiffable', np.array([1.0], dtype=np.float64), True), - ('NumpyIntDtypeIsDiffable', np.array([1], dtype=np.int64), True), - ('MessageNotDiffable', validation_result_pb2.ValidationResult(), False), - ( - 'TupleNotDiffable', - binary_confusion_matrices.Matrices( - thresholds=[-1e-7, 0.5, 1.0 + 1e-7], - tp=[2.0, 1.0, 0.0], - fp=[2.0, 0.0, 0.0], - tn=[0.0, 2.0, 2.0], - fn=[0.0, 1.0, 2.0], - ), - True, - ), - ('BytesNotDiffable', b'some bytes', False), - ('NumpyObjectDtypeNotDiffable', np.array(['obj'], dtype=object), False), - ) - def testIsMetricDiffable(self, metric_value, expected_is_diffable): - self.assertEqual( - expected_is_diffable, - metrics_plots_and_validations_evaluator._is_metric_diffable( - metric_value + with beam.Pipeline() as pipeline: + # pylint: disable=no-value-for-parameter + metrics = ( + pipeline + | "Create" >> beam.Create([e.SerializeToString() for e in examples]) + | "BatchExamples" >> tfx_io.BeamSource() + | "InputsToExtracts" >> model_eval_lib.BatchedInputsToExtracts() + | "ExtractAndEvaluate" + >> model_eval_lib.ExtractAndEvaluate( + extractors=extractors, evaluators=evaluators + ) + ) + + # pylint: enable=no-value-for-parameter + + def check_metrics(got): + try: + self.assertLen(got, 1) + got_slice_key, got_metrics = got[0] + self.assertEqual(got_slice_key, ()) + example_count_key = metric_types.MetricKey(name="example_count") + weighted_example_count_key = metric_types.MetricKey( + name="weighted_example_count", example_weighted=True + ) + label_key = metric_types.MetricKey( + name="mean_label", example_weighted=True + ) + label_unweighted_key = metric_types.MetricKey( + name="mean_label", example_weighted=False + ) + binary_accuracy_key = metric_types.MetricKey( + name="binary_accuracy", example_weighted=False + ) + self.assertIn(binary_accuracy_key, got_metrics) + binary_accuracy_unweighted_key = metric_types.MetricKey( + name="binary_accuracy", example_weighted=False + ) + self.assertIn(binary_accuracy_unweighted_key, got_metrics) + expected_values = { + 
example_count_key: 2, + weighted_example_count_key: 1.0 + 0.5, + label_key: (1.0 * 1.0 + 0.0 * 0.5) / (1.0 + 0.5), + label_unweighted_key: (1.0 + 0.0) / (1.0 + 1.0), + } + self.assertDictElementsAlmostEqual(got_metrics, expected_values) + + except AssertionError as err: + raise util.BeamAssertException(err) + + util.assert_that( + metrics[constants.METRICS_KEY], check_metrics, label="metrics" + ) + + def testAddCrossSliceMetricsMatchAll(self): + overall_slice_key = () + slice_key1 = (("feature", 1),) + slice_key2 = (("feature", 2),) + slice_key3 = (("feature", 3),) + metrics_dict = {} + sliced_metrics = [ + (overall_slice_key, metrics_dict), + (slice_key1, metrics_dict), + (slice_key2, metrics_dict), + (slice_key3, metrics_dict), + ] + with beam.Pipeline() as pipeline: + cross_sliced_metrics = ( + pipeline + | "CreateSlicedMetrics" >> beam.Create(sliced_metrics) + | "AddCrossSliceMetrics" + >> metrics_plots_and_validations_evaluator._AddCrossSliceMetrics( + cross_slice_specs=[ + config_pb2.CrossSlicingSpec(baseline_spec={}, slicing_specs=[]) + ], + cross_slice_computations=[], + ) + ) + + def check_result(got_sliced_metrics): + actual_slice_keys = [k for k, _ in got_sliced_metrics] + expected_slice_keys = [ + # cross slice keys + (overall_slice_key, slice_key1), + (overall_slice_key, slice_key2), + (overall_slice_key, slice_key3), + # single slice keys + overall_slice_key, + slice_key1, + slice_key2, + slice_key3, + ] + self.assertCountEqual(expected_slice_keys, actual_slice_keys) + + util.assert_that(cross_sliced_metrics, check_result) + + @parameterized.named_parameters( + ("IntIsDiffable", 1, True), + ("FloatIsDiffable", 1.0, True), + ("NumpyFloatDtypeIsDiffable", np.array([1.0], dtype=np.float64), True), + ("NumpyIntDtypeIsDiffable", np.array([1], dtype=np.int64), True), + ("MessageNotDiffable", validation_result_pb2.ValidationResult(), False), + ( + "TupleNotDiffable", + binary_confusion_matrices.Matrices( + thresholds=[-1e-7, 0.5, 1.0 + 1e-7], + tp=[2.0, 1.0, 0.0], + fp=[2.0, 0.0, 0.0], + tn=[0.0, 2.0, 2.0], + fn=[0.0, 1.0, 2.0], + ), + True, ), + ("BytesNotDiffable", b"some bytes", False), + ("NumpyObjectDtypeNotDiffable", np.array(["obj"], dtype=object), False), ) + def testIsMetricDiffable(self, metric_value, expected_is_diffable): + self.assertEqual( + expected_is_diffable, + metrics_plots_and_validations_evaluator._is_metric_diffable(metric_value), + ) - def testMetricsSpecsCountersInModelAgnosticMode(self): - schema = text_format.Parse( - """ + def testMetricsSpecsCountersInModelAgnosticMode(self): + schema = text_format.Parse( + """ feature { name: "label" type: FLOAT @@ -862,70 +856,70 @@ def testMetricsSpecsCountersInModelAgnosticMode(self): type: FLOAT } """, - schema_pb2.Schema(), - ) + schema_pb2.Schema(), + ) - tfx_io = test_util.InMemoryTFExampleRecord( - schema=schema, raw_record_column_name=constants.ARROW_INPUT_COLUMN - ) + tfx_io = test_util.InMemoryTFExampleRecord( + schema=schema, raw_record_column_name=constants.ARROW_INPUT_COLUMN + ) + + examples = [ + self._makeExample(label=1.0, prediction=0.7), + self._makeExample(label=0.0, prediction=0.3), + ] - examples = [ - self._makeExample(label=1.0, prediction=0.7), - self._makeExample(label=0.0, prediction=0.3), - ] - - eval_config = config_pb2.EvalConfig( - model_specs=[ - config_pb2.ModelSpec(prediction_key='prediction', label_key='label') - ], - metrics_specs=[ - config_pb2.MetricsSpec( - metrics=[config_pb2.MetricConfig(class_name='ExampleCount')] + eval_config = config_pb2.EvalConfig( + model_specs=[ + 
config_pb2.ModelSpec(prediction_key="prediction", label_key="label") + ], + metrics_specs=[ + config_pb2.MetricsSpec( + metrics=[config_pb2.MetricConfig(class_name="ExampleCount")] + ) + ], + slicing_specs=[config_pb2.SlicingSpec()], + ) + + extractors = [ + features_extractor.FeaturesExtractor(eval_config), + labels_extractor.LabelsExtractor(eval_config), + example_weights_extractor.ExampleWeightsExtractor(eval_config), + materialized_predictions_extractor.MaterializedPredictionsExtractor( + eval_config + ), + unbatch_extractor.UnbatchExtractor(), + slice_key_extractor.SliceKeyExtractor(eval_config=eval_config), + ] + evaluators = [ + metrics_plots_and_validations_evaluator.MetricsPlotsAndValidationsEvaluator( + eval_config ) - ], - slicing_specs=[config_pb2.SlicingSpec()], - ) + ] - extractors = [ - features_extractor.FeaturesExtractor(eval_config), - labels_extractor.LabelsExtractor(eval_config), - example_weights_extractor.ExampleWeightsExtractor(eval_config), - materialized_predictions_extractor.MaterializedPredictionsExtractor( - eval_config - ), - unbatch_extractor.UnbatchExtractor(), - slice_key_extractor.SliceKeyExtractor(eval_config=eval_config), - ] - evaluators = [ - metrics_plots_and_validations_evaluator.MetricsPlotsAndValidationsEvaluator( - eval_config + with beam.Pipeline() as pipeline: + _ = ( + pipeline + | "Create" >> beam.Create([e.SerializeToString() for e in examples]) + | "BatchExamples" >> tfx_io.BeamSource() + | "InputsToExtracts" >> model_eval_lib.BatchedInputsToExtracts() + | "ExtractEvaluate" + >> model_eval_lib.ExtractAndEvaluate( + extractors=extractors, evaluators=evaluators + ) + ) + + metric_filter = beam.metrics.metric.MetricsFilter().with_name( + "metric_computed_ExampleCount_v2_" + constants.MODEL_AGNOSTIC ) - ] - - with beam.Pipeline() as pipeline: - _ = ( - pipeline - | 'Create' >> beam.Create([e.SerializeToString() for e in examples]) - | 'BatchExamples' >> tfx_io.BeamSource() - | 'InputsToExtracts' >> model_eval_lib.BatchedInputsToExtracts() - | 'ExtractEvaluate' - >> model_eval_lib.ExtractAndEvaluate( - extractors=extractors, evaluators=evaluators - ) - ) - - metric_filter = beam.metrics.metric.MetricsFilter().with_name( - 'metric_computed_ExampleCount_v2_' + constants.MODEL_AGNOSTIC - ) - actual_metrics_count = ( - pipeline.run() - .metrics() - .query(filter=metric_filter)['counters'][0] - .committed - ) - self.assertEqual(actual_metrics_count, 1) + actual_metrics_count = ( + pipeline.run() + .metrics() + .query(filter=metric_filter)["counters"][0] + .committed + ) + self.assertEqual(actual_metrics_count, 1) -if __name__ == '__main__': - tf.compat.v1.enable_v2_behavior() - tf.test.main() +if __name__ == "__main__": + tf.compat.v1.enable_v2_behavior() + tf.test.main() diff --git a/tensorflow_model_analysis/evaluators/metrics_validator.py b/tensorflow_model_analysis/evaluators/metrics_validator.py index 2a47e0d7d8..7c628465d5 100644 --- a/tensorflow_model_analysis/evaluators/metrics_validator.py +++ b/tensorflow_model_analysis/evaluators/metrics_validator.py @@ -15,11 +15,11 @@ import math from typing import Any, Dict, Iterable, List, Tuple, Union + import numpy as np -from tensorflow_model_analysis.metrics import metric_specs -from tensorflow_model_analysis.metrics import metric_types -from tensorflow_model_analysis.proto import config_pb2 -from tensorflow_model_analysis.proto import validation_result_pb2 + +from tensorflow_model_analysis.metrics import metric_specs, metric_types +from tensorflow_model_analysis.proto import config_pb2, 
validation_result_pb2 from tensorflow_model_analysis.slicer import slicer_lib as slicer from tensorflow_model_analysis.utils import model_util @@ -35,199 +35,199 @@ def validate_metrics( sliced_metrics: Tuple[ Union[slicer.SliceKeyType, slicer.CrossSliceKeyType], - Dict['metric_types.MetricKey', Any], + Dict["metric_types.MetricKey", Any], ], eval_config: config_pb2.EvalConfig, ) -> validation_result_pb2.ValidationResult: - """Check the metrics and check whether they should be validated.""" - # Find out which model is baseline. - baseline_spec = model_util.get_baseline_model_spec(eval_config) - baseline_model_name = baseline_spec.name if baseline_spec else None + """Check the metrics and check whether they should be validated.""" + # Find out which model is baseline. + baseline_spec = model_util.get_baseline_model_spec(eval_config) + baseline_model_name = baseline_spec.name if baseline_spec else None - sliced_key, metrics = sliced_metrics - thresholds = metric_specs.metric_thresholds_from_metrics_specs( - eval_config.metrics_specs, eval_config=eval_config - ) - is_cross_slice = slicer.is_cross_slice_key(sliced_key) + sliced_key, metrics = sliced_metrics + thresholds = metric_specs.metric_thresholds_from_metrics_specs( + eval_config.metrics_specs, eval_config=eval_config + ) + is_cross_slice = slicer.is_cross_slice_key(sliced_key) - def _check_threshold( - key: metric_types.MetricKey, threshold: _ThresholdType, metric: Any - ) -> bool: - """Verify a metric given its metric key and metric value.""" - metric = float(metric) - if isinstance(threshold, config_pb2.GenericValueThreshold): - lower_bound, upper_bound = -np.inf, np.inf - if threshold.HasField('lower_bound'): - lower_bound = threshold.lower_bound.value - if threshold.HasField('upper_bound'): - upper_bound = threshold.upper_bound.value - return metric >= lower_bound and metric <= upper_bound - elif isinstance(threshold, config_pb2.GenericChangeThreshold): - diff = metric + def _check_threshold( + key: metric_types.MetricKey, threshold: _ThresholdType, metric: Any + ) -> bool: + """Verify a metric given its metric key and metric value.""" + metric = float(metric) + if isinstance(threshold, config_pb2.GenericValueThreshold): + lower_bound, upper_bound = -np.inf, np.inf + if threshold.HasField("lower_bound"): + lower_bound = threshold.lower_bound.value + if threshold.HasField("upper_bound"): + upper_bound = threshold.upper_bound.value + return metric >= lower_bound and metric <= upper_bound + elif isinstance(threshold, config_pb2.GenericChangeThreshold): + diff = metric - if threshold.HasField('absolute'): - absolute = threshold.absolute.value - if threshold.direction == config_pb2.MetricDirection.LOWER_IS_BETTER: - abs_result = diff <= absolute - elif threshold.direction == config_pb2.MetricDirection.HIGHER_IS_BETTER: - abs_result = diff >= absolute - else: - raise ValueError( - 'Unexpected change threshold direction: {}.'.format(threshold) - ) - else: - abs_result = True + if threshold.HasField("absolute"): + absolute = threshold.absolute.value + if threshold.direction == config_pb2.MetricDirection.LOWER_IS_BETTER: + abs_result = diff <= absolute + elif threshold.direction == config_pb2.MetricDirection.HIGHER_IS_BETTER: + abs_result = diff >= absolute + else: + raise ValueError( + f"Unexpected change threshold direction: {threshold}." 
+ ) + else: + abs_result = True - if threshold.HasField('relative'): - metric_baseline = float( - metrics[key.make_baseline_key(baseline_model_name)] - ) - if math.isclose(metric_baseline, 0.0): - ratio = float('nan') - else: - ratio = diff / metric_baseline + if threshold.HasField("relative"): + metric_baseline = float( + metrics[key.make_baseline_key(baseline_model_name)] + ) + if math.isclose(metric_baseline, 0.0): + ratio = float("nan") + else: + ratio = diff / metric_baseline - relative = threshold.relative.value - if threshold.direction == config_pb2.MetricDirection.LOWER_IS_BETTER: - rel_result = ratio <= relative - elif threshold.direction == config_pb2.MetricDirection.HIGHER_IS_BETTER: - rel_result = ratio >= relative - else: - raise ValueError( - 'Unexpected change threshold direction: {}.'.format(threshold) - ) - else: - rel_result = True + relative = threshold.relative.value + if threshold.direction == config_pb2.MetricDirection.LOWER_IS_BETTER: + rel_result = ratio <= relative + elif threshold.direction == config_pb2.MetricDirection.HIGHER_IS_BETTER: + rel_result = ratio >= relative + else: + raise ValueError( + f"Unexpected change threshold direction: {threshold}." + ) + else: + rel_result = True - return abs_result and rel_result - else: - raise ValueError('Unknown threshold: {}'.format(threshold)) + return abs_result and rel_result + else: + raise ValueError(f"Unknown threshold: {threshold}") - def _copy_metric(metric, to): - # Will add more types when more MetricValue are supported. - to.double_value.value = float(metric) + def _copy_metric(metric, to): + # Will add more types when more MetricValue are supported. + to.double_value.value = float(metric) - def _copy_threshold(threshold, to): - if isinstance(threshold, config_pb2.GenericValueThreshold): - to.value_threshold.CopyFrom(threshold) - if isinstance(threshold, config_pb2.GenericChangeThreshold): - to.change_threshold.CopyFrom(threshold) + def _copy_threshold(threshold, to): + if isinstance(threshold, config_pb2.GenericValueThreshold): + to.value_threshold.CopyFrom(threshold) + if isinstance(threshold, config_pb2.GenericChangeThreshold): + to.change_threshold.CopyFrom(threshold) - def _add_to_set(s, v): - """Adds value to set. Returns true if didn't exist.""" - if v in s: - return False - else: - s.add(v) - return True + def _add_to_set(s, v): + """Adds value to set. Returns true if didn't exist.""" + if v in s: + return False + else: + s.add(v) + return True - # Empty metrics per slice is considered validated. - result = validation_result_pb2.ValidationResult(validation_ok=True) - validation_for_slice = validation_result_pb2.MetricsValidationForSlice() - unchecked_thresholds = dict(thresholds) - for metric_key, metric in metrics.items(): - if metric_key not in thresholds: - continue - del unchecked_thresholds[metric_key] - # Not meaningful to check threshold for baseline model, thus always return - # True if such threshold is configured. We also do not compare Message type - # metrics. 
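A minimal sketch of the two threshold shapes `_check_threshold` evaluates, assuming only `config_pb2`; the numeric bounds are illustrative:

```python
from tensorflow_model_analysis.proto import config_pb2

# A value threshold bounds the metric itself: passes when 0.7 <= metric <= 1.0.
value_threshold = config_pb2.MetricThreshold(
    value_threshold=config_pb2.GenericValueThreshold(
        lower_bound={"value": 0.7}, upper_bound={"value": 1.0}
    )
)

# A change threshold bounds the candidate-minus-baseline diff. With
# LOWER_IS_BETTER both comparisons use "<=": the diff must be <= 0.01 and
# diff / baseline_metric must be <= 0.02 for the check to pass.
change_threshold = config_pb2.MetricThreshold(
    change_threshold=config_pb2.GenericChangeThreshold(
        direction=config_pb2.MetricDirection.LOWER_IS_BETTER,
        absolute={"value": 0.01},
        relative={"value": 0.02},
    )
)
```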
- if metric_key.model_name == baseline_model_name: - continue - msg = '' - existing_failures = set() - for slice_spec, threshold in thresholds[metric_key]: - if slice_spec is not None: - if isinstance(slice_spec, config_pb2.SlicingSpec) and ( - is_cross_slice - or not slicer.SingleSliceSpec(spec=slice_spec).is_slice_applicable( - sliced_key - ) - ): - continue - if isinstance(slice_spec, config_pb2.CrossSlicingSpec) and ( - not is_cross_slice - or not slicer.is_cross_slice_applicable( - cross_slice_key=sliced_key, cross_slicing_spec=slice_spec - ) - ): - continue - elif is_cross_slice: - continue - try: - check_result = _check_threshold(metric_key, threshold, metric) - except ValueError: - msg = """ + # Empty metrics per slice is considered validated. + result = validation_result_pb2.ValidationResult(validation_ok=True) + validation_for_slice = validation_result_pb2.MetricsValidationForSlice() + unchecked_thresholds = dict(thresholds) + for metric_key, metric in metrics.items(): + if metric_key not in thresholds: + continue + del unchecked_thresholds[metric_key] + # Not meaningful to check threshold for baseline model, thus always return + # True if such threshold is configured. We also do not compare Message type + # metrics. + if metric_key.model_name == baseline_model_name: + continue + msg = "" + existing_failures = set() + for slice_spec, threshold in thresholds[metric_key]: + if slice_spec is not None: + if isinstance(slice_spec, config_pb2.SlicingSpec) and ( + is_cross_slice + or not slicer.SingleSliceSpec(spec=slice_spec).is_slice_applicable( + sliced_key + ) + ): + continue + if isinstance(slice_spec, config_pb2.CrossSlicingSpec) and ( + not is_cross_slice + or not slicer.is_cross_slice_applicable( + cross_slice_key=sliced_key, cross_slicing_spec=slice_spec + ) + ): + continue + elif is_cross_slice: + continue + try: + check_result = _check_threshold(metric_key, threshold, metric) + except ValueError: + msg = f""" Invalid metrics or threshold for comparison: The type of the metric - is: {}, the metric value is: {}, and the threshold is: {}. - """.format(type(metric), metric, threshold) - check_result = False - else: - msg = '' - if not check_result: - # The same threshold values could be set for multiple matching slice - # specs. Only store the first match. - # - # Note that hashing by SerializeToString() is only safe if used within - # the same process. - if not _add_to_set(existing_failures, threshold.SerializeToString()): - continue - failure = validation_for_slice.failures.add() - failure.metric_key.CopyFrom(metric_key.to_proto()) - _copy_metric(metric, failure.metric_value) - _copy_threshold(threshold, failure.metric_threshold) - failure.message = msg - # Track we have completed a validation check for slice spec and metric - slicing_details = result.validation_details.slicing_details.add() - if slice_spec is not None: - if isinstance(slice_spec, config_pb2.SlicingSpec): - slicing_details.slicing_spec.CopyFrom(slice_spec) + is: {type(metric)}, the metric value is: {metric}, and the threshold is: {threshold}. + """ + check_result = False + else: + msg = "" + if not check_result: + # The same threshold values could be set for multiple matching slice + # specs. Only store the first match. + # + # Note that hashing by SerializeToString() is only safe if used within + # the same process. 
+ if not _add_to_set(existing_failures, threshold.SerializeToString()): + continue + failure = validation_for_slice.failures.add() + failure.metric_key.CopyFrom(metric_key.to_proto()) + _copy_metric(metric, failure.metric_value) + _copy_threshold(threshold, failure.metric_threshold) + failure.message = msg + # Track we have completed a validation check for slice spec and metric + slicing_details = result.validation_details.slicing_details.add() + if slice_spec is not None: + if isinstance(slice_spec, config_pb2.SlicingSpec): + slicing_details.slicing_spec.CopyFrom(slice_spec) + else: + slicing_details.cross_slicing_spec.CopyFrom(slice_spec) + else: + slicing_details.slicing_spec.CopyFrom(config_pb2.SlicingSpec()) + slicing_details.num_matching_slices = 1 + # All unchecked thresholds are considered failures. + for metric_key, thresholds in unchecked_thresholds.items(): + if metric_key.model_name == baseline_model_name: + continue + existing_failures = set() + for slice_spec, threshold in thresholds: + if slice_spec is not None: + if is_cross_slice != isinstance( + slice_spec, config_pb2.CrossSlicingSpec + ): + continue + if is_cross_slice and not slicer.is_cross_slice_applicable( + cross_slice_key=sliced_key, cross_slicing_spec=slice_spec + ): + continue + elif is_cross_slice: + continue + # The same threshold values could be set for multiple matching slice + # specs. Only store the first match. + # + # Note that hashing by SerializeToString() is only safe if used within + # the same process. + if not _add_to_set(existing_failures, threshold.SerializeToString()): + continue + failure = validation_for_slice.failures.add() + failure.metric_key.CopyFrom(metric_key.to_proto()) + _copy_threshold(threshold, failure.metric_threshold) + failure.message = "Metric not found." + # Any failure leads to overall failure. + if validation_for_slice.failures: + if not is_cross_slice: + validation_for_slice.slice_key.CopyFrom( + slicer.serialize_slice_key(sliced_key) + ) else: - slicing_details.cross_slicing_spec.CopyFrom(slice_spec) - else: - slicing_details.slicing_spec.CopyFrom(config_pb2.SlicingSpec()) - slicing_details.num_matching_slices = 1 - # All unchecked thresholds are considered failures. - for metric_key, thresholds in unchecked_thresholds.items(): - if metric_key.model_name == baseline_model_name: - continue - existing_failures = set() - for slice_spec, threshold in thresholds: - if slice_spec is not None: - if is_cross_slice != isinstance( - slice_spec, config_pb2.CrossSlicingSpec - ): - continue - if is_cross_slice and not slicer.is_cross_slice_applicable( - cross_slice_key=sliced_key, cross_slicing_spec=slice_spec - ): - continue - elif is_cross_slice: - continue - # The same threshold values could be set for multiple matching slice - # specs. Only store the first match. - # - # Note that hashing by SerializeToString() is only safe if used within - # the same process. - if not _add_to_set(existing_failures, threshold.SerializeToString()): - continue - failure = validation_for_slice.failures.add() - failure.metric_key.CopyFrom(metric_key.to_proto()) - _copy_threshold(threshold, failure.metric_threshold) - failure.message = 'Metric not found.' - # Any failure leads to overall failure. 
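A minimal usage sketch of `validate_metrics`, modelled on the value-threshold tests later in this change; the metric value and bound are illustrative:

```python
from tensorflow_model_analysis.evaluators import metrics_validator
from tensorflow_model_analysis.metrics import metric_types
from tensorflow_model_analysis.proto import config_pb2

eval_config = config_pb2.EvalConfig(
    model_specs=[config_pb2.ModelSpec()],
    metrics_specs=[
        config_pb2.MetricsSpec(
            metrics=[
                config_pb2.MetricConfig(
                    class_name="WeightedExampleCount",
                    # Require weighted_example_count <= 1.
                    threshold=config_pb2.MetricThreshold(
                        value_threshold=config_pb2.GenericValueThreshold(
                            upper_bound={"value": 1}
                        )
                    ),
                )
            ],
            model_names=[""],
            example_weights=config_pb2.ExampleWeightOptions(weighted=True),
        )
    ],
)
sliced_metrics = (
    (),  # the overall (empty) slice key
    {
        metric_types.MetricKey(
            name="weighted_example_count", example_weighted=True
        ): 1.5,  # 1.5 > 1, so the value threshold fails
    },
)
result = metrics_validator.validate_metrics(sliced_metrics, eval_config)
assert not result.validation_ok  # the failure is recorded under the slice key
```

Any recorded failure flips `validation_ok` to False for the whole result, which is exactly what the loop above implements.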
- if validation_for_slice.failures: - if not is_cross_slice: - validation_for_slice.slice_key.CopyFrom( - slicer.serialize_slice_key(sliced_key) - ) - else: - validation_for_slice.cross_slice_key.CopyFrom( - slicer.serialize_cross_slice_key(sliced_key) - ) - result.validation_ok = False - result.metric_validations_per_slice.append(validation_for_slice) - return result + validation_for_slice.cross_slice_key.CopyFrom( + slicer.serialize_cross_slice_key(sliced_key) + ) + result.validation_ok = False + result.metric_validations_per_slice.append(validation_for_slice) + return result def _hashed_slicing_details( @@ -236,67 +236,69 @@ def _hashed_slicing_details( Union[slicer.SingleSliceSpec, slicer.CrossSliceSpec], validation_result_pb2.SlicingDetails, ]: - """Returns hash table of slicing details keyed by serialized slice spec.""" - hashed_details = {} - for details in slicing_details: - hashable_slice_spec = slicer.deserialize_slice_spec(details.slicing_spec) - if hashable_slice_spec not in hashed_details: - hashed_details[hashable_slice_spec] = details - return hashed_details + """Returns hash table of slicing details keyed by serialized slice spec.""" + hashed_details = {} + for details in slicing_details: + hashable_slice_spec = slicer.deserialize_slice_spec(details.slicing_spec) + if hashable_slice_spec not in hashed_details: + hashed_details[hashable_slice_spec] = details + return hashed_details def merge_details( a: validation_result_pb2.ValidationResult, b: validation_result_pb2.ValidationResult, ): - """Merges validation details in ValidationtResult b into ValidationResult a.""" - hashed_details = _hashed_slicing_details(b.validation_details.slicing_details) - # Combine a with matching values from b - for details in a.validation_details.slicing_details: - hashable_slice_spec = slicer.deserialize_slice_spec(details.slicing_spec) - if hashable_slice_spec in hashed_details: - details.num_matching_slices = ( - details.num_matching_slices - + hashed_details[hashable_slice_spec].num_matching_slices - ) - del hashed_details[hashable_slice_spec] - # Add any values from b not matched in a - for details in hashed_details.values(): - a.validation_details.slicing_details.append(details) + """Merges validation details in ValidationtResult b into ValidationResult a.""" + hashed_details = _hashed_slicing_details(b.validation_details.slicing_details) + # Combine a with matching values from b + for details in a.validation_details.slicing_details: + hashable_slice_spec = slicer.deserialize_slice_spec(details.slicing_spec) + if hashable_slice_spec in hashed_details: + details.num_matching_slices = ( + details.num_matching_slices + + hashed_details[hashable_slice_spec].num_matching_slices + ) + del hashed_details[hashable_slice_spec] + # Add any values from b not matched in a + for details in hashed_details.values(): + a.validation_details.slicing_details.append(details) def get_missing_slices( slicing_details: Iterable[validation_result_pb2.SlicingDetails], eval_config: config_pb2.EvalConfig, ) -> List[Union[config_pb2.SlicingSpec, config_pb2.CrossSlicingSpec]]: - """Returns specs that are defined in the EvalConfig but not found in details. + """Returns specs that are defined in the EvalConfig but not found in details. - Args: - slicing_details: Slicing details. - eval_config: Eval config. + Args: + ---- + slicing_details: Slicing details. + eval_config: Eval config. - Returns: - List of missing slices or empty list if none are missing. 
- """ - hashed_details = _hashed_slicing_details(slicing_details) - thresholds = metric_specs.metric_thresholds_from_metrics_specs( - eval_config.metrics_specs, eval_config=eval_config - ) - baseline_spec = model_util.get_baseline_model_spec(eval_config) - baseline_model_name = baseline_spec.name if baseline_spec else None - missing_slices = [] - for metric_key, sliced_thresholds in thresholds.items(): - # Skip baseline. - if metric_key.model_name == baseline_model_name: - continue - for slice_spec, _ in sliced_thresholds: - if not slice_spec: - slice_spec = config_pb2.SlicingSpec() - hashable_slice_spec = slicer.deserialize_slice_spec(slice_spec) - if hashable_slice_spec not in hashed_details: - missing_slices.append(slice_spec) - # Same slice may be used by other metrics/thresholds, only add once - hashed_details[hashable_slice_spec] = ( - validation_result_pb2.SlicingDetails() - ) - return missing_slices + Returns: + ------- + List of missing slices or empty list if none are missing. + """ + hashed_details = _hashed_slicing_details(slicing_details) + thresholds = metric_specs.metric_thresholds_from_metrics_specs( + eval_config.metrics_specs, eval_config=eval_config + ) + baseline_spec = model_util.get_baseline_model_spec(eval_config) + baseline_model_name = baseline_spec.name if baseline_spec else None + missing_slices = [] + for metric_key, sliced_thresholds in thresholds.items(): + # Skip baseline. + if metric_key.model_name == baseline_model_name: + continue + for slice_spec, _ in sliced_thresholds: + if not slice_spec: + slice_spec = config_pb2.SlicingSpec() + hashable_slice_spec = slicer.deserialize_slice_spec(slice_spec) + if hashable_slice_spec not in hashed_details: + missing_slices.append(slice_spec) + # Same slice may be used by other metrics/thresholds, only add once + hashed_details[hashable_slice_spec] = ( + validation_result_pb2.SlicingDetails() + ) + return missing_slices diff --git a/tensorflow_model_analysis/evaluators/metrics_validator_test.py b/tensorflow_model_analysis/evaluators/metrics_validator_test.py index 10a5c1c8ed..eeda86e368 100644 --- a/tensorflow_model_analysis/evaluators/metrics_validator_test.py +++ b/tensorflow_model_analysis/evaluators/metrics_validator_test.py @@ -13,162 +13,161 @@ # limitations under the License. 
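A minimal usage sketch of the `get_missing_slices` helper defined above, assuming only `config_pb2` and this module; the feature key is illustrative:

```python
from tensorflow_model_analysis.evaluators import metrics_validator
from tensorflow_model_analysis.proto import config_pb2

# The config below defines a per-slice threshold on feature1, but no
# validation details are supplied, so that slice is reported as missing.
eval_config = config_pb2.EvalConfig(
    model_specs=[config_pb2.ModelSpec()],
    metrics_specs=[
        config_pb2.MetricsSpec(
            metrics=[
                config_pb2.MetricConfig(
                    class_name="WeightedExampleCount",
                    per_slice_thresholds=[
                        config_pb2.PerSliceMetricThreshold(
                            slicing_specs=[
                                config_pb2.SlicingSpec(feature_keys=["feature1"])
                            ],
                            threshold=config_pb2.MetricThreshold(
                                value_threshold=config_pb2.GenericValueThreshold(
                                    upper_bound={"value": 1}
                                )
                            ),
                        )
                    ],
                )
            ],
            model_names=[""],
        )
    ],
)
missing = metrics_validator.get_missing_slices([], eval_config)
# missing now contains the feature1 slicing spec (and nothing else).
```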
"""Test for MetricsAndPlotsEvaluator.""" -from absl.testing import parameterized import tensorflow as tf +from absl.testing import parameterized +from google.protobuf import text_format + from tensorflow_model_analysis.api import types from tensorflow_model_analysis.evaluators import metrics_validator from tensorflow_model_analysis.metrics import metric_types -from tensorflow_model_analysis.proto import config_pb2 -from tensorflow_model_analysis.proto import validation_result_pb2 +from tensorflow_model_analysis.proto import config_pb2, validation_result_pb2 from tensorflow_model_analysis.slicer import slicer_lib as slicer from tensorflow_model_analysis.utils import test_util -from google.protobuf import text_format # Tests involiving slices: (, , ) -_NO_SLICE_TEST = ('no_slice', None, (())) -_GLOBAL_SLICE_TEST = ('global_slice', [config_pb2.SlicingSpec()], (())) +_NO_SLICE_TEST = ("no_slice", None, (())) +_GLOBAL_SLICE_TEST = ("global_slice", [config_pb2.SlicingSpec()], (())) _FEATURE_SLICE_TEST = ( - 'feature_slice', - [config_pb2.SlicingSpec(feature_keys=['feature1'])], - (('feature1', 'value1'),), + "feature_slice", + [config_pb2.SlicingSpec(feature_keys=["feature1"])], + (("feature1", "value1"),), ) _FEATURE_VALUE_SLICE_TEST = ( - 'feature_value_slice', - [config_pb2.SlicingSpec(feature_values={'feature1': 'value1'})], - (('feature1', 'value1'),), + "feature_value_slice", + [config_pb2.SlicingSpec(feature_values={"feature1": "value1"})], + (("feature1", "value1"),), ) _MULTIPLE_SLICES_TEST = ( - 'multiple_slices', + "multiple_slices", [ - config_pb2.SlicingSpec(feature_values={'feature1': 'value1'}), - config_pb2.SlicingSpec(feature_values={'feature2': 'value2'}), + config_pb2.SlicingSpec(feature_values={"feature1": "value1"}), + config_pb2.SlicingSpec(feature_values={"feature2": "value2"}), ], - (('feature1', 'value1'),), + (("feature1", "value1"),), ) _UNMATCHED_SINGLE_SLICE_TEST = ( - 'single_slice', - [config_pb2.SlicingSpec(feature_keys='feature1')], - (('unmatched_feature', 'unmatched_value'),), + "single_slice", + [config_pb2.SlicingSpec(feature_keys="feature1")], + (("unmatched_feature", "unmatched_value"),), ) _UNMATCHED_MULTIPLE_SLICES_TEST = ( - 'multiple_slices', + "multiple_slices", [ - config_pb2.SlicingSpec(feature_values={'feature1': 'value1'}), - config_pb2.SlicingSpec(feature_values={'feature2': 'value2'}), + config_pb2.SlicingSpec(feature_values={"feature1": "value1"}), + config_pb2.SlicingSpec(feature_values={"feature2": "value2"}), ], - (('unmatched_feature', 'unmatched_value'),), + (("unmatched_feature", "unmatched_value"),), ) # Cross slice tests: (, , ) _CROSS_SLICE_GLOBAL_TEST = ( - 'global_slice', + "global_slice", [ config_pb2.CrossSlicingSpec( baseline_spec=config_pb2.SlicingSpec(), slicing_specs=[ - config_pb2.SlicingSpec(feature_values={'feature2': 'value2'}) + config_pb2.SlicingSpec(feature_values={"feature2": "value2"}) ], ) ], - ((()), (('feature2', 'value2'),)), + ((()), (("feature2", "value2"),)), ) _SINGLE_CROSS_SLICE_TEST = ( - 'single_slice', + "single_slice", [ config_pb2.CrossSlicingSpec( - baseline_spec=config_pb2.SlicingSpec(feature_keys=['feature1']), + baseline_spec=config_pb2.SlicingSpec(feature_keys=["feature1"]), slicing_specs=[ - config_pb2.SlicingSpec(feature_values={'feature2': 'value2'}) + config_pb2.SlicingSpec(feature_values={"feature2": "value2"}) ], ) ], - ((('feature1', 'value1'),), (('feature2', 'value2'),)), + ((("feature1", "value1"),), (("feature2", "value2"),)), ) _MULTIPLE_CROSS_SLICE_TEST = ( - 'multiple_slice', + 
"multiple_slice", [ config_pb2.CrossSlicingSpec( - baseline_spec=config_pb2.SlicingSpec(feature_keys=['feature1']), + baseline_spec=config_pb2.SlicingSpec(feature_keys=["feature1"]), slicing_specs=[ - config_pb2.SlicingSpec(feature_values={'feature2': 'value2'}) + config_pb2.SlicingSpec(feature_values={"feature2": "value2"}) ], ), config_pb2.CrossSlicingSpec( - baseline_spec=config_pb2.SlicingSpec(feature_keys=['feature2']), + baseline_spec=config_pb2.SlicingSpec(feature_keys=["feature2"]), slicing_specs=[ - config_pb2.SlicingSpec(feature_values={'feature3': 'value3'}) + config_pb2.SlicingSpec(feature_values={"feature3": "value3"}) ], ), ], - ((('feature2', 'value2'),), (('feature3', 'value3'),)), + ((("feature2", "value2"),), (("feature3", "value3"),)), ) _CROSS_SLICE_MULTIPLE_SLICING_SPEC_TEST = ( - 'multiple_slicing_spec', + "multiple_slicing_spec", [ config_pb2.CrossSlicingSpec( - baseline_spec=config_pb2.SlicingSpec(feature_keys=['feature1']), + baseline_spec=config_pb2.SlicingSpec(feature_keys=["feature1"]), slicing_specs=[ - config_pb2.SlicingSpec(feature_values={'feature2': 'value2'}), - config_pb2.SlicingSpec(feature_keys=['feature3']), + config_pb2.SlicingSpec(feature_values={"feature2": "value2"}), + config_pb2.SlicingSpec(feature_keys=["feature3"]), ], ) ], - ((('feature1', 'value1'),), (('feature3', 'value3'),)), + ((("feature1", "value1"),), (("feature3", "value3"),)), ) _UNMATCHED_CROSS_SLICE_TEST = ( - 'unmatched_cross_slice', + "unmatched_cross_slice", [ config_pb2.CrossSlicingSpec( - baseline_spec=config_pb2.SlicingSpec(feature_keys=['feature1']), + baseline_spec=config_pb2.SlicingSpec(feature_keys=["feature1"]), slicing_specs=[ - config_pb2.SlicingSpec(feature_values={'feature2': 'value2'}) + config_pb2.SlicingSpec(feature_values={"feature2": "value2"}) ], ), config_pb2.CrossSlicingSpec( - baseline_spec=config_pb2.SlicingSpec(feature_keys=['feature2']), + baseline_spec=config_pb2.SlicingSpec(feature_keys=["feature2"]), slicing_specs=[ - config_pb2.SlicingSpec(feature_values={'feature3': 'value3'}) + config_pb2.SlicingSpec(feature_values={"feature3": "value3"}) ], ), ], - ((('feature1', 'value1'),), (('feature3', 'value3'),)), + ((("feature1", "value1"),), (("feature3", "value3"),)), ) class MetricsValidatorTest( test_util.TensorflowModelAnalysisTest, parameterized.TestCase ): - - def testValidateMetricsInvalidThreshold(self): - eval_config = config_pb2.EvalConfig( - model_specs=[ - config_pb2.ModelSpec(), - ], - slicing_specs=[config_pb2.SlicingSpec()], - metrics_specs=[ - config_pb2.MetricsSpec( - thresholds={ - 'invalid_threshold': config_pb2.MetricThreshold( - value_threshold=config_pb2.GenericValueThreshold( - lower_bound={'value': 0.2} + def testValidateMetricsInvalidThreshold(self): + eval_config = config_pb2.EvalConfig( + model_specs=[ + config_pb2.ModelSpec(), + ], + slicing_specs=[config_pb2.SlicingSpec()], + metrics_specs=[ + config_pb2.MetricsSpec( + thresholds={ + "invalid_threshold": config_pb2.MetricThreshold( + value_threshold=config_pb2.GenericValueThreshold( + lower_bound={"value": 0.2} + ) ) - ) - } - ) - ], - ) - sliced_metrics = ( - (()), - { - metric_types.MetricKey( - name='weighted_example_count', example_weighted=True - ): 1.5, - }, - ) - result = metrics_validator.validate_metrics(sliced_metrics, eval_config) - self.assertFalse(result.validation_ok) - expected = text_format.Parse( - """ + } + ) + ], + ) + sliced_metrics = ( + (()), + { + metric_types.MetricKey( + name="weighted_example_count", example_weighted=True + ): 1.5, + }, + ) + result = 
metrics_validator.validate_metrics(sliced_metrics, eval_config) + self.assertFalse(result.validation_ok) + expected = text_format.Parse( + """ metric_validations_per_slice { slice_key { } @@ -186,59 +185,57 @@ def testValidateMetricsInvalidThreshold(self): message: 'Metric not found.' } }""", - validation_result_pb2.ValidationResult(), - ) - self.assertProtoEquals(expected, result) + validation_result_pb2.ValidationResult(), + ) + self.assertProtoEquals(expected, result) - @parameterized.named_parameters( - _NO_SLICE_TEST, - _GLOBAL_SLICE_TEST, - _FEATURE_SLICE_TEST, - _FEATURE_VALUE_SLICE_TEST, - _MULTIPLE_SLICES_TEST, - ) - def testValidateMetricsMetricTDistributionValueAndThreshold( - self, slicing_specs, slice_key - ): - threshold = config_pb2.MetricThreshold( - value_threshold=config_pb2.GenericValueThreshold( - lower_bound={'value': 0.9} + @parameterized.named_parameters( + _NO_SLICE_TEST, + _GLOBAL_SLICE_TEST, + _FEATURE_SLICE_TEST, + _FEATURE_VALUE_SLICE_TEST, + _MULTIPLE_SLICES_TEST, + ) + def testValidateMetricsMetricTDistributionValueAndThreshold( + self, slicing_specs, slice_key + ): + threshold = config_pb2.MetricThreshold( + value_threshold=config_pb2.GenericValueThreshold(lower_bound={"value": 0.9}) ) - ) - eval_config = config_pb2.EvalConfig( - model_specs=[ - config_pb2.ModelSpec(), - ], - slicing_specs=slicing_specs, - metrics_specs=[ - config_pb2.MetricsSpec( - metrics=[ - config_pb2.MetricConfig( - class_name='AUC', - threshold=threshold if slicing_specs is None else None, - per_slice_thresholds=[ - config_pb2.PerSliceMetricThreshold( - slicing_specs=slicing_specs, threshold=threshold - ) - ], - ), - ], - model_names=[''], - ), - ], - ) - sliced_metrics = ( - slice_key, - { - metric_types.MetricKey(name='auc'): types.ValueWithTDistribution( - sample_mean=0.91, unsampled_value=0.8 - ) - }, - ) - result = metrics_validator.validate_metrics(sliced_metrics, eval_config) - self.assertFalse(result.validation_ok) - expected = text_format.Parse( - """ + eval_config = config_pb2.EvalConfig( + model_specs=[ + config_pb2.ModelSpec(), + ], + slicing_specs=slicing_specs, + metrics_specs=[ + config_pb2.MetricsSpec( + metrics=[ + config_pb2.MetricConfig( + class_name="AUC", + threshold=threshold if slicing_specs is None else None, + per_slice_thresholds=[ + config_pb2.PerSliceMetricThreshold( + slicing_specs=slicing_specs, threshold=threshold + ) + ], + ), + ], + model_names=[""], + ), + ], + ) + sliced_metrics = ( + slice_key, + { + metric_types.MetricKey(name="auc"): types.ValueWithTDistribution( + sample_mean=0.91, unsampled_value=0.8 + ) + }, + ) + result = metrics_validator.validate_metrics(sliced_metrics, eval_config) + self.assertFalse(result.validation_ok) + expected = text_format.Parse( + """ metric_validations_per_slice { failures { metric_key { @@ -252,85 +249,81 @@ def testValidateMetricsMetricTDistributionValueAndThreshold( } } }""", - validation_result_pb2.ValidationResult(), - ) - expected.metric_validations_per_slice[0].failures[ - 0 - ].metric_threshold.CopyFrom(threshold) - expected.metric_validations_per_slice[0].slice_key.CopyFrom( - slicer.serialize_slice_key(slice_key) - ) - for spec in slicing_specs or [None]: - if spec is None or slicer.SingleSliceSpec(spec=spec).is_slice_applicable( - slice_key - ): - slicing_details = expected.validation_details.slicing_details.add() - if spec is not None: - slicing_details.slicing_spec.CopyFrom(spec) - else: - slicing_details.slicing_spec.CopyFrom(config_pb2.SlicingSpec()) - slicing_details.num_matching_slices = 1 - 
self.assertEqual(result, expected) + validation_result_pb2.ValidationResult(), + ) + expected.metric_validations_per_slice[0].failures[0].metric_threshold.CopyFrom( + threshold + ) + expected.metric_validations_per_slice[0].slice_key.CopyFrom( + slicer.serialize_slice_key(slice_key) + ) + for spec in slicing_specs or [None]: + if spec is None or slicer.SingleSliceSpec(spec=spec).is_slice_applicable( + slice_key + ): + slicing_details = expected.validation_details.slicing_details.add() + if spec is not None: + slicing_details.slicing_spec.CopyFrom(spec) + else: + slicing_details.slicing_spec.CopyFrom(config_pb2.SlicingSpec()) + slicing_details.num_matching_slices = 1 + self.assertEqual(result, expected) - @parameterized.named_parameters( - _NO_SLICE_TEST, - _GLOBAL_SLICE_TEST, - _FEATURE_SLICE_TEST, - _FEATURE_VALUE_SLICE_TEST, - _MULTIPLE_SLICES_TEST, - ) - def testValidateMetricsMetricTDistributionChangeAndThreshold( - self, slicing_specs, slice_key - ): - threshold = config_pb2.MetricThreshold( - change_threshold=config_pb2.GenericChangeThreshold( - direction=config_pb2.MetricDirection.LOWER_IS_BETTER, - absolute={'value': -1}, + @parameterized.named_parameters( + _NO_SLICE_TEST, + _GLOBAL_SLICE_TEST, + _FEATURE_SLICE_TEST, + _FEATURE_VALUE_SLICE_TEST, + _MULTIPLE_SLICES_TEST, + ) + def testValidateMetricsMetricTDistributionChangeAndThreshold( + self, slicing_specs, slice_key + ): + threshold = config_pb2.MetricThreshold( + change_threshold=config_pb2.GenericChangeThreshold( + direction=config_pb2.MetricDirection.LOWER_IS_BETTER, + absolute={"value": -1}, + ) ) - ) - eval_config = config_pb2.EvalConfig( - model_specs=[ - config_pb2.ModelSpec(), - config_pb2.ModelSpec(name='baseline', is_baseline=True), - ], - slicing_specs=slicing_specs, - metrics_specs=[ - config_pb2.MetricsSpec( - metrics=[ - config_pb2.MetricConfig( - class_name='AUC', - threshold=threshold if slicing_specs is None else None, - per_slice_thresholds=[ - config_pb2.PerSliceMetricThreshold( - slicing_specs=slicing_specs, threshold=threshold - ) - ], - ), - ], - model_names=[''], - ), - ], - ) - sliced_metrics = ( - slice_key, - { - # This is the mean of the diff. - metric_types.MetricKey( - name='auc', model_name='baseline' - ): types.ValueWithTDistribution( - sample_mean=0.91, unsampled_value=0.6 - ), - metric_types.MetricKey( - name='auc', is_diff=True - ): types.ValueWithTDistribution( - sample_mean=0.1, unsampled_value=0.1 - ), - }, - ) - result = metrics_validator.validate_metrics(sliced_metrics, eval_config) - self.assertFalse(result.validation_ok) - expected = text_format.Parse( - """ + eval_config = config_pb2.EvalConfig( + model_specs=[ + config_pb2.ModelSpec(), + config_pb2.ModelSpec(name="baseline", is_baseline=True), + ], + slicing_specs=slicing_specs, + metrics_specs=[ + config_pb2.MetricsSpec( + metrics=[ + config_pb2.MetricConfig( + class_name="AUC", + threshold=threshold if slicing_specs is None else None, + per_slice_thresholds=[ + config_pb2.PerSliceMetricThreshold( + slicing_specs=slicing_specs, threshold=threshold + ) + ], + ), + ], + model_names=[""], + ), + ], + ) + sliced_metrics = ( + slice_key, + { + # This is the mean of the diff. 
+ metric_types.MetricKey( + name="auc", model_name="baseline" + ): types.ValueWithTDistribution(sample_mean=0.91, unsampled_value=0.6), + metric_types.MetricKey( + name="auc", is_diff=True + ): types.ValueWithTDistribution(sample_mean=0.1, unsampled_value=0.1), + }, + ) + result = metrics_validator.validate_metrics(sliced_metrics, eval_config) + self.assertFalse(result.validation_ok) + expected = text_format.Parse( + """ metric_validations_per_slice { failures { metric_key { @@ -345,77 +338,73 @@ def testValidateMetricsMetricTDistributionChangeAndThreshold( } } }""", - validation_result_pb2.ValidationResult(), - ) - expected.metric_validations_per_slice[0].failures[ - 0 - ].metric_threshold.CopyFrom(threshold) - expected.metric_validations_per_slice[0].slice_key.CopyFrom( - slicer.serialize_slice_key(slice_key) - ) - for spec in slicing_specs or [None]: - if spec is None or slicer.SingleSliceSpec(spec=spec).is_slice_applicable( - slice_key - ): - slicing_details = expected.validation_details.slicing_details.add() - if spec is not None: - slicing_details.slicing_spec.CopyFrom(spec) - else: - slicing_details.slicing_spec.CopyFrom(config_pb2.SlicingSpec()) - slicing_details.num_matching_slices = 1 - self.assertAlmostEqual(result, expected) + validation_result_pb2.ValidationResult(), + ) + expected.metric_validations_per_slice[0].failures[0].metric_threshold.CopyFrom( + threshold + ) + expected.metric_validations_per_slice[0].slice_key.CopyFrom( + slicer.serialize_slice_key(slice_key) + ) + for spec in slicing_specs or [None]: + if spec is None or slicer.SingleSliceSpec(spec=spec).is_slice_applicable( + slice_key + ): + slicing_details = expected.validation_details.slicing_details.add() + if spec is not None: + slicing_details.slicing_spec.CopyFrom(spec) + else: + slicing_details.slicing_spec.CopyFrom(config_pb2.SlicingSpec()) + slicing_details.num_matching_slices = 1 + self.assertAlmostEqual(result, expected) - @parameterized.named_parameters( - _NO_SLICE_TEST, - _GLOBAL_SLICE_TEST, - _FEATURE_SLICE_TEST, - _FEATURE_VALUE_SLICE_TEST, - _MULTIPLE_SLICES_TEST, - ) - def testValidateMetricsMetricValueAndThreshold( - self, slicing_specs, slice_key - ): - threshold = config_pb2.MetricThreshold( - value_threshold=config_pb2.GenericValueThreshold( - upper_bound={'value': 1} + @parameterized.named_parameters( + _NO_SLICE_TEST, + _GLOBAL_SLICE_TEST, + _FEATURE_SLICE_TEST, + _FEATURE_VALUE_SLICE_TEST, + _MULTIPLE_SLICES_TEST, + ) + def testValidateMetricsMetricValueAndThreshold(self, slicing_specs, slice_key): + threshold = config_pb2.MetricThreshold( + value_threshold=config_pb2.GenericValueThreshold(upper_bound={"value": 1}) ) - ) - eval_config = config_pb2.EvalConfig( - model_specs=[ - config_pb2.ModelSpec(), - ], - slicing_specs=slicing_specs, - metrics_specs=[ - config_pb2.MetricsSpec( - metrics=[ - config_pb2.MetricConfig( - class_name='WeightedExampleCount', - # 1.5 < 1, NOT OK. 
- threshold=threshold if slicing_specs is None else None, - per_slice_thresholds=[ - config_pb2.PerSliceMetricThreshold( - slicing_specs=slicing_specs, threshold=threshold - ) - ], - ), - ], - model_names=[''], - example_weights=config_pb2.ExampleWeightOptions(weighted=True), - ), - ], - ) - sliced_metrics = ( - slice_key, - { - metric_types.MetricKey( - name='weighted_example_count', example_weighted=True - ): 1.5, - }, - ) - result = metrics_validator.validate_metrics(sliced_metrics, eval_config) - self.assertFalse(result.validation_ok) - expected = text_format.Parse( - """ + eval_config = config_pb2.EvalConfig( + model_specs=[ + config_pb2.ModelSpec(), + ], + slicing_specs=slicing_specs, + metrics_specs=[ + config_pb2.MetricsSpec( + metrics=[ + config_pb2.MetricConfig( + class_name="WeightedExampleCount", + # 1.5 < 1, NOT OK. + threshold=threshold if slicing_specs is None else None, + per_slice_thresholds=[ + config_pb2.PerSliceMetricThreshold( + slicing_specs=slicing_specs, threshold=threshold + ) + ], + ), + ], + model_names=[""], + example_weights=config_pb2.ExampleWeightOptions(weighted=True), + ), + ], + ) + sliced_metrics = ( + slice_key, + { + metric_types.MetricKey( + name="weighted_example_count", example_weighted=True + ): 1.5, + }, + ) + result = metrics_validator.validate_metrics(sliced_metrics, eval_config) + self.assertFalse(result.validation_ok) + expected = text_format.Parse( + """ metric_validations_per_slice { failures { metric_key { @@ -429,792 +418,743 @@ def testValidateMetricsMetricValueAndThreshold( } } }""", - validation_result_pb2.ValidationResult(), - ) - expected.metric_validations_per_slice[0].failures[ - 0 - ].metric_threshold.CopyFrom(threshold) - expected.metric_validations_per_slice[0].slice_key.CopyFrom( - slicer.serialize_slice_key(slice_key) - ) - for spec in slicing_specs or [None]: - if spec is None or slicer.SingleSliceSpec(spec=spec).is_slice_applicable( - slice_key - ): - slicing_details = expected.validation_details.slicing_details.add() - if spec is not None: - slicing_details.slicing_spec.CopyFrom(spec) - else: - slicing_details.slicing_spec.CopyFrom(config_pb2.SlicingSpec()) - slicing_details.num_matching_slices = 1 - self.assertEqual(result, expected) + validation_result_pb2.ValidationResult(), + ) + expected.metric_validations_per_slice[0].failures[0].metric_threshold.CopyFrom( + threshold + ) + expected.metric_validations_per_slice[0].slice_key.CopyFrom( + slicer.serialize_slice_key(slice_key) + ) + for spec in slicing_specs or [None]: + if spec is None or slicer.SingleSliceSpec(spec=spec).is_slice_applicable( + slice_key + ): + slicing_details = expected.validation_details.slicing_details.add() + if spec is not None: + slicing_details.slicing_spec.CopyFrom(spec) + else: + slicing_details.slicing_spec.CopyFrom(config_pb2.SlicingSpec()) + slicing_details.num_matching_slices = 1 + self.assertEqual(result, expected) - @parameterized.named_parameters( - _UNMATCHED_SINGLE_SLICE_TEST, _UNMATCHED_MULTIPLE_SLICES_TEST - ) - def testValidateMetricsMetricValueAndThresholdIgnoreUnmatchedSlice( - self, slicing_specs, slice_key - ): - threshold = config_pb2.MetricThreshold( - value_threshold=config_pb2.GenericValueThreshold( - upper_bound={'value': 1} + @parameterized.named_parameters( + _UNMATCHED_SINGLE_SLICE_TEST, _UNMATCHED_MULTIPLE_SLICES_TEST + ) + def testValidateMetricsMetricValueAndThresholdIgnoreUnmatchedSlice( + self, slicing_specs, slice_key + ): + threshold = config_pb2.MetricThreshold( + 
value_threshold=config_pb2.GenericValueThreshold(upper_bound={"value": 1}) ) - ) - eval_config = config_pb2.EvalConfig( - model_specs=[ - config_pb2.ModelSpec(), - ], - slicing_specs=slicing_specs, - metrics_specs=[ - config_pb2.MetricsSpec( - metrics=[ - config_pb2.MetricConfig( - class_name='WeightedExampleCount', - # 1.5 < 1, NOT OK. - per_slice_thresholds=[ - config_pb2.PerSliceMetricThreshold( - slicing_specs=slicing_specs, threshold=threshold - ) - ], - ), - ], - model_names=[''], - example_weights=config_pb2.ExampleWeightOptions(weighted=True), - ), - ], - ) - sliced_metrics = ( - slice_key, - { - metric_types.MetricKey( - name='weighted_example_count', example_weighted=True - ): 1.5, - }, - ) - result = metrics_validator.validate_metrics(sliced_metrics, eval_config) - self.assertTrue(result.validation_ok) + eval_config = config_pb2.EvalConfig( + model_specs=[ + config_pb2.ModelSpec(), + ], + slicing_specs=slicing_specs, + metrics_specs=[ + config_pb2.MetricsSpec( + metrics=[ + config_pb2.MetricConfig( + class_name="WeightedExampleCount", + # 1.5 < 1, NOT OK. + per_slice_thresholds=[ + config_pb2.PerSliceMetricThreshold( + slicing_specs=slicing_specs, threshold=threshold + ) + ], + ), + ], + model_names=[""], + example_weights=config_pb2.ExampleWeightOptions(weighted=True), + ), + ], + ) + sliced_metrics = ( + slice_key, + { + metric_types.MetricKey( + name="weighted_example_count", example_weighted=True + ): 1.5, + }, + ) + result = metrics_validator.validate_metrics(sliced_metrics, eval_config) + self.assertTrue(result.validation_ok) - @parameterized.named_parameters( - _NO_SLICE_TEST, - _GLOBAL_SLICE_TEST, - _FEATURE_SLICE_TEST, - _FEATURE_VALUE_SLICE_TEST, - _MULTIPLE_SLICES_TEST, - ) - def testValidateMetricsValueThresholdUpperBoundFail( - self, slicing_specs, slice_key - ): - threshold = config_pb2.MetricThreshold( - value_threshold=config_pb2.GenericValueThreshold( - upper_bound={'value': 1} + @parameterized.named_parameters( + _NO_SLICE_TEST, + _GLOBAL_SLICE_TEST, + _FEATURE_SLICE_TEST, + _FEATURE_VALUE_SLICE_TEST, + _MULTIPLE_SLICES_TEST, + ) + def testValidateMetricsValueThresholdUpperBoundFail(self, slicing_specs, slice_key): + threshold = config_pb2.MetricThreshold( + value_threshold=config_pb2.GenericValueThreshold(upper_bound={"value": 1}) ) - ) - eval_config = config_pb2.EvalConfig( - model_specs=[ - config_pb2.ModelSpec(), - ], - slicing_specs=slicing_specs, - metrics_specs=[ - config_pb2.MetricsSpec( - metrics=[ - config_pb2.MetricConfig( - class_name='WeightedExampleCount', - # 1.5 < 1, NOT OK. - threshold=threshold if slicing_specs is None else None, - per_slice_thresholds=[ - config_pb2.PerSliceMetricThreshold( - slicing_specs=slicing_specs, threshold=threshold - ) - ], - ), - ], - model_names=[''], - example_weights=config_pb2.ExampleWeightOptions(weighted=True), - ), - ], - ) - sliced_metrics = ( - slice_key, - { - metric_types.MetricKey(name='weighted_example_count'): 1.5, - }, - ) - result = metrics_validator.validate_metrics(sliced_metrics, eval_config) - self.assertFalse(result.validation_ok) + eval_config = config_pb2.EvalConfig( + model_specs=[ + config_pb2.ModelSpec(), + ], + slicing_specs=slicing_specs, + metrics_specs=[ + config_pb2.MetricsSpec( + metrics=[ + config_pb2.MetricConfig( + class_name="WeightedExampleCount", + # 1.5 < 1, NOT OK. 
+ threshold=threshold if slicing_specs is None else None, + per_slice_thresholds=[ + config_pb2.PerSliceMetricThreshold( + slicing_specs=slicing_specs, threshold=threshold + ) + ], + ), + ], + model_names=[""], + example_weights=config_pb2.ExampleWeightOptions(weighted=True), + ), + ], + ) + sliced_metrics = ( + slice_key, + { + metric_types.MetricKey(name="weighted_example_count"): 1.5, + }, + ) + result = metrics_validator.validate_metrics(sliced_metrics, eval_config) + self.assertFalse(result.validation_ok) - @parameterized.named_parameters( - _NO_SLICE_TEST, - _GLOBAL_SLICE_TEST, - _FEATURE_SLICE_TEST, - _FEATURE_VALUE_SLICE_TEST, - _MULTIPLE_SLICES_TEST, - ) - def testValidateMetricsValueThresholdLowerBoundFail( - self, slicing_specs, slice_key - ): - threshold = config_pb2.MetricThreshold( - value_threshold=config_pb2.GenericValueThreshold( - lower_bound={'value': 1} + @parameterized.named_parameters( + _NO_SLICE_TEST, + _GLOBAL_SLICE_TEST, + _FEATURE_SLICE_TEST, + _FEATURE_VALUE_SLICE_TEST, + _MULTIPLE_SLICES_TEST, + ) + def testValidateMetricsValueThresholdLowerBoundFail(self, slicing_specs, slice_key): + threshold = config_pb2.MetricThreshold( + value_threshold=config_pb2.GenericValueThreshold(lower_bound={"value": 1}) ) - ) - eval_config = config_pb2.EvalConfig( - model_specs=[ - config_pb2.ModelSpec(), - ], - slicing_specs=slicing_specs, - metrics_specs=[ - config_pb2.MetricsSpec( - metrics=[ - config_pb2.MetricConfig( - class_name='WeightedExampleCount', - # 0 > 1, NOT OK. - threshold=threshold if slicing_specs is None else None, - per_slice_thresholds=[ - config_pb2.PerSliceMetricThreshold( - slicing_specs=slicing_specs, threshold=threshold - ) - ], - ), - ], - model_names=[''], - example_weights=config_pb2.ExampleWeightOptions(weighted=True), - ), - ], - ) - sliced_metrics = ( - slice_key, - { - metric_types.MetricKey( - name='weighted_example_count', example_weighted=True - ): 0, - }, - ) - result = metrics_validator.validate_metrics(sliced_metrics, eval_config) - self.assertFalse(result.validation_ok) + eval_config = config_pb2.EvalConfig( + model_specs=[ + config_pb2.ModelSpec(), + ], + slicing_specs=slicing_specs, + metrics_specs=[ + config_pb2.MetricsSpec( + metrics=[ + config_pb2.MetricConfig( + class_name="WeightedExampleCount", + # 0 > 1, NOT OK. 
+ threshold=threshold if slicing_specs is None else None, + per_slice_thresholds=[ + config_pb2.PerSliceMetricThreshold( + slicing_specs=slicing_specs, threshold=threshold + ) + ], + ), + ], + model_names=[""], + example_weights=config_pb2.ExampleWeightOptions(weighted=True), + ), + ], + ) + sliced_metrics = ( + slice_key, + { + metric_types.MetricKey( + name="weighted_example_count", example_weighted=True + ): 0, + }, + ) + result = metrics_validator.validate_metrics(sliced_metrics, eval_config) + self.assertFalse(result.validation_ok) - @parameterized.named_parameters( - _NO_SLICE_TEST, - _GLOBAL_SLICE_TEST, - _FEATURE_SLICE_TEST, - _FEATURE_VALUE_SLICE_TEST, - _MULTIPLE_SLICES_TEST, - ) - def testValidateMetricsValueThresholdUpperBoundPass( - self, slicing_specs, slice_key - ): - threshold = config_pb2.MetricThreshold( - value_threshold=config_pb2.GenericValueThreshold( - upper_bound={'value': 1} + @parameterized.named_parameters( + _NO_SLICE_TEST, + _GLOBAL_SLICE_TEST, + _FEATURE_SLICE_TEST, + _FEATURE_VALUE_SLICE_TEST, + _MULTIPLE_SLICES_TEST, + ) + def testValidateMetricsValueThresholdUpperBoundPass(self, slicing_specs, slice_key): + threshold = config_pb2.MetricThreshold( + value_threshold=config_pb2.GenericValueThreshold(upper_bound={"value": 1}) ) - ) - eval_config = config_pb2.EvalConfig( - model_specs=[ - config_pb2.ModelSpec(), - ], - slicing_specs=slicing_specs, - metrics_specs=[ - config_pb2.MetricsSpec( - metrics=[ - config_pb2.MetricConfig( - class_name='WeightedExampleCount', - # 0 < 1, OK. - threshold=threshold if slicing_specs is None else None, - per_slice_thresholds=[ - config_pb2.PerSliceMetricThreshold( - slicing_specs=slicing_specs, threshold=threshold - ) - ], - ), - ], - model_names=[''], - example_weights=config_pb2.ExampleWeightOptions(weighted=True), - ), - ], - ) - sliced_metrics = ( - slice_key, - { - metric_types.MetricKey( - name='weighted_example_count', example_weighted=True - ): 0, - }, - ) - result = metrics_validator.validate_metrics(sliced_metrics, eval_config) - self.assertTrue(result.validation_ok) + eval_config = config_pb2.EvalConfig( + model_specs=[ + config_pb2.ModelSpec(), + ], + slicing_specs=slicing_specs, + metrics_specs=[ + config_pb2.MetricsSpec( + metrics=[ + config_pb2.MetricConfig( + class_name="WeightedExampleCount", + # 0 < 1, OK. 
+ threshold=threshold if slicing_specs is None else None, + per_slice_thresholds=[ + config_pb2.PerSliceMetricThreshold( + slicing_specs=slicing_specs, threshold=threshold + ) + ], + ), + ], + model_names=[""], + example_weights=config_pb2.ExampleWeightOptions(weighted=True), + ), + ], + ) + sliced_metrics = ( + slice_key, + { + metric_types.MetricKey( + name="weighted_example_count", example_weighted=True + ): 0, + }, + ) + result = metrics_validator.validate_metrics(sliced_metrics, eval_config) + self.assertTrue(result.validation_ok) - @parameterized.named_parameters( - _NO_SLICE_TEST, - _GLOBAL_SLICE_TEST, - _FEATURE_SLICE_TEST, - _FEATURE_VALUE_SLICE_TEST, - _MULTIPLE_SLICES_TEST, - ) - def testValidateMetricsValueThresholdLowerBoundPass( - self, slicing_specs, slice_key - ): - threshold = config_pb2.MetricThreshold( - value_threshold=config_pb2.GenericValueThreshold( - lower_bound={'value': 1} + @parameterized.named_parameters( + _NO_SLICE_TEST, + _GLOBAL_SLICE_TEST, + _FEATURE_SLICE_TEST, + _FEATURE_VALUE_SLICE_TEST, + _MULTIPLE_SLICES_TEST, + ) + def testValidateMetricsValueThresholdLowerBoundPass(self, slicing_specs, slice_key): + threshold = config_pb2.MetricThreshold( + value_threshold=config_pb2.GenericValueThreshold(lower_bound={"value": 1}) ) - ) - eval_config = config_pb2.EvalConfig( - model_specs=[ - config_pb2.ModelSpec(), - ], - slicing_specs=slicing_specs, - metrics_specs=[ - config_pb2.MetricsSpec( - metrics=[ - config_pb2.MetricConfig( - class_name='WeightedExampleCount', - # 2 > 1, OK. - threshold=threshold if slicing_specs is None else None, - per_slice_thresholds=[ - config_pb2.PerSliceMetricThreshold( - slicing_specs=slicing_specs, threshold=threshold - ) - ], - ), - ], - model_names=[''], - example_weights=config_pb2.ExampleWeightOptions(weighted=True), - ), - ], - ) - sliced_metrics = ( - slice_key, - { - metric_types.MetricKey( - name='weighted_example_count', example_weighted=True - ): 2, - }, - ) - result = metrics_validator.validate_metrics(sliced_metrics, eval_config) - self.assertTrue(result.validation_ok) + eval_config = config_pb2.EvalConfig( + model_specs=[ + config_pb2.ModelSpec(), + ], + slicing_specs=slicing_specs, + metrics_specs=[ + config_pb2.MetricsSpec( + metrics=[ + config_pb2.MetricConfig( + class_name="WeightedExampleCount", + # 2 > 1, OK. 
+ threshold=threshold if slicing_specs is None else None, + per_slice_thresholds=[ + config_pb2.PerSliceMetricThreshold( + slicing_specs=slicing_specs, threshold=threshold + ) + ], + ), + ], + model_names=[""], + example_weights=config_pb2.ExampleWeightOptions(weighted=True), + ), + ], + ) + sliced_metrics = ( + slice_key, + { + metric_types.MetricKey( + name="weighted_example_count", example_weighted=True + ): 2, + }, + ) + result = metrics_validator.validate_metrics(sliced_metrics, eval_config) + self.assertTrue(result.validation_ok) - @parameterized.named_parameters( - _NO_SLICE_TEST, - _GLOBAL_SLICE_TEST, - _FEATURE_SLICE_TEST, - _FEATURE_VALUE_SLICE_TEST, - _MULTIPLE_SLICES_TEST, - ) - def testValidateMetricsChangeThresholdAbsoluteFail( - self, slicing_specs, slice_key - ): - threshold = config_pb2.MetricThreshold( - change_threshold=config_pb2.GenericChangeThreshold( - direction=config_pb2.MetricDirection.LOWER_IS_BETTER, - absolute={'value': -1}, + @parameterized.named_parameters( + _NO_SLICE_TEST, + _GLOBAL_SLICE_TEST, + _FEATURE_SLICE_TEST, + _FEATURE_VALUE_SLICE_TEST, + _MULTIPLE_SLICES_TEST, + ) + def testValidateMetricsChangeThresholdAbsoluteFail(self, slicing_specs, slice_key): + threshold = config_pb2.MetricThreshold( + change_threshold=config_pb2.GenericChangeThreshold( + direction=config_pb2.MetricDirection.LOWER_IS_BETTER, + absolute={"value": -1}, + ) ) - ) - eval_config = config_pb2.EvalConfig( - model_specs=[ - config_pb2.ModelSpec(), - config_pb2.ModelSpec(name='baseline', is_baseline=True), - ], - slicing_specs=slicing_specs, - metrics_specs=[ - config_pb2.MetricsSpec( - metrics=[ - config_pb2.MetricConfig( - class_name='MeanPrediction', - # Diff = 0 - .333 = -.333 < -1, NOT OK. - threshold=threshold if slicing_specs is None else None, - per_slice_thresholds=[ - config_pb2.PerSliceMetricThreshold( - slicing_specs=slicing_specs, threshold=threshold - ) - ], - ) - ], - model_names=[''], - ), - ], - ) - sliced_metrics = ( - slice_key, - { - metric_types.MetricKey( - name='mean_prediction', model_name='baseline' - ): 0.333, - metric_types.MetricKey( - name='mean_prediction', is_diff=True - ): -0.333, - }, - ) - result = metrics_validator.validate_metrics(sliced_metrics, eval_config) - self.assertFalse(result.validation_ok) + eval_config = config_pb2.EvalConfig( + model_specs=[ + config_pb2.ModelSpec(), + config_pb2.ModelSpec(name="baseline", is_baseline=True), + ], + slicing_specs=slicing_specs, + metrics_specs=[ + config_pb2.MetricsSpec( + metrics=[ + config_pb2.MetricConfig( + class_name="MeanPrediction", + # Diff = 0 - .333 = -.333 < -1, NOT OK. 
+ threshold=threshold if slicing_specs is None else None, + per_slice_thresholds=[ + config_pb2.PerSliceMetricThreshold( + slicing_specs=slicing_specs, threshold=threshold + ) + ], + ) + ], + model_names=[""], + ), + ], + ) + sliced_metrics = ( + slice_key, + { + metric_types.MetricKey( + name="mean_prediction", model_name="baseline" + ): 0.333, + metric_types.MetricKey(name="mean_prediction", is_diff=True): -0.333, + }, + ) + result = metrics_validator.validate_metrics(sliced_metrics, eval_config) + self.assertFalse(result.validation_ok) - @parameterized.named_parameters( - _NO_SLICE_TEST, - _GLOBAL_SLICE_TEST, - _FEATURE_SLICE_TEST, - _FEATURE_VALUE_SLICE_TEST, - _MULTIPLE_SLICES_TEST, - ) - def testValidateMetricsChangeThresholdRelativeFail( - self, slicing_specs, slice_key - ): - threshold = config_pb2.MetricThreshold( - change_threshold=config_pb2.GenericChangeThreshold( - direction=config_pb2.MetricDirection.LOWER_IS_BETTER, - relative={'value': -2}, + @parameterized.named_parameters( + _NO_SLICE_TEST, + _GLOBAL_SLICE_TEST, + _FEATURE_SLICE_TEST, + _FEATURE_VALUE_SLICE_TEST, + _MULTIPLE_SLICES_TEST, + ) + def testValidateMetricsChangeThresholdRelativeFail(self, slicing_specs, slice_key): + threshold = config_pb2.MetricThreshold( + change_threshold=config_pb2.GenericChangeThreshold( + direction=config_pb2.MetricDirection.LOWER_IS_BETTER, + relative={"value": -2}, + ) ) - ) - eval_config = config_pb2.EvalConfig( - model_specs=[ - config_pb2.ModelSpec(), - config_pb2.ModelSpec(name='baseline', is_baseline=True), - ], - slicing_specs=slicing_specs, - metrics_specs=[ - config_pb2.MetricsSpec( - metrics=[ - config_pb2.MetricConfig( - class_name='MeanPrediction', - # Diff = -.333 - # Diff% = -.333/.333 = -100% < -200%, NOT OK. - threshold=threshold if slicing_specs is None else None, - per_slice_thresholds=[ - config_pb2.PerSliceMetricThreshold( - slicing_specs=slicing_specs, threshold=threshold - ) - ], - ) - ], - model_names=[''], - ), - ], - ) - sliced_metrics = ( - slice_key, - { - metric_types.MetricKey( - name='mean_prediction', model_name='baseline' - ): 0.333, - metric_types.MetricKey( - name='mean_prediction', is_diff=True - ): -0.333, - }, - ) - result = metrics_validator.validate_metrics(sliced_metrics, eval_config) - self.assertFalse(result.validation_ok) + eval_config = config_pb2.EvalConfig( + model_specs=[ + config_pb2.ModelSpec(), + config_pb2.ModelSpec(name="baseline", is_baseline=True), + ], + slicing_specs=slicing_specs, + metrics_specs=[ + config_pb2.MetricsSpec( + metrics=[ + config_pb2.MetricConfig( + class_name="MeanPrediction", + # Diff = -.333 + # Diff% = -.333/.333 = -100% < -200%, NOT OK. 
+ threshold=threshold if slicing_specs is None else None, + per_slice_thresholds=[ + config_pb2.PerSliceMetricThreshold( + slicing_specs=slicing_specs, threshold=threshold + ) + ], + ) + ], + model_names=[""], + ), + ], + ) + sliced_metrics = ( + slice_key, + { + metric_types.MetricKey( + name="mean_prediction", model_name="baseline" + ): 0.333, + metric_types.MetricKey(name="mean_prediction", is_diff=True): -0.333, + }, + ) + result = metrics_validator.validate_metrics(sliced_metrics, eval_config) + self.assertFalse(result.validation_ok) - @parameterized.named_parameters( - _NO_SLICE_TEST, - _GLOBAL_SLICE_TEST, - _FEATURE_SLICE_TEST, - _FEATURE_VALUE_SLICE_TEST, - _MULTIPLE_SLICES_TEST, - ) - def testValidateMetricsChangeThresholdAbsolutePass( - self, slicing_specs, slice_key - ): - threshold = config_pb2.MetricThreshold( - change_threshold=config_pb2.GenericChangeThreshold( - direction=config_pb2.MetricDirection.LOWER_IS_BETTER, - absolute={'value': 0}, + @parameterized.named_parameters( + _NO_SLICE_TEST, + _GLOBAL_SLICE_TEST, + _FEATURE_SLICE_TEST, + _FEATURE_VALUE_SLICE_TEST, + _MULTIPLE_SLICES_TEST, + ) + def testValidateMetricsChangeThresholdAbsolutePass(self, slicing_specs, slice_key): + threshold = config_pb2.MetricThreshold( + change_threshold=config_pb2.GenericChangeThreshold( + direction=config_pb2.MetricDirection.LOWER_IS_BETTER, + absolute={"value": 0}, + ) ) - ) - eval_config = config_pb2.EvalConfig( - model_specs=[ - config_pb2.ModelSpec(), - config_pb2.ModelSpec(name='baseline', is_baseline=True), - ], - slicing_specs=slicing_specs, - metrics_specs=[ - config_pb2.MetricsSpec( - metrics=[ - config_pb2.MetricConfig( - class_name='MeanPrediction', - # Diff = 0 - .333 = -.333 < 0, OK. - threshold=threshold if slicing_specs is None else None, - per_slice_thresholds=[ - config_pb2.PerSliceMetricThreshold( - slicing_specs=slicing_specs, threshold=threshold - ) - ], - ) - ], - model_names=[''], - ), - ], - ) - sliced_metrics = ( - slice_key, - { - metric_types.MetricKey( - name='mean_prediction', model_name='baseline' - ): 0.333, - metric_types.MetricKey( - name='mean_prediction', is_diff=True - ): -0.333, - }, - ) - result = metrics_validator.validate_metrics(sliced_metrics, eval_config) - self.assertTrue(result.validation_ok) + eval_config = config_pb2.EvalConfig( + model_specs=[ + config_pb2.ModelSpec(), + config_pb2.ModelSpec(name="baseline", is_baseline=True), + ], + slicing_specs=slicing_specs, + metrics_specs=[ + config_pb2.MetricsSpec( + metrics=[ + config_pb2.MetricConfig( + class_name="MeanPrediction", + # Diff = 0 - .333 = -.333 < 0, OK. + threshold=threshold if slicing_specs is None else None, + per_slice_thresholds=[ + config_pb2.PerSliceMetricThreshold( + slicing_specs=slicing_specs, threshold=threshold + ) + ], + ) + ], + model_names=[""], + ), + ], + ) + sliced_metrics = ( + slice_key, + { + metric_types.MetricKey( + name="mean_prediction", model_name="baseline" + ): 0.333, + metric_types.MetricKey(name="mean_prediction", is_diff=True): -0.333, + }, + ) + result = metrics_validator.validate_metrics(sliced_metrics, eval_config) + self.assertTrue(result.validation_ok) - @parameterized.named_parameters( - _NO_SLICE_TEST, - _GLOBAL_SLICE_TEST, - _FEATURE_SLICE_TEST, - _FEATURE_VALUE_SLICE_TEST, - _MULTIPLE_SLICES_TEST, - ) - def testValidateNativeDiffMetricsChangeThresholdAbsolutePass( - self, slicing_specs, slice_key - ): - # We must import metric so that it is registered and can be deserialized - # from class name specified in the config. 
- from tensorflow_model_analysis.metrics import model_cosine_similarity # pylint: disable=g-import-not-at-top, unused-import - # Diff = 0.99 >= 0.9, OK. - threshold = config_pb2.MetricThreshold( - change_threshold=config_pb2.GenericChangeThreshold( - direction=config_pb2.MetricDirection.HIGHER_IS_BETTER, - absolute={'value': 0.9}, + @parameterized.named_parameters( + _NO_SLICE_TEST, + _GLOBAL_SLICE_TEST, + _FEATURE_SLICE_TEST, + _FEATURE_VALUE_SLICE_TEST, + _MULTIPLE_SLICES_TEST, + ) + def testValidateNativeDiffMetricsChangeThresholdAbsolutePass( + self, slicing_specs, slice_key + ): + # We must import metric so that it is registered and can be deserialized + # from class name specified in the config. + # Diff = 0.99 >= 0.9, OK. + threshold = config_pb2.MetricThreshold( + change_threshold=config_pb2.GenericChangeThreshold( + direction=config_pb2.MetricDirection.HIGHER_IS_BETTER, + absolute={"value": 0.9}, + ) ) - ) - eval_config = config_pb2.EvalConfig( - model_specs=[ - config_pb2.ModelSpec(), - config_pb2.ModelSpec(name='baseline', is_baseline=True), - ], - slicing_specs=slicing_specs, - metrics_specs=[ - config_pb2.MetricsSpec( - metrics=[ - config_pb2.MetricConfig( - class_name='ModelCosineSimilarity', - threshold=threshold if slicing_specs is None else None, - per_slice_thresholds=[ - config_pb2.PerSliceMetricThreshold( - slicing_specs=slicing_specs, threshold=threshold - ) - ], - ) - ], - model_names=[''], - ), - ], - ) - sliced_metrics = ( - slice_key, - { - metric_types.MetricKey( - name='model_cosine_similarity', is_diff=True - ): 0.99, - }, - ) - result = metrics_validator.validate_metrics(sliced_metrics, eval_config) - self.assertTrue(result.validation_ok) + eval_config = config_pb2.EvalConfig( + model_specs=[ + config_pb2.ModelSpec(), + config_pb2.ModelSpec(name="baseline", is_baseline=True), + ], + slicing_specs=slicing_specs, + metrics_specs=[ + config_pb2.MetricsSpec( + metrics=[ + config_pb2.MetricConfig( + class_name="ModelCosineSimilarity", + threshold=threshold if slicing_specs is None else None, + per_slice_thresholds=[ + config_pb2.PerSliceMetricThreshold( + slicing_specs=slicing_specs, threshold=threshold + ) + ], + ) + ], + model_names=[""], + ), + ], + ) + sliced_metrics = ( + slice_key, + { + metric_types.MetricKey( + name="model_cosine_similarity", is_diff=True + ): 0.99, + }, + ) + result = metrics_validator.validate_metrics(sliced_metrics, eval_config) + self.assertTrue(result.validation_ok) - @parameterized.named_parameters( - _NO_SLICE_TEST, - _GLOBAL_SLICE_TEST, - _FEATURE_SLICE_TEST, - _FEATURE_VALUE_SLICE_TEST, - _MULTIPLE_SLICES_TEST, - ) - def testValidateMetricsChangeThresholdRelativePass( - self, slicing_specs, slice_key - ): - threshold = config_pb2.MetricThreshold( - change_threshold=config_pb2.GenericChangeThreshold( - direction=config_pb2.MetricDirection.LOWER_IS_BETTER, - relative={'value': 0}, + @parameterized.named_parameters( + _NO_SLICE_TEST, + _GLOBAL_SLICE_TEST, + _FEATURE_SLICE_TEST, + _FEATURE_VALUE_SLICE_TEST, + _MULTIPLE_SLICES_TEST, + ) + def testValidateMetricsChangeThresholdRelativePass(self, slicing_specs, slice_key): + threshold = config_pb2.MetricThreshold( + change_threshold=config_pb2.GenericChangeThreshold( + direction=config_pb2.MetricDirection.LOWER_IS_BETTER, + relative={"value": 0}, + ) ) - ) - eval_config = config_pb2.EvalConfig( - model_specs=[ - config_pb2.ModelSpec(), - config_pb2.ModelSpec(name='baseline', is_baseline=True), - ], - slicing_specs=slicing_specs, - metrics_specs=[ - config_pb2.MetricsSpec( - metrics=[ - 
config_pb2.MetricConfig( - class_name='MeanPrediction', - # Diff = -.333 - # Diff% = -.333/.333 = -100% < 0%, OK. - threshold=threshold if slicing_specs is None else None, - per_slice_thresholds=[ - config_pb2.PerSliceMetricThreshold( - slicing_specs=slicing_specs, threshold=threshold - ) - ], - ) - ], - model_names=[''], - ), - ], - ) - sliced_metrics = ( - slice_key, - { - metric_types.MetricKey( - name='mean_prediction', model_name='baseline' - ): 0.333, - metric_types.MetricKey( - name='mean_prediction', is_diff=True - ): -0.333, - }, - ) - result = metrics_validator.validate_metrics(sliced_metrics, eval_config) - self.assertTrue(result.validation_ok) + eval_config = config_pb2.EvalConfig( + model_specs=[ + config_pb2.ModelSpec(), + config_pb2.ModelSpec(name="baseline", is_baseline=True), + ], + slicing_specs=slicing_specs, + metrics_specs=[ + config_pb2.MetricsSpec( + metrics=[ + config_pb2.MetricConfig( + class_name="MeanPrediction", + # Diff = -.333 + # Diff% = -.333/.333 = -100% < 0%, OK. + threshold=threshold if slicing_specs is None else None, + per_slice_thresholds=[ + config_pb2.PerSliceMetricThreshold( + slicing_specs=slicing_specs, threshold=threshold + ) + ], + ) + ], + model_names=[""], + ), + ], + ) + sliced_metrics = ( + slice_key, + { + metric_types.MetricKey( + name="mean_prediction", model_name="baseline" + ): 0.333, + metric_types.MetricKey(name="mean_prediction", is_diff=True): -0.333, + }, + ) + result = metrics_validator.validate_metrics(sliced_metrics, eval_config) + self.assertTrue(result.validation_ok) - @parameterized.named_parameters( - _NO_SLICE_TEST, - _GLOBAL_SLICE_TEST, - _FEATURE_SLICE_TEST, - _FEATURE_VALUE_SLICE_TEST, - _MULTIPLE_SLICES_TEST, - ) - def testValidateMetricsChangeThresholdHigherIsBetterPass( - self, slicing_specs, slice_key - ): - threshold = config_pb2.MetricThreshold( - change_threshold=config_pb2.GenericChangeThreshold( - direction=config_pb2.MetricDirection.HIGHER_IS_BETTER, - absolute={'value': -1}, + @parameterized.named_parameters( + _NO_SLICE_TEST, + _GLOBAL_SLICE_TEST, + _FEATURE_SLICE_TEST, + _FEATURE_VALUE_SLICE_TEST, + _MULTIPLE_SLICES_TEST, + ) + def testValidateMetricsChangeThresholdHigherIsBetterPass( + self, slicing_specs, slice_key + ): + threshold = config_pb2.MetricThreshold( + change_threshold=config_pb2.GenericChangeThreshold( + direction=config_pb2.MetricDirection.HIGHER_IS_BETTER, + absolute={"value": -1}, + ) ) - ) - eval_config = config_pb2.EvalConfig( - model_specs=[ - config_pb2.ModelSpec(), - config_pb2.ModelSpec(name='baseline', is_baseline=True), - ], - slicing_specs=slicing_specs, - metrics_specs=[ - config_pb2.MetricsSpec( - metrics=[ - config_pb2.MetricConfig( - class_name='MeanPrediction', - # Diff = -.333 > -1, OK. 
- threshold=threshold if slicing_specs is None else None, - per_slice_thresholds=[ - config_pb2.PerSliceMetricThreshold( - slicing_specs=slicing_specs, threshold=threshold - ) - ], - ) - ], - model_names=[''], - ), - ], - ) - sliced_metrics = ( - slice_key, - { - metric_types.MetricKey( - name='mean_prediction', model_name='baseline' - ): 0.333, - metric_types.MetricKey( - name='mean_prediction', is_diff=True - ): -0.333, - }, - ) - result = metrics_validator.validate_metrics(sliced_metrics, eval_config) - self.assertTrue(result.validation_ok) + eval_config = config_pb2.EvalConfig( + model_specs=[ + config_pb2.ModelSpec(), + config_pb2.ModelSpec(name="baseline", is_baseline=True), + ], + slicing_specs=slicing_specs, + metrics_specs=[ + config_pb2.MetricsSpec( + metrics=[ + config_pb2.MetricConfig( + class_name="MeanPrediction", + # Diff = -.333 > -1, OK. + threshold=threshold if slicing_specs is None else None, + per_slice_thresholds=[ + config_pb2.PerSliceMetricThreshold( + slicing_specs=slicing_specs, threshold=threshold + ) + ], + ) + ], + model_names=[""], + ), + ], + ) + sliced_metrics = ( + slice_key, + { + metric_types.MetricKey( + name="mean_prediction", model_name="baseline" + ): 0.333, + metric_types.MetricKey(name="mean_prediction", is_diff=True): -0.333, + }, + ) + result = metrics_validator.validate_metrics(sliced_metrics, eval_config) + self.assertTrue(result.validation_ok) - @parameterized.named_parameters( - _NO_SLICE_TEST, - _GLOBAL_SLICE_TEST, - _FEATURE_SLICE_TEST, - _FEATURE_VALUE_SLICE_TEST, - _MULTIPLE_SLICES_TEST, - ) - def testValidateMetricsChangeThresholdEqualPass( - self, slicing_specs, slice_key - ): - # Change thresholds. - threshold1 = config_pb2.MetricThreshold( - change_threshold=config_pb2.GenericChangeThreshold( - direction=config_pb2.MetricDirection.HIGHER_IS_BETTER, - absolute={'value': -0.333}, - relative={'value': -0.333}, + @parameterized.named_parameters( + _NO_SLICE_TEST, + _GLOBAL_SLICE_TEST, + _FEATURE_SLICE_TEST, + _FEATURE_VALUE_SLICE_TEST, + _MULTIPLE_SLICES_TEST, + ) + def testValidateMetricsChangeThresholdEqualPass(self, slicing_specs, slice_key): + # Change thresholds. + threshold1 = config_pb2.MetricThreshold( + change_threshold=config_pb2.GenericChangeThreshold( + direction=config_pb2.MetricDirection.HIGHER_IS_BETTER, + absolute={"value": -0.333}, + relative={"value": -0.333}, + ) ) - ) - threshold2 = config_pb2.MetricThreshold( - change_threshold=config_pb2.GenericChangeThreshold( - direction=config_pb2.MetricDirection.LOWER_IS_BETTER, - absolute={'value': -0.333}, - relative={'value': -0.333}, + threshold2 = config_pb2.MetricThreshold( + change_threshold=config_pb2.GenericChangeThreshold( + direction=config_pb2.MetricDirection.LOWER_IS_BETTER, + absolute={"value": -0.333}, + relative={"value": -0.333}, + ) ) - ) - # Value thresholds. - threshold3 = config_pb2.MetricThreshold( - value_threshold=config_pb2.GenericValueThreshold( - lower_bound={'value': 1} + # Value thresholds. 
+ threshold3 = config_pb2.MetricThreshold( + value_threshold=config_pb2.GenericValueThreshold(lower_bound={"value": 1}) ) - ) - threshold4 = config_pb2.MetricThreshold( - value_threshold=config_pb2.GenericValueThreshold( - upper_bound={'value': 1} + threshold4 = config_pb2.MetricThreshold( + value_threshold=config_pb2.GenericValueThreshold(upper_bound={"value": 1}) ) - ) - eval_config = config_pb2.EvalConfig( - model_specs=[ - config_pb2.ModelSpec(name='candidate'), - config_pb2.ModelSpec(name='baseline', is_baseline=True), - ], - slicing_specs=slicing_specs, - metrics_specs=[ - config_pb2.MetricsSpec( - metrics=[ - config_pb2.MetricConfig( - class_name='MeanPrediction', - # Diff = -.333 == -.333, OK. - threshold=threshold1 if slicing_specs is None else None, - per_slice_thresholds=[ - config_pb2.PerSliceMetricThreshold( - slicing_specs=slicing_specs, - threshold=threshold1, - ) - ], - ), - config_pb2.MetricConfig( - class_name='MeanLabel', - # Diff = -.333 == -.333, OK. - threshold=threshold2 if slicing_specs is None else None, - per_slice_thresholds=[ - config_pb2.PerSliceMetricThreshold( - slicing_specs=slicing_specs, - threshold=threshold2, - ) - ], - ), - ], - model_names=['candidate'], - ), - config_pb2.MetricsSpec( - metrics=[ - config_pb2.MetricConfig( - class_name='ExampleCount', - # 1 == 1, OK. - threshold=threshold3 if slicing_specs is None else None, - per_slice_thresholds=[ - config_pb2.PerSliceMetricThreshold( - slicing_specs=slicing_specs, - threshold=threshold3, - ) - ], - ) - ], - model_names=['candidate'], - example_weights=config_pb2.ExampleWeightOptions( - unweighted=True + eval_config = config_pb2.EvalConfig( + model_specs=[ + config_pb2.ModelSpec(name="candidate"), + config_pb2.ModelSpec(name="baseline", is_baseline=True), + ], + slicing_specs=slicing_specs, + metrics_specs=[ + config_pb2.MetricsSpec( + metrics=[ + config_pb2.MetricConfig( + class_name="MeanPrediction", + # Diff = -.333 == -.333, OK. + threshold=threshold1 if slicing_specs is None else None, + per_slice_thresholds=[ + config_pb2.PerSliceMetricThreshold( + slicing_specs=slicing_specs, + threshold=threshold1, + ) + ], + ), + config_pb2.MetricConfig( + class_name="MeanLabel", + # Diff = -.333 == -.333, OK. + threshold=threshold2 if slicing_specs is None else None, + per_slice_thresholds=[ + config_pb2.PerSliceMetricThreshold( + slicing_specs=slicing_specs, + threshold=threshold2, + ) + ], + ), + ], + model_names=["candidate"], ), - ), - config_pb2.MetricsSpec( - metrics=[ - config_pb2.MetricConfig( - class_name='WeightedExampleCount', - # 1 == 1, OK. 
- threshold=threshold4 if slicing_specs is None else None, - per_slice_thresholds=[ - config_pb2.PerSliceMetricThreshold( - slicing_specs=slicing_specs, - threshold=threshold4, - ) - ], - ) - ], - model_names=['candidate'], - example_weights=config_pb2.ExampleWeightOptions(weighted=True), - ), - ], - ) - sliced_metrics = ( - slice_key, - { - metric_types.MetricKey( - name='mean_prediction', model_name='candidate' - ): 0.677, - metric_types.MetricKey( - name='mean_prediction', model_name='baseline' - ): 1, - metric_types.MetricKey( - name='mean_prediction', is_diff=True, model_name='candidate' - ): -0.333, - metric_types.MetricKey( - name='mean_label', model_name='candidate' - ): 0.677, - metric_types.MetricKey(name='mean_label', model_name='baseline'): 1, - metric_types.MetricKey( - name='mean_label', is_diff=True, model_name='candidate' - ): -0.333, - metric_types.MetricKey( - name='example_count', model_name='candidate' - ): 1, - metric_types.MetricKey( - name='weighted_example_count', - model_name='candidate', - example_weighted=True, - ): 1, - }, - ) - result = metrics_validator.validate_metrics(sliced_metrics, eval_config) - self.assertTrue(result.validation_ok) + config_pb2.MetricsSpec( + metrics=[ + config_pb2.MetricConfig( + class_name="ExampleCount", + # 1 == 1, OK. + threshold=threshold3 if slicing_specs is None else None, + per_slice_thresholds=[ + config_pb2.PerSliceMetricThreshold( + slicing_specs=slicing_specs, + threshold=threshold3, + ) + ], + ) + ], + model_names=["candidate"], + example_weights=config_pb2.ExampleWeightOptions(unweighted=True), + ), + config_pb2.MetricsSpec( + metrics=[ + config_pb2.MetricConfig( + class_name="WeightedExampleCount", + # 1 == 1, OK. + threshold=threshold4 if slicing_specs is None else None, + per_slice_thresholds=[ + config_pb2.PerSliceMetricThreshold( + slicing_specs=slicing_specs, + threshold=threshold4, + ) + ], + ) + ], + model_names=["candidate"], + example_weights=config_pb2.ExampleWeightOptions(weighted=True), + ), + ], + ) + sliced_metrics = ( + slice_key, + { + metric_types.MetricKey( + name="mean_prediction", model_name="candidate" + ): 0.677, + metric_types.MetricKey( + name="mean_prediction", model_name="baseline" + ): 1, + metric_types.MetricKey( + name="mean_prediction", is_diff=True, model_name="candidate" + ): -0.333, + metric_types.MetricKey( + name="mean_label", model_name="candidate" + ): 0.677, + metric_types.MetricKey(name="mean_label", model_name="baseline"): 1, + metric_types.MetricKey( + name="mean_label", is_diff=True, model_name="candidate" + ): -0.333, + metric_types.MetricKey(name="example_count", model_name="candidate"): 1, + metric_types.MetricKey( + name="weighted_example_count", + model_name="candidate", + example_weighted=True, + ): 1, + }, + ) + result = metrics_validator.validate_metrics(sliced_metrics, eval_config) + self.assertTrue(result.validation_ok) - @parameterized.named_parameters( - _NO_SLICE_TEST, - _GLOBAL_SLICE_TEST, - _FEATURE_SLICE_TEST, - _FEATURE_VALUE_SLICE_TEST, - _MULTIPLE_SLICES_TEST, - ) - def testValidateMetricsChangeThresholdHigherIsBetterFail( - self, slicing_specs, slice_key - ): - threshold = config_pb2.MetricThreshold( - change_threshold=config_pb2.GenericChangeThreshold( - direction=config_pb2.MetricDirection.HIGHER_IS_BETTER, - absolute={'value': 0}, + @parameterized.named_parameters( + _NO_SLICE_TEST, + _GLOBAL_SLICE_TEST, + _FEATURE_SLICE_TEST, + _FEATURE_VALUE_SLICE_TEST, + _MULTIPLE_SLICES_TEST, + ) + def testValidateMetricsChangeThresholdHigherIsBetterFail( + self, 
slicing_specs, slice_key + ): + threshold = config_pb2.MetricThreshold( + change_threshold=config_pb2.GenericChangeThreshold( + direction=config_pb2.MetricDirection.HIGHER_IS_BETTER, + absolute={"value": 0}, + ) ) - ) - eval_config = config_pb2.EvalConfig( - model_specs=[ - config_pb2.ModelSpec(), - config_pb2.ModelSpec(name='baseline', is_baseline=True), - ], - slicing_specs=slicing_specs, - metrics_specs=[ - config_pb2.MetricsSpec( - metrics=[ - config_pb2.MetricConfig( - class_name='MeanPrediction', - # Diff = -.333 > 0, NOT OK. - threshold=threshold if slicing_specs is None else None, - per_slice_thresholds=[ - config_pb2.PerSliceMetricThreshold( - slicing_specs=slicing_specs, threshold=threshold - ) - ], - ) - ], - model_names=[''], - ), - ], - ) - sliced_metrics = ( - slice_key, - { - metric_types.MetricKey( - name='mean_prediction', model_name='baseline' - ): 0.333, - metric_types.MetricKey( - name='mean_prediction', is_diff=True - ): -0.333, - }, - ) - result = metrics_validator.validate_metrics(sliced_metrics, eval_config) - self.assertFalse(result.validation_ok) + eval_config = config_pb2.EvalConfig( + model_specs=[ + config_pb2.ModelSpec(), + config_pb2.ModelSpec(name="baseline", is_baseline=True), + ], + slicing_specs=slicing_specs, + metrics_specs=[ + config_pb2.MetricsSpec( + metrics=[ + config_pb2.MetricConfig( + class_name="MeanPrediction", + # Diff = -.333 > 0, NOT OK. + threshold=threshold if slicing_specs is None else None, + per_slice_thresholds=[ + config_pb2.PerSliceMetricThreshold( + slicing_specs=slicing_specs, threshold=threshold + ) + ], + ) + ], + model_names=[""], + ), + ], + ) + sliced_metrics = ( + slice_key, + { + metric_types.MetricKey( + name="mean_prediction", model_name="baseline" + ): 0.333, + metric_types.MetricKey(name="mean_prediction", is_diff=True): -0.333, + }, + ) + result = metrics_validator.validate_metrics(sliced_metrics, eval_config) + self.assertFalse(result.validation_ok) - def testMergeDetails(self): - a = text_format.Parse( - """ + def testMergeDetails(self): + a = text_format.Parse( + """ validation_details { slicing_details { slicing_spec {} @@ -1227,11 +1167,11 @@ def testMergeDetails(self): num_matching_slices: 1 } }""", - validation_result_pb2.ValidationResult(), - ) + validation_result_pb2.ValidationResult(), + ) - b = text_format.Parse( - """ + b = text_format.Parse( + """ validation_details { slicing_details { slicing_spec { @@ -1246,11 +1186,11 @@ def testMergeDetails(self): num_matching_slices: 2 } }""", - validation_result_pb2.ValidationResult(), - ) + validation_result_pb2.ValidationResult(), + ) - expected = text_format.Parse( - """ + expected = text_format.Parse( + """ validation_details { slicing_details { slicing_spec {} @@ -1269,58 +1209,56 @@ def testMergeDetails(self): num_matching_slices: 1 } }""", - validation_result_pb2.ValidationResult(), - ) + validation_result_pb2.ValidationResult(), + ) - metrics_validator.merge_details(a, b) - self.assertProtoEquals(expected, a) + metrics_validator.merge_details(a, b) + self.assertProtoEquals(expected, a) - def testGetMissingSlices(self): - slicing_specs = [ - config_pb2.SlicingSpec(), - config_pb2.SlicingSpec(feature_values={'feature1': 'value1'}), - config_pb2.SlicingSpec(feature_values={'feature2': 'value2'}), - ] - threshold = config_pb2.MetricThreshold( - value_threshold=config_pb2.GenericValueThreshold( - upper_bound={'value': 1} + def testGetMissingSlices(self): + slicing_specs = [ + config_pb2.SlicingSpec(), + config_pb2.SlicingSpec(feature_values={"feature1": "value1"}), 
+ config_pb2.SlicingSpec(feature_values={"feature2": "value2"}), + ] + threshold = config_pb2.MetricThreshold( + value_threshold=config_pb2.GenericValueThreshold(upper_bound={"value": 1}) ) - ) - eval_config = config_pb2.EvalConfig( - model_specs=[ - config_pb2.ModelSpec(), - ], - slicing_specs=slicing_specs, - metrics_specs=[ - config_pb2.MetricsSpec( - metrics=[ - config_pb2.MetricConfig( - class_name='WeightedExampleCount', - # 1.5 < 1, NOT OK. - per_slice_thresholds=[ - config_pb2.PerSliceMetricThreshold( - slicing_specs=slicing_specs, threshold=threshold - ) - ], - ), - ], - model_names=[''], - example_weights=config_pb2.ExampleWeightOptions(weighted=True), - ), - ], - ) - sliced_metrics = ( - (('feature1', 'value1'),), - { - metric_types.MetricKey( - name='weighted_example_count', example_weighted=True - ): 0, - }, - ) - result = metrics_validator.validate_metrics(sliced_metrics, eval_config) + eval_config = config_pb2.EvalConfig( + model_specs=[ + config_pb2.ModelSpec(), + ], + slicing_specs=slicing_specs, + metrics_specs=[ + config_pb2.MetricsSpec( + metrics=[ + config_pb2.MetricConfig( + class_name="WeightedExampleCount", + # 1.5 < 1, NOT OK. + per_slice_thresholds=[ + config_pb2.PerSliceMetricThreshold( + slicing_specs=slicing_specs, threshold=threshold + ) + ], + ), + ], + model_names=[""], + example_weights=config_pb2.ExampleWeightOptions(weighted=True), + ), + ], + ) + sliced_metrics = ( + (("feature1", "value1"),), + { + metric_types.MetricKey( + name="weighted_example_count", example_weighted=True + ): 0, + }, + ) + result = metrics_validator.validate_metrics(sliced_metrics, eval_config) - expected_checks = text_format.Parse( - """ + expected_checks = text_format.Parse( + """ validation_ok: true validation_details { slicing_details { @@ -1333,217 +1271,211 @@ def testGetMissingSlices(self): num_matching_slices: 1 } }""", - validation_result_pb2.ValidationResult(), - ) + validation_result_pb2.ValidationResult(), + ) - self.assertProtoEquals(expected_checks, result) + self.assertProtoEquals(expected_checks, result) - missing = metrics_validator.get_missing_slices( - result.validation_details.slicing_details, eval_config - ) - self.assertLen(missing, 2) - self.assertProtoEquals(missing[0], slicing_specs[0]) - self.assertProtoEquals(missing[1], slicing_specs[2]) + missing = metrics_validator.get_missing_slices( + result.validation_details.slicing_details, eval_config + ) + self.assertLen(missing, 2) + self.assertProtoEquals(missing[0], slicing_specs[0]) + self.assertProtoEquals(missing[1], slicing_specs[2]) - @parameterized.named_parameters( - _NO_SLICE_TEST, - _SINGLE_CROSS_SLICE_TEST, - _CROSS_SLICE_GLOBAL_TEST, - _MULTIPLE_CROSS_SLICE_TEST, - _CROSS_SLICE_MULTIPLE_SLICING_SPEC_TEST, - ) - def testValidateMetricsCrossSliceThresholdPass( - self, cross_slicing_specs, slice_key - ): - threshold = config_pb2.MetricThreshold( - value_threshold=config_pb2.GenericValueThreshold( - upper_bound={'value': 1} + @parameterized.named_parameters( + _NO_SLICE_TEST, + _SINGLE_CROSS_SLICE_TEST, + _CROSS_SLICE_GLOBAL_TEST, + _MULTIPLE_CROSS_SLICE_TEST, + _CROSS_SLICE_MULTIPLE_SLICING_SPEC_TEST, + ) + def testValidateMetricsCrossSliceThresholdPass( + self, cross_slicing_specs, slice_key + ): + threshold = config_pb2.MetricThreshold( + value_threshold=config_pb2.GenericValueThreshold(upper_bound={"value": 1}) ) - ) - eval_config = config_pb2.EvalConfig( - model_specs=[ - config_pb2.ModelSpec(), - ], - cross_slicing_specs=cross_slicing_specs, - metrics_specs=[ - config_pb2.MetricsSpec( - metrics=[ 
- config_pb2.MetricConfig( - class_name='WeightedExampleCount', - # 1.5 < 1, NOT OK. - threshold=( - threshold if cross_slicing_specs is None else None + eval_config = config_pb2.EvalConfig( + model_specs=[ + config_pb2.ModelSpec(), + ], + cross_slicing_specs=cross_slicing_specs, + metrics_specs=[ + config_pb2.MetricsSpec( + metrics=[ + config_pb2.MetricConfig( + class_name="WeightedExampleCount", + # 1.5 < 1, NOT OK. + threshold=( + threshold if cross_slicing_specs is None else None + ), + cross_slice_thresholds=[ + config_pb2.CrossSliceMetricThreshold( + cross_slicing_specs=cross_slicing_specs, + threshold=threshold, + ) + ], ), - cross_slice_thresholds=[ - config_pb2.CrossSliceMetricThreshold( - cross_slicing_specs=cross_slicing_specs, - threshold=threshold, - ) - ], - ), - ], - model_names=[''], - example_weights=config_pb2.ExampleWeightOptions(weighted=True), - ), - ], - ) - sliced_metrics = ( - slice_key, - { - metric_types.MetricKey( - name='weighted_example_count', example_weighted=True - ): 0, - }, - ) - result = metrics_validator.validate_metrics(sliced_metrics, eval_config) - self.assertTrue(result.validation_ok) + ], + model_names=[""], + example_weights=config_pb2.ExampleWeightOptions(weighted=True), + ), + ], + ) + sliced_metrics = ( + slice_key, + { + metric_types.MetricKey( + name="weighted_example_count", example_weighted=True + ): 0, + }, + ) + result = metrics_validator.validate_metrics(sliced_metrics, eval_config) + self.assertTrue(result.validation_ok) - @parameterized.named_parameters( - _NO_SLICE_TEST, - _SINGLE_CROSS_SLICE_TEST, - _CROSS_SLICE_GLOBAL_TEST, - _MULTIPLE_CROSS_SLICE_TEST, - _CROSS_SLICE_MULTIPLE_SLICING_SPEC_TEST, - ) - def testValidateMetricsCrossSliceThresholdFail( - self, cross_slicing_specs, slice_key - ): - threshold = config_pb2.MetricThreshold( - value_threshold=config_pb2.GenericValueThreshold( - upper_bound={'value': 1} + @parameterized.named_parameters( + _NO_SLICE_TEST, + _SINGLE_CROSS_SLICE_TEST, + _CROSS_SLICE_GLOBAL_TEST, + _MULTIPLE_CROSS_SLICE_TEST, + _CROSS_SLICE_MULTIPLE_SLICING_SPEC_TEST, + ) + def testValidateMetricsCrossSliceThresholdFail( + self, cross_slicing_specs, slice_key + ): + threshold = config_pb2.MetricThreshold( + value_threshold=config_pb2.GenericValueThreshold(upper_bound={"value": 1}) ) - ) - eval_config = config_pb2.EvalConfig( - model_specs=[ - config_pb2.ModelSpec(), - ], - cross_slicing_specs=cross_slicing_specs, - metrics_specs=[ - config_pb2.MetricsSpec( - metrics=[ - config_pb2.MetricConfig( - class_name='WeightedExampleCount', - # 1.5 < 1, NOT OK. - threshold=( - threshold if cross_slicing_specs is None else None + eval_config = config_pb2.EvalConfig( + model_specs=[ + config_pb2.ModelSpec(), + ], + cross_slicing_specs=cross_slicing_specs, + metrics_specs=[ + config_pb2.MetricsSpec( + metrics=[ + config_pb2.MetricConfig( + class_name="WeightedExampleCount", + # 1.5 < 1, NOT OK. 
+ threshold=( + threshold if cross_slicing_specs is None else None + ), + cross_slice_thresholds=[ + config_pb2.CrossSliceMetricThreshold( + cross_slicing_specs=cross_slicing_specs, + threshold=threshold, + ) + ], ), - cross_slice_thresholds=[ - config_pb2.CrossSliceMetricThreshold( - cross_slicing_specs=cross_slicing_specs, - threshold=threshold, - ) - ], - ), - ], - model_names=[''], - example_weights=config_pb2.ExampleWeightOptions(weighted=True), - ), - ], - ) - sliced_metrics = ( - slice_key, - { - metric_types.MetricKey( - name='weighted_example_count', example_weighted=True - ): 1.5, - }, - ) - result = metrics_validator.validate_metrics(sliced_metrics, eval_config) - self.assertFalse(result.validation_ok) + ], + model_names=[""], + example_weights=config_pb2.ExampleWeightOptions(weighted=True), + ), + ], + ) + sliced_metrics = ( + slice_key, + { + metric_types.MetricKey( + name="weighted_example_count", example_weighted=True + ): 1.5, + }, + ) + result = metrics_validator.validate_metrics(sliced_metrics, eval_config) + self.assertFalse(result.validation_ok) - @parameterized.named_parameters(_UNMATCHED_CROSS_SLICE_TEST) - def testValidateMetricsCrossSliceThresholdUnmacthed( - self, cross_slicing_specs, slice_key - ): - threshold = config_pb2.MetricThreshold( - value_threshold=config_pb2.GenericValueThreshold( - upper_bound={'value': 1} + @parameterized.named_parameters(_UNMATCHED_CROSS_SLICE_TEST) + def testValidateMetricsCrossSliceThresholdUnmacthed( + self, cross_slicing_specs, slice_key + ): + threshold = config_pb2.MetricThreshold( + value_threshold=config_pb2.GenericValueThreshold(upper_bound={"value": 1}) ) - ) - eval_config = config_pb2.EvalConfig( - model_specs=[ - config_pb2.ModelSpec(), - ], - cross_slicing_specs=cross_slicing_specs, - metrics_specs=[ - config_pb2.MetricsSpec( - metrics=[ - config_pb2.MetricConfig( - class_name='WeightedExampleCount', - # 1.5 < 1, NOT OK. - threshold=( - threshold if cross_slicing_specs is None else None + eval_config = config_pb2.EvalConfig( + model_specs=[ + config_pb2.ModelSpec(), + ], + cross_slicing_specs=cross_slicing_specs, + metrics_specs=[ + config_pb2.MetricsSpec( + metrics=[ + config_pb2.MetricConfig( + class_name="WeightedExampleCount", + # 1.5 < 1, NOT OK. 
+ threshold=( + threshold if cross_slicing_specs is None else None + ), + cross_slice_thresholds=[ + config_pb2.CrossSliceMetricThreshold( + cross_slicing_specs=cross_slicing_specs, + threshold=threshold, + ) + ], ), - cross_slice_thresholds=[ - config_pb2.CrossSliceMetricThreshold( - cross_slicing_specs=cross_slicing_specs, - threshold=threshold, - ) - ], - ), - ], - model_names=[''], - example_weights=config_pb2.ExampleWeightOptions(weighted=True), - ), - ], - ) - sliced_metrics = ( - slice_key, - { - metric_types.MetricKey( - name='weighted_example_count', example_weighted=True - ): 0, - }, - ) - result = metrics_validator.validate_metrics(sliced_metrics, eval_config) - self.assertTrue(result.validation_ok) + ], + model_names=[""], + example_weights=config_pb2.ExampleWeightOptions(weighted=True), + ), + ], + ) + sliced_metrics = ( + slice_key, + { + metric_types.MetricKey( + name="weighted_example_count", example_weighted=True + ): 0, + }, + ) + result = metrics_validator.validate_metrics(sliced_metrics, eval_config) + self.assertTrue(result.validation_ok) - def testValidateMetricsDivByZero(self): - threshold = config_pb2.MetricThreshold( - change_threshold=config_pb2.GenericChangeThreshold( - direction=config_pb2.MetricDirection.HIGHER_IS_BETTER, - relative={'value': 0.1}, + def testValidateMetricsDivByZero(self): + threshold = config_pb2.MetricThreshold( + change_threshold=config_pb2.GenericChangeThreshold( + direction=config_pb2.MetricDirection.HIGHER_IS_BETTER, + relative={"value": 0.1}, + ) ) - ) - slicing_specs = [config_pb2.SlicingSpec()] - eval_config = config_pb2.EvalConfig( - model_specs=[ - config_pb2.ModelSpec(name='candidate'), - config_pb2.ModelSpec(name='baseline', is_baseline=True), - ], - slicing_specs=slicing_specs, - metrics_specs=[ - config_pb2.MetricsSpec( - metrics=[ - config_pb2.MetricConfig( - class_name='MeanPrediction', - threshold=threshold if slicing_specs is None else None, - per_slice_thresholds=[ - config_pb2.PerSliceMetricThreshold( - slicing_specs=slicing_specs, threshold=threshold - ) - ], - ) - ], - model_names=['baseline', 'candidate'], - ), - ], - ) - sliced_metrics = ( - (()), - { - metric_types.MetricKey( - name='mean_prediction', model_name='baseline' - ): 0.0, - metric_types.MetricKey( - name='mean_prediction', model_name='candidate', is_diff=True - ): 0.1, - }, - ) - result = metrics_validator.validate_metrics(sliced_metrics, eval_config) - self.assertFalse(result.validation_ok) + slicing_specs = [config_pb2.SlicingSpec()] + eval_config = config_pb2.EvalConfig( + model_specs=[ + config_pb2.ModelSpec(name="candidate"), + config_pb2.ModelSpec(name="baseline", is_baseline=True), + ], + slicing_specs=slicing_specs, + metrics_specs=[ + config_pb2.MetricsSpec( + metrics=[ + config_pb2.MetricConfig( + class_name="MeanPrediction", + threshold=threshold if slicing_specs is None else None, + per_slice_thresholds=[ + config_pb2.PerSliceMetricThreshold( + slicing_specs=slicing_specs, threshold=threshold + ) + ], + ) + ], + model_names=["baseline", "candidate"], + ), + ], + ) + sliced_metrics = ( + (()), + { + metric_types.MetricKey( + name="mean_prediction", model_name="baseline" + ): 0.0, + metric_types.MetricKey( + name="mean_prediction", model_name="candidate", is_diff=True + ): 0.1, + }, + ) + result = metrics_validator.validate_metrics(sliced_metrics, eval_config) + self.assertFalse(result.validation_ok) -if __name__ == '__main__': - tf.compat.v1.enable_v2_behavior() - tf.test.main() +if __name__ == "__main__": + tf.compat.v1.enable_v2_behavior() + 
tf.test.main() diff --git a/tensorflow_model_analysis/evaluators/poisson_bootstrap.py b/tensorflow_model_analysis/evaluators/poisson_bootstrap.py index b7826f1755..a7be911714 100644 --- a/tensorflow_model_analysis/evaluators/poisson_bootstrap.py +++ b/tensorflow_model_analysis/evaluators/poisson_bootstrap.py @@ -17,6 +17,7 @@ import apache_beam as beam import numpy as np + from tensorflow_model_analysis.api import types from tensorflow_model_analysis.evaluators import confidence_intervals_util from tensorflow_model_analysis.metrics import metric_types @@ -26,37 +27,35 @@ DEFAULT_NUM_BOOTSTRAP_SAMPLES = 20 _FULL_SAMPLE_ID = -1 -_AccumulatorType = TypeVar('_AccumulatorType') +_AccumulatorType = TypeVar("_AccumulatorType") class _BootstrapCombineFn(beam_util.DelegatingCombineFn): - """CombineFn wrapper which adds poisson resampling to input elements.""" + """CombineFn wrapper which adds poisson resampling to input elements.""" - def __init__( - self, combine_fn: beam.CombineFn, random_seed: Optional[int] = None - ): - super().__init__(combine_fn) - self._random_seed = random_seed + def __init__(self, combine_fn: beam.CombineFn, random_seed: Optional[int] = None): + super().__init__(combine_fn) + self._random_seed = random_seed - def setup(self): - super().setup() - self._random_state = np.random.RandomState(self._random_seed) + def setup(self): + super().setup() + self._random_state = np.random.RandomState(self._random_seed) - def add_input( - self, accumulator: _AccumulatorType, element: Any - ) -> _AccumulatorType: - for sampled_element in [element] * int(self._random_state.poisson(1, 1)): - accumulator = self._combine_fn.add_input(accumulator, sampled_element) - return accumulator + def add_input( + self, accumulator: _AccumulatorType, element: Any + ) -> _AccumulatorType: + for sampled_element in [element] * int(self._random_state.poisson(1, 1)): + accumulator = self._combine_fn.add_input(accumulator, sampled_element) + return accumulator def _add_sample_id( # pylint: disable=invalid-name slice_key, metrics_dict: metric_types.MetricsDict, sample_id: int = 0 ): - # sample_id has a default value in order to satisfy requirement of MapTuple - return slice_key, confidence_intervals_util.SampleMetrics( - metrics=metrics_dict, sample_id=sample_id - ) + # sample_id has a default value in order to satisfy requirement of MapTuple + return slice_key, confidence_intervals_util.SampleMetrics( + metrics=metrics_dict, sample_id=sample_id + ) @beam.ptransform_fn @@ -70,88 +69,89 @@ def _ComputeBootstrapSample( # pylint disable=invalid-name seed: int, hot_key_fanout: int, ) -> beam.PCollection[confidence_intervals_util.SampleMetrics]: - """Computes a single bootstrap sample from SlicedExtracts. - - Args: - sliced_extracts: Incoming PCollection consisting of slice key and extracts. - sample_id: The sample_id to attach to the computed metrics as part of the - returned SampleMetrics objects. - computations_combine_fn: a beam.CombineFn instance that takes input elements - of type Extracts and returns a MetricsDict. This will be invoked as part - of a CombinePerKey in which the key is a slice key. - derived_metrics_ptransform: A PTransform which adds derived metrics to the - results of the computations_combine_fn. This PTransform should both input - and output a single PCollection with elements of type MetricsDict where - the output MetricsDict includes additional derived metrics. - seed: The seed to use when doing resampling. 
Note that this is only useful - in testing or when using a single worker, as otherwise Beam will introduce - non-determinism in when using distributed computation. - hot_key_fanout: The hot key fanout factor to use when calling - beam.CombinePerKey with the computations_combine_fn on replicates. Note - that these replicates will in expectation have the same size as the input - PCollection of extracts and will use the normal set of slices keys. - - Returns: - A PCollection of sliced SampleMetrics objects, containing the metrics dicts - for a given slice, computed from the resampled extracts, along with the - provided sample_id. - """ - return ( - sliced_extracts - | 'CombineSampledMetricsPerSlice' - >> beam.CombinePerKey( - _BootstrapCombineFn(computations_combine_fn, seed) - ).with_hot_key_fanout(hot_key_fanout) - | 'AddSampledDerivedCrossSliceAndDiffMetrics' - >> derived_metrics_ptransform - | 'AddSampleIdToValue' - >> beam.MapTuple(_add_sample_id, sample_id=sample_id) - ) - - -class _BootstrapSampleCombineFn(confidence_intervals_util.SampleCombineFn): - """Computes the bootstrap standard error for each metric from samples.""" - - def __init__( - self, - num_bootstrap_samples: int, - skip_ci_metric_keys: Optional[Set[metric_types.MetricKey]] = None, - ): - """Initializes a _BootstrapSampleCombineFn. + """Computes a single bootstrap sample from SlicedExtracts. Args: - num_bootstrap_samples: The expected number of samples computed per slice. - skip_ci_metric_keys: Set of metric keys for which to skip confidence - interval computation. For metric keys in this set, just the point - estimate will be returned. + ---- + sliced_extracts: Incoming PCollection consisting of slice key and extracts. + sample_id: The sample_id to attach to the computed metrics as part of the + returned SampleMetrics objects. + computations_combine_fn: a beam.CombineFn instance that takes input elements + of type Extracts and returns a MetricsDict. This will be invoked as part + of a CombinePerKey in which the key is a slice key. + derived_metrics_ptransform: A PTransform which adds derived metrics to the + results of the computations_combine_fn. This PTransform should both input + and output a single PCollection with elements of type MetricsDict where + the output MetricsDict includes additional derived metrics. + seed: The seed to use when doing resampling. Note that this is only useful + in testing or when using a single worker, as otherwise Beam will introduce + non-determinism in when using distributed computation. + hot_key_fanout: The hot key fanout factor to use when calling + beam.CombinePerKey with the computations_combine_fn on replicates. Note + that these replicates will in expectation have the same size as the input + PCollection of extracts and will use the normal set of slices keys. + + Returns: + ------- + A PCollection of sliced SampleMetrics objects, containing the metrics dicts + for a given slice, computed from the resampled extracts, along with the + provided sample_id. 
""" - super().__init__( - num_samples=num_bootstrap_samples, - full_sample_id=_FULL_SAMPLE_ID, - skip_ci_metric_keys=skip_ci_metric_keys, + return ( + sliced_extracts + | "CombineSampledMetricsPerSlice" + >> beam.CombinePerKey( + _BootstrapCombineFn(computations_combine_fn, seed) + ).with_hot_key_fanout(hot_key_fanout) + | "AddSampledDerivedCrossSliceAndDiffMetrics" >> derived_metrics_ptransform + | "AddSampleIdToValue" >> beam.MapTuple(_add_sample_id, sample_id=sample_id) ) - def extract_output( - self, - accumulator: confidence_intervals_util.SampleCombineFn.SampleAccumulator, - ) -> metric_types.MetricsDict: - accumulator = self._validate_accumulator(accumulator) - result = {} - dof = self._num_samples - 1 - for key, point_estimate in accumulator.point_estimates.items(): - if key not in accumulator.metric_samples: - result[key] = point_estimate - else: - mean, std_error = confidence_intervals_util.mean_and_std( - accumulator.metric_samples[key], ddof=1 - ) - result[key] = types.ValueWithTDistribution( - sample_mean=mean, - sample_standard_deviation=std_error, - unsampled_value=point_estimate, - sample_degrees_of_freedom=dof, + +class _BootstrapSampleCombineFn(confidence_intervals_util.SampleCombineFn): + """Computes the bootstrap standard error for each metric from samples.""" + + def __init__( + self, + num_bootstrap_samples: int, + skip_ci_metric_keys: Optional[Set[metric_types.MetricKey]] = None, + ): + """Initializes a _BootstrapSampleCombineFn. + + Args: + ---- + num_bootstrap_samples: The expected number of samples computed per slice. + skip_ci_metric_keys: Set of metric keys for which to skip confidence + interval computation. For metric keys in this set, just the point + estimate will be returned. + """ + super().__init__( + num_samples=num_bootstrap_samples, + full_sample_id=_FULL_SAMPLE_ID, + skip_ci_metric_keys=skip_ci_metric_keys, ) - return result # pytype: disable=bad-return-type # numpy-scalars + + def extract_output( + self, + accumulator: confidence_intervals_util.SampleCombineFn.SampleAccumulator, + ) -> metric_types.MetricsDict: + accumulator = self._validate_accumulator(accumulator) + result = {} + dof = self._num_samples - 1 + for key, point_estimate in accumulator.point_estimates.items(): + if key not in accumulator.metric_samples: + result[key] = point_estimate + else: + mean, std_error = confidence_intervals_util.mean_and_std( + accumulator.metric_samples[key], ddof=1 + ) + result[key] = types.ValueWithTDistribution( + sample_mean=mean, + sample_standard_deviation=std_error, + unsampled_value=point_estimate, + sample_degrees_of_freedom=dof, + ) + return result # pytype: disable=bad-return-type # numpy-scalars @beam.ptransform_fn @@ -168,74 +168,76 @@ def ComputeWithConfidenceIntervals( # pylint: disable=invalid-name ) -> beam.pvalue.PCollection[ Tuple[slicer.SliceKeyOrCrossSliceKeyType, metric_types.MetricsDict] ]: - """PTransform for computing metrics using T-Distribution values. - - Args: - sliced_extracts: Incoming PCollection consisting of slice key and extracts. - computations_combine_fn: a beam.CombineFn instance that takes input elements - of type Extracts and returns a MetricsDict. This will be invoked as part - of a CombinePerKey in which the key is a slice key. - derived_metrics_ptransform: A PTransform which adds derived metrics to the - results of the computations_combine_fn. This PTransform should both input - and output a single PCollection with elements of type MetricsDict where - the output MetricsDict includes additional derived metrics. 
- num_bootstrap_samples: The number of bootstrap replicates to use in - computing the bootstrap standard error. - hot_key_fanout: The hot key fanout factor to use when calling - beam.CombinePerKey with the computations_combine_fn on replicates. Note - that these replicates will in expectation have the same size as the input - PCollection of extracts and will use the normal set of slices keys. - skip_ci_metric_keys: Set of metric keys for which to skip confidence - interval computation. For metric keys in this set, just the unsampled - value will be returned. - random_seed_for_testing: Seed to use for unit testing, because - nondeterministic tests stink. Each partition will use this value + i. - - Returns: - PCollection of (slice key, dict of metrics) - """ - if num_bootstrap_samples < 1: - raise ValueError( - 'num_bootstrap_samples should be > 0, got %d' % num_bootstrap_samples - ) + """PTransform for computing metrics using T-Distribution values. - unsampled_metrics = ( - sliced_extracts - | 'CombineUnsampledMetricsPerSlice' - >> beam.CombinePerKey(computations_combine_fn).with_hot_key_fanout( - hot_key_fanout - ) - | 'AddDerivedMetrics' >> derived_metrics_ptransform - | 'AddUnsampledSampleId' >> beam.MapTuple(_add_sample_id, _FULL_SAMPLE_ID) - ) - - sampled_metrics = [] - for sample_id in range(num_bootstrap_samples): - seed = ( - None - if random_seed_for_testing is None - else random_seed_for_testing + sample_id - ) - sampled_metrics.append( + Args: + ---- + sliced_extracts: Incoming PCollection consisting of slice key and extracts. + computations_combine_fn: a beam.CombineFn instance that takes input elements + of type Extracts and returns a MetricsDict. This will be invoked as part + of a CombinePerKey in which the key is a slice key. + derived_metrics_ptransform: A PTransform which adds derived metrics to the + results of the computations_combine_fn. This PTransform should both input + and output a single PCollection with elements of type MetricsDict where + the output MetricsDict includes additional derived metrics. + num_bootstrap_samples: The number of bootstrap replicates to use in + computing the bootstrap standard error. + hot_key_fanout: The hot key fanout factor to use when calling + beam.CombinePerKey with the computations_combine_fn on replicates. Note + that these replicates will in expectation have the same size as the input + PCollection of extracts and will use the normal set of slices keys. + skip_ci_metric_keys: Set of metric keys for which to skip confidence + interval computation. For metric keys in this set, just the unsampled + value will be returned. + random_seed_for_testing: Seed to use for unit testing, because + nondeterministic tests stink. Each partition will use this value + i. 
+ + Returns: + ------- + PCollection of (slice key, dict of metrics) + """ + if num_bootstrap_samples < 1: + raise ValueError( + "num_bootstrap_samples should be > 0, got %d" % num_bootstrap_samples + ) + + unsampled_metrics = ( sliced_extracts - | f'ComputeBootstrapSample[{sample_id}]' - >> _ComputeBootstrapSample( # pylint: disable=no-value-for-parameter - sample_id=sample_id, - computations_combine_fn=computations_combine_fn, - derived_metrics_ptransform=derived_metrics_ptransform, - seed=seed, - hot_key_fanout=hot_key_fanout, + | "CombineUnsampledMetricsPerSlice" + >> beam.CombinePerKey(computations_combine_fn).with_hot_key_fanout( + hot_key_fanout ) + | "AddDerivedMetrics" >> derived_metrics_ptransform + | "AddUnsampledSampleId" >> beam.MapTuple(_add_sample_id, _FULL_SAMPLE_ID) ) - return ( - sampled_metrics + [unsampled_metrics] - | 'FlattenBootstrapPartitions' >> beam.Flatten() - | 'CombineSamplesPerSlice' - >> beam.CombinePerKey( - _BootstrapSampleCombineFn( - num_bootstrap_samples=num_bootstrap_samples, - skip_ci_metric_keys=skip_ci_metric_keys, - ) - ) - ) + sampled_metrics = [] + for sample_id in range(num_bootstrap_samples): + seed = ( + None + if random_seed_for_testing is None + else random_seed_for_testing + sample_id + ) + sampled_metrics.append( + sliced_extracts + | f"ComputeBootstrapSample[{sample_id}]" + >> _ComputeBootstrapSample( # pylint: disable=no-value-for-parameter + sample_id=sample_id, + computations_combine_fn=computations_combine_fn, + derived_metrics_ptransform=derived_metrics_ptransform, + seed=seed, + hot_key_fanout=hot_key_fanout, + ) + ) + + return ( + sampled_metrics + [unsampled_metrics] + | "FlattenBootstrapPartitions" >> beam.Flatten() + | "CombineSamplesPerSlice" + >> beam.CombinePerKey( + _BootstrapSampleCombineFn( + num_bootstrap_samples=num_bootstrap_samples, + skip_ci_metric_keys=skip_ci_metric_keys, + ) + ) + ) diff --git a/tensorflow_model_analysis/evaluators/poisson_bootstrap_test.py b/tensorflow_model_analysis/evaluators/poisson_bootstrap_test.py index 1789ef4fec..67109be022 100644 --- a/tensorflow_model_analysis/evaluators/poisson_bootstrap_test.py +++ b/tensorflow_model_analysis/evaluators/poisson_bootstrap_test.py @@ -13,337 +13,334 @@ # limitations under the License. 
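The module above implements Poisson bootstrapping: each extract is effectively replicated a Poisson(1)-distributed number of times per sample, so every replicate has the same expected size as the input, and the per-replicate metrics are then summarized by their mean and standard deviation with `dof = num_samples - 1`. The sketch below illustrates the idea with plain NumPy; the helper names, the RNG, and the use of `scipy.stats` for the critical value are assumptions for illustration only, so the exact replicates will differ from those produced by `_BootstrapCombineFn`.

```python
# Illustrative sketch of a Poisson bootstrap with a t-based interval.
# Helper names and the scipy dependency are assumptions, not TFMA APIs.
import numpy as np
from scipy import stats


def poisson_replicate(values: np.ndarray, rng: np.random.RandomState) -> np.ndarray:
    """Repeats each element k times with k ~ Poisson(1), so the replicate
    has the same expected size as the input."""
    counts = rng.poisson(lam=1.0, size=len(values))
    return np.repeat(values, counts)


def bootstrap_t_interval(values, metric_fn, num_samples=20, alpha=0.05, seed=0):
    """Point estimate plus a t-distribution interval from bootstrap samples."""
    rng = np.random.RandomState(seed)
    point_estimate = metric_fn(values)  # the "unsampled" value
    samples = [metric_fn(poisson_replicate(values, rng)) for _ in range(num_samples)]
    sample_mean = np.mean(samples)
    std_error = np.std(samples, ddof=1)  # same ddof=1 as mean_and_std above
    dof = num_samples - 1  # sample_degrees_of_freedom
    t_crit = stats.t.ppf(1.0 - alpha / 2.0, dof)
    return point_estimate, (sample_mean - t_crit * std_error,
                            sample_mean + t_crit * std_error)


# Example: a 95% interval for the mean of a toy dataset.
estimate, (lower, upper) = bootstrap_t_interval(np.arange(100.0), np.mean)
```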
"""Test for using the poisson bootstrap API.""" -from absl.testing import absltest import apache_beam as beam -from apache_beam.testing import util import numpy as np +from absl.testing import absltest +from apache_beam.testing import util + from tensorflow_model_analysis.api import types -from tensorflow_model_analysis.evaluators import confidence_intervals_util -from tensorflow_model_analysis.evaluators import poisson_bootstrap -from tensorflow_model_analysis.metrics import binary_confusion_matrices -from tensorflow_model_analysis.metrics import metric_types +from tensorflow_model_analysis.evaluators import ( + confidence_intervals_util, + poisson_bootstrap, +) +from tensorflow_model_analysis.metrics import binary_confusion_matrices, metric_types class PoissonBootstrapTest(absltest.TestCase): + def test_bootstrap_combine_fn(self): + with beam.Pipeline() as pipeline: + result = ( + pipeline + | "Create" >> beam.Create(range(5), reshuffle=False) + | "BootstrapCombine" + >> beam.CombineGlobally( + poisson_bootstrap._BootstrapCombineFn( + combine_fn=beam.combiners.ToListCombineFn(), random_seed=0 + ) + ) + ) - def test_bootstrap_combine_fn(self): - with beam.Pipeline() as pipeline: - result = ( - pipeline - | 'Create' >> beam.Create(range(5), reshuffle=False) - | 'BootstrapCombine' - >> beam.CombineGlobally( - poisson_bootstrap._BootstrapCombineFn( - combine_fn=beam.combiners.ToListCombineFn(), random_seed=0 - ) - ) - ) - - def check_result(got_pcoll): - self.assertLen(got_pcoll, 1) - self.assertEqual([0, 0, 1, 2, 3, 3, 4, 4], got_pcoll[0]) + def check_result(got_pcoll): + self.assertLen(got_pcoll, 1) + self.assertEqual([0, 0, 1, 2, 3, 3, 4, 4], got_pcoll[0]) - util.assert_that(result, check_result) + util.assert_that(result, check_result) - def test_boostrap_sample_combine_fn(self): - metric_key = metric_types.MetricKey(name='metric') - samples = [ - confidence_intervals_util.SampleMetrics( - sample_id=0, metrics={metric_key: 0} - ), - confidence_intervals_util.SampleMetrics( - sample_id=1, metrics={metric_key: 7} - ), - confidence_intervals_util.SampleMetrics( - sample_id=poisson_bootstrap._FULL_SAMPLE_ID, metrics={metric_key: 4} - ), - ] + def test_boostrap_sample_combine_fn(self): + metric_key = metric_types.MetricKey(name="metric") + samples = [ + confidence_intervals_util.SampleMetrics( + sample_id=0, metrics={metric_key: 0} + ), + confidence_intervals_util.SampleMetrics( + sample_id=1, metrics={metric_key: 7} + ), + confidence_intervals_util.SampleMetrics( + sample_id=poisson_bootstrap._FULL_SAMPLE_ID, metrics={metric_key: 4} + ), + ] - with beam.Pipeline() as pipeline: - result = ( - pipeline - | 'Create' >> beam.Create(samples, reshuffle=False) - | 'CombineSamples' - >> beam.CombineGlobally( - poisson_bootstrap._BootstrapSampleCombineFn( - num_bootstrap_samples=2 - ) - ) - ) + with beam.Pipeline() as pipeline: + result = ( + pipeline + | "Create" >> beam.Create(samples, reshuffle=False) + | "CombineSamples" + >> beam.CombineGlobally( + poisson_bootstrap._BootstrapSampleCombineFn(num_bootstrap_samples=2) + ) + ) - def check_result(got_pcoll): - self.assertLen(got_pcoll, 1) - metrics = got_pcoll[0] + def check_result(got_pcoll): + self.assertLen(got_pcoll, 1) + metrics = got_pcoll[0] - self.assertIn(metric_key, metrics) - self.assertAlmostEqual(metrics[metric_key].sample_mean, 3.5, delta=0.1) - self.assertAlmostEqual( - metrics[metric_key].sample_standard_deviation, 4.94, delta=0.1 - ) - self.assertEqual(metrics[metric_key].sample_degrees_of_freedom, 1) - 
self.assertEqual(metrics[metric_key].unsampled_value, 4.0) + self.assertIn(metric_key, metrics) + self.assertAlmostEqual(metrics[metric_key].sample_mean, 3.5, delta=0.1) + self.assertAlmostEqual( + metrics[metric_key].sample_standard_deviation, 4.94, delta=0.1 + ) + self.assertEqual(metrics[metric_key].sample_degrees_of_freedom, 1) + self.assertEqual(metrics[metric_key].unsampled_value, 4.0) - util.assert_that(result, check_result) + util.assert_that(result, check_result) - def test_boostrap_sample_combine_fn_per_slice(self): - x_key = metric_types.MetricKey('x') - y_key = metric_types.MetricKey('y') - cm_key = metric_types.MetricKey('confusion_matrix') - cm_metric = binary_confusion_matrices.Matrices( - thresholds=[0.5], tp=[0], fp=[1], tn=[2], fn=[3] - ) - skipped_metric_key = metric_types.MetricKey('skipped_metric') - slice_key1 = (('slice_feature', 1),) - slice_key2 = (('slice_feature', 2),) - samples = [ - # unsampled value for slice 1 - ( - slice_key1, - confidence_intervals_util.SampleMetrics( - sample_id=poisson_bootstrap._FULL_SAMPLE_ID, - metrics={ - x_key: 1.6, - y_key: 16, - cm_key: cm_metric, - skipped_metric_key: 100, - }, + def test_boostrap_sample_combine_fn_per_slice(self): + x_key = metric_types.MetricKey("x") + y_key = metric_types.MetricKey("y") + cm_key = metric_types.MetricKey("confusion_matrix") + cm_metric = binary_confusion_matrices.Matrices( + thresholds=[0.5], tp=[0], fp=[1], tn=[2], fn=[3] + ) + skipped_metric_key = metric_types.MetricKey("skipped_metric") + slice_key1 = (("slice_feature", 1),) + slice_key2 = (("slice_feature", 2),) + samples = [ + # unsampled value for slice 1 + ( + slice_key1, + confidence_intervals_util.SampleMetrics( + sample_id=poisson_bootstrap._FULL_SAMPLE_ID, + metrics={ + x_key: 1.6, + y_key: 16, + cm_key: cm_metric, + skipped_metric_key: 100, + }, + ), ), - ), - # sample values 1 of 2 for slice 1 - ( - slice_key1, - confidence_intervals_util.SampleMetrics( - sample_id=0, - metrics={ - x_key: 1, - y_key: 10, - cm_key: cm_metric, - skipped_metric_key: 45, - }, + # sample values 1 of 2 for slice 1 + ( + slice_key1, + confidence_intervals_util.SampleMetrics( + sample_id=0, + metrics={ + x_key: 1, + y_key: 10, + cm_key: cm_metric, + skipped_metric_key: 45, + }, + ), ), - ), - # sample values 2 of 2 for slice 1 - ( - slice_key1, - confidence_intervals_util.SampleMetrics( - sample_id=1, - metrics={ - x_key: 2, - y_key: 20, - cm_key: cm_metric, - skipped_metric_key: 55, - }, + # sample values 2 of 2 for slice 1 + ( + slice_key1, + confidence_intervals_util.SampleMetrics( + sample_id=1, + metrics={ + x_key: 2, + y_key: 20, + cm_key: cm_metric, + skipped_metric_key: 55, + }, + ), ), - ), - # unsampled value for slice 2 - ( - slice_key2, - confidence_intervals_util.SampleMetrics( - sample_id=poisson_bootstrap._FULL_SAMPLE_ID, - metrics={ - x_key: 3.3, - y_key: 33, - cm_key: cm_metric, - skipped_metric_key: 1000, - }, + # unsampled value for slice 2 + ( + slice_key2, + confidence_intervals_util.SampleMetrics( + sample_id=poisson_bootstrap._FULL_SAMPLE_ID, + metrics={ + x_key: 3.3, + y_key: 33, + cm_key: cm_metric, + skipped_metric_key: 1000, + }, + ), ), - ), - # sample values 1 of 2 for slice 2 - ( - slice_key2, - confidence_intervals_util.SampleMetrics( - sample_id=0, - metrics={ - x_key: 2, - y_key: 20, - cm_key: cm_metric, - skipped_metric_key: 450, - }, + # sample values 1 of 2 for slice 2 + ( + slice_key2, + confidence_intervals_util.SampleMetrics( + sample_id=0, + metrics={ + x_key: 2, + y_key: 20, + cm_key: cm_metric, + skipped_metric_key: 
450, + }, + ), ), - ), - # sample values 2 of 2 for slice 2 - ( - slice_key2, - confidence_intervals_util.SampleMetrics( - sample_id=1, - metrics={ - x_key: 4, - y_key: 40, - cm_key: cm_metric, - skipped_metric_key: 550, - }, + # sample values 2 of 2 for slice 2 + ( + slice_key2, + confidence_intervals_util.SampleMetrics( + sample_id=1, + metrics={ + x_key: 4, + y_key: 40, + cm_key: cm_metric, + skipped_metric_key: 550, + }, + ), ), - ), - ] + ] - with beam.Pipeline() as pipeline: - result = ( - pipeline - | 'Create' >> beam.Create(samples, reshuffle=False) - | 'CombineSamplesPerKey' - >> beam.CombinePerKey( - poisson_bootstrap._BootstrapSampleCombineFn( - num_bootstrap_samples=2, - skip_ci_metric_keys=[skipped_metric_key], - ) - ) - ) + with beam.Pipeline() as pipeline: + result = ( + pipeline + | "Create" >> beam.Create(samples, reshuffle=False) + | "CombineSamplesPerKey" + >> beam.CombinePerKey( + poisson_bootstrap._BootstrapSampleCombineFn( + num_bootstrap_samples=2, + skip_ci_metric_keys=[skipped_metric_key], + ) + ) + ) - def check_result(got_pcoll): - expected_pcoll = [ - ( - slice_key1, - { - x_key: types.ValueWithTDistribution( - sample_mean=1.5, - # sample_standard_deviation=0.5 - sample_standard_deviation=np.std([1, 2], ddof=1), - sample_degrees_of_freedom=1, - unsampled_value=1.6, - ), - y_key: types.ValueWithTDistribution( - sample_mean=15.0, - # sample_standard_deviation=5, - sample_standard_deviation=np.std([10, 20], ddof=1), - sample_degrees_of_freedom=1, - unsampled_value=16, + def check_result(got_pcoll): + expected_pcoll = [ + ( + slice_key1, + { + x_key: types.ValueWithTDistribution( + sample_mean=1.5, + # sample_standard_deviation=0.5 + sample_standard_deviation=np.std([1, 2], ddof=1), + sample_degrees_of_freedom=1, + unsampled_value=1.6, + ), + y_key: types.ValueWithTDistribution( + sample_mean=15.0, + # sample_standard_deviation=5, + sample_standard_deviation=np.std([10, 20], ddof=1), + sample_degrees_of_freedom=1, + unsampled_value=16, + ), + cm_key: types.ValueWithTDistribution( + sample_mean=cm_metric, + sample_standard_deviation=cm_metric * 0, + sample_degrees_of_freedom=1, + unsampled_value=cm_metric, + ), + skipped_metric_key: 100, + }, ), - cm_key: types.ValueWithTDistribution( - sample_mean=cm_metric, - sample_standard_deviation=cm_metric * 0, - sample_degrees_of_freedom=1, - unsampled_value=cm_metric, + ( + slice_key2, + { + x_key: types.ValueWithTDistribution( + sample_mean=3.0, + # sample_standard_deviation=1, + sample_standard_deviation=np.std([2, 4], ddof=1), + sample_degrees_of_freedom=1, + unsampled_value=3.3, + ), + y_key: types.ValueWithTDistribution( + sample_mean=30.0, + # sample_standard_deviation=10, + sample_standard_deviation=np.std([20, 40], ddof=1), + sample_degrees_of_freedom=1, + unsampled_value=33, + ), + cm_key: types.ValueWithTDistribution( + sample_mean=cm_metric, + sample_standard_deviation=cm_metric * 0, + sample_degrees_of_freedom=1, + unsampled_value=cm_metric, + ), + skipped_metric_key: 1000, + }, ), - skipped_metric_key: 100, - }, + ] + self.assertCountEqual(expected_pcoll, got_pcoll) + + util.assert_that(result, check_result) + + def test_bootstrap_sample_combine_fn_sample_is_nan(self): + metric_key = metric_types.MetricKey("metric") + # the sample value is irrelevant for this test as we only verify counters. 
+ samples = [ + # unsampled value + ( + confidence_intervals_util.SampleMetrics( + sample_id=poisson_bootstrap._FULL_SAMPLE_ID, + metrics={ + metric_key: 2, + }, + ) ), ( - slice_key2, - { - x_key: types.ValueWithTDistribution( - sample_mean=3.0, - # sample_standard_deviation=1, - sample_standard_deviation=np.std([2, 4], ddof=1), - sample_degrees_of_freedom=1, - unsampled_value=3.3, - ), - y_key: types.ValueWithTDistribution( - sample_mean=30.0, - # sample_standard_deviation=10, - sample_standard_deviation=np.std([20, 40], ddof=1), - sample_degrees_of_freedom=1, - unsampled_value=33, - ), - cm_key: types.ValueWithTDistribution( - sample_mean=cm_metric, - sample_standard_deviation=cm_metric * 0, - sample_degrees_of_freedom=1, - unsampled_value=cm_metric, - ), - skipped_metric_key: 1000, - }, + confidence_intervals_util.SampleMetrics( + sample_id=0, metrics={metric_key: 2} + ) + ), + ( + confidence_intervals_util.SampleMetrics( + sample_id=1, metrics={metric_key: float("nan")} + ) ), ] - self.assertCountEqual(expected_pcoll, got_pcoll) - util.assert_that(result, check_result) + with beam.Pipeline() as pipeline: + result = ( + pipeline + | "Create" >> beam.Create(samples, reshuffle=False) + | "CombineSamplesPerKey" + >> beam.CombineGlobally( + poisson_bootstrap._BootstrapSampleCombineFn(num_bootstrap_samples=2) + ) + ) + + def check_result(got_pcoll): + self.assertLen(got_pcoll, 1) + metrics = got_pcoll[0] + + self.assertIn(metric_key, metrics) + self.assertTrue(np.isnan(metrics[metric_key].sample_mean)) + self.assertTrue(np.isnan(metrics[metric_key].sample_standard_deviation)) + self.assertEqual(metrics[metric_key].sample_degrees_of_freedom, 1) + self.assertEqual(metrics[metric_key].unsampled_value, 2.0) - def test_bootstrap_sample_combine_fn_sample_is_nan(self): - metric_key = metric_types.MetricKey('metric') - # the sample value is irrelevant for this test as we only verify counters. 
- samples = [ - # unsampled value - ( + util.assert_that(result, check_result) + + def test_boostrap_sample_combine_fn_numpy_overflow(self): + sample_values = np.random.RandomState(seed=0).randint(0, 1e10, 20) + metric_key = metric_types.MetricKey("metric") + samples = [ confidence_intervals_util.SampleMetrics( sample_id=poisson_bootstrap._FULL_SAMPLE_ID, metrics={ - metric_key: 2, + metric_key: 1, }, ) - ), - ( - confidence_intervals_util.SampleMetrics( - sample_id=0, metrics={metric_key: 2} + ] + for sample_id, value in enumerate(sample_values): + samples.append( + confidence_intervals_util.SampleMetrics( + sample_id=sample_id, + metrics={ + metric_key: value, + }, + ) ) - ), - ( - confidence_intervals_util.SampleMetrics( - sample_id=1, metrics={metric_key: float('nan')} + with beam.Pipeline() as pipeline: + result = ( + pipeline + | "Create" >> beam.Create(samples, reshuffle=False) + | "CombineSamples" + >> beam.CombineGlobally( + poisson_bootstrap._BootstrapSampleCombineFn( + num_bootstrap_samples=20 + ) + ) ) - ), - ] - - with beam.Pipeline() as pipeline: - result = ( - pipeline - | 'Create' >> beam.Create(samples, reshuffle=False) - | 'CombineSamplesPerKey' - >> beam.CombineGlobally( - poisson_bootstrap._BootstrapSampleCombineFn( - num_bootstrap_samples=2 - ) - ) - ) - - def check_result(got_pcoll): - self.assertLen(got_pcoll, 1) - metrics = got_pcoll[0] - - self.assertIn(metric_key, metrics) - self.assertTrue(np.isnan(metrics[metric_key].sample_mean)) - self.assertTrue(np.isnan(metrics[metric_key].sample_standard_deviation)) - self.assertEqual(metrics[metric_key].sample_degrees_of_freedom, 1) - self.assertEqual(metrics[metric_key].unsampled_value, 2.0) - util.assert_that(result, check_result) - - def test_boostrap_sample_combine_fn_numpy_overflow(self): - sample_values = np.random.RandomState(seed=0).randint(0, 1e10, 20) - metric_key = metric_types.MetricKey('metric') - samples = [ - confidence_intervals_util.SampleMetrics( - sample_id=poisson_bootstrap._FULL_SAMPLE_ID, - metrics={ - metric_key: 1, - }, - ) - ] - for sample_id, value in enumerate(sample_values): - samples.append( - confidence_intervals_util.SampleMetrics( - sample_id=sample_id, - metrics={ - metric_key: value, - }, - ) - ) - with beam.Pipeline() as pipeline: - result = ( - pipeline - | 'Create' >> beam.Create(samples, reshuffle=False) - | 'CombineSamples' - >> beam.CombineGlobally( - poisson_bootstrap._BootstrapSampleCombineFn( - num_bootstrap_samples=20 - ) - ) - ) - - def check_result(got_pcoll): - expected_pcoll = [ - { - metric_key: types.ValueWithTDistribution( - sample_mean=5293977041.15, - sample_standard_deviation=3023624729.537024, - sample_degrees_of_freedom=19, - unsampled_value=1, - ), - }, - ] - self.assertCountEqual(expected_pcoll, got_pcoll) + def check_result(got_pcoll): + expected_pcoll = [ + { + metric_key: types.ValueWithTDistribution( + sample_mean=5293977041.15, + sample_standard_deviation=3023624729.537024, + sample_degrees_of_freedom=19, + unsampled_value=1, + ), + }, + ] + self.assertCountEqual(expected_pcoll, got_pcoll) - util.assert_that(result, check_result) + util.assert_that(result, check_result) -if __name__ == '__main__': - absltest.main() +if __name__ == "__main__": + absltest.main() diff --git a/tensorflow_model_analysis/evaluators/query_metrics/__init__.py b/tensorflow_model_analysis/evaluators/query_metrics/__init__.py index d44bba196a..bfbb271e0a 100644 --- a/tensorflow_model_analysis/evaluators/query_metrics/__init__.py +++ 
b/tensorflow_model_analysis/evaluators/query_metrics/__init__.py @@ -13,6 +13,10 @@ # limitations under the License. """Init module for TensorFlow Model Analysis query metrics.""" -from tensorflow_model_analysis.evaluators.query_metrics.min_label_position import MinLabelPositionCombineFn +from tensorflow_model_analysis.evaluators.query_metrics.min_label_position import ( + MinLabelPositionCombineFn, +) from tensorflow_model_analysis.evaluators.query_metrics.ndcg import NdcgMetricCombineFn -from tensorflow_model_analysis.evaluators.query_metrics.query_statistics import QueryStatisticsCombineFn +from tensorflow_model_analysis.evaluators.query_metrics.query_statistics import ( + QueryStatisticsCombineFn, +) diff --git a/tensorflow_model_analysis/evaluators/query_metrics/min_label_position.py b/tensorflow_model_analysis/evaluators/query_metrics/min_label_position.py index c413a7544e..75c2d1ffe4 100644 --- a/tensorflow_model_analysis/evaluators/query_metrics/min_label_position.py +++ b/tensorflow_model_analysis/evaluators/query_metrics/min_label_position.py @@ -23,118 +23,120 @@ from typing import Any, Dict, Iterable, NamedTuple import apache_beam as beam + from tensorflow_model_analysis import constants from tensorflow_model_analysis.evaluators.query_metrics import query_types from tensorflow_model_analysis.post_export_metrics import metric_keys from tensorflow_model_analysis.utils import util -_State = NamedTuple('_State', [('min_pos_sum', float), ('weight_sum', float)]) - -def _get_feature_value(fpl: query_types.FPL, key: str) -> float: - """Get value of the given feature from the features dictionary. - - The feature must have exactly one value. - - Args: - fpl: FPL - key: Key of feature to retrieve in features dictionary. - - Returns: - The singular value of the feature. - """ - feature = fpl['features'].get(key) - if feature is None: - raise ValueError( - 'feature %s not found in features %s' % (key, fpl['features']) - ) - if feature.size != 1: - raise ValueError( - 'feature %s did not contain exactly 1 value. value was: %s' - % (key, feature) - ) - return feature[0][0] +class _State(NamedTuple): + min_pos_sum: float + weight_sum: float -class MinLabelPositionCombineFn(beam.CombineFn): - """Computes minimum label position.""" +def _get_feature_value(fpl: query_types.FPL, key: str) -> float: + """Get value of the given feature from the features dictionary. - def __init__(self, label_key: str, weight_key: str): - """Initialize. + The feature must have exactly one value. Args: - label_key: The key in the labels dictionary which holds the label. Set - this to empty to if labels is a Tensor and not a dictionary. - weight_key: The key in the features dictionary which holds the weights. - Note that the weight value must be identical across all examples in the - same query. If set to empty, uses 1.0 instead. + ---- + fpl: FPL + key: Key of feature to retrieve in features dictionary. + + Returns: + ------- + The singular value of the feature. """ - if not label_key: - # If label_key is set to the empty string, the user is telling us - # that their Estimator returns a labels Tensor rather than a - # dictionary. Set the key to the magic key we use in that case. 
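In plain terms, `MinLabelPositionCombineFn` sums the 1-indexed position of the first positively labeled document across queries that contain one, and divides by the summed per-query weights. A toy recalculation with hypothetical data (the real combiner consumes `QueryFPL` objects and reads labels and weights from feature dictionaries) looks like this:

```python
# Hypothetical data illustrating the "average minimum label position" metric.
queries = [
    {"labels": [0, 0, 1, 0], "weight": 1.0},  # first relevant doc at position 3
    {"labels": [1, 0, 0], "weight": 2.0},     # first relevant doc at position 1
    {"labels": [0, 0], "weight": 1.0},        # no relevant doc, contributes nothing
]

min_pos_sum = 0.0
weight_sum = 0.0
for query in queries:
    positions = [i + 1 for i, label in enumerate(query["labels"]) if label > 0]
    if positions:  # mirrors the `if min_label_pos` check in add_input
        min_pos_sum += positions[0]
        weight_sum += query["weight"]

average_min_label_position = min_pos_sum / weight_sum  # (3 + 1) / (1.0 + 2.0) = 1.33...
```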
- self._label_key = util.default_dict_key(constants.LABELS_KEY) - else: - self._label_key = label_key - self._weight_key = weight_key - - def _get_label(self, fpl: query_types.FPL) -> float: - result = fpl['labels'].get(self._label_key) - if result is None: - return 0.0 - return result - - def create_accumulator(self): - return _State(min_pos_sum=0.0, weight_sum=0.0) - - def _add_states(self, left: _State, right: _State) -> _State: - return _State( - min_pos_sum=left.min_pos_sum + right.min_pos_sum, - weight_sum=left.weight_sum + right.weight_sum, - ) - - def add_input( - self, accumulator: _State, query_fpl: query_types.QueryFPL - ) -> _State: - weight = 1.0 - if self._weight_key: - weights = [ - float(_get_feature_value(fpl, self._weight_key)) - for fpl in query_fpl.fpls - ] - if weights: - if min(weights) != max(weights): - raise ValueError( - 'weights were not identical for all examples in the ' - 'query. query_id was: %s, weights were: %s' - % (query_fpl.query_id, weights) - ) - weight = weights[0] - - min_label_pos = None - for pos, fpl in enumerate(query_fpl.fpls): - if self._get_label(fpl) > 0: - min_label_pos = pos + 1 # Use 1-indexed positions - break - - state_to_add = _State(min_pos_sum=0.0, weight_sum=0.0) - if min_label_pos: - state_to_add = _State(min_pos_sum=min_label_pos, weight_sum=weight) - - return self._add_states(accumulator, state_to_add) - - def merge_accumulators(self, accumulators: Iterable[_State]) -> _State: - accumulators = iter(accumulators) - result = next(accumulators) - for accumulator in accumulators: - result = self._add_states(result, accumulator) - return result - - def extract_output(self, accumulator: _State) -> Dict[str, Any]: - if accumulator.weight_sum > 0: - return { - metric_keys.base_key( - 'average_min_label_position/%s' % self._label_key - ): (accumulator.min_pos_sum / accumulator.weight_sum) - } - return {} + feature = fpl["features"].get(key) + if feature is None: + raise ValueError("feature %s not found in features %s" % (key, fpl["features"])) + if feature.size != 1: + raise ValueError( + "feature %s did not contain exactly 1 value. value was: %s" % (key, feature) + ) + return feature[0][0] + + +class MinLabelPositionCombineFn(beam.CombineFn): + """Computes minimum label position.""" + + def __init__(self, label_key: str, weight_key: str): + """Initialize. + + Args: + ---- + label_key: The key in the labels dictionary which holds the label. Set + this to empty to if labels is a Tensor and not a dictionary. + weight_key: The key in the features dictionary which holds the weights. + Note that the weight value must be identical across all examples in the + same query. If set to empty, uses 1.0 instead. + """ + if not label_key: + # If label_key is set to the empty string, the user is telling us + # that their Estimator returns a labels Tensor rather than a + # dictionary. Set the key to the magic key we use in that case. 
+ self._label_key = util.default_dict_key(constants.LABELS_KEY) + else: + self._label_key = label_key + self._weight_key = weight_key + + def _get_label(self, fpl: query_types.FPL) -> float: + result = fpl["labels"].get(self._label_key) + if result is None: + return 0.0 + return result + + def create_accumulator(self): + return _State(min_pos_sum=0.0, weight_sum=0.0) + + def _add_states(self, left: _State, right: _State) -> _State: + return _State( + min_pos_sum=left.min_pos_sum + right.min_pos_sum, + weight_sum=left.weight_sum + right.weight_sum, + ) + + def add_input(self, accumulator: _State, query_fpl: query_types.QueryFPL) -> _State: + weight = 1.0 + if self._weight_key: + weights = [ + float(_get_feature_value(fpl, self._weight_key)) + for fpl in query_fpl.fpls + ] + if weights: + if min(weights) != max(weights): + raise ValueError( + "weights were not identical for all examples in the " + "query. query_id was: %s, weights were: %s" + % (query_fpl.query_id, weights) + ) + weight = weights[0] + + min_label_pos = None + for pos, fpl in enumerate(query_fpl.fpls): + if self._get_label(fpl) > 0: + min_label_pos = pos + 1 # Use 1-indexed positions + break + + state_to_add = _State(min_pos_sum=0.0, weight_sum=0.0) + if min_label_pos: + state_to_add = _State(min_pos_sum=min_label_pos, weight_sum=weight) + + return self._add_states(accumulator, state_to_add) + + def merge_accumulators(self, accumulators: Iterable[_State]) -> _State: + accumulators = iter(accumulators) + result = next(accumulators) + for accumulator in accumulators: + result = self._add_states(result, accumulator) + return result + + def extract_output(self, accumulator: _State) -> Dict[str, Any]: + if accumulator.weight_sum > 0: + return { + metric_keys.base_key( + "average_min_label_position/%s" % self._label_key + ): (accumulator.min_pos_sum / accumulator.weight_sum) + } + return {} diff --git a/tensorflow_model_analysis/evaluators/query_metrics/ndcg.py b/tensorflow_model_analysis/evaluators/query_metrics/ndcg.py index 53a2249e00..ff48b9b7e5 100644 --- a/tensorflow_model_analysis/evaluators/query_metrics/ndcg.py +++ b/tensorflow_model_analysis/evaluators/query_metrics/ndcg.py @@ -27,147 +27,151 @@ import apache_beam as beam import numpy as np + from tensorflow_model_analysis.evaluators.query_metrics import query_types from tensorflow_model_analysis.post_export_metrics import metric_keys -_State = NamedTuple('_State', [('ndcg', Dict[int, float]), ('weight', float)]) - -def _get_feature_value(fpl: query_types.FPL, key: str) -> float: - """Get value of the given feature from the features dictionary. - - The feature must have exactly one value. - - Args: - fpl: FPL - key: Key of feature to retrieve in features dictionary. - - Returns: - The singular value of the feature. - """ - feature = fpl['features'].get(key) - if feature is None: - raise ValueError( - 'feature %s not found in features %s' % (key, fpl['features']) - ) - if feature.size != 1: - raise ValueError( - 'feature %s did not contain exactly 1 value. value was: %s' - % (key, feature) - ) - return feature[0][0] +class _State(NamedTuple): + ndcg: Dict[int, float] + weight: float -class NdcgMetricCombineFn(beam.CombineFn): - """Computes normalized discounted cumulative gain.""" - - def __init__(self, at_vals: List[int], gain_key: str, weight_key: str): - """Initialize. - - Args: - at_vals: A list containing the number of values to consider in calculating - the values of nDCG (eg. nDCG@at). - gain_key: The key in the features dictionary which holds the gain values. 
- weight_key: The key in the features dictionary which holds the weights. - Note that the weight value must be identical across all examples in the - same query. If set to empty, uses 1.0 instead. - """ - self._at_vals = at_vals - self._gain_key = gain_key - self._weight_key = weight_key +def _get_feature_value(fpl: query_types.FPL, key: str) -> float: + """Get value of the given feature from the features dictionary. - def _calculate_dcg_at_k(self, k: int, sorted_values: List[float]) -> float: - """Calculate the value of DCG@k. + The feature must have exactly one value. Args: - k: The last position to consider. - sorted_values: A list of gain values assumed to be sorted in the desired - ranking order. + ---- + fpl: FPL + key: Key of feature to retrieve in features dictionary. Returns: - The value of DCG@k. + ------- + The singular value of the feature. """ - return np.sum( - np.array(sorted_values)[:k] / np.log2(np.array(range(2, k + 2))) - ) + feature = fpl["features"].get(key) + if feature is None: + raise ValueError("feature %s not found in features %s" % (key, fpl["features"])) + if feature.size != 1: + raise ValueError( + "feature %s did not contain exactly 1 value. value was: %s" % (key, feature) + ) + return feature[0][0] - def _calculate_ndcg(self, values: List[Tuple[int, float]], k: int) -> float: - """Calculate nDCG@k, based on given rank and gain values. - - Args: - values: A list of tuples representing rank order and gain values. - k: The maximum position to consider in calculating nDCG - Returns: - The value of nDCG@k, for the given list of values. - """ - max_rank = min(k, len(values)) - ranked_values = [ - gain for _, gain in sorted(values, key=lambda x: x[0], reverse=False) - ] - optimal_values = [ - gain for _, gain in sorted(values, key=lambda x: x[1], reverse=True) - ] - dcg = self._calculate_dcg_at_k(max_rank, ranked_values) - optimal_dcg = self._calculate_dcg_at_k(max_rank, optimal_values) - if optimal_dcg > 0: - return dcg / optimal_dcg - else: - return 0 - - def _new_ndcg_dict(self): - return dict.fromkeys(self._at_vals, 0) - - def create_accumulator(self): - return _State(ndcg=self._new_ndcg_dict(), weight=0.0) - - def _add_states(self, left: _State, right: _State) -> _State: - ndcg_dict = self._new_ndcg_dict() - for at in self._at_vals: - ndcg_dict[at] = left.ndcg[at] + right.ndcg[at] - return _State(ndcg_dict, left.weight + right.weight) - - def add_input( - self, accumulator: _State, query_fpl: query_types.QueryFPL - ) -> _State: - weight = 1.0 - if self._weight_key: - weights = [ - float(_get_feature_value(fpl, self._weight_key)) - for fpl in query_fpl.fpls - ] - if weights: - if min(weights) != max(weights): - raise ValueError( - 'weights were not identical for all examples in the ' - 'query. 
query_id was: %s, weights were: %s' - % (query_fpl.query_id, weights) - ) - weight = weights[0] - - ndcg_dict = {} - for at in self._at_vals: - rank_gain = [ - (pos + 1, float(_get_feature_value(fpl, self._gain_key))) - for pos, fpl in enumerate(query_fpl.fpls) - ] - ndcg_dict[at] = self._calculate_ndcg(rank_gain, at) * weight - - return self._add_states(accumulator, _State(ndcg=ndcg_dict, weight=weight)) - - def merge_accumulators(self, accumulators: Iterable[_State]) -> _State: - accumulators = iter(accumulators) - result = next(accumulators) - for accumulator in accumulators: - result = self._add_states(result, accumulator) - return result - - def extract_output(self, accumulator: _State) -> Dict[str, Any]: - avg_dict = {} - for at in self._at_vals: - if accumulator.weight > 0: - avg_ndcg = accumulator.ndcg[at] / accumulator.weight - else: - avg_ndcg = 0 - avg_dict[metric_keys.base_key('ndcg@%d' % at)] = avg_ndcg - return avg_dict +class NdcgMetricCombineFn(beam.CombineFn): + """Computes normalized discounted cumulative gain.""" + + def __init__(self, at_vals: List[int], gain_key: str, weight_key: str): + """Initialize. + + Args: + ---- + at_vals: A list containing the number of values to consider in calculating + the values of nDCG (eg. nDCG@at). + gain_key: The key in the features dictionary which holds the gain values. + weight_key: The key in the features dictionary which holds the weights. + Note that the weight value must be identical across all examples in the + same query. If set to empty, uses 1.0 instead. + """ + self._at_vals = at_vals + self._gain_key = gain_key + self._weight_key = weight_key + + def _calculate_dcg_at_k(self, k: int, sorted_values: List[float]) -> float: + """Calculate the value of DCG@k. + + Args: + ---- + k: The last position to consider. + sorted_values: A list of gain values assumed to be sorted in the desired + ranking order. + + Returns: + ------- + The value of DCG@k. + """ + return np.sum(np.array(sorted_values)[:k] / np.log2(np.array(range(2, k + 2)))) + + def _calculate_ndcg(self, values: List[Tuple[int, float]], k: int) -> float: + """Calculate nDCG@k, based on given rank and gain values. + + Args: + ---- + values: A list of tuples representing rank order and gain values. + k: The maximum position to consider in calculating nDCG + + Returns: + ------- + The value of nDCG@k, for the given list of values. 
+ """ + max_rank = min(k, len(values)) + ranked_values = [ + gain for _, gain in sorted(values, key=lambda x: x[0], reverse=False) + ] + optimal_values = [ + gain for _, gain in sorted(values, key=lambda x: x[1], reverse=True) + ] + dcg = self._calculate_dcg_at_k(max_rank, ranked_values) + optimal_dcg = self._calculate_dcg_at_k(max_rank, optimal_values) + if optimal_dcg > 0: + return dcg / optimal_dcg + else: + return 0 + + def _new_ndcg_dict(self): + return dict.fromkeys(self._at_vals, 0) + + def create_accumulator(self): + return _State(ndcg=self._new_ndcg_dict(), weight=0.0) + + def _add_states(self, left: _State, right: _State) -> _State: + ndcg_dict = self._new_ndcg_dict() + for at in self._at_vals: + ndcg_dict[at] = left.ndcg[at] + right.ndcg[at] + return _State(ndcg_dict, left.weight + right.weight) + + def add_input(self, accumulator: _State, query_fpl: query_types.QueryFPL) -> _State: + weight = 1.0 + if self._weight_key: + weights = [ + float(_get_feature_value(fpl, self._weight_key)) + for fpl in query_fpl.fpls + ] + if weights: + if min(weights) != max(weights): + raise ValueError( + "weights were not identical for all examples in the " + "query. query_id was: %s, weights were: %s" + % (query_fpl.query_id, weights) + ) + weight = weights[0] + + ndcg_dict = {} + for at in self._at_vals: + rank_gain = [ + (pos + 1, float(_get_feature_value(fpl, self._gain_key))) + for pos, fpl in enumerate(query_fpl.fpls) + ] + ndcg_dict[at] = self._calculate_ndcg(rank_gain, at) * weight + + return self._add_states(accumulator, _State(ndcg=ndcg_dict, weight=weight)) + + def merge_accumulators(self, accumulators: Iterable[_State]) -> _State: + accumulators = iter(accumulators) + result = next(accumulators) + for accumulator in accumulators: + result = self._add_states(result, accumulator) + return result + + def extract_output(self, accumulator: _State) -> Dict[str, Any]: + avg_dict = {} + for at in self._at_vals: + if accumulator.weight > 0: + avg_ndcg = accumulator.ndcg[at] / accumulator.weight + else: + avg_ndcg = 0 + avg_dict[metric_keys.base_key("ndcg@%d" % at)] = avg_ndcg + return avg_dict diff --git a/tensorflow_model_analysis/evaluators/query_metrics/query_statistics.py b/tensorflow_model_analysis/evaluators/query_metrics/query_statistics.py index 1bd2eeede0..f54eb553e9 100644 --- a/tensorflow_model_analysis/evaluators/query_metrics/query_statistics.py +++ b/tensorflow_model_analysis/evaluators/query_metrics/query_statistics.py @@ -17,62 +17,61 @@ from typing import Any, Dict, Iterable import apache_beam as beam + from tensorflow_model_analysis.evaluators.query_metrics import query_types from tensorflow_model_analysis.post_export_metrics import metric_keys @dataclasses.dataclass class _State: - """QueryStatisticsCombineFn accumulator type.""" + """QueryStatisticsCombineFn accumulator type.""" - total_queries: int - total_documents: int - min_documents: int - max_documents: int + total_queries: int + total_documents: int + min_documents: int + max_documents: int - def merge(self, other: '_State') -> None: - self.total_queries += other.total_queries - self.total_documents += other.total_documents - self.min_documents = min(self.min_documents, other.min_documents) - self.max_documents = max(self.max_documents, other.max_documents) + def merge(self, other: "_State") -> None: + self.total_queries += other.total_queries + self.total_documents += other.total_documents + self.min_documents = min(self.min_documents, other.min_documents) + self.max_documents = max(self.max_documents, 
other.max_documents) - def add(self, query_fpl: query_types.QueryFPL) -> None: - self.total_queries += 1 - self.total_documents += len(query_fpl.fpls) - self.min_documents = min(self.min_documents, len(query_fpl.fpls)) - self.max_documents = max(self.max_documents, len(query_fpl.fpls)) + def add(self, query_fpl: query_types.QueryFPL) -> None: + self.total_queries += 1 + self.total_documents += len(query_fpl.fpls) + self.min_documents = min(self.min_documents, len(query_fpl.fpls)) + self.max_documents = max(self.max_documents, len(query_fpl.fpls)) class QueryStatisticsCombineFn(beam.CombineFn): - """Computes simple statistics about queries.""" + """Computes simple statistics about queries.""" - LARGE_INT = 1000000000 + LARGE_INT = 1000000000 - def create_accumulator(self): - return _State( - total_queries=0, - total_documents=0, - min_documents=self.LARGE_INT, - max_documents=0, - ) + def create_accumulator(self): + return _State( + total_queries=0, + total_documents=0, + min_documents=self.LARGE_INT, + max_documents=0, + ) - def add_input( - self, accumulator: _State, query_fpl: query_types.QueryFPL - ) -> _State: - accumulator.add(query_fpl) - return accumulator + def add_input(self, accumulator: _State, query_fpl: query_types.QueryFPL) -> _State: + accumulator.add(query_fpl) + return accumulator - def merge_accumulators(self, accumulators: Iterable[_State]) -> _State: - it = iter(accumulators) - result = next(it) - for acc in it: - result.merge(acc) - return result + def merge_accumulators(self, accumulators: Iterable[_State]) -> _State: + it = iter(accumulators) + result = next(it) + for acc in it: + result.merge(acc) + return result - def extract_output(self, accumulator: _State) -> Dict[str, Any]: - return { - metric_keys.base_key('total_queries'): accumulator.total_queries, - metric_keys.base_key('total_documents'): accumulator.total_documents, - metric_keys.base_key('min_documents'): accumulator.min_documents, - metric_keys.base_key('max_documents'): accumulator.max_documents, - } + def extract_output(self, accumulator: _State) -> Dict[str, Any]: + return { + metric_keys.base_key("total_queries"): accumulator.total_queries, + metric_keys.base_key("total_documents"): accumulator.total_documents, + metric_keys.base_key("min_documents"): accumulator.min_documents, + metric_keys.base_key("max_documents"): accumulator.max_documents, + } diff --git a/tensorflow_model_analysis/evaluators/query_metrics/query_types.py b/tensorflow_model_analysis/evaluators/query_metrics/query_types.py index ba458aae76..7300900344 100644 --- a/tensorflow_model_analysis/evaluators/query_metrics/query_types.py +++ b/tensorflow_model_analysis/evaluators/query_metrics/query_types.py @@ -19,4 +19,8 @@ # Should contain features, predictions, labels FPL = Dict[str, types.DictOfTensorValue] -QueryFPL = NamedTuple('QueryFPL', [('fpls', List[FPL]), ('query_id', str)]) + + +class QueryFPL(NamedTuple): + fpls: List[FPL] + query_id: str diff --git a/tensorflow_model_analysis/evaluators/testing/confidence_interval_validation.py b/tensorflow_model_analysis/evaluators/testing/confidence_interval_validation.py index aea3ee1414..ed264e7ed6 100644 --- a/tensorflow_model_analysis/evaluators/testing/confidence_interval_validation.py +++ b/tensorflow_model_analysis/evaluators/testing/confidence_interval_validation.py @@ -17,74 +17,69 @@ import os from typing import Callable, Dict, Iterable, Iterator, List, Sequence, Tuple -from absl import app -from absl import flags import apache_beam as beam import numpy as np import tensorflow 
as tf +from absl import app, flags +from google.protobuf import text_format + +# from tensorflow.core.protobuf import saver_pb2 +from tensorflow.core.example import example_pb2 +from tfx_bsl.tfxio import tf_example_record + from tensorflow_model_analysis import constants from tensorflow_model_analysis.api import model_eval_lib from tensorflow_model_analysis.eval_saved_model import util from tensorflow_model_analysis.metrics import metric_types -from tensorflow_model_analysis.proto import config_pb2 -from tensorflow_model_analysis.proto import metrics_for_slice_pb2 -from tfx_bsl.tfxio import tf_example_record - -from google.protobuf import text_format -# from tensorflow.core.protobuf import saver_pb2 -from tensorflow.core.example import example_pb2 +from tensorflow_model_analysis.proto import config_pb2, metrics_for_slice_pb2 -_BINARY_CLASSIFICATION_SCENARIO = 'BINARY_CLASSIFICATION' -_REGRESSION_SCENARIO = 'REGRESSION' +_BINARY_CLASSIFICATION_SCENARIO = "BINARY_CLASSIFICATION" +_REGRESSION_SCENARIO = "REGRESSION" _SCENARIOS = [_BINARY_CLASSIFICATION_SCENARIO, _REGRESSION_SCENARIO] flags.DEFINE_enum( - 'scenario', + "scenario", None, _SCENARIOS, - 'The scenario to validate, where the ' - 'scenario encodes the task type and example generation logic', + "The scenario to validate, where the " + "scenario encodes the task type and example generation logic", ) flags.DEFINE_enum( - 'methodology', + "methodology", None, - ['JACKKNIFE', 'POISSON_BOOTSTRAP'], - 'The CI methodology to use', + ["JACKKNIFE", "POISSON_BOOTSTRAP"], + "The CI methodology to use", ) flags.DEFINE_integer( - 'num_trials', + "num_trials", None, - 'number of datasets to generate and TFMA runs to perform', + "number of datasets to generate and TFMA runs to perform", lower_bound=0, ) flags.DEFINE_integer( - 'num_examples_per_trial', + "num_examples_per_trial", None, - 'number of examples to generate in each trial dataset', + "number of examples to generate in each trial dataset", lower_bound=0, ) +flags.DEFINE_string("output_dir", None, "existing dir in which to write results") flags.DEFINE_string( - 'output_dir', None, 'existing dir in which to write results' -) -flags.DEFINE_string( - 'pipeline_options', - '', - 'Command line flags to use in constructing the Beam pipeline options. ' + "pipeline_options", + "", + "Command line flags to use in constructing the Beam pipeline options. 
" 'For example, "--runner=DirectRunner,--streaming=True"', ) FLAGS = flags.FLAGS _ExampleGeneratorType = Callable[[int], Iterable[example_pb2.Example]] _CIType = Tuple[float, float] -_POPULATION_OUTPUT_NAME = 'population' +_POPULATION_OUTPUT_NAME = "population" -def get_regression_scenario() -> ( - Tuple[config_pb2.EvalConfig, _ExampleGeneratorType] -): - """Returns an EvalConfig and example generator for regression.""" - eval_config = text_format.Parse( - """ +def get_regression_scenario() -> Tuple[config_pb2.EvalConfig, _ExampleGeneratorType]: + """Returns an EvalConfig and example generator for regression.""" + eval_config = text_format.Parse( + """ model_specs { label_key: "label" prediction_key: "prediction" @@ -95,26 +90,26 @@ def get_regression_scenario() -> ( metrics { class_name: "Calibration" } } """, - config_pb2.EvalConfig(), - ) + config_pb2.EvalConfig(), + ) - def generate_regression_examples( - num_examples, - ) -> Iterator[example_pb2.Example]: - for _ in range(num_examples): - yield util.make_example( - label=float(np.random.random()), prediction=float(np.random.uniform()) - ) + def generate_regression_examples( + num_examples, + ) -> Iterator[example_pb2.Example]: + for _ in range(num_examples): + yield util.make_example( + label=float(np.random.random()), prediction=float(np.random.uniform()) + ) - return eval_config, generate_regression_examples + return eval_config, generate_regression_examples def get_binary_classification_scenario() -> ( Tuple[config_pb2.EvalConfig, _ExampleGeneratorType] ): - """Returns an EvalConfig and example generator for binary classification.""" - eval_config = text_format.Parse( - """ + """Returns an EvalConfig and example generator for binary classification.""" + eval_config = text_format.Parse( + """ model_specs { label_key: "label" prediction_key: "prediction" @@ -125,19 +120,19 @@ def get_binary_classification_scenario() -> ( metrics { class_name: "BinaryPrecision" } } """, - config_pb2.EvalConfig(), - ) + config_pb2.EvalConfig(), + ) - def generate_classification_examples( - num_examples, - ) -> Iterator[example_pb2.Example]: - for _ in range(num_examples): - yield util.make_example( - label=float(np.random.choice([0, 1])), - prediction=float(np.random.uniform()), - ) + def generate_classification_examples( + num_examples, + ) -> Iterator[example_pb2.Example]: + for _ in range(num_examples): + yield util.make_example( + label=float(np.random.choice([0, 1])), + prediction=float(np.random.uniform()), + ) - return eval_config, generate_classification_examples + return eval_config, generate_classification_examples def compute_cis( @@ -147,92 +142,90 @@ def compute_cis( num_examples_per_trial: int, output_dir: str, ) -> None: - """Computes a collection of CIs and the population values for a scenario.""" - if scenario == _BINARY_CLASSIFICATION_SCENARIO: - eval_config, example_gen_fn = get_binary_classification_scenario() - elif scenario == _REGRESSION_SCENARIO: - eval_config, example_gen_fn = get_regression_scenario() - else: - raise ValueError( - f'Unexpected scenario {scenario}. Expected one of {_SCENARIOS}' + """Computes a collection of CIs and the population values for a scenario.""" + if scenario == _BINARY_CLASSIFICATION_SCENARIO: + eval_config, example_gen_fn = get_binary_classification_scenario() + elif scenario == _REGRESSION_SCENARIO: + eval_config, example_gen_fn = get_regression_scenario() + else: + raise ValueError( + f"Unexpected scenario {scenario}. 
Expected one of {_SCENARIOS}" + ) + eval_config.options.compute_confidence_intervals.value = True + eval_config.options.confidence_intervals.method = ( + config_pb2.ConfidenceIntervalOptions.ConfidenceIntervalMethod.Value(methodology) ) - eval_config.options.compute_confidence_intervals.value = True - eval_config.options.confidence_intervals.method = ( - config_pb2.ConfidenceIntervalOptions.ConfidenceIntervalMethod.Value( - methodology - ) - ) - pipeline_options = beam.options.pipeline_options.PipelineOptions( - FLAGS.pipeline_options.split(',') - ) - with beam.Pipeline(options=pipeline_options) as pipeline: - tfx_io = tf_example_record.TFExampleBeamRecord( - physical_format='generated', - raw_record_column_name=constants.ARROW_INPUT_COLUMN, + pipeline_options = beam.options.pipeline_options.PipelineOptions( + FLAGS.pipeline_options.split(",") ) - inputs_per_trial = [] - for i in range(num_trials): - inputs = ( - pipeline - | f'CreateExamples[{i}]' - >> beam.Create(example_gen_fn(num_examples_per_trial)) - | f'Serialize[{i}]' - >> beam.Map(lambda example: example.SerializeToString()) - | f'BatchExamples[{i}]' >> tfx_io.BeamSource() - ) - inputs_per_trial.append(inputs) + with beam.Pipeline(options=pipeline_options) as pipeline: + tfx_io = tf_example_record.TFExampleBeamRecord( + physical_format="generated", + raw_record_column_name=constants.ARROW_INPUT_COLUMN, + ) + inputs_per_trial = [] + for i in range(num_trials): + inputs = ( + pipeline + | f"CreateExamples[{i}]" + >> beam.Create(example_gen_fn(num_examples_per_trial)) + | f"Serialize[{i}]" + >> beam.Map(lambda example: example.SerializeToString()) + | f"BatchExamples[{i}]" >> tfx_io.BeamSource() + ) + inputs_per_trial.append(inputs) - trial_output_dir = os.path.join(output_dir, str(i)) - _ = ( - inputs - | f'Evaluate[{i}]' - >> model_eval_lib.ExtractEvaluateAndWriteResults( - eval_config=eval_config, output_path=trial_output_dir - ) - ) - population_output_dir = os.path.join(output_dir, _POPULATION_OUTPUT_NAME) - _ = ( - inputs_per_trial - | 'FlattenInputs' >> beam.Flatten() - | 'EvaluatePopulation' - >> model_eval_lib.ExtractEvaluateAndWriteResults( - eval_config=eval_config, output_path=population_output_dir + trial_output_dir = os.path.join(output_dir, str(i)) + _ = ( + inputs + | f"Evaluate[{i}]" + >> model_eval_lib.ExtractEvaluateAndWriteResults( + eval_config=eval_config, output_path=trial_output_dir + ) + ) + population_output_dir = os.path.join(output_dir, _POPULATION_OUTPUT_NAME) + _ = ( + inputs_per_trial + | "FlattenInputs" >> beam.Flatten() + | "EvaluatePopulation" + >> model_eval_lib.ExtractEvaluateAndWriteResults( + eval_config=eval_config, output_path=population_output_dir + ) ) - ) def load_point_estimates( trial_output_dir: str, ) -> Dict[metric_types.MetricKey, float]: - """Loads the point estimates for each metric in a TFMA run.""" - population_values = {} - path = os.path.join(trial_output_dir, 'metrics') - for rec in tf.compat.v1.io.tf_record_iterator(path): - metrics_for_slice = metrics_for_slice_pb2.MetricsForSlice.FromString(rec) - for kv in metrics_for_slice.metric_keys_and_values: - if kv.value.WhichOneof('type') != 'double_value': - continue - population_values[metric_types.MetricKey.from_proto(kv.key)] = ( - kv.value.double_value.value - ) - return population_values + """Loads the point estimates for each metric in a TFMA run.""" + population_values = {} + path = os.path.join(trial_output_dir, "metrics") + for rec in tf.compat.v1.io.tf_record_iterator(path): + metrics_for_slice = 
metrics_for_slice_pb2.MetricsForSlice.FromString(rec) + for kv in metrics_for_slice.metric_keys_and_values: + if kv.value.WhichOneof("type") != "double_value": + continue + population_values[metric_types.MetricKey.from_proto(kv.key)] = ( + kv.value.double_value.value + ) + return population_values def load_trial_cis( trial_output_dir: str, ) -> Dict[metric_types.MetricKey, _CIType]: - """Loads the CI (lower, upper) for each metric in a TFMA run.""" - trial_cis = {} - path = os.path.join(trial_output_dir, 'metrics') - for rec in tf.compat.v1.io.tf_record_iterator(path): - metrics_for_slice = metrics_for_slice_pb2.MetricsForSlice.FromString(rec) - for kv in metrics_for_slice.metric_keys_and_values: - if kv.value.WhichOneof('type') != 'double_value': - continue - lower = kv.confidence_interval.lower_bound.double_value.value - upper = kv.confidence_interval.upper_bound.double_value.value - trial_cis[metric_types.MetricKey.from_proto(kv.key)] = (lower, upper) - return trial_cis + """Loads the CI (lower, upper) for each metric in a TFMA run.""" + trial_cis = {} + path = os.path.join(trial_output_dir, "metrics") + for rec in tf.compat.v1.io.tf_record_iterator(path): + metrics_for_slice = metrics_for_slice_pb2.MetricsForSlice.FromString(rec) + for kv in metrics_for_slice.metric_keys_and_values: + if kv.value.WhichOneof("type") != "double_value": + continue + lower = kv.confidence_interval.lower_bound.double_value.value + upper = kv.confidence_interval.upper_bound.double_value.value + trial_cis[metric_types.MetricKey.from_proto(kv.key)] = (lower, upper) + return trial_cis def load_cis( @@ -241,48 +234,48 @@ def load_cis( Dict[metric_types.MetricKey, List[_CIType]], Dict[metric_types.MetricKey, float], ]: - """Loads the population point estimates and trial CIs from TFMA runs.""" - population_values = load_point_estimates( - os.path.join(output_dir, _POPULATION_OUTPUT_NAME) - ) - trials_cis = collections.defaultdict(list) - pattern = os.path.join(output_dir, 'trial-*') - for trial_path in tf.io.gfile.glob(pattern): - trial_cis = load_trial_cis(trial_path) - for key, ci in trial_cis.items(): - trials_cis[key].append(ci) - return trials_cis, population_values + """Loads the population point estimates and trial CIs from TFMA runs.""" + population_values = load_point_estimates( + os.path.join(output_dir, _POPULATION_OUTPUT_NAME) + ) + trials_cis = collections.defaultdict(list) + pattern = os.path.join(output_dir, "trial-*") + for trial_path in tf.io.gfile.glob(pattern): + trial_cis = load_trial_cis(trial_path) + for key, ci in trial_cis.items(): + trials_cis[key].append(ci) + return trials_cis, population_values def compute_coverage(output_dir: str) -> Dict[metric_types.MetricKey, float]: - """Computes the per-metric CI coverage fraction.""" - trial_cis, population_values = load_cis(output_dir) - coverage_counts = collections.defaultdict(int) - for metric_name, cis in trial_cis.items(): - for lower, upper in cis: - coverage_counts[metric_name] += int( - population_values[metric_name] >= lower - and population_values[metric_name] <= upper - ) + """Computes the per-metric CI coverage fraction.""" + trial_cis, population_values = load_cis(output_dir) + coverage_counts = collections.defaultdict(int) + for metric_name, cis in trial_cis.items(): + for lower, upper in cis: + coverage_counts[metric_name] += int( + population_values[metric_name] >= lower + and population_values[metric_name] <= upper + ) - coverage_rates = { - k: count / len(trial_cis[k]) for k, count in coverage_counts.items() - } - return 
coverage_rates + coverage_rates = { + k: count / len(trial_cis[k]) for k, count in coverage_counts.items() + } + return coverage_rates def main(argv: Sequence[str]) -> None: - del argv - compute_cis( - scenario=FLAGS.scenario, - methodology=FLAGS.methodology, - num_trials=FLAGS.num_trials, - num_examples_per_trial=FLAGS.num_examples_per_trial, - output_dir=FLAGS.output_dir, - ) - coverage_rates = compute_coverage(FLAGS.output_dir) - print(coverage_rates) + del argv + compute_cis( + scenario=FLAGS.scenario, + methodology=FLAGS.methodology, + num_trials=FLAGS.num_trials, + num_examples_per_trial=FLAGS.num_examples_per_trial, + output_dir=FLAGS.output_dir, + ) + coverage_rates = compute_coverage(FLAGS.output_dir) + print(coverage_rates) -if __name__ == '__main__': - app.run(main) +if __name__ == "__main__": + app.run(main) diff --git a/tensorflow_model_analysis/experimental/dataframe.py b/tensorflow_model_analysis/experimental/dataframe.py index be2bfa19ba..995a370b5f 100644 --- a/tensorflow_model_analysis/experimental/dataframe.py +++ b/tensorflow_model_analysis/experimental/dataframe.py @@ -14,9 +14,10 @@ """Pandas DataFrame utils for Tensorflow Model Analysis.""" from absl import logging + from tensorflow_model_analysis.api.dataframe import * # pylint: disable=wildcard-import logging.warn( - 'tfma.experimental.dataframe is moved to core library, use tfma.dataframe' - ' instead.' + "tfma.experimental.dataframe is moved to core library, use tfma.dataframe" + " instead." ) diff --git a/tensorflow_model_analysis/experimental/preprocessing_functions/__init__.py b/tensorflow_model_analysis/experimental/preprocessing_functions/__init__.py index 3e16aa879b..057c6e49c4 100644 --- a/tensorflow_model_analysis/experimental/preprocessing_functions/__init__.py +++ b/tensorflow_model_analysis/experimental/preprocessing_functions/__init__.py @@ -16,5 +16,9 @@ # pylint: disable=unused-import # pylint: disable=g-bad-import-order -from tensorflow_model_analysis.experimental.preprocessing_functions.test_util import _plus_one -from tensorflow_model_analysis.experimental.preprocessing_functions.text import whitespace_tokenization +from tensorflow_model_analysis.experimental.preprocessing_functions.test_util import ( + _plus_one, +) +from tensorflow_model_analysis.experimental.preprocessing_functions.text import ( + whitespace_tokenization, +) diff --git a/tensorflow_model_analysis/experimental/preprocessing_functions/test_util.py b/tensorflow_model_analysis/experimental/preprocessing_functions/test_util.py index df030aeae7..e81a61f343 100644 --- a/tensorflow_model_analysis/experimental/preprocessing_functions/test_util.py +++ b/tensorflow_model_analysis/experimental/preprocessing_functions/test_util.py @@ -19,4 +19,4 @@ # A simple plus one mostly for testing. 
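Stepping back to `confidence_interval_validation.py` above: the coverage check it performs reduces to counting how often each trial's (lower, upper) bound pair contains the population point estimate. A minimal sketch of that calculation, using hypothetical in-memory values instead of the TFMA output files read by `compute_coverage`:

```python
# Minimal sketch of the coverage computation done by compute_coverage above,
# with hypothetical in-memory intervals instead of TFMA output files.
from typing import List, Tuple


def coverage_fraction(cis: List[Tuple[float, float]], population_value: float) -> float:
    """Fraction of intervals that contain the population point estimate."""
    covered = sum(1 for lower, upper in cis if lower <= population_value <= upper)
    return covered / len(cis)


# Well-calibrated 95% confidence intervals should give a value near 0.95
# once enough trials are run.
trial_cis = [(0.48, 0.61), (0.44, 0.57), (0.55, 0.70), (0.47, 0.59)]
print(coverage_fraction(trial_cis, population_value=0.52))  # 3 of 4 intervals cover it
```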
@tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32)]) def _plus_one(input_data): - return input_data + 1 + return input_data + 1 diff --git a/tensorflow_model_analysis/experimental/preprocessing_functions/text.py b/tensorflow_model_analysis/experimental/preprocessing_functions/text.py index 249197eb8f..24c8d66597 100644 --- a/tensorflow_model_analysis/experimental/preprocessing_functions/text.py +++ b/tensorflow_model_analysis/experimental/preprocessing_functions/text.py @@ -15,6 +15,7 @@ import re import string + import tensorflow as tf _ESCAPED_PUNCTUATIONS = re.escape(string.punctuation) @@ -22,8 +23,8 @@ @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.string)]) def whitespace_tokenization(input_data): - standardized = tf.strings.regex_replace( - tf.strings.lower(input_data), '[%s]' % _ESCAPED_PUNCTUATIONS, '' - ) - tokens = tf.strings.split(standardized) - return tf.map_fn(fn=lambda t: tf.unique(t)[0], elems=tokens) + standardized = tf.strings.regex_replace( + tf.strings.lower(input_data), "[%s]" % _ESCAPED_PUNCTUATIONS, "" + ) + tokens = tf.strings.split(standardized) + return tf.map_fn(fn=lambda t: tf.unique(t)[0], elems=tokens) diff --git a/tensorflow_model_analysis/experimental/preprocessing_functions/text_test.py b/tensorflow_model_analysis/experimental/preprocessing_functions/text_test.py index 519af48f46..d3fceb15ca 100644 --- a/tensorflow_model_analysis/experimental/preprocessing_functions/text_test.py +++ b/tensorflow_model_analysis/experimental/preprocessing_functions/text_test.py @@ -12,31 +12,32 @@ # See the License for the specific language governing permissions and # limitations under the License. """Tests for text_util.""" -from absl.testing import parameterized + import tensorflow as tf +from absl.testing import parameterized + from tensorflow_model_analysis.experimental.preprocessing_functions import text class TextTest(tf.test.TestCase, parameterized.TestCase): + @parameterized.named_parameters( + ("EmptyString", [""], [[]]), + ("SingleString", ["Test foo Bar"], [["test", "foo", "bar"]]), + ( + "BatchedString", + ["app dog", "test foo bar"], + [["app", "dog", ""], ["test", "foo", "bar"]], + ), + ) + def testWhitespaceTokenization(self, input_text, expected_output): + # TODO(b/194508683) Delete the check when TF1 is deprecated. + if tf.__version__ < "2": + return - @parameterized.named_parameters( - ('EmptyString', [''], [[]]), - ('SingleString', ['Test foo Bar'], [['test', 'foo', 'bar']]), - ( - 'BatchedString', - ['app dog', 'test foo bar'], - [['app', 'dog', ''], ['test', 'foo', 'bar']], - ), - ) - def testWhitespaceTokenization(self, input_text, expected_output): - # TODO(b/194508683) Delete the check when TF1 is deprecated. - if tf.__version__ < '2': - return - - actual = text.whitespace_tokenization(input_text).to_tensor() - expected = tf.constant(expected_output, dtype=tf.string) - self.assertAllEqual(actual, expected) + actual = text.whitespace_tokenization(input_text).to_tensor() + expected = tf.constant(expected_output, dtype=tf.string) + self.assertAllEqual(actual, expected) -if __name__ == '__main__': - tf.test.main() +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow_model_analysis/export_only/__init__.py b/tensorflow_model_analysis/export_only/__init__.py index e4ab459106..d94e15d05c 100644 --- a/tensorflow_model_analysis/export_only/__init__.py +++ b/tensorflow_model_analysis/export_only/__init__.py @@ -28,4 +28,3 @@ def eval_input_receiver_fn(): tfma_export.export.export_eval_saved_model(...) 
""" - diff --git a/tensorflow_model_analysis/extractors/__init__.py b/tensorflow_model_analysis/extractors/__init__.py index 171684df43..cfa5f81250 100644 --- a/tensorflow_model_analysis/extractors/__init__.py +++ b/tensorflow_model_analysis/extractors/__init__.py @@ -13,33 +13,46 @@ # limitations under the License. """Init module for TensorFlow Model Analysis extractors.""" -from tensorflow_model_analysis.extractors import legacy_meta_feature_extractor as meta_feature_extractor -from tensorflow_model_analysis.extractors.example_weights_extractor import ExampleWeightsExtractor -from tensorflow_model_analysis.extractors.extractor import Extractor -from tensorflow_model_analysis.extractors.extractor import Filter +from tensorflow_model_analysis.extractors import ( + legacy_meta_feature_extractor as meta_feature_extractor, +) +from tensorflow_model_analysis.extractors.example_weights_extractor import ( + ExampleWeightsExtractor, +) +from tensorflow_model_analysis.extractors.extractor import Extractor, Filter from tensorflow_model_analysis.extractors.features_extractor import FeaturesExtractor from tensorflow_model_analysis.extractors.labels_extractor import LabelsExtractor -from tensorflow_model_analysis.extractors.legacy_feature_extractor import FeatureExtractor +from tensorflow_model_analysis.extractors.legacy_feature_extractor import ( + FeatureExtractor, +) from tensorflow_model_analysis.extractors.legacy_input_extractor import InputExtractor -from tensorflow_model_analysis.extractors.legacy_predict_extractor import PredictExtractor -from tensorflow_model_analysis.extractors.predictions_extractor import PredictionsExtractor -from tensorflow_model_analysis.extractors.slice_key_extractor import SLICE_KEY_EXTRACTOR_STAGE_NAME -from tensorflow_model_analysis.extractors.slice_key_extractor import SliceKeyExtractor -from tensorflow_model_analysis.extractors.transformed_features_extractor import TransformedFeaturesExtractor +from tensorflow_model_analysis.extractors.legacy_predict_extractor import ( + PredictExtractor, +) +from tensorflow_model_analysis.extractors.predictions_extractor import ( + PredictionsExtractor, +) +from tensorflow_model_analysis.extractors.slice_key_extractor import ( + SLICE_KEY_EXTRACTOR_STAGE_NAME, + SliceKeyExtractor, +) +from tensorflow_model_analysis.extractors.transformed_features_extractor import ( + TransformedFeaturesExtractor, +) from tensorflow_model_analysis.extractors.unbatch_extractor import UnbatchExtractor __all__ = [ - "ExampleWeightsExtractor", - "Extractor", - "FeatureExtractor", - "FeaturesExtractor", - "Filter", - "InputExtractor", - "LabelsExtractor", - "PredictExtractor", - "PredictionsExtractor", - "SliceKeyExtractor", - "SLICE_KEY_EXTRACTOR_STAGE_NAME", - "TransformedFeaturesExtractor", - "UnbatchExtractor", + "ExampleWeightsExtractor", + "Extractor", + "FeatureExtractor", + "FeaturesExtractor", + "Filter", + "InputExtractor", + "LabelsExtractor", + "PredictExtractor", + "PredictionsExtractor", + "SliceKeyExtractor", + "SLICE_KEY_EXTRACTOR_STAGE_NAME", + "TransformedFeaturesExtractor", + "UnbatchExtractor", ] diff --git a/tensorflow_model_analysis/extractors/counterfactual_predictions_extractor.py b/tensorflow_model_analysis/extractors/counterfactual_predictions_extractor.py index 65460f0a80..14915cb22f 100644 --- a/tensorflow_model_analysis/extractors/counterfactual_predictions_extractor.py +++ b/tensorflow_model_analysis/extractors/counterfactual_predictions_extractor.py @@ -18,20 +18,18 @@ import apache_beam as beam import numpy as np import 
tensorflow as tf + from tensorflow_model_analysis import constants from tensorflow_model_analysis.api import types -from tensorflow_model_analysis.extractors import extractor -from tensorflow_model_analysis.extractors import predictions_extractor +from tensorflow_model_analysis.extractors import extractor, predictions_extractor from tensorflow_model_analysis.proto import config_pb2 from tensorflow_model_analysis.utils import model_util _SUPPORTED_MODEL_TYPES = frozenset([constants.TF_KERAS, constants.TF_GENERIC]) -_COUNTERFACTUAL_PREDICTIONS_EXTRACTOR_NAME = ( - 'CounterfactualPredictionsExtractor' -) +_COUNTERFACTUAL_PREDICTIONS_EXTRACTOR_NAME = "CounterfactualPredictionsExtractor" # The extracts key under which the non-CF INPUT_KEY value is temporarily stored, # when invoking one or more PredictionsExtractors on modified inputs. -_TEMP_ORIG_INPUT_KEY = 'non_counterfactual_input' +_TEMP_ORIG_INPUT_KEY = "non_counterfactual_input" CounterfactualConfig = Dict[str, str] @@ -40,220 +38,230 @@ def CounterfactualPredictionsExtractor( # pylint: disable=invalid-name eval_config: config_pb2.EvalConfig, cf_configs: Mapping[str, CounterfactualConfig], ) -> extractor.Extractor: - """Creates a CF predictions extractor by wrapping the PredictionsExtractor. - - Example usage: - - eval_config = tfma.EvalConfig(model_specs=[ - tfma.ModelSpec(name='orig', is_baseline=True), - tfma.ModelSpec(name='cf')]) - eval_shared_models = { - 'orig': eval_shared_model, - 'cf' eval_shared_model} - cf_configs = {'cf': {'x_cf': 'x'}} - extractors = tfma.default_extractors(eval_shared_models, eval_config, - custom_predictions_extractor=CounterfactualPredictionsExtractor( - eval_shared_models,cf_configs)) - tfma.run_model_analysis(eval_shared_models, eval_config, - extractors=etractors) - - Args: - eval_shared_models: The set of eval_shared_models for which to generate - predictions. If a model is to be computed with original inputs and CF - inputs, it should be provided twice, with distinct names. The name of the - model to be computed with CF inputs should match the name provided in - cf_configs as well as the ModelSpec.name in the provided EvalConfig. - eval_config: The EvalConfig for this evaluation. If a model is to be - computed with original inputs and CF inputs, it should correspond to two - ModelSpecs with distinct names. The CF model name should match the name - provided in cf_configs as well as the EvalSharedModel.model_name. - cf_configs: A mapping from a model name to the CF config which should be - used to preprocess its inputs. Any models in eval_shared_models not - specified will have their predictions computed on the original input - - Returns: - A tfma.Extractor which performs counterfactual inference along with non- - counterfactual inference. - - Raises: - ValueError if eval_shared_models is empty. - """ - eval_shared_models, cf_configs = _validate_and_update_models_and_configs( - eval_shared_models, cf_configs - ) - cf_ptransforms = {} - non_cf_models = [] - for model in eval_shared_models: - cf_config = cf_configs.get(model.model_name, None) - if cf_config: - # filter EvalConfig so that it matches single EvalSavedModel - cf_eval_config = _filter_model_specs(eval_config, [model]) - # TODO(b/258850519): Refactor default_extractors logic to expose new api - # for constructing the default predictions extractor and call it here. 
- predictions_ptransform = predictions_extractor.PredictionsExtractor( - eval_shared_model=model, - eval_config=cf_eval_config, - output_keypath=(constants.PREDICTIONS_KEY, model.model_name), - ).ptransform - cf_ptransforms[model.model_name] = _ExtractCounterfactualPredictions( # pylint: disable=no-value-for-parameter - config=cf_config, predictions_ptransform=predictions_ptransform - ) + """Creates a CF predictions extractor by wrapping the PredictionsExtractor. + + Example usage: + + eval_config = tfma.EvalConfig(model_specs=[ + tfma.ModelSpec(name='orig', is_baseline=True), + tfma.ModelSpec(name='cf')]) + eval_shared_models = { + 'orig': eval_shared_model, + 'cf' eval_shared_model} + cf_configs = {'cf': {'x_cf': 'x'}} + extractors = tfma.default_extractors(eval_shared_models, eval_config, + custom_predictions_extractor=CounterfactualPredictionsExtractor( + eval_shared_models,cf_configs)) + tfma.run_model_analysis(eval_shared_models, eval_config, + extractors=etractors) + + Args: + ---- + eval_shared_models: The set of eval_shared_models for which to generate + predictions. If a model is to be computed with original inputs and CF + inputs, it should be provided twice, with distinct names. The name of the + model to be computed with CF inputs should match the name provided in + cf_configs as well as the ModelSpec.name in the provided EvalConfig. + eval_config: The EvalConfig for this evaluation. If a model is to be + computed with original inputs and CF inputs, it should correspond to two + ModelSpecs with distinct names. The CF model name should match the name + provided in cf_configs as well as the EvalSharedModel.model_name. + cf_configs: A mapping from a model name to the CF config which should be + used to preprocess its inputs. Any models in eval_shared_models not + specified will have their predictions computed on the original input + + Returns: + ------- + A tfma.Extractor which performs counterfactual inference along with non- + counterfactual inference. + + Raises: + ------ + ValueError if eval_shared_models is empty. + """ + eval_shared_models, cf_configs = _validate_and_update_models_and_configs( + eval_shared_models, cf_configs + ) + cf_ptransforms = {} + non_cf_models = [] + for model in eval_shared_models: + cf_config = cf_configs.get(model.model_name, None) + if cf_config: + # filter EvalConfig so that it matches single EvalSavedModel + cf_eval_config = _filter_model_specs(eval_config, [model]) + # TODO(b/258850519): Refactor default_extractors logic to expose new api + # for constructing the default predictions extractor and call it here. 
+ predictions_ptransform = predictions_extractor.PredictionsExtractor( + eval_shared_model=model, + eval_config=cf_eval_config, + output_keypath=(constants.PREDICTIONS_KEY, model.model_name), + ).ptransform + cf_ptransforms[model.model_name] = _ExtractCounterfactualPredictions( # pylint: disable=no-value-for-parameter + config=cf_config, predictions_ptransform=predictions_ptransform + ) + else: + non_cf_models.append(model) + non_cf_eval_config = _filter_model_specs(eval_config, non_cf_models) + if non_cf_models: + output_keypath = (constants.PREDICTIONS_KEY,) + if len(non_cf_models) == 1: + output_keypath = output_keypath + (non_cf_models[0].model_name,) + non_cf_ptransform = predictions_extractor.PredictionsExtractor( + eval_shared_model=non_cf_models, + eval_config=non_cf_eval_config, + output_keypath=output_keypath, + ).ptransform else: - non_cf_models.append(model) - non_cf_eval_config = _filter_model_specs(eval_config, non_cf_models) - if non_cf_models: - output_keypath = (constants.PREDICTIONS_KEY,) - if len(non_cf_models) == 1: - output_keypath = output_keypath + (non_cf_models[0].model_name,) - non_cf_ptransform = predictions_extractor.PredictionsExtractor( - eval_shared_model=non_cf_models, - eval_config=non_cf_eval_config, - output_keypath=output_keypath, - ).ptransform - else: - non_cf_ptransform = None - return extractor.Extractor( - stage_name=_COUNTERFACTUAL_PREDICTIONS_EXTRACTOR_NAME, - ptransform=_ExtractPredictions( # pylint: disable=no-value-for-parameter - cf_ptransforms=cf_ptransforms, non_cf_ptransform=non_cf_ptransform - ), - ) + non_cf_ptransform = None + return extractor.Extractor( + stage_name=_COUNTERFACTUAL_PREDICTIONS_EXTRACTOR_NAME, + ptransform=_ExtractPredictions( # pylint: disable=no-value-for-parameter + cf_ptransforms=cf_ptransforms, non_cf_ptransform=non_cf_ptransform + ), + ) def _validate_and_update_models_and_configs( eval_shared_models: types.MaybeMultipleEvalSharedModels, cf_configs: Mapping[str, CounterfactualConfig], ): - """Validates and updates the EvalSharedModels and CF configs. - - Args: - eval_shared_models: The set of EvalSharedModels to validate and update. - cf_configs: The CF configs to validate and update. - - Returns: - A tuple of updated eval_shared_models and cf_configs. - - Raises: - ValueError if: - - eval_shared_models is empty - - eval_shared_models are not all _SUPPORTED_MODEL_TYPES - - cf_configs is empty - - The model names in cf_configs do not match eval_shared_models - """ - eval_shared_models = model_util.verify_and_update_eval_shared_models( - eval_shared_models - ) - if not eval_shared_models: - raise ValueError( - 'The CounterfactualPredictionsExtractor requires at least one ' - f'EvalSharedModel, but got normalized models: {eval_shared_models}.' - ) - model_types = {m.model_type for m in eval_shared_models} - if not model_types.issubset(_SUPPORTED_MODEL_TYPES): - raise ValueError( - f'Only {_SUPPORTED_MODEL_TYPES} model types are supported, but found ' - f'model types: {model_types}.' - ) - if not cf_configs: - raise ValueError( - 'The CounterfactualPredictionsExtractor requires at least ' - 'one cf_configs, but got 0.' + """Validates and updates the EvalSharedModels and CF configs. + + Args: + ---- + eval_shared_models: The set of EvalSharedModels to validate and update. + cf_configs: The CF configs to validate and update. + + Returns: + ------- + A tuple of updated eval_shared_models and cf_configs. 
+ + Raises: + ------ + ValueError if: + - eval_shared_models is empty + - eval_shared_models are not all _SUPPORTED_MODEL_TYPES + - cf_configs is empty + - The model names in cf_configs do not match eval_shared_models + """ + eval_shared_models = model_util.verify_and_update_eval_shared_models( + eval_shared_models ) - - if len(eval_shared_models) == 1: - if len(cf_configs) == 1: - # Follow the normalization logic in verify_and_update_eval_shared_models - # and rekey cf_config in single model case under the empty key, "". - cf_configs = {'': next(iter(cf_configs.values()))} - else: - raise ValueError( - 'The CounterfactualPredictionsExtractor was provided only one ' - 'EvalSharedModel, in which case exactly one config is expected, but ' - f'got {len(cf_configs)}: {cf_configs}' - ) - - configured_model_names = set(cf_configs) - eval_shared_model_names = {model.model_name for model in eval_shared_models} - unmatched_config_names = configured_model_names - eval_shared_model_names - if unmatched_config_names: - raise ValueError( - 'model_name_to_config contains model names which do not match the ' - 'eval_shared_model model_names. Configured names: ' - f'{configured_model_names}, eval_shared_models names: ' - f'{eval_shared_model_names}. Unmatched configured model names: ' - f'{unmatched_config_names}.' - ) - return eval_shared_models, cf_configs + if not eval_shared_models: + raise ValueError( + "The CounterfactualPredictionsExtractor requires at least one " + f"EvalSharedModel, but got normalized models: {eval_shared_models}." + ) + model_types = {m.model_type for m in eval_shared_models} + if not model_types.issubset(_SUPPORTED_MODEL_TYPES): + raise ValueError( + f"Only {_SUPPORTED_MODEL_TYPES} model types are supported, but found " + f"model types: {model_types}." + ) + if not cf_configs: + raise ValueError( + "The CounterfactualPredictionsExtractor requires at least " + "one cf_configs, but got 0." + ) + + if len(eval_shared_models) == 1: + if len(cf_configs) == 1: + # Follow the normalization logic in verify_and_update_eval_shared_models + # and rekey cf_config in single model case under the empty key, "". + cf_configs = {"": next(iter(cf_configs.values()))} + else: + raise ValueError( + "The CounterfactualPredictionsExtractor was provided only one " + "EvalSharedModel, in which case exactly one config is expected, but " + f"got {len(cf_configs)}: {cf_configs}" + ) + + configured_model_names = set(cf_configs) + eval_shared_model_names = {model.model_name for model in eval_shared_models} + unmatched_config_names = configured_model_names - eval_shared_model_names + if unmatched_config_names: + raise ValueError( + "model_name_to_config contains model names which do not match the " + "eval_shared_model model_names. Configured names: " + f"{configured_model_names}, eval_shared_models names: " + f"{eval_shared_model_names}. Unmatched configured model names: " + f"{unmatched_config_names}." 
+ ) + return eval_shared_models, cf_configs def _filter_model_specs( eval_config: config_pb2.EvalConfig, eval_shared_models: Iterable[types.EvalSharedModel], ) -> config_pb2.EvalConfig: - """Filters EvalConfig.model_specs to match the set of EvalSharedModels.""" - result = config_pb2.EvalConfig() - result.CopyFrom(eval_config) - del result.model_specs[:] - model_names = [model.model_name for model in eval_shared_models] - result.model_specs.extend( - [spec for spec in eval_config.model_specs if spec.name in model_names] - ) - return result + """Filters EvalConfig.model_specs to match the set of EvalSharedModels.""" + result = config_pb2.EvalConfig() + result.CopyFrom(eval_config) + del result.model_specs[:] + model_names = [model.model_name for model in eval_shared_models] + result.model_specs.extend( + [spec for spec in eval_config.model_specs if spec.name in model_names] + ) + return result def _cf_preprocess( extracts: types.Extracts, config: CounterfactualConfig, ) -> types.Extracts: - """Preprocesses extracts for counterfactual prediction. - - This method is to be called on each Extracts object prior to applying the - wrapped prediction PTransform. - - Args: - extracts: An Extracts instance which is suitable for feeding to a non-CF - prediction extractor. - config: The counterfactual config which determines how the inputs should be - counterfactually modified. - - Returns: - An Extracts instance which, when fed to a predictions PTransform, will - produce counterfactual predictions. - """ - result = extracts.copy() - result[_TEMP_ORIG_INPUT_KEY] = result[constants.INPUT_KEY] - cf_inputs = [] - for serialized_input in result[constants.INPUT_KEY]: - cf_example = tf.train.Example.FromString(serialized_input) - for dst_key, src_key in config.items(): - cf_example.features.feature[dst_key].CopyFrom( - cf_example.features.feature[src_key] - ) - cf_inputs.append(cf_example.SerializeToString()) - cf_inputs = np.array(cf_inputs, dtype=object) - result[constants.INPUT_KEY] = cf_inputs - return result + """Preprocesses extracts for counterfactual prediction. + + This method is to be called on each Extracts object prior to applying the + wrapped prediction PTransform. + + Args: + ---- + extracts: An Extracts instance which is suitable for feeding to a non-CF + prediction extractor. + config: The counterfactual config which determines how the inputs should be + counterfactually modified. + + Returns: + ------- + An Extracts instance which, when fed to a predictions PTransform, will + produce counterfactual predictions. + """ + result = extracts.copy() + result[_TEMP_ORIG_INPUT_KEY] = result[constants.INPUT_KEY] + cf_inputs = [] + for serialized_input in result[constants.INPUT_KEY]: + cf_example = tf.train.Example.FromString(serialized_input) + for dst_key, src_key in config.items(): + cf_example.features.feature[dst_key].CopyFrom( + cf_example.features.feature[src_key] + ) + cf_inputs.append(cf_example.SerializeToString()) + cf_inputs = np.array(cf_inputs, dtype=object) + result[constants.INPUT_KEY] = cf_inputs + return result def _cf_postprocess(extracts: types.Extracts) -> types.Extracts: - """Postprocesses the result of applying a CF prediction ptransform. + """Postprocesses the result of applying a CF prediction ptransform. - This method takes in an Extracts instance that has been prepocessed by - _preprocess_cf and has had a prediction PTransform applied, and makes it look - as if just a non-CF prediction PTransform had been applied. 
+ This method takes in an Extracts instance that has been prepocessed by + _preprocess_cf and has had a prediction PTransform applied, and makes it look + as if just a non-CF prediction PTransform had been applied. - Args: - extracts: An Extracts instance which has been preprocessed by _preprocess_cf - and gone through a prediction PTransform. + Args: + ---- + extracts: An Extracts instance which has been preprocessed by _preprocess_cf + and gone through a prediction PTransform. - Returns: - An Extracts instance which appears to have been produced by a standard - predictions PTransform. - """ - extracts = extracts.copy() - extracts[constants.INPUT_KEY] = extracts[_TEMP_ORIG_INPUT_KEY] - del extracts[_TEMP_ORIG_INPUT_KEY] - return extracts + Returns: + ------- + An Extracts instance which appears to have been produced by a standard + predictions PTransform. + """ + extracts = extracts.copy() + extracts[constants.INPUT_KEY] = extracts[_TEMP_ORIG_INPUT_KEY] + del extracts[_TEMP_ORIG_INPUT_KEY] + return extracts @beam.ptransform_fn @@ -262,13 +270,13 @@ def _ExtractCounterfactualPredictions( # pylint: disable=invalid-name config: CounterfactualConfig, predictions_ptransform: beam.PTransform, ) -> beam.PCollection[types.Extracts]: - """Computes counterfactual predictions for a single model.""" - return ( - extracts - | 'PreprocessInputs' >> beam.Map(_cf_preprocess, config=config) - | 'Predict' >> predictions_ptransform - | 'PostProcessPredictions' >> beam.Map(_cf_postprocess) - ) + """Computes counterfactual predictions for a single model.""" + return ( + extracts + | "PreprocessInputs" >> beam.Map(_cf_preprocess, config=config) + | "Predict" >> predictions_ptransform + | "PostProcessPredictions" >> beam.Map(_cf_postprocess) + ) @beam.ptransform_fn @@ -277,21 +285,23 @@ def _ExtractPredictions( # pylint: disable=invalid-name cf_ptransforms: Dict[str, beam.PTransform], non_cf_ptransform: Optional[beam.PTransform], ) -> beam.PCollection[types.Extracts]: - """Applies both CF and non-CF prediction ptransforms and merges results. - - Args: - extracts: Incoming TFMA extracts. - cf_ptransforms: A mapping from model name to - _ExtractCounterfactualPredictions ptransforms - non_cf_ptransform: Optionally, a ptransform responsible for computing the - non-counterfactual predictions. - - Returns: - A PCollection of extracts containing merged predictions from both - counterfactual and non-counterfactual models. - """ - if non_cf_ptransform: - extracts = extracts | 'PredictNonCF' >> non_cf_ptransform - for model_name, cf_ptransform in cf_ptransforms.items(): - extracts = extracts | f'PredictCF[{model_name}]' >> cf_ptransform - return extracts + """Applies both CF and non-CF prediction ptransforms and merges results. + + Args: + ---- + extracts: Incoming TFMA extracts. + cf_ptransforms: A mapping from model name to + _ExtractCounterfactualPredictions ptransforms + non_cf_ptransform: Optionally, a ptransform responsible for computing the + non-counterfactual predictions. + + Returns: + ------- + A PCollection of extracts containing merged predictions from both + counterfactual and non-counterfactual models. 
+ """ + if non_cf_ptransform: + extracts = extracts | "PredictNonCF" >> non_cf_ptransform + for model_name, cf_ptransform in cf_ptransforms.items(): + extracts = extracts | f"PredictCF[{model_name}]" >> cf_ptransform + return extracts diff --git a/tensorflow_model_analysis/extractors/counterfactual_predictions_extractor_test.py b/tensorflow_model_analysis/extractors/counterfactual_predictions_extractor_test.py index d8b57da04c..00f6942e0d 100644 --- a/tensorflow_model_analysis/extractors/counterfactual_predictions_extractor_test.py +++ b/tensorflow_model_analysis/extractors/counterfactual_predictions_extractor_test.py @@ -16,185 +16,169 @@ import os import tempfile -from absl.testing import parameterized import apache_beam as beam -from apache_beam.testing import util import numpy as np import tensorflow as tf +from absl.testing import parameterized +from apache_beam.testing import util +from google.protobuf import text_format +from tensorflow_metadata.proto.v0 import schema_pb2 +from tfx_bsl.tfxio import tensor_adapter +from tfx_bsl.tfxio import test_util as tfx_bsl_test_util + from tensorflow_model_analysis import constants -from tensorflow_model_analysis.api import model_eval_lib -from tensorflow_model_analysis.api import types -from tensorflow_model_analysis.extractors import counterfactual_predictions_extractor -from tensorflow_model_analysis.extractors import features_extractor +from tensorflow_model_analysis.api import model_eval_lib, types +from tensorflow_model_analysis.extractors import ( + counterfactual_predictions_extractor, + features_extractor, +) from tensorflow_model_analysis.proto import config_pb2 from tensorflow_model_analysis.utils import test_util from tensorflow_model_analysis.utils.keras_lib import tf_keras -from tfx_bsl.tfxio import tensor_adapter -from tfx_bsl.tfxio import test_util as tfx_bsl_test_util - -from google.protobuf import text_format -from tensorflow_metadata.proto.v0 import schema_pb2 class IdentityParsingLayer(tf_keras.layers.Layer): - """A Kears layer which performs parsing and returns a single tensor.""" + """A Kears layer which performs parsing and returns a single tensor.""" - def __init__(self, feature_key): - self._feature_key = feature_key - super(IdentityParsingLayer, self).__init__(trainable=False) + def __init__(self, feature_key): + self._feature_key = feature_key + super(IdentityParsingLayer, self).__init__(trainable=False) - def call(self, serialized_example): - parsed = tf.io.parse_example( - serialized_example, - {self._feature_key: tf.io.FixedLenFeature(shape=[], dtype=tf.int64)}, - ) - return parsed[self._feature_key] + def call(self, serialized_example): + parsed = tf.io.parse_example( + serialized_example, + {self._feature_key: tf.io.FixedLenFeature(shape=[], dtype=tf.int64)}, + ) + return parsed[self._feature_key] class CounterfactualPredictionsExtactorTest( test_util.TensorflowModelAnalysisTest, parameterized.TestCase ): + def _makeIdentityParsingModel(self): + """Builds a Keras model that parses and returns a single input tensor.""" + inputs = tf_keras.Input(shape=(), dtype=tf.string) + outputs = IdentityParsingLayer(feature_key="x")(inputs) + model = tf_keras.Model(inputs=inputs, outputs=outputs) + path = os.path.join(tempfile.mkdtemp(), "export_dir") + tf.saved_model.save(model, path) + return path - def _makeIdentityParsingModel(self): - """Builds a Keras model that parses and returns a single input tensor.""" - inputs = tf_keras.Input(shape=(), dtype=tf.string) - outputs = IdentityParsingLayer(feature_key='x')(inputs) - 
model = tf_keras.Model(inputs=inputs, outputs=outputs) - path = os.path.join(tempfile.mkdtemp(), 'export_dir') - tf.saved_model.save(model, path) - return path - - @parameterized.named_parameters( - { - 'testcase_name': 'empty_eval_shared_models', - 'eval_shared_models': [], - 'cf_configs': {}, - 'expected_exception_regex': r'requires at least one EvalSharedModel', - }, - { - 'testcase_name': 'empty_cf_configs', - 'eval_shared_models': [ - types.EvalSharedModel( - model_path='', model_type=constants.TF_KERAS - ) - ], - 'cf_configs': {}, - 'expected_exception_regex': r'requires at least one cf_configs', - }, - { - 'testcase_name': 'unsupported_type', - 'eval_shared_models': [ - types.EvalSharedModel( - model_path='', model_type=constants.TF_ESTIMATOR - ) - ], - 'cf_configs': {}, - 'expected_exception_regex': r'found model types.*tf_estimator', - }, - { - 'testcase_name': 'not_exacty_one_config', - 'eval_shared_models': [ - types.EvalSharedModel( - model_path='', model_type=constants.TF_KERAS - ) - ], - 'cf_configs': {'orig': {}, 'cf': {}}, - 'expected_exception_regex': r'one config is expected, but got 2', - }, - { - 'testcase_name': 'unmatched_configs', - 'eval_shared_models': [ - types.EvalSharedModel( - model_name='orig', - model_path='', - model_type=constants.TF_KERAS, - ), - types.EvalSharedModel( - model_name='cf', model_path='', model_type=constants.TF_KERAS - ), - ], - 'cf_configs': {'orig': {}, 'cf': {}, 'cf1': {}}, - 'expected_exception_regex': r'Unmatched configured model names:.*cf1', - }, - ) - def test_validate_and_update_models_and_configs( - self, eval_shared_models, cf_configs, expected_exception_regex - ): - with self.assertRaisesRegex(ValueError, expected_exception_regex): - ( - counterfactual_predictions_extractor._validate_and_update_models_and_configs( - eval_shared_models, cf_configs - ) - ) + @parameterized.named_parameters( + { + "testcase_name": "empty_eval_shared_models", + "eval_shared_models": [], + "cf_configs": {}, + "expected_exception_regex": r"requires at least one EvalSharedModel", + }, + { + "testcase_name": "empty_cf_configs", + "eval_shared_models": [ + types.EvalSharedModel(model_path="", model_type=constants.TF_KERAS) + ], + "cf_configs": {}, + "expected_exception_regex": r"requires at least one cf_configs", + }, + { + "testcase_name": "unsupported_type", + "eval_shared_models": [ + types.EvalSharedModel(model_path="", model_type=constants.TF_ESTIMATOR) + ], + "cf_configs": {}, + "expected_exception_regex": r"found model types.*tf_estimator", + }, + { + "testcase_name": "not_exacty_one_config", + "eval_shared_models": [ + types.EvalSharedModel(model_path="", model_type=constants.TF_KERAS) + ], + "cf_configs": {"orig": {}, "cf": {}}, + "expected_exception_regex": r"one config is expected, but got 2", + }, + { + "testcase_name": "unmatched_configs", + "eval_shared_models": [ + types.EvalSharedModel( + model_name="orig", + model_path="", + model_type=constants.TF_KERAS, + ), + types.EvalSharedModel( + model_name="cf", model_path="", model_type=constants.TF_KERAS + ), + ], + "cf_configs": {"orig": {}, "cf": {}, "cf1": {}}, + "expected_exception_regex": r"Unmatched configured model names:.*cf1", + }, + ) + def test_validate_and_update_models_and_configs( + self, eval_shared_models, cf_configs, expected_exception_regex + ): + with self.assertRaisesRegex(ValueError, expected_exception_regex): + ( + counterfactual_predictions_extractor._validate_and_update_models_and_configs( + eval_shared_models, cf_configs + ) + ) - @parameterized.named_parameters( - { - 
'testcase_name': 'single_non_cf_single_cf', - 'eval_shared_model_names': ['orig', 'cf'], - 'model_specs': [ - config_pb2.ModelSpec( - name='orig', signature_name='serving_default' - ), - config_pb2.ModelSpec(name='cf', signature_name='serving_default'), - ], - 'cf_configs': {'cf': {'x': 'x_cf1'}}, - 'expected_predictions': { - 'orig': np.array([1, 2]), - 'cf': np.array([1, 1]), - }, - }, - { - 'testcase_name': 'single_cf', - 'eval_shared_model_names': [''], - 'model_specs': [ - config_pb2.ModelSpec(signature_name='serving_default') - ], - 'cf_configs': {'cf': {'x': 'x_cf1'}}, - 'expected_predictions': {'': np.array([1, 1])}, - }, - { - 'testcase_name': 'single_non_cf_multiple_cf', - 'eval_shared_model_names': ['orig', 'cf1', 'cf2'], - 'model_specs': [ - config_pb2.ModelSpec( - name='orig', signature_name='serving_default' - ), - config_pb2.ModelSpec( - name='cf1', signature_name='serving_default' - ), - config_pb2.ModelSpec( - name='cf2', signature_name='serving_default' - ), - ], - 'cf_configs': {'cf1': {'x': 'x_cf1'}, 'cf2': {'x': 'x_cf2'}}, - 'expected_predictions': { - 'orig': np.array([1, 2]), - 'cf1': np.array([1, 1]), - 'cf2': np.array([2, 2]), - }, - }, - ) - def test_cf_predictions_extractor( - self, - eval_shared_model_names, - model_specs, - cf_configs, - expected_predictions, - ): - model_path = self._makeIdentityParsingModel() - eval_config = config_pb2.EvalConfig(model_specs=model_specs) - eval_shared_models = [] - for model_name in eval_shared_model_names: - eval_shared_models.append( - model_eval_lib.default_eval_shared_model( - eval_saved_model_path=model_path, - tags=[tf.saved_model.SERVING], - model_name=model_name, - eval_config=eval_config, - ) - ) - schema = text_format.Parse( - """ + @parameterized.named_parameters( + { + "testcase_name": "single_non_cf_single_cf", + "eval_shared_model_names": ["orig", "cf"], + "model_specs": [ + config_pb2.ModelSpec(name="orig", signature_name="serving_default"), + config_pb2.ModelSpec(name="cf", signature_name="serving_default"), + ], + "cf_configs": {"cf": {"x": "x_cf1"}}, + "expected_predictions": { + "orig": np.array([1, 2]), + "cf": np.array([1, 1]), + }, + }, + { + "testcase_name": "single_cf", + "eval_shared_model_names": [""], + "model_specs": [config_pb2.ModelSpec(signature_name="serving_default")], + "cf_configs": {"cf": {"x": "x_cf1"}}, + "expected_predictions": {"": np.array([1, 1])}, + }, + { + "testcase_name": "single_non_cf_multiple_cf", + "eval_shared_model_names": ["orig", "cf1", "cf2"], + "model_specs": [ + config_pb2.ModelSpec(name="orig", signature_name="serving_default"), + config_pb2.ModelSpec(name="cf1", signature_name="serving_default"), + config_pb2.ModelSpec(name="cf2", signature_name="serving_default"), + ], + "cf_configs": {"cf1": {"x": "x_cf1"}, "cf2": {"x": "x_cf2"}}, + "expected_predictions": { + "orig": np.array([1, 2]), + "cf1": np.array([1, 1]), + "cf2": np.array([2, 2]), + }, + }, + ) + def test_cf_predictions_extractor( + self, + eval_shared_model_names, + model_specs, + cf_configs, + expected_predictions, + ): + model_path = self._makeIdentityParsingModel() + eval_config = config_pb2.EvalConfig(model_specs=model_specs) + eval_shared_models = [] + for model_name in eval_shared_model_names: + eval_shared_models.append( + model_eval_lib.default_eval_shared_model( + eval_saved_model_path=model_path, + tags=[tf.saved_model.SERVING], + model_name=model_name, + eval_config=eval_config, + ) + ) + schema = text_format.Parse( + """ feature { name: "x" type: INT @@ -208,71 +192,72 @@ def 
test_cf_predictions_extractor( type: INT } """, - schema_pb2.Schema(), - ) - - tfx_io = tfx_bsl_test_util.InMemoryTFExampleRecord( - schema=schema, raw_record_column_name=constants.ARROW_INPUT_COLUMN - ) - examples = [ - self._makeExample(x=1, x_cf1=1, x_cf2=2), - self._makeExample(x=2, x_cf1=1, x_cf2=2), - ] - num_examples = len(examples) - tensor_adapter_config = tensor_adapter.TensorAdapterConfig( - arrow_schema=tfx_io.ArrowSchema(), - tensor_representations=tfx_io.TensorRepresentations(), - ) - feature_extractor = features_extractor.FeaturesExtractor( - eval_config=eval_config, - tensor_representations=tensor_adapter_config.tensor_representations, - ) + schema_pb2.Schema(), + ) - cf_predictions_extractor = ( - counterfactual_predictions_extractor.CounterfactualPredictionsExtractor( - eval_shared_models=eval_shared_models, + tfx_io = tfx_bsl_test_util.InMemoryTFExampleRecord( + schema=schema, raw_record_column_name=constants.ARROW_INPUT_COLUMN + ) + examples = [ + self._makeExample(x=1, x_cf1=1, x_cf2=2), + self._makeExample(x=2, x_cf1=1, x_cf2=2), + ] + num_examples = len(examples) + tensor_adapter_config = tensor_adapter.TensorAdapterConfig( + arrow_schema=tfx_io.ArrowSchema(), + tensor_representations=tfx_io.TensorRepresentations(), + ) + feature_extractor = features_extractor.FeaturesExtractor( eval_config=eval_config, - cf_configs=cf_configs, + tensor_representations=tensor_adapter_config.tensor_representations, + ) + + cf_predictions_extractor = ( + counterfactual_predictions_extractor.CounterfactualPredictionsExtractor( + eval_shared_models=eval_shared_models, + eval_config=eval_config, + cf_configs=cf_configs, + ) ) - ) - with beam.Pipeline() as pipeline: - result = ( - pipeline - | 'Create' - >> beam.Create( - [e.SerializeToString() for e in examples], reshuffle=False - ) - | 'BatchExamples' >> tfx_io.BeamSource(batch_size=num_examples) - | 'InputsToExtracts' >> model_eval_lib.BatchedInputsToExtracts() - | 'ExtractFeatures' >> feature_extractor.ptransform - | cf_predictions_extractor.stage_name - >> cf_predictions_extractor.ptransform - ) + with beam.Pipeline() as pipeline: + result = ( + pipeline + | "Create" + >> beam.Create( + [e.SerializeToString() for e in examples], reshuffle=False + ) + | "BatchExamples" >> tfx_io.BeamSource(batch_size=num_examples) + | "InputsToExtracts" >> model_eval_lib.BatchedInputsToExtracts() + | "ExtractFeatures" >> feature_extractor.ptransform + | cf_predictions_extractor.stage_name + >> cf_predictions_extractor.ptransform + ) - def check_result(got): - try: - self.assertLen(got, 1) - # exact outputs are non-deterministic because model is trained - # so we can't assert full extracts. Instead, we just assert keys. - self.assertIn(constants.PREDICTIONS_KEY, got[0]) - np.testing.assert_equal( - got[0][constants.PREDICTIONS_KEY], - expected_predictions, - err_msg=( - f'actual:{got[0][constants.PREDICTIONS_KEY]}\n' - f'expected:{expected_predictions}' - ), - ) - self.assertNotIn( - counterfactual_predictions_extractor._TEMP_ORIG_INPUT_KEY, got[0] - ) - except AssertionError as err: - raise util.BeamAssertException(err) + def check_result(got): + try: + self.assertLen(got, 1) + # exact outputs are non-deterministic because model is trained + # so we can't assert full extracts. Instead, we just assert keys. 
+ self.assertIn(constants.PREDICTIONS_KEY, got[0]) + np.testing.assert_equal( + got[0][constants.PREDICTIONS_KEY], + expected_predictions, + err_msg=( + f"actual:{got[0][constants.PREDICTIONS_KEY]}\n" + f"expected:{expected_predictions}" + ), + ) + self.assertNotIn( + counterfactual_predictions_extractor._TEMP_ORIG_INPUT_KEY, + got[0], + ) + except AssertionError as err: + raise util.BeamAssertException(err) - util.assert_that(result, check_result, label='result') + util.assert_that(result, check_result, label="result") -if __name__ == '__main__': - tf.compat.v1.enable_v2_behavior() - tf.test.main() +if __name__ == "__main__": + tf.compat.v1.enable_v2_behavior() + tf.test.main() diff --git a/tensorflow_model_analysis/extractors/example_weights_extractor.py b/tensorflow_model_analysis/extractors/example_weights_extractor.py index 220cce2a3a..f2e0846482 100644 --- a/tensorflow_model_analysis/extractors/example_weights_extractor.py +++ b/tensorflow_model_analysis/extractors/example_weights_extractor.py @@ -16,37 +16,40 @@ import copy import apache_beam as beam + from tensorflow_model_analysis import constants from tensorflow_model_analysis.api import types from tensorflow_model_analysis.extractors import extractor from tensorflow_model_analysis.proto import config_pb2 from tensorflow_model_analysis.utils import model_util -_EXAMPLE_WEIGHTS_EXTRACTOR_STAGE_NAME = 'ExtractExampleWeights' +_EXAMPLE_WEIGHTS_EXTRACTOR_STAGE_NAME = "ExtractExampleWeights" def ExampleWeightsExtractor( eval_config: config_pb2.EvalConfig, ) -> extractor.Extractor: - """Creates an extractor for extracting example weights. + """Creates an extractor for extracting example weights. - The extractor's PTransform uses the config's ModelSpec.example_weight_key(s) - to lookup the associated example weight values stored as features under the - tfma.FEATURES_KEY (and optionally tfma.TRANSFORMED_FEATURES_KEY) in extracts. - The resulting values are then added to the extracts under the key - tfma.EXAMPLE_WEIGHTS_KEY. + The extractor's PTransform uses the config's ModelSpec.example_weight_key(s) + to lookup the associated example weight values stored as features under the + tfma.FEATURES_KEY (and optionally tfma.TRANSFORMED_FEATURES_KEY) in extracts. + The resulting values are then added to the extracts under the key + tfma.EXAMPLE_WEIGHTS_KEY. - Args: - eval_config: Eval config. + Args: + ---- + eval_config: Eval config. - Returns: - Extractor for extracting example weights. - """ - # pylint: disable=no-value-for-parameter - return extractor.Extractor( - stage_name=_EXAMPLE_WEIGHTS_EXTRACTOR_STAGE_NAME, - ptransform=_ExtractExampleWeights(eval_config=eval_config), - ) + Returns: + ------- + Extractor for extracting example weights. + """ + # pylint: disable=no-value-for-parameter + return extractor.Extractor( + stage_name=_EXAMPLE_WEIGHTS_EXTRACTOR_STAGE_NAME, + ptransform=_ExtractExampleWeights(eval_config=eval_config), + ) @beam.ptransform_fn @@ -55,30 +58,32 @@ def ExampleWeightsExtractor( def _ExtractExampleWeights( extracts: beam.pvalue.PCollection, eval_config: config_pb2.EvalConfig ) -> beam.pvalue.PCollection: - """Extracts example weights from features extracts. + """Extracts example weights from features extracts. - Args: - extracts: PCollection containing features under tfma.FEATURES_KEY. - eval_config: Eval config. + Args: + ---- + extracts: PCollection containing features under tfma.FEATURES_KEY. + eval_config: Eval config. 
- Returns: - PCollection of extracts with additional example weights added under the key - tfma.EXAMPLE_WEIGHTS_KEY. - """ + Returns: + ------- + PCollection of extracts with additional example weights added under the key + tfma.EXAMPLE_WEIGHTS_KEY. + """ - def extract_example_weights( # pylint: disable=invalid-name - batched_extracts: types.Extracts, - ) -> types.Extracts: - """Extract example weights from extracts containing features.""" - result = copy.copy(batched_extracts) - example_weights = model_util.get_feature_values_for_model_spec_field( - list(eval_config.model_specs), - 'example_weight_key', - 'example_weight_keys', - result, - ) - if example_weights is not None: - result[constants.EXAMPLE_WEIGHTS_KEY] = example_weights - return result + def extract_example_weights( # pylint: disable=invalid-name + batched_extracts: types.Extracts, + ) -> types.Extracts: + """Extract example weights from extracts containing features.""" + result = copy.copy(batched_extracts) + example_weights = model_util.get_feature_values_for_model_spec_field( + list(eval_config.model_specs), + "example_weight_key", + "example_weight_keys", + result, + ) + if example_weights is not None: + result[constants.EXAMPLE_WEIGHTS_KEY] = example_weights + return result - return extracts | 'ExtractExampleWeights' >> beam.Map(extract_example_weights) + return extracts | "ExtractExampleWeights" >> beam.Map(extract_example_weights) diff --git a/tensorflow_model_analysis/extractors/example_weights_extractor_test.py b/tensorflow_model_analysis/extractors/example_weights_extractor_test.py index 789db14407..f12a043465 100644 --- a/tensorflow_model_analysis/extractors/example_weights_extractor_test.py +++ b/tensorflow_model_analysis/extractors/example_weights_extractor_test.py @@ -13,142 +13,147 @@ # limitations under the License. 
"""Test for example weights extractor.""" -from absl.testing import parameterized import apache_beam as beam -from apache_beam.testing import util import numpy as np import tensorflow as tf +from absl.testing import parameterized +from apache_beam.testing import util +from google.protobuf import text_format +from tensorflow_metadata.proto.v0 import schema_pb2 +from tfx_bsl.tfxio import test_util as tfx_bsl_test_util + from tensorflow_model_analysis import constants from tensorflow_model_analysis.api import model_eval_lib -from tensorflow_model_analysis.extractors import example_weights_extractor -from tensorflow_model_analysis.extractors import features_extractor +from tensorflow_model_analysis.extractors import ( + example_weights_extractor, + features_extractor, +) from tensorflow_model_analysis.proto import config_pb2 from tensorflow_model_analysis.utils import test_util -from tfx_bsl.tfxio import test_util as tfx_bsl_test_util - -from google.protobuf import text_format -from tensorflow_metadata.proto.v0 import schema_pb2 class ExampleWeightsExtractorTest( test_util.TensorflowModelAnalysisTest, parameterized.TestCase ): - - @parameterized.named_parameters( - ('with_example_weight', 'example_weight'), - ('without_example_weight', None), - ) - def testExampleWeightsExtractor(self, example_weight): - model_spec = config_pb2.ModelSpec(example_weight_key=example_weight) - eval_config = config_pb2.EvalConfig(model_specs=[model_spec]) - feature_extractor = features_extractor.FeaturesExtractor(eval_config) - example_weight_extractor = ( - example_weights_extractor.ExampleWeightsExtractor(eval_config) + @parameterized.named_parameters( + ("with_example_weight", "example_weight"), + ("without_example_weight", None), ) + def testExampleWeightsExtractor(self, example_weight): + model_spec = config_pb2.ModelSpec(example_weight_key=example_weight) + eval_config = config_pb2.EvalConfig(model_specs=[model_spec]) + feature_extractor = features_extractor.FeaturesExtractor(eval_config) + example_weight_extractor = example_weights_extractor.ExampleWeightsExtractor( + eval_config + ) - example_weight_feature = '' - if example_weight is not None: - example_weight_feature = """ + example_weight_feature = "" + if example_weight is not None: + example_weight_feature = ( + """ feature { name: "%s" type: FLOAT } - """ % example_weight - schema = text_format.Parse( - example_weight_feature + """ + """ + % example_weight + ) + schema = text_format.Parse( + example_weight_feature + + """ feature { name: "fixed_int" type: INT } """, - schema_pb2.Schema(), - ) - tfx_io = tfx_bsl_test_util.InMemoryTFExampleRecord( - schema=schema, raw_record_column_name=constants.ARROW_INPUT_COLUMN - ) + schema_pb2.Schema(), + ) + tfx_io = tfx_bsl_test_util.InMemoryTFExampleRecord( + schema=schema, raw_record_column_name=constants.ARROW_INPUT_COLUMN + ) - def maybe_add_key(d, key, value): - if key is not None: - d[key] = value - return d + def maybe_add_key(d, key, value): + if key is not None: + d[key] = value + return d - example_kwargs = [ - maybe_add_key( - { - 'fixed_int': 1, - }, - example_weight, - 0.5, - ), - maybe_add_key( - { - 'fixed_int': 1, - }, - example_weight, - 0.0, - ), - maybe_add_key( - { - 'fixed_int': 2, - }, - example_weight, - 1.0, - ), - ] + example_kwargs = [ + maybe_add_key( + { + "fixed_int": 1, + }, + example_weight, + 0.5, + ), + maybe_add_key( + { + "fixed_int": 1, + }, + example_weight, + 0.0, + ), + maybe_add_key( + { + "fixed_int": 2, + }, + example_weight, + 1.0, + ), + ] - with beam.Pipeline() as 
pipeline: - # pylint: disable=no-value-for-parameter - result = ( - pipeline - | 'Create' - >> beam.Create( - [ - self._makeExample(**kwargs).SerializeToString() - for kwargs in example_kwargs - ], - reshuffle=False, - ) - | 'BatchExamples' >> tfx_io.BeamSource(batch_size=3) - | 'InputsToExtracts' >> model_eval_lib.BatchedInputsToExtracts() - | feature_extractor.stage_name >> feature_extractor.ptransform - | example_weight_extractor.stage_name - >> example_weight_extractor.ptransform - ) + with beam.Pipeline() as pipeline: + # pylint: disable=no-value-for-parameter + result = ( + pipeline + | "Create" + >> beam.Create( + [ + self._makeExample(**kwargs).SerializeToString() + for kwargs in example_kwargs + ], + reshuffle=False, + ) + | "BatchExamples" >> tfx_io.BeamSource(batch_size=3) + | "InputsToExtracts" >> model_eval_lib.BatchedInputsToExtracts() + | feature_extractor.stage_name >> feature_extractor.ptransform + | example_weight_extractor.stage_name + >> example_weight_extractor.ptransform + ) - # pylint: enable=no-value-for-parameter + # pylint: enable=no-value-for-parameter - def check_result(got): - try: - self.assertLen(got, 1) - if example_weight: - self.assertAllClose( - got[0][constants.EXAMPLE_WEIGHTS_KEY], - np.array([[0.5], [0.0], [1.0]]), - ) - else: - self.assertNotIn(constants.EXAMPLE_WEIGHTS_KEY, got[0]) + def check_result(got): + try: + self.assertLen(got, 1) + if example_weight: + self.assertAllClose( + got[0][constants.EXAMPLE_WEIGHTS_KEY], + np.array([[0.5], [0.0], [1.0]]), + ) + else: + self.assertNotIn(constants.EXAMPLE_WEIGHTS_KEY, got[0]) - except AssertionError as err: - raise util.BeamAssertException(err) + except AssertionError as err: + raise util.BeamAssertException(err) - util.assert_that(result, check_result, label='result') + util.assert_that(result, check_result, label="result") - def testExampleWeightsExtractorMultiOutput(self): - model_spec = config_pb2.ModelSpec( - example_weight_keys={ - 'output1': 'example_weight1', - 'output2': 'example_weight2', - 'output3': 'example_weight3', - } - ) - eval_config = config_pb2.EvalConfig(model_specs=[model_spec]) - feature_extractor = features_extractor.FeaturesExtractor(eval_config) - example_weight_extractor = ( - example_weights_extractor.ExampleWeightsExtractor(eval_config) - ) + def testExampleWeightsExtractorMultiOutput(self): + model_spec = config_pb2.ModelSpec( + example_weight_keys={ + "output1": "example_weight1", + "output2": "example_weight2", + "output3": "example_weight3", + } + ) + eval_config = config_pb2.EvalConfig(model_specs=[model_spec]) + feature_extractor = features_extractor.FeaturesExtractor(eval_config) + example_weight_extractor = example_weights_extractor.ExampleWeightsExtractor( + eval_config + ) - schema = text_format.Parse( - """ + schema = text_format.Parse( + """ feature { name: "example_weight1" type: FLOAT @@ -162,73 +167,69 @@ def testExampleWeightsExtractorMultiOutput(self): type: INT } """, - schema_pb2.Schema(), - ) - tfx_io = tfx_bsl_test_util.InMemoryTFExampleRecord( - schema=schema, raw_record_column_name=constants.ARROW_INPUT_COLUMN - ) + schema_pb2.Schema(), + ) + tfx_io = tfx_bsl_test_util.InMemoryTFExampleRecord( + schema=schema, raw_record_column_name=constants.ARROW_INPUT_COLUMN + ) - examples = [ - self._makeExample( - example_weight1=0.5, example_weight2=0.5, fixed_int=1 - ), - self._makeExample( - example_weight1=0.0, example_weight2=1.0, fixed_int=1 - ), - ] + examples = [ + self._makeExample(example_weight1=0.5, example_weight2=0.5, fixed_int=1), + 
self._makeExample(example_weight1=0.0, example_weight2=1.0, fixed_int=1), + ] - with beam.Pipeline() as pipeline: - # pylint: disable=no-value-for-parameter - result = ( - pipeline - | 'Create' - >> beam.Create( - [e.SerializeToString() for e in examples], reshuffle=False - ) - | 'BatchExamples' >> tfx_io.BeamSource(batch_size=2) - | 'InputsToExtracts' >> model_eval_lib.BatchedInputsToExtracts() - | feature_extractor.stage_name >> feature_extractor.ptransform - | example_weight_extractor.stage_name - >> example_weight_extractor.ptransform - ) + with beam.Pipeline() as pipeline: + # pylint: disable=no-value-for-parameter + result = ( + pipeline + | "Create" + >> beam.Create( + [e.SerializeToString() for e in examples], reshuffle=False + ) + | "BatchExamples" >> tfx_io.BeamSource(batch_size=2) + | "InputsToExtracts" >> model_eval_lib.BatchedInputsToExtracts() + | feature_extractor.stage_name >> feature_extractor.ptransform + | example_weight_extractor.stage_name + >> example_weight_extractor.ptransform + ) - # pylint: enable=no-value-for-parameter + # pylint: enable=no-value-for-parameter - def check_result(got): - try: - self.assertLen(got, 1) - self.assertAllClose( - got[0][constants.EXAMPLE_WEIGHTS_KEY], - { - 'output1': np.array([[0.5], [0.0]]), - 'output2': np.array([[0.5], [1.0]]), - }, - ) + def check_result(got): + try: + self.assertLen(got, 1) + self.assertAllClose( + got[0][constants.EXAMPLE_WEIGHTS_KEY], + { + "output1": np.array([[0.5], [0.0]]), + "output2": np.array([[0.5], [1.0]]), + }, + ) - except AssertionError as err: - raise util.BeamAssertException(err) + except AssertionError as err: + raise util.BeamAssertException(err) - util.assert_that(result, check_result, label='result') + util.assert_that(result, check_result, label="result") - def testExampleWeightsExtractorMultiModel(self): - model_spec1 = config_pb2.ModelSpec( - name='model1', example_weight_key='example_weight' - ) - model_spec2 = config_pb2.ModelSpec( - name='model2', - example_weight_keys={ - 'output1': 'example_weight1', - 'output2': 'example_weight2', - }, - ) - eval_config = config_pb2.EvalConfig(model_specs=[model_spec1, model_spec2]) - feature_extractor = features_extractor.FeaturesExtractor(eval_config) - example_weight_extractor = ( - example_weights_extractor.ExampleWeightsExtractor(eval_config) - ) + def testExampleWeightsExtractorMultiModel(self): + model_spec1 = config_pb2.ModelSpec( + name="model1", example_weight_key="example_weight" + ) + model_spec2 = config_pb2.ModelSpec( + name="model2", + example_weight_keys={ + "output1": "example_weight1", + "output2": "example_weight2", + }, + ) + eval_config = config_pb2.EvalConfig(model_specs=[model_spec1, model_spec2]) + feature_extractor = features_extractor.FeaturesExtractor(eval_config) + example_weight_extractor = example_weights_extractor.ExampleWeightsExtractor( + eval_config + ) - schema = text_format.Parse( - """ + schema = text_format.Parse( + """ feature { name: "example_weight" type: FLOAT @@ -246,66 +247,66 @@ def testExampleWeightsExtractorMultiModel(self): type: INT } """, - schema_pb2.Schema(), - ) - tfx_io = tfx_bsl_test_util.InMemoryTFExampleRecord( - schema=schema, raw_record_column_name=constants.ARROW_INPUT_COLUMN - ) + schema_pb2.Schema(), + ) + tfx_io = tfx_bsl_test_util.InMemoryTFExampleRecord( + schema=schema, raw_record_column_name=constants.ARROW_INPUT_COLUMN + ) - examples = [ - self._makeExample( - example_weight=0.5, - example_weight1=0.5, - example_weight2=0.5, - fixed_int=1, - ), - self._makeExample( - 
example_weight=0.0, - example_weight1=0.0, - example_weight2=1.0, - fixed_int=1, - ), - ] + examples = [ + self._makeExample( + example_weight=0.5, + example_weight1=0.5, + example_weight2=0.5, + fixed_int=1, + ), + self._makeExample( + example_weight=0.0, + example_weight1=0.0, + example_weight2=1.0, + fixed_int=1, + ), + ] - with beam.Pipeline() as pipeline: - # pylint: disable=no-value-for-parameter - result = ( - pipeline - | 'Create' - >> beam.Create( - [e.SerializeToString() for e in examples], reshuffle=False - ) - | 'BatchExamples' >> tfx_io.BeamSource(batch_size=2) - | 'InputsToExtracts' >> model_eval_lib.BatchedInputsToExtracts() - | feature_extractor.stage_name >> feature_extractor.ptransform - | example_weight_extractor.stage_name - >> example_weight_extractor.ptransform - ) + with beam.Pipeline() as pipeline: + # pylint: disable=no-value-for-parameter + result = ( + pipeline + | "Create" + >> beam.Create( + [e.SerializeToString() for e in examples], reshuffle=False + ) + | "BatchExamples" >> tfx_io.BeamSource(batch_size=2) + | "InputsToExtracts" >> model_eval_lib.BatchedInputsToExtracts() + | feature_extractor.stage_name >> feature_extractor.ptransform + | example_weight_extractor.stage_name + >> example_weight_extractor.ptransform + ) - # pylint: enable=no-value-for-parameter + # pylint: enable=no-value-for-parameter - def check_result(got): - try: - self.assertLen(got, 1) - for model_name in ('model1', 'model2'): - self.assertIn(model_name, got[0][constants.EXAMPLE_WEIGHTS_KEY]) - self.assertAllClose( - got[0][constants.EXAMPLE_WEIGHTS_KEY]['model1'], - np.array([[0.5], [0.0]]), - ) - self.assertAllClose( - got[0][constants.EXAMPLE_WEIGHTS_KEY]['model2'], - { - 'output1': np.array([[0.5], [0.0]]), - 'output2': np.array([[0.5], [1.0]]), - }, - ) + def check_result(got): + try: + self.assertLen(got, 1) + for model_name in ("model1", "model2"): + self.assertIn(model_name, got[0][constants.EXAMPLE_WEIGHTS_KEY]) + self.assertAllClose( + got[0][constants.EXAMPLE_WEIGHTS_KEY]["model1"], + np.array([[0.5], [0.0]]), + ) + self.assertAllClose( + got[0][constants.EXAMPLE_WEIGHTS_KEY]["model2"], + { + "output1": np.array([[0.5], [0.0]]), + "output2": np.array([[0.5], [1.0]]), + }, + ) - except AssertionError as err: - raise util.BeamAssertException(err) + except AssertionError as err: + raise util.BeamAssertException(err) - util.assert_that(result, check_result, label='result') + util.assert_that(result, check_result, label="result") -if __name__ == '__main__': - tf.test.main() +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow_model_analysis/extractors/extractor.py b/tensorflow_model_analysis/extractors/extractor.py index 9d779980d4..f862f16448 100644 --- a/tensorflow_model_analysis/extractors/extractor.py +++ b/tensorflow_model_analysis/extractors/extractor.py @@ -16,24 +16,21 @@ from typing import Any, Dict, Iterable, NamedTuple, Optional, Union import apache_beam as beam + from tensorflow_model_analysis.api import types from tensorflow_model_analysis.utils import util # Tag for the last extractor in list of extractors. -LAST_EXTRACTOR_STAGE_NAME = '' +LAST_EXTRACTOR_STAGE_NAME = "" + # An Extractor is a PTransform that takes Extracts as input and returns Extracts # as output. A typical example is a PredictExtractor that receives an 'input' # placeholder for input and adds additional 'features', 'labels', and # 'predictions' extracts. 
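As context for the hunk that follows (where this NamedTuple is converted to a class-style NamedTuple), an Extractor is simply a named Extracts-to-Extracts PTransform. A minimal custom extractor can be assembled as sketched here; the stage name, key, and helper are invented purely for illustration.

```
import apache_beam as beam

from tensorflow_model_analysis.extractors import extractor


def _add_flag(extracts):
    # Copy the Extracts dict and stamp an illustrative marker key into it.
    result = dict(extracts)
    result["my_flag"] = True
    return result


# Construction is identical before and after the NamedTuple-to-class change below.
my_extractor = extractor.Extractor(
    stage_name="ExtractMyFlag",
    ptransform=beam.Map(_add_flag),
)
# In a pipeline: extracts | my_extractor.stage_name >> my_extractor.ptransform
```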
-Extractor = NamedTuple( # pylint: disable=invalid-name - 'Extractor', - [ - ('stage_name', str), - # PTransform Extracts -> Extracts - ('ptransform', beam.PTransform), - ], -) +class Extractor(NamedTuple): + stage_name: str + ptransform: beam.PTransform @beam.ptransform_fn @@ -44,44 +41,47 @@ def Filter( # pylint: disable=invalid-name include: Optional[Union[Iterable[str], Dict[str, Any]]] = None, exclude: Optional[Union[Iterable[str], Dict[str, Any]]] = None, ) -> beam.pvalue.PCollection: - """Filters extracts to include/exclude specified keys. + """Filters extracts to include/exclude specified keys. - Args: - extracts: PCollection of extracts. - include: List or map of keys to include in output. If a map of keys is - passed then the keys and sub-keys that exist in the map will be included - in the output. An empty dict behaves as a wildcard matching all keys or - the value itself. Since matching on feature values is not currently - supported, an empty dict must be used to represent the leaf nodes. For - example, {'key1': {'key1-subkey': {}}, 'key2': {}}. - exclude: List or map of keys to exclude from output. If a map of keys is - passed then the keys and sub-keys that exist in the map will be excluded - from the output. An empty dict behaves as a wildcard matching all keys or - the value itself. Since matching on feature values is not currently - supported, an empty dict must be used to represent the leaf nodes. For - example, {'key1': {'key1-subkey': {}}, 'key2': {}}. + Args: + ---- + extracts: PCollection of extracts. + include: List or map of keys to include in output. If a map of keys is + passed then the keys and sub-keys that exist in the map will be included + in the output. An empty dict behaves as a wildcard matching all keys or + the value itself. Since matching on feature values is not currently + supported, an empty dict must be used to represent the leaf nodes. For + example, {'key1': {'key1-subkey': {}}, 'key2': {}}. + exclude: List or map of keys to exclude from output. If a map of keys is + passed then the keys and sub-keys that exist in the map will be excluded + from the output. An empty dict behaves as a wildcard matching all keys or + the value itself. Since matching on feature values is not currently + supported, an empty dict must be used to represent the leaf nodes. For + example, {'key1': {'key1-subkey': {}}, 'key2': {}}. - Returns: - Filtered PCollection of Extracts. + Returns: + ------- + Filtered PCollection of Extracts. - Raises: - ValueError: If both include and exclude are used. - """ - if include and exclude: - raise ValueError('only one of include or exclude should be used.') + Raises: + ------ + ValueError: If both include and exclude are used. 
+ """ + if include and exclude: + raise ValueError("only one of include or exclude should be used.") - if not isinstance(include, dict): - include = {k: {} for k in include or []} - if not isinstance(exclude, dict): - exclude = {k: {} for k in exclude or []} + if not isinstance(include, dict): + include = {k: {} for k in include or []} + if not isinstance(exclude, dict): + exclude = {k: {} for k in exclude or []} - def filter_extracts(extracts: types.Extracts) -> types.Extracts: # pylint: disable=invalid-name - """Filters extracts.""" - if not include and not exclude: - return extracts - elif include: - return util.include_filter(include, extracts) - else: - return util.exclude_filter(exclude, extracts) + def filter_extracts(extracts: types.Extracts) -> types.Extracts: # pylint: disable=invalid-name + """Filters extracts.""" + if not include and not exclude: + return extracts + elif include: + return util.include_filter(include, extracts) + else: + return util.exclude_filter(exclude, extracts) - return extracts | beam.Map(filter_extracts) + return extracts | beam.Map(filter_extracts) diff --git a/tensorflow_model_analysis/extractors/extractor_test.py b/tensorflow_model_analysis/extractors/extractor_test.py index 7d80ef45b6..15002df91e 100644 --- a/tensorflow_model_analysis/extractors/extractor_test.py +++ b/tensorflow_model_analysis/extractors/extractor_test.py @@ -14,103 +14,111 @@ """Test for extractor.""" import apache_beam as beam -from apache_beam.testing import util import tensorflow as tf +from apache_beam.testing import util + from tensorflow_model_analysis.extractors import extractor from tensorflow_model_analysis.utils import test_util class ExtractorTest(test_util.TensorflowModelAnalysisTest): - - def testFilterRaisesValueError(self): - with self.assertRaises(ValueError): - with beam.Pipeline() as pipeline: - _ = ( - pipeline - | 'Create' >> beam.Create([]) - | 'Filter' >> extractor.Filter(include=['a'], exclude=['b']) - ) - - def testIncludeFilter(self): - with beam.Pipeline() as pipeline: - got = ( - pipeline - | 'Create' >> beam.Create([{'a': 1, 'b': 2, 'c': 3, 'd': 4}]) - | 'Filter' >> extractor.Filter(include=['a', 'c']) - ) - - def check_result(got): - try: - self.assertEqual(got, [{'a': 1, 'c': 3}]) - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that(got, check_result) - - def testIncludeFilterWithDict(self): - with beam.Pipeline() as pipeline: - got = ( - pipeline - | 'Create' - >> beam.Create([{ - 'a': 1, - 'b': {'b2': 2}, - 'c': {'c2': {'c21': 3, 'c22': 4}}, - 'd': {'d2': 4}, - }]) - | 'Filter' - >> extractor.Filter(include={'b': {}, 'c': {'c2': {'c21': {}}}}) - ) - - def check_result(got): - try: - self.assertEqual(got, [{'b': {'b2': 2}, 'c': {'c2': {'c21': 3}}}]) - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that(got, check_result) - - def testExludeFilter(self): - with beam.Pipeline() as pipeline: - got = ( - pipeline - | 'Create' >> beam.Create([{'a': 1, 'b': 2, 'c': 3, 'd': 4}]) - | 'Filter' >> extractor.Filter(exclude=['b', 'd']) - ) - - def check_result(got): - try: - self.assertEqual(got, [{'a': 1, 'c': 3}]) - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that(got, check_result) - - def testExcludeFilterWithDict(self): - with beam.Pipeline() as pipeline: - got = ( - pipeline - | 'Create' - >> beam.Create([{ - 'a': 1, - 'b': {'b2': 2}, - 'c': {'c2': {'c21': 3, 'c22': 4}}, - 'd': {'d2': 4}, - }]) - | 'Filter' - >> extractor.Filter(exclude={'b': {}, 
'c': {'c2': {'c21': {}}}}) - ) - - def check_result(got): - try: - self.assertEqual( - got, [{'a': 1, 'c': {'c2': {'c22': 4}}, 'd': {'d2': 4}}] - ) - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that(got, check_result) - - -if __name__ == '__main__': - tf.test.main() + def testFilterRaisesValueError(self): + with self.assertRaises(ValueError): + with beam.Pipeline() as pipeline: + _ = ( + pipeline + | "Create" >> beam.Create([]) + | "Filter" >> extractor.Filter(include=["a"], exclude=["b"]) + ) + + def testIncludeFilter(self): + with beam.Pipeline() as pipeline: + got = ( + pipeline + | "Create" >> beam.Create([{"a": 1, "b": 2, "c": 3, "d": 4}]) + | "Filter" >> extractor.Filter(include=["a", "c"]) + ) + + def check_result(got): + try: + self.assertEqual(got, [{"a": 1, "c": 3}]) + except AssertionError as err: + raise util.BeamAssertException(err) + + util.assert_that(got, check_result) + + def testIncludeFilterWithDict(self): + with beam.Pipeline() as pipeline: + got = ( + pipeline + | "Create" + >> beam.Create( + [ + { + "a": 1, + "b": {"b2": 2}, + "c": {"c2": {"c21": 3, "c22": 4}}, + "d": {"d2": 4}, + } + ] + ) + | "Filter" + >> extractor.Filter(include={"b": {}, "c": {"c2": {"c21": {}}}}) + ) + + def check_result(got): + try: + self.assertEqual(got, [{"b": {"b2": 2}, "c": {"c2": {"c21": 3}}}]) + except AssertionError as err: + raise util.BeamAssertException(err) + + util.assert_that(got, check_result) + + def testExludeFilter(self): + with beam.Pipeline() as pipeline: + got = ( + pipeline + | "Create" >> beam.Create([{"a": 1, "b": 2, "c": 3, "d": 4}]) + | "Filter" >> extractor.Filter(exclude=["b", "d"]) + ) + + def check_result(got): + try: + self.assertEqual(got, [{"a": 1, "c": 3}]) + except AssertionError as err: + raise util.BeamAssertException(err) + + util.assert_that(got, check_result) + + def testExcludeFilterWithDict(self): + with beam.Pipeline() as pipeline: + got = ( + pipeline + | "Create" + >> beam.Create( + [ + { + "a": 1, + "b": {"b2": 2}, + "c": {"c2": {"c21": 3, "c22": 4}}, + "d": {"d2": 4}, + } + ] + ) + | "Filter" + >> extractor.Filter(exclude={"b": {}, "c": {"c2": {"c21": {}}}}) + ) + + def check_result(got): + try: + self.assertEqual( + got, [{"a": 1, "c": {"c2": {"c22": 4}}, "d": {"d2": 4}}] + ) + except AssertionError as err: + raise util.BeamAssertException(err) + + util.assert_that(got, check_result) + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow_model_analysis/extractors/features_extractor.py b/tensorflow_model_analysis/extractors/features_extractor.py index e692273dc4..448b1156e9 100644 --- a/tensorflow_model_analysis/extractors/features_extractor.py +++ b/tensorflow_model_analysis/extractors/features_extractor.py @@ -14,134 +14,136 @@ """Features extractor.""" import copy -from typing import Mapping, Optional, Text, Tuple +from typing import Mapping, Optional, Tuple import apache_beam as beam import numpy as np import pyarrow as pa +from tensorflow_metadata.proto.v0 import schema_pb2 + from tensorflow_model_analysis import constants from tensorflow_model_analysis.api import types from tensorflow_model_analysis.extractors import extractor from tensorflow_model_analysis.proto import config_pb2 from tensorflow_model_analysis.utils import util -from tensorflow_metadata.proto.v0 import schema_pb2 - -_FEATURES_EXTRACTOR_STAGE_NAME = 'ExtractFeatures' -FEATURES_KEY = 'features' -ARROW_RECORD_BATCH_KEY = 'arrow_record_batch' +_FEATURES_EXTRACTOR_STAGE_NAME = "ExtractFeatures" +FEATURES_KEY = "features" 
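# Editor's note: an illustrative sketch, not part of this patch, showing how the
# extractor.Filter transform exercised in the tests above is used outside a test
# harness; the pipeline and the extracts dict below are hypothetical.
import apache_beam as beam
from tensorflow_model_analysis.extractors import extractor

with beam.Pipeline() as sketch_pipeline:
    _ = (
        sketch_pipeline
        | "Create"
        >> beam.Create([{"features": {"age": 42}, "labels": 1, "predictions": 0.9}])
        # Keep only the keys needed downstream; nested keys use the dict form,
        # e.g. include={"features": {"age": {}}}.
        | "KeepLabelsAndPredictions"
        >> extractor.Filter(include=["labels", "predictions"])
    )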
+ARROW_RECORD_BATCH_KEY = "arrow_record_batch" def FeaturesExtractor( # pylint: disable=invalid-name eval_config: config_pb2.EvalConfig, tensor_representations: Optional[ - Mapping[Text, schema_pb2.TensorRepresentation] + Mapping[str, schema_pb2.TensorRepresentation] ] = None, ) -> extractor.Extractor: - """Creates an extractor for extracting features. + """Creates an extractor for extracting features. - The extractor acts as follows depending on the existence of certain keys - within the incoming extracts: + The extractor acts as follows depending on the existence of certain keys + within the incoming extracts: - 1) Extracts contains tfma.ARROW_RECORD_BATCH_KEY + 1) Extracts contains tfma.ARROW_RECORD_BATCH_KEY - The features stored in the RecordBatch will be extracted and added to the - output extract under the key tfma.FEATURES_KEY and the raw serialized inputs - will be added under the tfma.INPUT_KEY. Any extracts that already exist will - be merged with the values from the RecordBatch with the RecordBatch values - taking precedence when duplicate keys are detected. The - tfma.ARROW_RECORD_BATCH_KEY key will be removed from the output extracts. + The features stored in the RecordBatch will be extracted and added to the + output extract under the key tfma.FEATURES_KEY and the raw serialized inputs + will be added under the tfma.INPUT_KEY. Any extracts that already exist will + be merged with the values from the RecordBatch with the RecordBatch values + taking precedence when duplicate keys are detected. The + tfma.ARROW_RECORD_BATCH_KEY key will be removed from the output extracts. - 2) Extracts contains tfma.FEATURES_KEY (but not tfma.ARROW_RECORD_BATCH_KEY) + 2) Extracts contains tfma.FEATURES_KEY (but not tfma.ARROW_RECORD_BATCH_KEY) - The operation will be a no-op and the incoming extracts will be passed as is - to the output. + The operation will be a no-op and the incoming extracts will be passed as is + to the output. - 3) Extracts contains neither tfma.FEATURES_KEY | tfma.ARROW_RECORD_BATCH_KEY + 3) Extracts contains neither tfma.FEATURES_KEY | tfma.ARROW_RECORD_BATCH_KEY - An exception will be raised. + An exception will be raised. - Args: - eval_config: Eval config. - tensor_representations: Optional tensor representations to use when parsing - the data. If tensor_representations are not passed or a representation is - not found for a given feature name a default representation will be used - where possible, otherwise an exception will be raised. + Args: + ---- + eval_config: Eval config. + tensor_representations: Optional tensor representations to use when parsing + the data. If tensor_representations are not passed or a representation is + not found for a given feature name a default representation will be used + where possible, otherwise an exception will be raised. - Returns: - Extractor for extracting features. - """ - del eval_config - # pylint: disable=no-value-for-parameter - return extractor.Extractor( - stage_name=_FEATURES_EXTRACTOR_STAGE_NAME, - ptransform=_ExtractFeatures(tensor_representations or {}), - ) + Returns: + ------- + Extractor for extracting features. + """ + del eval_config + # pylint: disable=no-value-for-parameter + return extractor.Extractor( + stage_name=_FEATURES_EXTRACTOR_STAGE_NAME, + ptransform=_ExtractFeatures(tensor_representations or {}), + ) # TODO(b/214273030): Move to tfx-bsl. 
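# Editor's note: an illustrative sketch, not part of this patch, of how the
# FeaturesExtractor factory above is typically wired into a pipeline. It mirrors
# the tests further down in this diff; the empty serialized example and the
# telemetry descriptor are hypothetical placeholders.
import apache_beam as beam
from tfx_bsl.tfxio import tf_example_record

from tensorflow_model_analysis import constants
from tensorflow_model_analysis.api import model_eval_lib
from tensorflow_model_analysis.extractors import features_extractor
from tensorflow_model_analysis.proto import config_pb2

sketch_eval_config = config_pb2.EvalConfig(model_specs=[config_pb2.ModelSpec()])
sketch_extractor = features_extractor.FeaturesExtractor(sketch_eval_config)
sketch_tfx_io = tf_example_record.TFExampleBeamRecord(
    raw_record_column_name=constants.ARROW_INPUT_COLUMN,
    physical_format="inmem",
    telemetry_descriptors=["editor_sketch"],
)
with beam.Pipeline() as sketch_pipeline:
    _ = (
        sketch_pipeline
        | "Create" >> beam.Create([b""])  # serialized tf.Examples
        | "DecodeToRecordBatch" >> sketch_tfx_io.BeamSource(batch_size=1)
        | "InputsToExtracts" >> model_eval_lib.BatchedInputsToExtracts()
        # Adds tfma.FEATURES_KEY (and tfma.INPUT_KEY) to each batched extract.
        | sketch_extractor.stage_name >> sketch_extractor.ptransform
    )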
def _is_list_like(arrow_type: pa.DataType) -> bool: - return pa.types.is_list(arrow_type) or pa.types.is_large_list(arrow_type) + return pa.types.is_list(arrow_type) or pa.types.is_large_list(arrow_type) # TODO(b/214273030): Move to tfx-bsl. def _is_binary_like(arrow_type: pa.DataType) -> bool: - return ( - pa.types.is_binary(arrow_type) - or pa.types.is_large_binary(arrow_type) - or pa.types.is_string(arrow_type) - or pa.types.is_large_string(arrow_type) - ) + return ( + pa.types.is_binary(arrow_type) + or pa.types.is_large_binary(arrow_type) + or pa.types.is_string(arrow_type) + or pa.types.is_large_string(arrow_type) + ) # TODO(b/214273030): Move to tfx-bsl. def _is_supported_arrow_value_type(arrow_type: pa.DataType) -> bool: - return ( - pa.types.is_integer(arrow_type) - or pa.types.is_floating(arrow_type) - or _is_binary_like(arrow_type) - ) + return ( + pa.types.is_integer(arrow_type) + or pa.types.is_floating(arrow_type) + or _is_binary_like(arrow_type) + ) def _drop_unsupported_columns_and_fetch_raw_data_column( record_batch: pa.RecordBatch, ) -> Tuple[pa.RecordBatch, Optional[np.ndarray]]: - """Drops unsupported columns and fetches the raw data column. - - Currently, types that are not binary_like or ListArray[primitive types] are - dropped. - - Args: - record_batch: An Arrow RecordBatch. - - Returns: - Arrow RecordBatch with only supported columns. - """ - column_names, column_arrays = [], [] - serialized_examples = None - for column_name, column_array in zip( - record_batch.schema.names, record_batch.columns - ): - column_type = column_array.type - if column_name == constants.ARROW_INPUT_COLUMN: - assert _is_list_like(column_type) and _is_binary_like( - column_type.value_type - ), 'Invalid type for batched input key: {}. Expected binary like.'.format( - column_type - ) - serialized_examples = np.asarray(column_array.flatten()) - # Currently we only handle columns of type list. - # We ignore other columns as we cannot efficiently convert them into an - # instance dict format. - elif _is_list_like(column_type) and _is_supported_arrow_value_type( - column_type.value_type + """Drops unsupported columns and fetches the raw data column. + + Currently, types that are not binary_like or ListArray[primitive types] are + dropped. + + Args: + ---- + record_batch: An Arrow RecordBatch. + + Returns: + ------- + Arrow RecordBatch with only supported columns. + """ + column_names, column_arrays = [], [] + serialized_examples = None + for column_name, column_array in zip( + record_batch.schema.names, record_batch.columns ): - column_names.append(column_name) - column_arrays.append(column_array) - return ( - pa.RecordBatch.from_arrays(column_arrays, column_names), - serialized_examples, - ) + column_type = column_array.type + if column_name == constants.ARROW_INPUT_COLUMN: + assert ( + _is_list_like(column_type) and _is_binary_like(column_type.value_type) + ), f"Invalid type for batched input key: {column_type}. Expected binary like." + serialized_examples = np.asarray(column_array.flatten()) + # Currently we only handle columns of type list. + # We ignore other columns as we cannot efficiently convert them into an + # instance dict format. 
+ elif _is_list_like(column_type) and _is_supported_arrow_value_type( + column_type.value_type + ): + column_names.append(column_name) + column_arrays.append(column_array) + return ( + pa.RecordBatch.from_arrays(column_arrays, column_names), + serialized_examples, + ) @beam.ptransform_fn @@ -151,49 +153,50 @@ def _ExtractFeatures( # pylint: disable=invalid-name extracts: beam.pvalue.PCollection, tensor_representations: Mapping[str, schema_pb2.TensorRepresentation], ) -> beam.pvalue.PCollection: - """Extracts features from extracts. - - Args: - extracts: PCollection containing features under tfma.ARROW_RECORD_BATCH_KEY - or tfma.FEATURES_KEY. - tensor_representations: Tensor representations. - - Returns: - PCollection of extracts with additional features added under the key - tfma.FEATURES_KEY and optionally inputs added under the tfma.INPUTS_KEY. - - Raises: - ValueError: If incoming extracts contains neither tfma.FEATURES_KEY nor - tfma.ARROW_RECORD_BATCH_KEY. - """ - - def extract_features(extracts: types.Extracts) -> types.Extracts: - """Extract features from extracts containing arrow table.""" - result = copy.copy(extracts) - if constants.ARROW_RECORD_BATCH_KEY in extracts: - (record_batch, serialized_examples) = ( - _drop_unsupported_columns_and_fetch_raw_data_column( - extracts[constants.ARROW_RECORD_BATCH_KEY] - ) - ) - del result[constants.ARROW_RECORD_BATCH_KEY] - features = ( - result[constants.FEATURES_KEY] - if constants.FEATURES_KEY in result - else {} - ) - features.update( - util.record_batch_to_tensor_values( - record_batch, tensor_representations - ) - ) - result[constants.FEATURES_KEY] = features - result[constants.INPUT_KEY] = serialized_examples - elif constants.FEATURES_KEY not in extracts: - raise ValueError( - 'Incoming extracts must contain either tfma.ARROW_RECORD_BATCH_KEY ' - f'or tfma.FEATURES_KEY, but extracts={extracts}' - ) - return result - - return extracts | 'ExtractFeatures' >> beam.Map(extract_features) + """Extracts features from extracts. + + Args: + ---- + extracts: PCollection containing features under tfma.ARROW_RECORD_BATCH_KEY + or tfma.FEATURES_KEY. + tensor_representations: Tensor representations. + + Returns: + ------- + PCollection of extracts with additional features added under the key + tfma.FEATURES_KEY and optionally inputs added under the tfma.INPUTS_KEY. + + Raises: + ------ + ValueError: If incoming extracts contains neither tfma.FEATURES_KEY nor + tfma.ARROW_RECORD_BATCH_KEY. 
+ """ + + def extract_features(extracts: types.Extracts) -> types.Extracts: + """Extract features from extracts containing arrow table.""" + result = copy.copy(extracts) + if constants.ARROW_RECORD_BATCH_KEY in extracts: + (record_batch, serialized_examples) = ( + _drop_unsupported_columns_and_fetch_raw_data_column( + extracts[constants.ARROW_RECORD_BATCH_KEY] + ) + ) + del result[constants.ARROW_RECORD_BATCH_KEY] + features = ( + result[constants.FEATURES_KEY] + if constants.FEATURES_KEY in result + else {} + ) + features.update( + util.record_batch_to_tensor_values(record_batch, tensor_representations) + ) + result[constants.FEATURES_KEY] = features + result[constants.INPUT_KEY] = serialized_examples + elif constants.FEATURES_KEY not in extracts: + raise ValueError( + "Incoming extracts must contain either tfma.ARROW_RECORD_BATCH_KEY " + f"or tfma.FEATURES_KEY, but extracts={extracts}" + ) + return result + + return extracts | "ExtractFeatures" >> beam.Map(extract_features) diff --git a/tensorflow_model_analysis/extractors/features_extractor_test.py b/tensorflow_model_analysis/extractors/features_extractor_test.py index c7ab1a5cbd..15b7d948aa 100644 --- a/tensorflow_model_analysis/extractors/features_extractor_test.py +++ b/tensorflow_model_analysis/extractors/features_extractor_test.py @@ -13,61 +13,60 @@ # limitations under the License. """Test for features extractor.""" -from absl.testing import parameterized import apache_beam as beam -from apache_beam.testing import util import numpy as np import tensorflow as tf +from absl.testing import parameterized +from apache_beam.testing import util +from google.protobuf import text_format +from tensorflow_metadata.proto.v0 import schema_pb2 +from tfx_bsl.tfxio import tf_example_record + from tensorflow_model_analysis import constants from tensorflow_model_analysis.api import model_eval_lib from tensorflow_model_analysis.extractors import features_extractor from tensorflow_model_analysis.proto import config_pb2 from tensorflow_model_analysis.utils import test_util -from tfx_bsl.tfxio import tf_example_record - -from google.protobuf import text_format -from tensorflow_metadata.proto.v0 import schema_pb2 class FeaturesExtractorTest( test_util.TensorflowModelAnalysisTest, parameterized.TestCase ): - - def test_features_extractor_no_features(self): - model_spec = config_pb2.ModelSpec() - eval_config = config_pb2.EvalConfig(model_specs=[model_spec]) - feature_extractor = features_extractor.FeaturesExtractor(eval_config) - tfx_io = tf_example_record.TFExampleBeamRecord( - raw_record_column_name=constants.ARROW_INPUT_COLUMN, - physical_format='inmem', - telemetry_descriptors=['testing'], - ) - - with beam.Pipeline() as pipeline: - result = ( - pipeline - | 'Create' >> beam.Create([b''] * 3) - | 'DecodeToRecordBatch' >> tfx_io.BeamSource(batch_size=3) - | 'InputsToExtracts' >> model_eval_lib.BatchedInputsToExtracts() - | feature_extractor.stage_name >> feature_extractor.ptransform - ) - - def check_result(got): - self.assertLen(got, 1) - self.assertIn(constants.FEATURES_KEY, got[0]) - self.assertEmpty(got[0][constants.FEATURES_KEY]) - self.assertIn(constants.INPUT_KEY, got[0]) - self.assertLen(got[0][constants.INPUT_KEY], 3) - - util.assert_that(result, check_result, label='CheckResult') - - def test_features_extractor(self): - model_spec = config_pb2.ModelSpec() - eval_config = config_pb2.EvalConfig(model_specs=[model_spec]) - feature_extractor = features_extractor.FeaturesExtractor(eval_config) - - schema = text_format.Parse( - """ + def 
test_features_extractor_no_features(self): + model_spec = config_pb2.ModelSpec() + eval_config = config_pb2.EvalConfig(model_specs=[model_spec]) + feature_extractor = features_extractor.FeaturesExtractor(eval_config) + tfx_io = tf_example_record.TFExampleBeamRecord( + raw_record_column_name=constants.ARROW_INPUT_COLUMN, + physical_format="inmem", + telemetry_descriptors=["testing"], + ) + + with beam.Pipeline() as pipeline: + result = ( + pipeline + | "Create" >> beam.Create([b""] * 3) + | "DecodeToRecordBatch" >> tfx_io.BeamSource(batch_size=3) + | "InputsToExtracts" >> model_eval_lib.BatchedInputsToExtracts() + | feature_extractor.stage_name >> feature_extractor.ptransform + ) + + def check_result(got): + self.assertLen(got, 1) + self.assertIn(constants.FEATURES_KEY, got[0]) + self.assertEmpty(got[0][constants.FEATURES_KEY]) + self.assertIn(constants.INPUT_KEY, got[0]) + self.assertLen(got[0][constants.INPUT_KEY], 3) + + util.assert_that(result, check_result, label="CheckResult") + + def test_features_extractor(self): + model_spec = config_pb2.ModelSpec() + eval_config = config_pb2.EvalConfig(model_specs=[model_spec]) + feature_extractor = features_extractor.FeaturesExtractor(eval_config) + + schema = text_format.Parse( + """ feature { name: "example_weight" type: FLOAT @@ -85,75 +84,75 @@ def test_features_extractor(self): type: BYTES } """, - schema_pb2.Schema(), - ) - tfx_io = tf_example_record.TFExampleBeamRecord( - schema=schema, - raw_record_column_name=constants.ARROW_INPUT_COLUMN, - physical_format='inmem', - telemetry_descriptors=['testing'], - ) - - example_kwargs = [ - {'fixed_int': 1, 'fixed_float': 1.0, 'fixed_string': 'fixed_string1'}, - {'fixed_int': 1, 'fixed_float': 1.0, 'fixed_string': 'fixed_string2'}, - {'fixed_int': 2, 'fixed_float': 0.0, 'fixed_string': 'fixed_string3'}, - ] - - with beam.Pipeline() as pipeline: - # pylint: disable=no-value-for-parameter - result = ( - pipeline - | 'Create' - >> beam.Create( - [ - self._makeExample(**kwargs).SerializeToString() - for kwargs in example_kwargs - ], - reshuffle=False, - ) - | 'DecodeToRecordBatch' >> tfx_io.BeamSource(batch_size=3) - | 'InputsToExtracts' >> model_eval_lib.BatchedInputsToExtracts() - | feature_extractor.stage_name >> feature_extractor.ptransform - ) - - # pylint: enable=no-value-for-parameter - - def check_result(got): - try: - self.assertLen(got, 1) - self.assertIn(constants.FEATURES_KEY, got[0]) - self.assertLen(got[0][constants.FEATURES_KEY], 4) # 4 features - self.assertIn('example_weight', got[0][constants.FEATURES_KEY]) - # Arrays of type np.object won't compare with assertAllClose - self.assertEqual( - got[0][constants.FEATURES_KEY]['example_weight'].tolist(), - [None, None, None], - ) - self.assertIn('fixed_int', got[0][constants.FEATURES_KEY]) - self.assertAllClose( - got[0][constants.FEATURES_KEY]['fixed_int'], - np.array([[1], [1], [2]]), - ) - self.assertIn('fixed_float', got[0][constants.FEATURES_KEY]) - self.assertAllClose( - got[0][constants.FEATURES_KEY]['fixed_float'], - np.array([[1.0], [1.0], [0.0]]), - ) - self.assertIn('fixed_string', got[0][constants.FEATURES_KEY]) - # Arrays of type np.object won't compare with assertAllClose - self.assertEqual( - got[0][constants.FEATURES_KEY]['fixed_string'].tolist(), - [[b'fixed_string1'], [b'fixed_string2'], [b'fixed_string3']], - ) - self.assertIn(constants.INPUT_KEY, got[0]) - self.assertLen(got[0][constants.INPUT_KEY], 3) # 3 examples - - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that(result, 
check_result, label='result') - - -if __name__ == '__main__': - tf.test.main() + schema_pb2.Schema(), + ) + tfx_io = tf_example_record.TFExampleBeamRecord( + schema=schema, + raw_record_column_name=constants.ARROW_INPUT_COLUMN, + physical_format="inmem", + telemetry_descriptors=["testing"], + ) + + example_kwargs = [ + {"fixed_int": 1, "fixed_float": 1.0, "fixed_string": "fixed_string1"}, + {"fixed_int": 1, "fixed_float": 1.0, "fixed_string": "fixed_string2"}, + {"fixed_int": 2, "fixed_float": 0.0, "fixed_string": "fixed_string3"}, + ] + + with beam.Pipeline() as pipeline: + # pylint: disable=no-value-for-parameter + result = ( + pipeline + | "Create" + >> beam.Create( + [ + self._makeExample(**kwargs).SerializeToString() + for kwargs in example_kwargs + ], + reshuffle=False, + ) + | "DecodeToRecordBatch" >> tfx_io.BeamSource(batch_size=3) + | "InputsToExtracts" >> model_eval_lib.BatchedInputsToExtracts() + | feature_extractor.stage_name >> feature_extractor.ptransform + ) + + # pylint: enable=no-value-for-parameter + + def check_result(got): + try: + self.assertLen(got, 1) + self.assertIn(constants.FEATURES_KEY, got[0]) + self.assertLen(got[0][constants.FEATURES_KEY], 4) # 4 features + self.assertIn("example_weight", got[0][constants.FEATURES_KEY]) + # Arrays of type np.object won't compare with assertAllClose + self.assertEqual( + got[0][constants.FEATURES_KEY]["example_weight"].tolist(), + [None, None, None], + ) + self.assertIn("fixed_int", got[0][constants.FEATURES_KEY]) + self.assertAllClose( + got[0][constants.FEATURES_KEY]["fixed_int"], + np.array([[1], [1], [2]]), + ) + self.assertIn("fixed_float", got[0][constants.FEATURES_KEY]) + self.assertAllClose( + got[0][constants.FEATURES_KEY]["fixed_float"], + np.array([[1.0], [1.0], [0.0]]), + ) + self.assertIn("fixed_string", got[0][constants.FEATURES_KEY]) + # Arrays of type np.object won't compare with assertAllClose + self.assertEqual( + got[0][constants.FEATURES_KEY]["fixed_string"].tolist(), + [[b"fixed_string1"], [b"fixed_string2"], [b"fixed_string3"]], + ) + self.assertIn(constants.INPUT_KEY, got[0]) + self.assertLen(got[0][constants.INPUT_KEY], 3) # 3 examples + + except AssertionError as err: + raise util.BeamAssertException(err) + + util.assert_that(result, check_result, label="result") + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow_model_analysis/extractors/inference_base.py b/tensorflow_model_analysis/extractors/inference_base.py index c3196c5462..199f38999b 100644 --- a/tensorflow_model_analysis/extractors/inference_base.py +++ b/tensorflow_model_analysis/extractors/inference_base.py @@ -15,241 +15,240 @@ from typing import Dict, Optional, Sequence, Tuple, Union -from absl import logging import apache_beam as beam import numpy as np import tensorflow as tf +from absl import logging +from tensorflow.python.saved_model import ( + loader_impl, # pylint: disable=g-direct-tensorflow-import +) +from tensorflow_serving.apis import prediction_log_pb2 + from tensorflow_model_analysis import constants from tensorflow_model_analysis.api import types from tensorflow_model_analysis.proto import config_pb2 -from tensorflow_model_analysis.utils import model_util -from tensorflow_model_analysis.utils import util - -from tensorflow.python.saved_model import loader_impl # pylint: disable=g-direct-tensorflow-import -from tensorflow_serving.apis import prediction_log_pb2 +from tensorflow_model_analysis.utils import model_util, util def is_valid_config_for_bulk_inference( eval_config: config_pb2.EvalConfig, 
eval_shared_model: Optional[types.MaybeMultipleEvalSharedModels] = None, ) -> bool: - """Validates config for use with Tfx-Bsl and ServoBeam Bulk Inference.""" - eval_shared_models = model_util.verify_and_update_eval_shared_models( - eval_shared_model - ) - if eval_shared_models is None: - logging.warning( - 'Invalid Bulk Inference Config: There must be at least one ' - 'eval_shared_model to run servo/tfx-bsl bulk inference.' + """Validates config for use with Tfx-Bsl and ServoBeam Bulk Inference.""" + eval_shared_models = model_util.verify_and_update_eval_shared_models( + eval_shared_model ) - return False - for eval_shared_model in eval_shared_models: - if eval_shared_model.model_type not in ( - constants.TF_GENERIC, - constants.TF_ESTIMATOR, - ): - logging.warning( - 'Invalid Bulk Inference Config: Only TF2 and TF ' - 'Estimator models are supported for servo/tfx-bsl bulk ' - 'inference' - ) - return False - name_to_eval_shared_model = {m.model_name: m for m in eval_shared_models} - for model_spec in eval_config.model_specs: - eval_shared_model = model_util.get_eval_shared_model( - model_spec.name, name_to_eval_shared_model - ) - saved_model = loader_impl.parse_saved_model(eval_shared_model.model_path) - if model_spec.signature_name: - signature_name = model_spec.signature_name - else: - signature_name = ( - model_util.get_default_signature_name_from_saved_model_proto( - saved_model - ) - ) - try: - signature_def = model_util.get_signature_def_from_saved_model_proto( - signature_name, saved_model - ) - except ValueError: - logging.warning( - 'Invalid Bulk Inference Config: models must have a ' - 'signature to run servo/tfx-bsl bulk inference. Consider ' - 'setting the signature explicitly in the ModelSpec.' - ) - return False - if len(signature_def.inputs) != 1: - logging.warning( - 'Invalid Bulk Inference Config: signature must accept ' - 'only one input for servo/tfx-bsl bulk inference.' - ) - return False - if list(signature_def.inputs.values())[0].dtype != tf.string: - logging.warning( - 'Invalid Bulk Inference Config: signature must accept ' - 'string input to run servo/tfx-bsl bulk inference.' - ) - return False - return True + if eval_shared_models is None: + logging.warning( + "Invalid Bulk Inference Config: There must be at least one " + "eval_shared_model to run servo/tfx-bsl bulk inference." + ) + return False + for eval_shared_model in eval_shared_models: + if eval_shared_model.model_type not in ( + constants.TF_GENERIC, + constants.TF_ESTIMATOR, + ): + logging.warning( + "Invalid Bulk Inference Config: Only TF2 and TF " + "Estimator models are supported for servo/tfx-bsl bulk " + "inference" + ) + return False + name_to_eval_shared_model = {m.model_name: m for m in eval_shared_models} + for model_spec in eval_config.model_specs: + eval_shared_model = model_util.get_eval_shared_model( + model_spec.name, name_to_eval_shared_model + ) + saved_model = loader_impl.parse_saved_model(eval_shared_model.model_path) + if model_spec.signature_name: + signature_name = model_spec.signature_name + else: + signature_name = ( + model_util.get_default_signature_name_from_saved_model_proto( + saved_model + ) + ) + try: + signature_def = model_util.get_signature_def_from_saved_model_proto( + signature_name, saved_model + ) + except ValueError: + logging.warning( + "Invalid Bulk Inference Config: models must have a " + "signature to run servo/tfx-bsl bulk inference. Consider " + "setting the signature explicitly in the ModelSpec." 
+ ) + return False + if len(signature_def.inputs) != 1: + logging.warning( + "Invalid Bulk Inference Config: signature must accept " + "only one input for servo/tfx-bsl bulk inference." + ) + return False + if list(signature_def.inputs.values())[0].dtype != tf.string: + logging.warning( + "Invalid Bulk Inference Config: signature must accept " + "string input to run servo/tfx-bsl bulk inference." + ) + return False + return True def _create_inference_input_tuple( extracts: types.Extracts, ) -> Tuple[types.Extracts, bytes]: - """Creates a tuple containing the Extracts and input to the model.""" - try: - # Note after split_extracts splits the Extracts batch, INPUT_KEY has a value - # that is a 0 dimensional np.array. These arrays are indexed using the [()] - # syntax. - model_input = extracts[constants.INPUT_KEY][()] - except KeyError as e: - raise ValueError( - f'Extracts must contain the input keyed by "{constants.INPUT_KEY}" for ' - 'inference.' - ) from e - if not isinstance(model_input, bytes): - raise ValueError( - f'Extracts value at key: "{constants.INPUT_KEY}" is not of ' - 'type bytes. Only serialized tf.Examples and serialized ' - 'tf.SequenceExamples are currently supported. The value ' - f'is {model_input} and type {type(model_input)}.' - ) - return (extracts, model_input) + """Creates a tuple containing the Extracts and input to the model.""" + try: + # Note after split_extracts splits the Extracts batch, INPUT_KEY has a value + # that is a 0 dimensional np.array. These arrays are indexed using the [()] + # syntax. + model_input = extracts[constants.INPUT_KEY][()] + except KeyError as e: + raise ValueError( + f'Extracts must contain the input keyed by "{constants.INPUT_KEY}" for ' + "inference." + ) from e + if not isinstance(model_input, bytes): + raise ValueError( + f'Extracts value at key: "{constants.INPUT_KEY}" is not of ' + "type bytes. Only serialized tf.Examples and serialized " + "tf.SequenceExamples are currently supported. The value " + f"is {model_input} and type {type(model_input)}." + ) + return (extracts, model_input) def _parse_prediction_log_to_tensor_value( prediction_log: prediction_log_pb2.PredictionLog, ) -> Union[np.ndarray, Dict[str, np.ndarray]]: - """Parses the model inference values from a PredictionLog. + """Parses the model inference values from a PredictionLog. - Args: - prediction_log: Prediction_log_pb2.PredictionLog containing inference - results. + Args: + ---- + prediction_log: Prediction_log_pb2.PredictionLog containing inference + results. - Returns: - Values parsed from the PredictionLog inference result. These values are - formated in the format expected in TFMA PREDICTION_KEY Extracts value. - """ - log_type = prediction_log.WhichOneof('log_type') - if log_type == 'classify_log': - assert ( - len(prediction_log.classify_log.response.result.classifications) == 1 - ), ( - 'We expecth the number of classifications per PredictionLog to be ' - 'one because TFX-BSL RunInference expects single input/output and ' - 'handles batching entirely internally.' 
- ) - classes = np.array( - [ - c.label - for c in prediction_log.classify_log.response.result.classifications[ - 0 - ].classes - ], - dtype=object, - ) - scores = np.array( - [ - c.score - for c in prediction_log.classify_log.response.result.classifications[ - 0 - ].classes - ], - dtype=np.float32, - ) - return {'classes': classes, 'scores': scores} - elif log_type == 'regress_log': - return np.array( - [ - regression.value - for regression in prediction_log.regress_log.response.result.regressions - ], - dtype=float, - ) - elif log_type == 'predict_log': - output_tensor_name_to_tensor = {} - for k, v in prediction_log.predict_log.response.outputs.items(): - output_tensor_name_to_tensor[k] = np.squeeze(tf.make_ndarray(v), axis=0) - # If there is only one tensor (i.e. one dictionary item), we remove the - # tensor from the dict and return it directly. Generally, TFMA will not - # return a dictionary with a single value. - if len(output_tensor_name_to_tensor) == 1: - return list(output_tensor_name_to_tensor.values())[0] - return output_tensor_name_to_tensor - elif log_type == 'multi_inference_log': - raise NotImplementedError( - 'MultiInferenceLog processing not implemented yet.' - ) - elif log_type == 'session_log': - raise ValueError('SessionLog processing is not supported.') - else: - raise NotImplementedError(f'Unsupported log_type: {log_type}') + Returns: + ------- + Values parsed from the PredictionLog inference result. These values are + formated in the format expected in TFMA PREDICTION_KEY Extracts value. + """ + log_type = prediction_log.WhichOneof("log_type") + if log_type == "classify_log": + assert len(prediction_log.classify_log.response.result.classifications) == 1, ( + "We expecth the number of classifications per PredictionLog to be " + "one because TFX-BSL RunInference expects single input/output and " + "handles batching entirely internally." + ) + classes = np.array( + [ + c.label + for c in prediction_log.classify_log.response.result.classifications[ + 0 + ].classes + ], + dtype=object, + ) + scores = np.array( + [ + c.score + for c in prediction_log.classify_log.response.result.classifications[ + 0 + ].classes + ], + dtype=np.float32, + ) + return {"classes": classes, "scores": scores} + elif log_type == "regress_log": + return np.array( + [ + regression.value + for regression in prediction_log.regress_log.response.result.regressions + ], + dtype=float, + ) + elif log_type == "predict_log": + output_tensor_name_to_tensor = {} + for k, v in prediction_log.predict_log.response.outputs.items(): + output_tensor_name_to_tensor[k] = np.squeeze(tf.make_ndarray(v), axis=0) + # If there is only one tensor (i.e. one dictionary item), we remove the + # tensor from the dict and return it directly. Generally, TFMA will not + # return a dictionary with a single value. 
+ if len(output_tensor_name_to_tensor) == 1: + return list(output_tensor_name_to_tensor.values())[0] + return output_tensor_name_to_tensor + elif log_type == "multi_inference_log": + raise NotImplementedError("MultiInferenceLog processing not implemented yet.") + elif log_type == "session_log": + raise ValueError("SessionLog processing is not supported.") + else: + raise NotImplementedError(f"Unsupported log_type: {log_type}") def insert_predictions_into_extracts( - inference_tuple: Tuple[ - types.Extracts, Dict[str, prediction_log_pb2.PredictionLog] - ], + inference_tuple: Tuple[types.Extracts, Dict[str, prediction_log_pb2.PredictionLog]], output_keypath: Sequence[str] = (constants.PREDICTIONS_KEY,), prediction_log_keypath: Optional[Sequence[str]] = None, ) -> types.Extracts: - """Inserts tensor values from PredictionLogs into the Extracts. - - If prediction_log_keypath is provided, then raw prediction logs from extracts - are - also inserted into a separate extract. + """Inserts tensor values from PredictionLogs into the Extracts. - Args: - inference_tuple: This is the output of inference. It includes the key - forwarded extracts and a dict of model name to predicition logs. - output_keypath: A sequence of keys to be used as the path to traverse and - insert the outputs in the extract. - prediction_log_keypath: A sequence of keys to be used as the path to - traverse and insert the prediction logs in the extract. + If prediction_log_keypath is provided, then raw prediction logs from extracts + are + also inserted into a separate extract. - Returns: - Extracts with the PREDICTIONS_KEY populated. If prediction_log_keypath is - provided, - then PREDICTION_LOG_KEY is also populated with prediction logs. Note: By - convention, - PREDICTIONS_KEY will point to a dictionary if there are multiple - prediction logs and a single value if there is only one prediction log. - """ - extracts, model_names_to_prediction_logs = inference_tuple - model_name_to_tensors = { - name: _parse_prediction_log_to_tensor_value(log) - for name, log in model_names_to_prediction_logs.items() - } - # If there is only one model (i.e. one dictionary item), we remove the model - # output from the dict and store it directly under the PREDICTIONS_KEY. This - # is in line with the general TFMA pattern of not storing one-item - # dictionaries. - if len(model_name_to_tensors) == 1: - value = next(iter(model_name_to_tensors.values())) - else: - value = model_name_to_tensors - extracts = util.copy_and_set_by_keys( # pylint: disable=protected-access - root=extracts, - keypath=output_keypath, - value=value, - ) + Args: + ---- + inference_tuple: This is the output of inference. It includes the key + forwarded extracts and a dict of model name to predicition logs. + output_keypath: A sequence of keys to be used as the path to traverse and + insert the outputs in the extract. + prediction_log_keypath: A sequence of keys to be used as the path to + traverse and insert the prediction logs in the extract. - # Save un-parsed prediction log if prediction_log_keypath is provided. - # Prediction log(s) will be saved under PREDICTION_LOG_KEY either within - # a dictionary of model names to prediction logs (if there are multiple - # models), or by itself (if there is only one model). - if prediction_log_keypath: - if len(model_names_to_prediction_logs) == 1: - value = next(iter(model_names_to_prediction_logs.values())) + Returns: + ------- + Extracts with the PREDICTIONS_KEY populated. 
If prediction_log_keypath is + provided, + then PREDICTION_LOG_KEY is also populated with prediction logs. Note: By + convention, + PREDICTIONS_KEY will point to a dictionary if there are multiple + prediction logs and a single value if there is only one prediction log. + """ + extracts, model_names_to_prediction_logs = inference_tuple + model_name_to_tensors = { + name: _parse_prediction_log_to_tensor_value(log) + for name, log in model_names_to_prediction_logs.items() + } + # If there is only one model (i.e. one dictionary item), we remove the model + # output from the dict and store it directly under the PREDICTIONS_KEY. This + # is in line with the general TFMA pattern of not storing one-item + # dictionaries. + if len(model_name_to_tensors) == 1: + value = next(iter(model_name_to_tensors.values())) else: - value = model_names_to_prediction_logs + value = model_name_to_tensors extracts = util.copy_and_set_by_keys( # pylint: disable=protected-access root=extracts, - keypath=prediction_log_keypath, + keypath=output_keypath, value=value, ) - return extracts + + # Save un-parsed prediction log if prediction_log_keypath is provided. + # Prediction log(s) will be saved under PREDICTION_LOG_KEY either within + # a dictionary of model names to prediction logs (if there are multiple + # models), or by itself (if there is only one model). + if prediction_log_keypath: + if len(model_names_to_prediction_logs) == 1: + value = next(iter(model_names_to_prediction_logs.values())) + else: + value = model_names_to_prediction_logs + extracts = util.copy_and_set_by_keys( # pylint: disable=protected-access + root=extracts, + keypath=prediction_log_keypath, + value=value, + ) + return extracts @beam.ptransform_fn @@ -262,62 +261,62 @@ def RunInference( # pylint: disable=invalid-name output_keypath: Sequence[str] = (constants.PREDICTIONS_KEY,), prediction_log_keypath: Optional[Sequence[str]] = None, ) -> beam.pvalue.PCollection: - """A PTransform that adds predictions and possibly other tensors to Extracts. + """A PTransform that adds predictions and possibly other tensors to Extracts. - Args: - extracts: PCollection of Extracts containing model inputs keyed by - tfma.FEATURES_KEY (if model inputs are named) or tfma.INPUTS_KEY (if model - takes raw tf.Examples as input). - inference_ptransform: Bulk inference ptransform used to generate - predictions. This allows users to use different implementations depending - on evironment or Beam runner (e.g. a cloud-friendly OSS implementation or - an internal-specific implementation). These implementations should accept - a pcollection consisting of tuples containing a key and a single example. - The key may be anything and the example may be a tf.Example or serialized - tf.Example. - output_batch_size: Sets a static output batch size. - output_keypath: A sequence of keys to be used as the path to traverse and - insert the outputs in the extract. - prediction_log_keypath: A sequence of keys to be used as the path to - traverse and insert the prediction logs in the extract. + Args: + ---- + extracts: PCollection of Extracts containing model inputs keyed by + tfma.FEATURES_KEY (if model inputs are named) or tfma.INPUTS_KEY (if model + takes raw tf.Examples as input). + inference_ptransform: Bulk inference ptransform used to generate + predictions. This allows users to use different implementations depending + on evironment or Beam runner (e.g. a cloud-friendly OSS implementation or + an internal-specific implementation). 
These implementations should accept + a pcollection consisting of tuples containing a key and a single example. + The key may be anything and the example may be a tf.Example or serialized + tf.Example. + output_batch_size: Sets a static output batch size. + output_keypath: A sequence of keys to be used as the path to traverse and + insert the outputs in the extract. + prediction_log_keypath: A sequence of keys to be used as the path to + traverse and insert the prediction logs in the extract. - Returns: - PCollection of Extracts updated with the predictions and prediction logs. - """ - extracts = ( - extracts - # Extracts are fed in pre-batched, but BulkInference has specific - # batch handling and batching requirements. To accomodate the API and - # encapsulate the inference batching logic, we unbatch here. This function - # returns new Extracts dicts and will not modify the input Extracts. - | 'SplitExtracts' - >> beam.FlatMap(util.split_extracts, expand_zero_dims=False) - # The BulkInference API allows for key forwarding. To avoid a join - # after running inference, we forward the unbatched Extracts as a key. - | 'CreateInferenceInputTuple' >> beam.Map(_create_inference_input_tuple) - | 'RunInferencePerModel' >> inference_ptransform - # Combine predictions back into the original Extracts. - | 'InsertPredictionsIntoExtracts' - >> beam.Map( - insert_predictions_into_extracts, - output_keypath=output_keypath, - prediction_log_keypath=prediction_log_keypath, - ) - ) - # Beam batch will group single Extracts into a batch. Then - # merge_extracts will flatten the batch into a single "batched" - # extract. - if output_batch_size is not None: - batch_kwargs = { - 'min_batch_size': output_batch_size, - 'max_batch_size': output_batch_size, - } - else: - # Default batch parameters. - batch_kwargs = {} - return ( - extracts - | 'BatchSingleExampleExtracts' >> beam.BatchElements(**batch_kwargs) - | 'MergeExtracts' - >> beam.Map(util.merge_extracts, squeeze_two_dim_vector=False) - ) + Returns: + ------- + PCollection of Extracts updated with the predictions and prediction logs. + """ + extracts = ( + extracts + # Extracts are fed in pre-batched, but BulkInference has specific + # batch handling and batching requirements. To accomodate the API and + # encapsulate the inference batching logic, we unbatch here. This function + # returns new Extracts dicts and will not modify the input Extracts. + | "SplitExtracts" >> beam.FlatMap(util.split_extracts, expand_zero_dims=False) + # The BulkInference API allows for key forwarding. To avoid a join + # after running inference, we forward the unbatched Extracts as a key. + | "CreateInferenceInputTuple" >> beam.Map(_create_inference_input_tuple) + | "RunInferencePerModel" >> inference_ptransform + # Combine predictions back into the original Extracts. + | "InsertPredictionsIntoExtracts" + >> beam.Map( + insert_predictions_into_extracts, + output_keypath=output_keypath, + prediction_log_keypath=prediction_log_keypath, + ) + ) + # Beam batch will group single Extracts into a batch. Then + # merge_extracts will flatten the batch into a single "batched" + # extract. + if output_batch_size is not None: + batch_kwargs = { + "min_batch_size": output_batch_size, + "max_batch_size": output_batch_size, + } + else: + # Default batch parameters. 
+ batch_kwargs = {} + return ( + extracts + | "BatchSingleExampleExtracts" >> beam.BatchElements(**batch_kwargs) + | "MergeExtracts" >> beam.Map(util.merge_extracts, squeeze_two_dim_vector=False) + ) diff --git a/tensorflow_model_analysis/extractors/inference_base_test.py b/tensorflow_model_analysis/extractors/inference_base_test.py index a1dd0fc313..f6c2b7fcd6 100644 --- a/tensorflow_model_analysis/extractors/inference_base_test.py +++ b/tensorflow_model_analysis/extractors/inference_base_test.py @@ -18,65 +18,62 @@ """ import os +import unittest import tensorflow as tf -from tensorflow_model_analysis import constants -from tensorflow_model_analysis.extractors import features_extractor -from tensorflow_model_analysis.extractors import inference_base -from tensorflow_model_analysis.proto import config_pb2 -from tensorflow_model_analysis.utils import test_util as testutil -from tfx_bsl.tfxio import tensor_adapter -from tfx_bsl.tfxio import test_util - from google.protobuf import text_format -from tensorflow.core.protobuf import saved_model_pb2 # pylint: disable=g-direct-tensorflow-import +from tensorflow.core.protobuf import ( + saved_model_pb2, # pylint: disable=g-direct-tensorflow-import +) from tensorflow_metadata.proto.v0 import schema_pb2 -from tensorflow_serving.apis import logging_pb2 -from tensorflow_serving.apis import prediction_log_pb2 +from tensorflow_serving.apis import logging_pb2, prediction_log_pb2 +from tfx_bsl.tfxio import tensor_adapter, test_util -import unittest +from tensorflow_model_analysis import constants +from tensorflow_model_analysis.extractors import features_extractor, inference_base +from tensorflow_model_analysis.proto import config_pb2 +from tensorflow_model_analysis.utils import test_util as testutil class TfxBslPredictionsExtractorTest(testutil.TensorflowModelAnalysisTest): + def setUp(self): + super().setUp() + log_metadata1 = logging_pb2.LogMetadata(timestamp_secs=1) + predict_log1 = prediction_log_pb2.PredictLog() + self.prediction_log1 = prediction_log_pb2.PredictionLog( + predict_log=predict_log1, log_metadata=log_metadata1 + ) - def setUp(self): - super().setUp() - log_metadata1 = logging_pb2.LogMetadata(timestamp_secs=1) - predict_log1 = prediction_log_pb2.PredictLog() - self.prediction_log1 = prediction_log_pb2.PredictionLog( - predict_log=predict_log1, log_metadata=log_metadata1 - ) - - log_metadata2 = logging_pb2.LogMetadata(timestamp_secs=2) - predict_log2 = prediction_log_pb2.PredictLog() - self.prediction_log2 = prediction_log_pb2.PredictionLog( - predict_log=predict_log2, log_metadata=log_metadata2 - ) + log_metadata2 = logging_pb2.LogMetadata(timestamp_secs=2) + predict_log2 = prediction_log_pb2.PredictLog() + self.prediction_log2 = prediction_log_pb2.PredictionLog( + predict_log=predict_log2, log_metadata=log_metadata2 + ) - def _getExportDir(self): - return os.path.join(self._getTempDir(), 'export_dir') + def _getExportDir(self): + return os.path.join(self._getTempDir(), "export_dir") - def _create_tfxio_and_feature_extractor( - self, eval_config: config_pb2.EvalConfig, schema: schema_pb2.Schema - ): - tfx_io = test_util.InMemoryTFExampleRecord( - schema=schema, raw_record_column_name=constants.ARROW_INPUT_COLUMN - ) - tensor_adapter_config = tensor_adapter.TensorAdapterConfig( - arrow_schema=tfx_io.ArrowSchema(), - tensor_representations=tfx_io.TensorRepresentations(), - ) - feature_extractor = features_extractor.FeaturesExtractor( - eval_config=eval_config, - tensor_representations=tensor_adapter_config.tensor_representations, - ) - 
return tfx_io, feature_extractor + def _create_tfxio_and_feature_extractor( + self, eval_config: config_pb2.EvalConfig, schema: schema_pb2.Schema + ): + tfx_io = test_util.InMemoryTFExampleRecord( + schema=schema, raw_record_column_name=constants.ARROW_INPUT_COLUMN + ) + tensor_adapter_config = tensor_adapter.TensorAdapterConfig( + arrow_schema=tfx_io.ArrowSchema(), + tensor_representations=tfx_io.TensorRepresentations(), + ) + feature_extractor = features_extractor.FeaturesExtractor( + eval_config=eval_config, + tensor_representations=tensor_adapter_config.tensor_representations, + ) + return tfx_io, feature_extractor - # PR 189: Remove the `expectedFailure` mark if the test passes - @unittest.expectedFailure - def testIsValidConfigForBulkInferencePass(self): - saved_model_proto = text_format.Parse( - """ + # PR 189: Remove the `expectedFailure` mark if the test passes + @unittest.expectedFailure + def testIsValidConfigForBulkInferencePass(self): + saved_model_proto = text_format.Parse( + """ saved_model_schema_version: 1 meta_graphs { meta_info_def { @@ -107,37 +104,35 @@ def testIsValidConfigForBulkInferencePass(self): } } """, - saved_model_pb2.SavedModel(), - ) - temp_dir = self.create_tempdir() - temp_dir.create_file( - 'saved_model.pb', content=saved_model_proto.SerializeToString() - ) - eval_config = config_pb2.EvalConfig( - model_specs=[ - config_pb2.ModelSpec( - name='model_1', signature_name='serving_default' - ) - ] - ) - eval_shared_model = self.createTestEvalSharedModel( - model_path=temp_dir.full_path, - model_name='model_1', - tags=[tf.saved_model.SERVING], - model_type=constants.TF_GENERIC, - ) + saved_model_pb2.SavedModel(), + ) + temp_dir = self.create_tempdir() + temp_dir.create_file( + "saved_model.pb", content=saved_model_proto.SerializeToString() + ) + eval_config = config_pb2.EvalConfig( + model_specs=[ + config_pb2.ModelSpec(name="model_1", signature_name="serving_default") + ] + ) + eval_shared_model = self.createTestEvalSharedModel( + model_path=temp_dir.full_path, + model_name="model_1", + tags=[tf.saved_model.SERVING], + model_type=constants.TF_GENERIC, + ) - self.assertTrue( - inference_base.is_valid_config_for_bulk_inference( - eval_config, eval_shared_model + self.assertTrue( + inference_base.is_valid_config_for_bulk_inference( + eval_config, eval_shared_model + ) ) - ) - # PR 189: Remove the `expectedFailure` mark if the test passes - @unittest.expectedFailure - def testIsValidConfigForBulkInferencePassDefaultSignatureLookUp(self): - saved_model_proto = text_format.Parse( - """ + # PR 189: Remove the `expectedFailure` mark if the test passes + @unittest.expectedFailure + def testIsValidConfigForBulkInferencePassDefaultSignatureLookUp(self): + saved_model_proto = text_format.Parse( + """ saved_model_schema_version: 1 meta_graphs { meta_info_def { @@ -168,33 +163,33 @@ def testIsValidConfigForBulkInferencePassDefaultSignatureLookUp(self): } } """, - saved_model_pb2.SavedModel(), - ) - temp_dir = self.create_tempdir() - temp_dir.create_file( - 'saved_model.pb', content=saved_model_proto.SerializeToString() - ) - eval_config = config_pb2.EvalConfig( - model_specs=[config_pb2.ModelSpec(name='model_1')] - ) - eval_shared_model = self.createTestEvalSharedModel( - model_path=temp_dir.full_path, - model_name='model_1', - tags=[tf.saved_model.SERVING], - model_type=constants.TF_GENERIC, - ) + saved_model_pb2.SavedModel(), + ) + temp_dir = self.create_tempdir() + temp_dir.create_file( + "saved_model.pb", content=saved_model_proto.SerializeToString() + ) + eval_config 
= config_pb2.EvalConfig( + model_specs=[config_pb2.ModelSpec(name="model_1")] + ) + eval_shared_model = self.createTestEvalSharedModel( + model_path=temp_dir.full_path, + model_name="model_1", + tags=[tf.saved_model.SERVING], + model_type=constants.TF_GENERIC, + ) - self.assertTrue( - inference_base.is_valid_config_for_bulk_inference( - eval_config, eval_shared_model + self.assertTrue( + inference_base.is_valid_config_for_bulk_inference( + eval_config, eval_shared_model + ) ) - ) - # PR 189: Remove the `expectedFailure` mark if the test passes - @unittest.expectedFailure - def testIsValidConfigForBulkInferenceFailNoSignatureFound(self): - saved_model_proto = text_format.Parse( - """ + # PR 189: Remove the `expectedFailure` mark if the test passes + @unittest.expectedFailure + def testIsValidConfigForBulkInferenceFailNoSignatureFound(self): + saved_model_proto = text_format.Parse( + """ saved_model_schema_version: 1 meta_graphs { meta_info_def { @@ -225,33 +220,33 @@ def testIsValidConfigForBulkInferenceFailNoSignatureFound(self): } } """, - saved_model_pb2.SavedModel(), - ) - temp_dir = self.create_tempdir() - temp_dir.create_file( - 'saved_model.pb', content=saved_model_proto.SerializeToString() - ) - eval_config = config_pb2.EvalConfig( - model_specs=[ - config_pb2.ModelSpec(name='model_1', signature_name='not_found') - ] - ) - eval_shared_model = self.createTestEvalSharedModel( - model_path=temp_dir.full_path, - model_name='model_1', - model_type=constants.TF_GENERIC, - ) - self.assertFalse( - inference_base.is_valid_config_for_bulk_inference( - eval_config, eval_shared_model - ) - ) + saved_model_pb2.SavedModel(), + ) + temp_dir = self.create_tempdir() + temp_dir.create_file( + "saved_model.pb", content=saved_model_proto.SerializeToString() + ) + eval_config = config_pb2.EvalConfig( + model_specs=[ + config_pb2.ModelSpec(name="model_1", signature_name="not_found") + ] + ) + eval_shared_model = self.createTestEvalSharedModel( + model_path=temp_dir.full_path, + model_name="model_1", + model_type=constants.TF_GENERIC, + ) + self.assertFalse( + inference_base.is_valid_config_for_bulk_inference( + eval_config, eval_shared_model + ) + ) - # PR 189: Remove the `expectedFailure` mark if the test passes - @unittest.expectedFailure - def testIsValidConfigForBulkInferenceFailKerasModel(self): - saved_model_proto = text_format.Parse( - """ + # PR 189: Remove the `expectedFailure` mark if the test passes + @unittest.expectedFailure + def testIsValidConfigForBulkInferenceFailKerasModel(self): + saved_model_proto = text_format.Parse( + """ saved_model_schema_version: 1 meta_graphs { meta_info_def { @@ -282,35 +277,33 @@ def testIsValidConfigForBulkInferenceFailKerasModel(self): } } """, - saved_model_pb2.SavedModel(), - ) - temp_dir = self.create_tempdir() - temp_dir.create_file( - 'saved_model.pb', content=saved_model_proto.SerializeToString() - ) - eval_config = config_pb2.EvalConfig( - model_specs=[ - config_pb2.ModelSpec( - name='model_1', signature_name='serving_default' + saved_model_pb2.SavedModel(), + ) + temp_dir = self.create_tempdir() + temp_dir.create_file( + "saved_model.pb", content=saved_model_proto.SerializeToString() + ) + eval_config = config_pb2.EvalConfig( + model_specs=[ + config_pb2.ModelSpec(name="model_1", signature_name="serving_default") + ] + ) + eval_shared_model = self.createTestEvalSharedModel( + model_path=temp_dir.full_path, + model_name="model_1", + model_type=constants.TF_KERAS, + ) + self.assertFalse( + inference_base.is_valid_config_for_bulk_inference( + 
eval_config, eval_shared_model ) - ] - ) - eval_shared_model = self.createTestEvalSharedModel( - model_path=temp_dir.full_path, - model_name='model_1', - model_type=constants.TF_KERAS, - ) - self.assertFalse( - inference_base.is_valid_config_for_bulk_inference( - eval_config, eval_shared_model - ) - ) + ) - # PR 189: Remove the `expectedFailure` mark if the test passes - @unittest.expectedFailure - def testIsValidConfigForBulkInferenceFailWrongInputType(self): - saved_model_proto = text_format.Parse( - """ + # PR 189: Remove the `expectedFailure` mark if the test passes + @unittest.expectedFailure + def testIsValidConfigForBulkInferenceFailWrongInputType(self): + saved_model_proto = text_format.Parse( + """ saved_model_schema_version: 1 meta_graphs { meta_info_def { @@ -341,79 +334,75 @@ def testIsValidConfigForBulkInferenceFailWrongInputType(self): } } """, - saved_model_pb2.SavedModel(), - ) - temp_dir = self.create_tempdir() - temp_dir.create_file( - 'saved_model.pb', content=saved_model_proto.SerializeToString() - ) - eval_config = config_pb2.EvalConfig( - model_specs=[ - config_pb2.ModelSpec( - name='model_1', signature_name='serving_default' + saved_model_pb2.SavedModel(), + ) + temp_dir = self.create_tempdir() + temp_dir.create_file( + "saved_model.pb", content=saved_model_proto.SerializeToString() + ) + eval_config = config_pb2.EvalConfig( + model_specs=[ + config_pb2.ModelSpec(name="model_1", signature_name="serving_default") + ] + ) + eval_shared_model = self.createTestEvalSharedModel( + model_path=temp_dir.full_path, + model_name="model_1", + model_type=constants.TF_GENERIC, + ) + self.assertFalse( + inference_base.is_valid_config_for_bulk_inference( + eval_config, eval_shared_model ) - ] - ) - eval_shared_model = self.createTestEvalSharedModel( - model_path=temp_dir.full_path, - model_name='model_1', - model_type=constants.TF_GENERIC, - ) - self.assertFalse( - inference_base.is_valid_config_for_bulk_inference( - eval_config, eval_shared_model - ) - ) + ) - def testInsertSinglePredictionLogIntoExtract(self): - model_names_to_prediction_logs = {'prediction_log1': self.prediction_log1} - inference_tuple = ({}, model_names_to_prediction_logs) - output_extracts = inference_base.insert_predictions_into_extracts( - inference_tuple=inference_tuple, - prediction_log_keypath=[constants.PREDICTION_LOG_KEY], - ) + def testInsertSinglePredictionLogIntoExtract(self): + model_names_to_prediction_logs = {"prediction_log1": self.prediction_log1} + inference_tuple = ({}, model_names_to_prediction_logs) + output_extracts = inference_base.insert_predictions_into_extracts( + inference_tuple=inference_tuple, + prediction_log_keypath=[constants.PREDICTION_LOG_KEY], + ) - ref_extracts = {constants.PREDICTION_LOG_KEY: self.prediction_log1} + ref_extracts = {constants.PREDICTION_LOG_KEY: self.prediction_log1} - self.assertEqual( - output_extracts[constants.PREDICTION_LOG_KEY], - ref_extracts[constants.PREDICTION_LOG_KEY], - ) + self.assertEqual( + output_extracts[constants.PREDICTION_LOG_KEY], + ref_extracts[constants.PREDICTION_LOG_KEY], + ) - def testInsertTwoPredictionLogsIntoExtracts(self): - model_names_to_prediction_logs = { - 'prediction_log1': self.prediction_log1, - 'prediction_log2': self.prediction_log2, - } - inference_tuple = ({}, model_names_to_prediction_logs) - extracts = inference_base.insert_predictions_into_extracts( - inference_tuple, - prediction_log_keypath=[constants.PREDICTION_LOG_KEY], - ) + def testInsertTwoPredictionLogsIntoExtracts(self): + model_names_to_prediction_logs = { 
+ "prediction_log1": self.prediction_log1, + "prediction_log2": self.prediction_log2, + } + inference_tuple = ({}, model_names_to_prediction_logs) + extracts = inference_base.insert_predictions_into_extracts( + inference_tuple, + prediction_log_keypath=[constants.PREDICTION_LOG_KEY], + ) - ref_extracts = { - constants.PREDICTION_LOG_KEY: model_names_to_prediction_logs - } + ref_extracts = {constants.PREDICTION_LOG_KEY: model_names_to_prediction_logs} - self.assertEqual( - extracts[constants.PREDICTION_LOG_KEY], - ref_extracts[constants.PREDICTION_LOG_KEY], - ) + self.assertEqual( + extracts[constants.PREDICTION_LOG_KEY], + ref_extracts[constants.PREDICTION_LOG_KEY], + ) - def testInsertPredictionLogsWithCustomPathIntoExtracts(self): - model_names_to_prediction_logs = { - 'prediction_log1': self.prediction_log1, - 'prediction_log2': self.prediction_log2, - } - inference_tuple = ({}, model_names_to_prediction_logs) - extracts = inference_base.insert_predictions_into_extracts( - inference_tuple, - prediction_log_keypath=['foo', 'bar'], - ) + def testInsertPredictionLogsWithCustomPathIntoExtracts(self): + model_names_to_prediction_logs = { + "prediction_log1": self.prediction_log1, + "prediction_log2": self.prediction_log2, + } + inference_tuple = ({}, model_names_to_prediction_logs) + extracts = inference_base.insert_predictions_into_extracts( + inference_tuple, + prediction_log_keypath=["foo", "bar"], + ) - ref_extracts = {'foo': {'bar': model_names_to_prediction_logs}} - self.assertEqual(extracts['foo']['bar'], ref_extracts['foo']['bar']) + ref_extracts = {"foo": {"bar": model_names_to_prediction_logs}} + self.assertEqual(extracts["foo"]["bar"], ref_extracts["foo"]["bar"]) -if __name__ == '__main__': - tf.test.main() +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow_model_analysis/extractors/labels_extractor.py b/tensorflow_model_analysis/extractors/labels_extractor.py index 426a3b64b2..5186507fa0 100644 --- a/tensorflow_model_analysis/extractors/labels_extractor.py +++ b/tensorflow_model_analysis/extractors/labels_extractor.py @@ -16,34 +16,37 @@ import copy import apache_beam as beam + from tensorflow_model_analysis import constants from tensorflow_model_analysis.api import types from tensorflow_model_analysis.extractors import extractor from tensorflow_model_analysis.proto import config_pb2 from tensorflow_model_analysis.utils import model_util -LABELS_EXTRACTOR_STAGE_NAME = 'ExtractLabels' +LABELS_EXTRACTOR_STAGE_NAME = "ExtractLabels" def LabelsExtractor(eval_config: config_pb2.EvalConfig) -> extractor.Extractor: - """Creates an extractor for extracting labels. + """Creates an extractor for extracting labels. - The extractor's PTransform uses the config's ModelSpec.label_key(s) to lookup - the associated label values stored as features under the tfma.FEATURES_KEY - (and optionally tfma.TRANSFORMED_FEATURES_KEY) in extracts. The resulting - values are then added to the extracts under the key tfma.LABELS_KEY. + The extractor's PTransform uses the config's ModelSpec.label_key(s) to lookup + the associated label values stored as features under the tfma.FEATURES_KEY + (and optionally tfma.TRANSFORMED_FEATURES_KEY) in extracts. The resulting + values are then added to the extracts under the key tfma.LABELS_KEY. - Args: - eval_config: Eval config. + Args: + ---- + eval_config: Eval config. - Returns: - Extractor for extracting labels. 
- """ - # pylint: disable=no-value-for-parameter - return extractor.Extractor( - stage_name=LABELS_EXTRACTOR_STAGE_NAME, - ptransform=_ExtractLabels(eval_config=eval_config), - ) + Returns: + ------- + Extractor for extracting labels. + """ + # pylint: disable=no-value-for-parameter + return extractor.Extractor( + stage_name=LABELS_EXTRACTOR_STAGE_NAME, + ptransform=_ExtractLabels(eval_config=eval_config), + ) @beam.ptransform_fn @@ -52,31 +55,33 @@ def LabelsExtractor(eval_config: config_pb2.EvalConfig) -> extractor.Extractor: def _ExtractLabels( extracts: beam.pvalue.PCollection, eval_config: config_pb2.EvalConfig ) -> beam.pvalue.PCollection: - """Extracts labels from features extracts. + """Extracts labels from features extracts. - Args: - extracts: PCollection containing features under tfma.FEATURES_KEY. - eval_config: Eval config. + Args: + ---- + extracts: PCollection containing features under tfma.FEATURES_KEY. + eval_config: Eval config. - Returns: - PCollection of extracts with additional labels added under the key - tfma.LABELS_KEY. - """ + Returns: + ------- + PCollection of extracts with additional labels added under the key + tfma.LABELS_KEY. + """ - def extract_labels( # pylint: disable=invalid-name - batched_extracts: types.Extracts, - ) -> types.Extracts: - """Extract labels from extracts containing features.""" - result = copy.copy(batched_extracts) - result[constants.LABELS_KEY] = ( - model_util.get_feature_values_for_model_spec_field( - list(eval_config.model_specs), - 'label_key', - 'label_keys', - result, - True, + def extract_labels( # pylint: disable=invalid-name + batched_extracts: types.Extracts, + ) -> types.Extracts: + """Extract labels from extracts containing features.""" + result = copy.copy(batched_extracts) + result[constants.LABELS_KEY] = ( + model_util.get_feature_values_for_model_spec_field( + list(eval_config.model_specs), + "label_key", + "label_keys", + result, + True, + ) ) - ) - return result + return result - return extracts | 'ExtractLabels' >> beam.Map(extract_labels) + return extracts | "ExtractLabels" >> beam.Map(extract_labels) diff --git a/tensorflow_model_analysis/extractors/labels_extractor_test.py b/tensorflow_model_analysis/extractors/labels_extractor_test.py index 04e48148bc..d69cf3386d 100644 --- a/tensorflow_model_analysis/extractors/labels_extractor_test.py +++ b/tensorflow_model_analysis/extractors/labels_extractor_test.py @@ -13,135 +13,136 @@ # limitations under the License. 
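For orientation, here is a small sketch (not taken from the diff) of the extract shapes that the LabelsExtractor defined above produces, matching what the tests below assert: with a ModelSpec that sets label_key, the label values are read out of the batched features and surfaced under tfma.LABELS_KEY as one column per example; with label_keys (multi-output), the value is instead a dict keyed by output name. The literal arrays and string keys are illustrative only.

```
# Sketch of a batched extract before and after the LabelsExtractor stage,
# assuming ModelSpec(label_key="label") and a batch of three examples.
# The string keys stand in for constants.FEATURES_KEY / constants.LABELS_KEY.
import numpy as np

batched_extract_in = {
    "features": {
        "label": np.array([[1.0], [0.0], [0.0]]),
        "fixed_int": np.array([[1], [1], [2]]),
    },
}

# After the stage, the label values are available under the top-level
# labels key:
batched_extract_out = {
    "features": batched_extract_in["features"],
    "labels": np.array([[1.0], [0.0], [0.0]]),
}

# With ModelSpec(label_keys={"output1": "label1", "output2": "label2"}) the
# labels entry is a dict instead, e.g.
# {"output1": np.array([[1.0], [1.0]]), "output2": np.array([[0.0], [1.0]])}
```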
"""Test for labels extractor.""" -from absl.testing import parameterized import apache_beam as beam -from apache_beam.testing import util import numpy as np import tensorflow as tf +from absl.testing import parameterized +from apache_beam.testing import util +from google.protobuf import text_format +from tensorflow_metadata.proto.v0 import schema_pb2 +from tfx_bsl.tfxio import test_util as tfx_bsl_test_util + from tensorflow_model_analysis import constants from tensorflow_model_analysis.api import model_eval_lib -from tensorflow_model_analysis.extractors import features_extractor -from tensorflow_model_analysis.extractors import labels_extractor +from tensorflow_model_analysis.extractors import features_extractor, labels_extractor from tensorflow_model_analysis.proto import config_pb2 from tensorflow_model_analysis.utils import test_util -from tfx_bsl.tfxio import test_util as tfx_bsl_test_util - -from google.protobuf import text_format -from tensorflow_metadata.proto.v0 import schema_pb2 class LabelsExtractorTest( test_util.TensorflowModelAnalysisTest, parameterized.TestCase ): + @parameterized.named_parameters(("with_label", "label"), ("without_label", None)) + def testLabelsExtractor(self, label): + model_spec = config_pb2.ModelSpec(label_key=label) + eval_config = config_pb2.EvalConfig(model_specs=[model_spec]) + feature_extractor = features_extractor.FeaturesExtractor(eval_config) + label_extractor = labels_extractor.LabelsExtractor(eval_config) - @parameterized.named_parameters( - ('with_label', 'label'), ('without_label', None) - ) - def testLabelsExtractor(self, label): - model_spec = config_pb2.ModelSpec(label_key=label) - eval_config = config_pb2.EvalConfig(model_specs=[model_spec]) - feature_extractor = features_extractor.FeaturesExtractor(eval_config) - label_extractor = labels_extractor.LabelsExtractor(eval_config) - - label_feature = '' - if label is not None: - label_feature = """ + label_feature = "" + if label is not None: + label_feature = ( + """ feature { name: "%s" type: FLOAT } - """ % label - schema = text_format.Parse( - label_feature + """ + """ + % label + ) + schema = text_format.Parse( + label_feature + + """ feature { name: "fixed_int" type: INT } """, - schema_pb2.Schema(), - ) - tfx_io = tfx_bsl_test_util.InMemoryTFExampleRecord( - schema=schema, raw_record_column_name=constants.ARROW_INPUT_COLUMN - ) + schema_pb2.Schema(), + ) + tfx_io = tfx_bsl_test_util.InMemoryTFExampleRecord( + schema=schema, raw_record_column_name=constants.ARROW_INPUT_COLUMN + ) - def maybe_add_key(d, key, value): - if key is not None: - d[key] = value - return d + def maybe_add_key(d, key, value): + if key is not None: + d[key] = value + return d - example_kwargs = [ - maybe_add_key( - { - 'fixed_int': 1, - }, - label, - 1.0, - ), - maybe_add_key( - { - 'fixed_int': 1, - }, - label, - 0.0, - ), - maybe_add_key( - { - 'fixed_int': 2, - }, - label, - 0.0, - ), - ] + example_kwargs = [ + maybe_add_key( + { + "fixed_int": 1, + }, + label, + 1.0, + ), + maybe_add_key( + { + "fixed_int": 1, + }, + label, + 0.0, + ), + maybe_add_key( + { + "fixed_int": 2, + }, + label, + 0.0, + ), + ] - with beam.Pipeline() as pipeline: - # pylint: disable=no-value-for-parameter - result = ( - pipeline - | 'Create' - >> beam.Create( - [ - self._makeExample(**kwargs).SerializeToString() - for kwargs in example_kwargs - ], - reshuffle=False, - ) - | 'BatchExamples' >> tfx_io.BeamSource(batch_size=3) - | 'InputsToExtracts' >> model_eval_lib.BatchedInputsToExtracts() - | feature_extractor.stage_name >> 
feature_extractor.ptransform - | label_extractor.stage_name >> label_extractor.ptransform - ) + with beam.Pipeline() as pipeline: + # pylint: disable=no-value-for-parameter + result = ( + pipeline + | "Create" + >> beam.Create( + [ + self._makeExample(**kwargs).SerializeToString() + for kwargs in example_kwargs + ], + reshuffle=False, + ) + | "BatchExamples" >> tfx_io.BeamSource(batch_size=3) + | "InputsToExtracts" >> model_eval_lib.BatchedInputsToExtracts() + | feature_extractor.stage_name >> feature_extractor.ptransform + | label_extractor.stage_name >> label_extractor.ptransform + ) - # pylint: enable=no-value-for-parameter + # pylint: enable=no-value-for-parameter - def check_result(got): - try: - self.assertLen(got, 1) - if label is None: - self.assertIsNone(got[0][constants.LABELS_KEY]) - else: - self.assertAllClose( - got[0][constants.LABELS_KEY], np.array([[1.0], [0.0], [0.0]]) - ) + def check_result(got): + try: + self.assertLen(got, 1) + if label is None: + self.assertIsNone(got[0][constants.LABELS_KEY]) + else: + self.assertAllClose( + got[0][constants.LABELS_KEY], + np.array([[1.0], [0.0], [0.0]]), + ) - except AssertionError as err: - raise util.BeamAssertException(err) + except AssertionError as err: + raise util.BeamAssertException(err) - util.assert_that(result, check_result, label='result') + util.assert_that(result, check_result, label="result") - def testLabelsExtractorMultiOutput(self): - model_spec = config_pb2.ModelSpec( - label_keys={ - 'output1': 'label1', - 'output2': 'label2', - 'output3': 'label3', - } - ) - eval_config = config_pb2.EvalConfig(model_specs=[model_spec]) - feature_extractor = features_extractor.FeaturesExtractor(eval_config) - label_extractor = labels_extractor.LabelsExtractor(eval_config) + def testLabelsExtractorMultiOutput(self): + model_spec = config_pb2.ModelSpec( + label_keys={ + "output1": "label1", + "output2": "label2", + "output3": "label3", + } + ) + eval_config = config_pb2.EvalConfig(model_specs=[model_spec]) + feature_extractor = features_extractor.FeaturesExtractor(eval_config) + label_extractor = labels_extractor.LabelsExtractor(eval_config) - schema = text_format.Parse( - """ + schema = text_format.Parse( + """ feature { name: "label1" type: FLOAT @@ -155,64 +156,64 @@ def testLabelsExtractorMultiOutput(self): type: INT } """, - schema_pb2.Schema(), - ) - tfx_io = tfx_bsl_test_util.InMemoryTFExampleRecord( - schema=schema, raw_record_column_name=constants.ARROW_INPUT_COLUMN - ) + schema_pb2.Schema(), + ) + tfx_io = tfx_bsl_test_util.InMemoryTFExampleRecord( + schema=schema, raw_record_column_name=constants.ARROW_INPUT_COLUMN + ) - examples = [ - self._makeExample(label1=1.0, label2=0.0, fixed_int=1), - self._makeExample(label1=1.0, label2=1.0, fixed_int=1), - ] + examples = [ + self._makeExample(label1=1.0, label2=0.0, fixed_int=1), + self._makeExample(label1=1.0, label2=1.0, fixed_int=1), + ] - with beam.Pipeline() as pipeline: - # pylint: disable=no-value-for-parameter - result = ( - pipeline - | 'Create' - >> beam.Create( - [e.SerializeToString() for e in examples], reshuffle=False - ) - | 'BatchExamples' >> tfx_io.BeamSource(batch_size=2) - | 'InputsToExtracts' >> model_eval_lib.BatchedInputsToExtracts() - | feature_extractor.stage_name >> feature_extractor.ptransform - | label_extractor.stage_name >> label_extractor.ptransform - ) + with beam.Pipeline() as pipeline: + # pylint: disable=no-value-for-parameter + result = ( + pipeline + | "Create" + >> beam.Create( + [e.SerializeToString() for e in examples], reshuffle=False + ) 
+ | "BatchExamples" >> tfx_io.BeamSource(batch_size=2) + | "InputsToExtracts" >> model_eval_lib.BatchedInputsToExtracts() + | feature_extractor.stage_name >> feature_extractor.ptransform + | label_extractor.stage_name >> label_extractor.ptransform + ) - # pylint: enable=no-value-for-parameter + # pylint: enable=no-value-for-parameter - def check_result(got): - try: - self.assertLen(got, 1) - # None cannot be compared with assertAllClose - self.assertIn('output3', got[0][constants.LABELS_KEY]) - self.assertIsNone(got[0][constants.LABELS_KEY]['output3']) - del got[0][constants.LABELS_KEY]['output3'] - self.assertAllClose( - got[0][constants.LABELS_KEY], - { - 'output1': np.array([[1.0], [1.0]]), - 'output2': np.array([[0.0], [1.0]]), - }, - ) + def check_result(got): + try: + self.assertLen(got, 1) + # None cannot be compared with assertAllClose + self.assertIn("output3", got[0][constants.LABELS_KEY]) + self.assertIsNone(got[0][constants.LABELS_KEY]["output3"]) + del got[0][constants.LABELS_KEY]["output3"] + self.assertAllClose( + got[0][constants.LABELS_KEY], + { + "output1": np.array([[1.0], [1.0]]), + "output2": np.array([[0.0], [1.0]]), + }, + ) - except AssertionError as err: - raise util.BeamAssertException(err) + except AssertionError as err: + raise util.BeamAssertException(err) - util.assert_that(result, check_result, label='result') + util.assert_that(result, check_result, label="result") - def testLabelsExtractorMultiModel(self): - model_spec1 = config_pb2.ModelSpec(name='model1', label_key='label') - model_spec2 = config_pb2.ModelSpec( - name='model2', label_keys={'output1': 'label1', 'output2': 'label2'} - ) - eval_config = config_pb2.EvalConfig(model_specs=[model_spec1, model_spec2]) - feature_extractor = features_extractor.FeaturesExtractor(eval_config) - label_extractor = labels_extractor.LabelsExtractor(eval_config) + def testLabelsExtractorMultiModel(self): + model_spec1 = config_pb2.ModelSpec(name="model1", label_key="label") + model_spec2 = config_pb2.ModelSpec( + name="model2", label_keys={"output1": "label1", "output2": "label2"} + ) + eval_config = config_pb2.EvalConfig(model_specs=[model_spec1, model_spec2]) + feature_extractor = features_extractor.FeaturesExtractor(eval_config) + label_extractor = labels_extractor.LabelsExtractor(eval_config) - schema = text_format.Parse( - """ + schema = text_format.Parse( + """ feature { name: "label" type: FLOAT @@ -230,54 +231,54 @@ def testLabelsExtractorMultiModel(self): type: INT } """, - schema_pb2.Schema(), - ) - tfx_io = tfx_bsl_test_util.InMemoryTFExampleRecord( - schema=schema, raw_record_column_name=constants.ARROW_INPUT_COLUMN - ) + schema_pb2.Schema(), + ) + tfx_io = tfx_bsl_test_util.InMemoryTFExampleRecord( + schema=schema, raw_record_column_name=constants.ARROW_INPUT_COLUMN + ) - examples = [ - self._makeExample(label=1.0, label1=1.0, label2=0.0, fixed_int=1), - self._makeExample(label=1.0, label1=1.0, label2=1.0, fixed_int=1), - ] + examples = [ + self._makeExample(label=1.0, label1=1.0, label2=0.0, fixed_int=1), + self._makeExample(label=1.0, label1=1.0, label2=1.0, fixed_int=1), + ] - with beam.Pipeline() as pipeline: - # pylint: disable=no-value-for-parameter - result = ( - pipeline - | 'Create' - >> beam.Create( - [e.SerializeToString() for e in examples], reshuffle=False - ) - | 'BatchExamples' >> tfx_io.BeamSource(batch_size=2) - | 'InputsToExtracts' >> model_eval_lib.BatchedInputsToExtracts() - | feature_extractor.stage_name >> feature_extractor.ptransform - | label_extractor.stage_name >> 
label_extractor.ptransform - ) + with beam.Pipeline() as pipeline: + # pylint: disable=no-value-for-parameter + result = ( + pipeline + | "Create" + >> beam.Create( + [e.SerializeToString() for e in examples], reshuffle=False + ) + | "BatchExamples" >> tfx_io.BeamSource(batch_size=2) + | "InputsToExtracts" >> model_eval_lib.BatchedInputsToExtracts() + | feature_extractor.stage_name >> feature_extractor.ptransform + | label_extractor.stage_name >> label_extractor.ptransform + ) - # pylint: enable=no-value-for-parameter + # pylint: enable=no-value-for-parameter - def check_result(got): - try: - self.assertLen(got, 1) - for model_name in ('model1', 'model2'): - self.assertIn(model_name, got[0][constants.LABELS_KEY]) - self.assertAllClose( - got[0][constants.LABELS_KEY]['model1'], np.array([[1.0], [1.0]]) - ) - self.assertAllClose( - got[0][constants.LABELS_KEY]['model2'], - { - 'output1': np.array([[1.0], [1.0]]), - 'output2': np.array([[0.0], [1.0]]), - }, - ) + def check_result(got): + try: + self.assertLen(got, 1) + for model_name in ("model1", "model2"): + self.assertIn(model_name, got[0][constants.LABELS_KEY]) + self.assertAllClose( + got[0][constants.LABELS_KEY]["model1"], np.array([[1.0], [1.0]]) + ) + self.assertAllClose( + got[0][constants.LABELS_KEY]["model2"], + { + "output1": np.array([[1.0], [1.0]]), + "output2": np.array([[0.0], [1.0]]), + }, + ) - except AssertionError as err: - raise util.BeamAssertException(err) + except AssertionError as err: + raise util.BeamAssertException(err) - util.assert_that(result, check_result, label='result') + util.assert_that(result, check_result, label="result") -if __name__ == '__main__': - tf.test.main() +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow_model_analysis/extractors/legacy_feature_extractor.py b/tensorflow_model_analysis/extractors/legacy_feature_extractor.py index 4291e2589f..9e48d92524 100644 --- a/tensorflow_model_analysis/extractors/legacy_feature_extractor.py +++ b/tensorflow_model_analysis/extractors/legacy_feature_extractor.py @@ -16,17 +16,18 @@ import copy from typing import Any, Dict, List, Optional -from absl import logging import apache_beam as beam import numpy as np import tensorflow as tf +from absl import logging + from tensorflow_model_analysis import constants from tensorflow_model_analysis.api import types from tensorflow_model_analysis.extractors import extractor from tensorflow_model_analysis.utils import util -_FEATURE_EXTRACTOR_STAGE_NAME = 'ExtractFeatures' -_ENCODING_NODE_SUFFIX = 'node' +_FEATURE_EXTRACTOR_STAGE_NAME = "ExtractFeatures" +_ENCODING_NODE_SUFFIX = "node" def FeatureExtractor( @@ -35,17 +36,17 @@ def FeatureExtractor( extract_source: str = constants.FEATURES_PREDICTIONS_LABELS_KEY, extract_dest: str = constants.MATERIALIZE_COLUMNS, ): - # pylint: disable=no-value-for-parameter - return extractor.Extractor( - stage_name=_FEATURE_EXTRACTOR_STAGE_NAME, - ptransform=_ExtractFeatures( - additional_extracts=additional_extracts, - excludes=excludes, - source=extract_source, - dest=extract_dest, - ), - ) - # pylint: enable=no-value-for-parameter + # pylint: disable=no-value-for-parameter + return extractor.Extractor( + stage_name=_FEATURE_EXTRACTOR_STAGE_NAME, + ptransform=_ExtractFeatures( + additional_extracts=additional_extracts, + excludes=excludes, + source=extract_source, + dest=extract_dest, + ), + ) + # pylint: enable=no-value-for-parameter def _AugmentExtracts( @@ -54,81 +55,81 @@ def _AugmentExtracts( excludes: List[bytes], extracts: types.Extracts, ) -> None: - 
"""Augments the Extracts with FeaturesPredictionsLabels. - - Args: - data: Data dictionary returned by PredictExtractor. - prefix: Prefix to use in column naming (e.g. 'features', 'labels', etc). - excludes: List of strings containing features, predictions, or labels to - exclude from materialization. - extracts: The Extracts to be augmented. This is mutated in-place. - - Raises: - TypeError: If the FeaturesPredictionsLabels is corrupt. - """ - for name, val in data.items(): - if excludes is not None and name in excludes: - continue - # If data originated from FeaturesPredictionsLabels, then the value will be - # stored under a 'node' key. - if isinstance(val, dict) and _ENCODING_NODE_SUFFIX in val: - val = val.get(_ENCODING_NODE_SUFFIX) - - if name in (prefix, util.KEY_SEPARATOR + prefix): - col_name = prefix - elif prefix not in ('features', 'predictions', 'labels'): - # Names used by additional extracts should be properly escaped already so - # avoid escaping the name a second time by manually combining the prefix. - col_name = prefix + util.KEY_SEPARATOR + name - else: - col_name = util.compound_key([prefix, name]) - - if isinstance(val, tf.compat.v1.SparseTensorValue): - extracts[col_name] = types.MaterializedColumn( - name=col_name, value=val.values - ) - - elif isinstance(val, np.ndarray) or isinstance(val, list): - # Only support first dim for now - val = val[0] if len(val) > 0 else [] # pylint: disable=g-explicit-length-test - extracts[col_name] = types.MaterializedColumn(name=col_name, value=val) - - else: - raise TypeError( - 'Dictionary item with key %s, value %s had unexpected type %s' - % (name, val, type(val)) - ) - - -def _ParseExample( - extracts: types.Extracts, materialize_columns: bool = True -) -> None: - """Feature extraction from serialized tf.Example.""" - # Deserialize the example. - example = tf.train.Example() - try: - example.ParseFromString(extracts[constants.INPUT_KEY]) - except: # pylint: disable=bare-except - logging.warning('Could not parse tf.Example from the input source.') - - features = {} - if constants.FEATURES_PREDICTIONS_LABELS_KEY in extracts: - features = extracts[constants.FEATURES_PREDICTIONS_LABELS_KEY].features - - for name in example.features.feature: - if materialize_columns or name not in features: - key = util.compound_key(['features', name]) - value = example.features.feature[name] - if value.HasField('bytes_list'): - values = list(v for v in value.bytes_list.value) - elif value.HasField('float_list'): - values = list(v for v in value.float_list.value) - elif value.HasField('int64_list'): - values = list(v for v in value.int64_list.value) - if materialize_columns: - extracts[key] = types.MaterializedColumn(name=key, value=values) - if name not in features: - features[name] = {_ENCODING_NODE_SUFFIX: np.array([values])} + """Augments the Extracts with FeaturesPredictionsLabels. + + Args: + ---- + data: Data dictionary returned by PredictExtractor. + prefix: Prefix to use in column naming (e.g. 'features', 'labels', etc). + excludes: List of strings containing features, predictions, or labels to + exclude from materialization. + extracts: The Extracts to be augmented. This is mutated in-place. + + Raises: + ------ + TypeError: If the FeaturesPredictionsLabels is corrupt. + """ + for name, val in data.items(): + if excludes is not None and name in excludes: + continue + # If data originated from FeaturesPredictionsLabels, then the value will be + # stored under a 'node' key. 
+ if isinstance(val, dict) and _ENCODING_NODE_SUFFIX in val: + val = val.get(_ENCODING_NODE_SUFFIX) + + if name in (prefix, util.KEY_SEPARATOR + prefix): + col_name = prefix + elif prefix not in ("features", "predictions", "labels"): + # Names used by additional extracts should be properly escaped already so + # avoid escaping the name a second time by manually combining the prefix. + col_name = prefix + util.KEY_SEPARATOR + name + else: + col_name = util.compound_key([prefix, name]) + + if isinstance(val, tf.compat.v1.SparseTensorValue): + extracts[col_name] = types.MaterializedColumn( + name=col_name, value=val.values + ) + + elif isinstance(val, np.ndarray) or isinstance(val, list): + # Only support first dim for now + val = val[0] if len(val) > 0 else [] # pylint: disable=g-explicit-length-test + extracts[col_name] = types.MaterializedColumn(name=col_name, value=val) + + else: + raise TypeError( + "Dictionary item with key %s, value %s had unexpected type %s" + % (name, val, type(val)) + ) + + +def _ParseExample(extracts: types.Extracts, materialize_columns: bool = True) -> None: + """Feature extraction from serialized tf.Example.""" + # Deserialize the example. + example = tf.train.Example() + try: + example.ParseFromString(extracts[constants.INPUT_KEY]) + except: # pylint: disable=bare-except + logging.warning("Could not parse tf.Example from the input source.") + + features = {} + if constants.FEATURES_PREDICTIONS_LABELS_KEY in extracts: + features = extracts[constants.FEATURES_PREDICTIONS_LABELS_KEY].features + + for name in example.features.feature: + if materialize_columns or name not in features: + key = util.compound_key(["features", name]) + value = example.features.feature[name] + if value.HasField("bytes_list"): + values = list(v for v in value.bytes_list.value) + elif value.HasField("float_list"): + values = list(v for v in value.float_list.value) + elif value.HasField("int64_list"): + values = list(v for v in value.int64_list.value) + if materialize_columns: + extracts[key] = types.MaterializedColumn(name=key, value=values) + if name not in features: + features[name] = {_ENCODING_NODE_SUFFIX: np.array([values])} def _MaterializeFeatures( @@ -138,70 +139,71 @@ def _MaterializeFeatures( source: str = constants.FEATURES_PREDICTIONS_LABELS_KEY, dest: str = constants.MATERIALIZE_COLUMNS, ) -> types.Extracts: - """Converts FeaturesPredictionsLabels into MaterializedColumn in the extract. - - It must be the case that the PredictExtractor was called before calling this - function. - - Args: - extracts: The Extracts to be augmented. - additional_extracts: Optional list of additional extracts to include along - with the features, predictions, and labels. - excludes: Optional list of strings containing features, predictions, or - labels to exclude from materialization. - source: Source for extracting features. Currently it supports extracting - features from FPLs and input tf.Example protos. - dest: Destination for extracted features. Currently supported are adding - materialized columns, or the features dict of the FPLs. - - Returns: - Returns Extracts (which is a deep copy of the original Extracts, so the - original isn't mutated) with features populated. - - Raises: - RuntimeError: When tfma.FEATURES_PREDICTIONS_LABELS_KEY key is not populated - by PredictExtractor for FPL source or incorrect extraction source given. - """ - # Make a deep copy, so we don't mutate the original. 
- result = copy.deepcopy(extracts) - - if additional_extracts: - for key in additional_extracts: - if key in result: - _AugmentExtracts(result[key], key, excludes, result) - - if source == constants.FEATURES_PREDICTIONS_LABELS_KEY: - fpl = result.get(constants.FEATURES_PREDICTIONS_LABELS_KEY) - if not fpl: - raise RuntimeError('FPL missing. Ensure PredictExtractor was called.') - - if not isinstance(fpl, types.FeaturesPredictionsLabels): - raise TypeError( - 'Expected FPL to be instance of FeaturesPredictionsLabel. FPL was: %s' - 'of type %s' % (str(fpl), type(fpl)) - ) - - # We disable pytyping here because we know that 'fpl' key corresponds to a - # non-materialized column. - # pytype: disable=attribute-error - _AugmentExtracts(fpl.features, constants.FEATURES_KEY, excludes, result) - _AugmentExtracts( - fpl.predictions, constants.PREDICTIONS_KEY, excludes, result - ) - _AugmentExtracts(fpl.labels, constants.LABELS_KEY, excludes, result) - # pytype: enable=attribute-error - return result - elif source == constants.INPUT_KEY: - serialized_example = result.get(constants.INPUT_KEY) - if not serialized_example: - raise RuntimeError( - 'tf.Example missing. Ensure extracts contain serialized tf.Example.' - ) - materialize_columns = dest == constants.MATERIALIZE_COLUMNS - _ParseExample(result, materialize_columns) - return result - else: - raise RuntimeError('Unsupported feature extraction source.') + """Converts FeaturesPredictionsLabels into MaterializedColumn in the extract. + + It must be the case that the PredictExtractor was called before calling this + function. + + Args: + ---- + extracts: The Extracts to be augmented. + additional_extracts: Optional list of additional extracts to include along + with the features, predictions, and labels. + excludes: Optional list of strings containing features, predictions, or + labels to exclude from materialization. + source: Source for extracting features. Currently it supports extracting + features from FPLs and input tf.Example protos. + dest: Destination for extracted features. Currently supported are adding + materialized columns, or the features dict of the FPLs. + + Returns: + ------- + Returns Extracts (which is a deep copy of the original Extracts, so the + original isn't mutated) with features populated. + + Raises: + ------ + RuntimeError: When tfma.FEATURES_PREDICTIONS_LABELS_KEY key is not populated + by PredictExtractor for FPL source or incorrect extraction source given. + """ + # Make a deep copy, so we don't mutate the original. + result = copy.deepcopy(extracts) + + if additional_extracts: + for key in additional_extracts: + if key in result: + _AugmentExtracts(result[key], key, excludes, result) + + if source == constants.FEATURES_PREDICTIONS_LABELS_KEY: + fpl = result.get(constants.FEATURES_PREDICTIONS_LABELS_KEY) + if not fpl: + raise RuntimeError("FPL missing. Ensure PredictExtractor was called.") + + if not isinstance(fpl, types.FeaturesPredictionsLabels): + raise TypeError( + "Expected FPL to be instance of FeaturesPredictionsLabel. FPL was: %s" + "of type %s" % (str(fpl), type(fpl)) + ) + + # We disable pytyping here because we know that 'fpl' key corresponds to a + # non-materialized column. 
+ # pytype: disable=attribute-error + _AugmentExtracts(fpl.features, constants.FEATURES_KEY, excludes, result) + _AugmentExtracts(fpl.predictions, constants.PREDICTIONS_KEY, excludes, result) + _AugmentExtracts(fpl.labels, constants.LABELS_KEY, excludes, result) + # pytype: enable=attribute-error + return result + elif source == constants.INPUT_KEY: + serialized_example = result.get(constants.INPUT_KEY) + if not serialized_example: + raise RuntimeError( + "tf.Example missing. Ensure extracts contain serialized tf.Example." + ) + materialize_columns = dest == constants.MATERIALIZE_COLUMNS + _ParseExample(result, materialize_columns) + return result + else: + raise RuntimeError("Unsupported feature extraction source.") @beam.ptransform_fn @@ -214,30 +216,32 @@ def _ExtractFeatures( source: str = constants.FEATURES_PREDICTIONS_LABELS_KEY, dest: str = constants.MATERIALIZE_COLUMNS, ) -> beam.pvalue.PCollection: - """Builds MaterializedColumn extracts from FPL created in evaluate.Predict(). - - It must be the case that the PredictExtractor was called before calling this - function. - - Args: - extracts: PCollection containing the Extracts that will have - MaterializedColumn added to. - additional_extracts: Optional list of additional extracts to include along - with the features, predictions, and labels. - excludes: Optional list of strings containing features, predictions, or - labels to exclude from materialization. - source: Source for extracting features. Currently it supports extracting - features from FPLs and input tf.Example protos. - dest: Destination for extracted features. Currently supported are adding - materialized columns, or the features dict of the FPLs. - - Returns: - PCollection of Extracts - """ - return extracts | 'MaterializeFeatures' >> beam.Map( - _MaterializeFeatures, - additional_extracts=additional_extracts, - excludes=excludes, - source=source, - dest=dest, - ) + """Builds MaterializedColumn extracts from FPL created in evaluate.Predict(). + + It must be the case that the PredictExtractor was called before calling this + function. + + Args: + ---- + extracts: PCollection containing the Extracts that will have + MaterializedColumn added to. + additional_extracts: Optional list of additional extracts to include along + with the features, predictions, and labels. + excludes: Optional list of strings containing features, predictions, or + labels to exclude from materialization. + source: Source for extracting features. Currently it supports extracting + features from FPLs and input tf.Example protos. + dest: Destination for extracted features. Currently supported are adding + materialized columns, or the features dict of the FPLs. 
+ + Returns: + ------- + PCollection of Extracts + """ + return extracts | "MaterializeFeatures" >> beam.Map( + _MaterializeFeatures, + additional_extracts=additional_extracts, + excludes=excludes, + source=source, + dest=dest, + ) diff --git a/tensorflow_model_analysis/extractors/legacy_feature_extractor_test.py b/tensorflow_model_analysis/extractors/legacy_feature_extractor_test.py index 0872bbeb6e..3fe8667812 100644 --- a/tensorflow_model_analysis/extractors/legacy_feature_extractor_test.py +++ b/tensorflow_model_analysis/extractors/legacy_feature_extractor_test.py @@ -15,229 +15,228 @@ import os import tempfile + import numpy as np import tensorflow as tf + from tensorflow_model_analysis import constants from tensorflow_model_analysis.api import types -from tensorflow_model_analysis.extractors import legacy_feature_extractor as feature_extractor +from tensorflow_model_analysis.extractors import ( + legacy_feature_extractor as feature_extractor, +) from tensorflow_model_analysis.utils import test_util -_ENCODING_NODE_SUFFIX = 'node' +_ENCODING_NODE_SUFFIX = "node" class BuildDiagnosticsTableTest(test_util.TensorflowModelAnalysisTest): - - def _getTempDir(self): - return tempfile.mkdtemp() - - def _exportEvalSavedModel(self, classifier): - temp_model_location = os.path.join(self._getTempDir(), 'eval_export_dir') - _, model_location = classifier(None, temp_model_location) - return model_location - - def testMaterializeFeaturesNoFpl(self): - example1 = self._makeExample( - age=3.0, language='english', label=1.0, slice_key='first_slice' - ) - - extracts = {constants.INPUT_KEY: example1.SerializeToString()} - self.assertRaises( - RuntimeError, feature_extractor._MaterializeFeatures, extracts - ) - - def testMaterializeFeaturesBadFPL(self): - example1 = self._makeExample( - age=3.0, language='english', label=1.0, slice_key='first_slice' - ) - - extracts = { - constants.INPUT_KEY: example1.SerializeToString(), - constants.FEATURES_PREDICTIONS_LABELS_KEY: 123, - } - self.assertRaises( - TypeError, feature_extractor._MaterializeFeatures, extracts - ) - - def testMaterializeFeaturesNoMaterializedColumns(self): - example1 = self._makeExample( - age=3.0, language='english', label=1.0, slice_key='first_slice' - ) - - features = { - 'f': {_ENCODING_NODE_SUFFIX: np.array([1])}, - 's': { - _ENCODING_NODE_SUFFIX: tf.compat.v1.SparseTensorValue( - indices=[[0, 5], [1, 2], [3, 6]], - values=[100.0, 200.0, 300.0], - dense_shape=[4, 10], - ) - }, - } - predictions = {'p': {_ENCODING_NODE_SUFFIX: np.array([2])}} - labels = {'l': {_ENCODING_NODE_SUFFIX: np.array([3])}} - - extracts = { - constants.INPUT_KEY: example1.SerializeToString(), - constants.FEATURES_PREDICTIONS_LABELS_KEY: ( - types.FeaturesPredictionsLabels( - input_ref=0, - features=features, - predictions=predictions, - labels=labels, - ) - ), - } - fpl = extracts[constants.FEATURES_PREDICTIONS_LABELS_KEY] - result = feature_extractor._MaterializeFeatures(extracts) - self.assertIsInstance(result, dict) - self.assertEqual( - result[constants.FEATURES_PREDICTIONS_LABELS_KEY], fpl - ) # should still be there. 
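The behavior exercised by the tests in this class can be summarized with a short, self-contained sketch (not taken from the diff). When _MaterializeFeatures is pointed at the serialized tf.train.Example stored under tfma.INPUT_KEY, each parsed feature is added back to the extracts as a MaterializedColumn keyed "features__<feature_name>". Building the proto by hand below stands in for the test helper self._makeExample.

```
# Minimal sketch of materializing features straight from a serialized
# tf.train.Example (mirrors testMaterializeFeaturesFromTfExample).
import tensorflow as tf

from tensorflow_model_analysis import constants
from tensorflow_model_analysis.extractors import (
    legacy_feature_extractor as feature_extractor,
)

example = tf.train.Example()
example.features.feature["age"].float_list.value.append(3.0)
example.features.feature["language"].bytes_list.value.append(b"english")

extracts = {constants.INPUT_KEY: example.SerializeToString()}
result = feature_extractor._MaterializeFeatures(extracts, source=constants.INPUT_KEY)

# result["features__age"] == MaterializedColumn(name="features__age", value=[3.0])
# result["features__language"].value == [b"english"]
```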
- self.assertEqual( - result['features__f'], - types.MaterializedColumn(name='features__f', value=[1]), - ) - self.assertEqual( - result['predictions__p'], - types.MaterializedColumn(name='predictions__p', value=[2]), - ) - self.assertEqual( - result['labels__l'], - types.MaterializedColumn(name='labels__l', value=[3]), - ) - self.assertEqual( - result['features__s'], - types.MaterializedColumn( - name='features__s', value=[100.0, 200.0, 300.0] - ), - ) - - def testAugmentFPLFromTfExample(self): - example1 = self._makeExample( - age=3.0, language='english', label=1.0, slice_key='first_slice', f=0.0 - ) - - features = { - 'f': {_ENCODING_NODE_SUFFIX: np.array([1])}, - 's': { - _ENCODING_NODE_SUFFIX: tf.compat.v1.SparseTensorValue( - indices=[[0, 5], [1, 2], [3, 6]], - values=[100.0, 200.0, 300.0], - dense_shape=[4, 10], - ) - }, - } - predictions = {'p': {_ENCODING_NODE_SUFFIX: np.array([2])}} - labels = {'l': {_ENCODING_NODE_SUFFIX: np.array([3])}} - - extracts = { - constants.INPUT_KEY: example1.SerializeToString(), - constants.FEATURES_PREDICTIONS_LABELS_KEY: ( - types.FeaturesPredictionsLabels( - input_ref=0, - features=features, - predictions=predictions, - labels=labels, - ) - ), - } - result = feature_extractor._MaterializeFeatures( - extracts, - source=constants.INPUT_KEY, - dest=constants.FEATURES_PREDICTIONS_LABELS_KEY, - ) - self.assertIsInstance(result, dict) - # Assert that materialized columns are not added. - self.assertNotIn('features__f', result) - self.assertNotIn('features__age', result) - # But that tf.Example features not present in FPL are. - result_fpl = result[constants.FEATURES_PREDICTIONS_LABELS_KEY] - self.assertEqual( - result_fpl.features['age'], {_ENCODING_NODE_SUFFIX: np.array([3.0])} - ) - self.assertEqual( - result_fpl.features['language'], - {'node': np.array([['english']], dtype='|S7')}, - ) - self.assertEqual( - result_fpl.features['slice_key'], - {'node': np.array([['first_slice']], dtype='|S11')}, - ) - # And that features present in both are not overwritten by tf.Example value. - self.assertEqual( - result_fpl.features['f'], {_ENCODING_NODE_SUFFIX: np.array([1])} - ) - - def testMaterializeFeaturesFromTfExample(self): - example1 = self._makeExample(age=3.0, language='english', label=1.0) - - extracts = {constants.INPUT_KEY: example1.SerializeToString()} - input_example = extracts[constants.INPUT_KEY] - result = feature_extractor._MaterializeFeatures( - extracts, source=constants.INPUT_KEY - ) - self.assertIsInstance(result, dict) - self.assertEqual( - result[constants.INPUT_KEY], input_example - ) # should still be there. 
- self.assertEqual( - result['features__age'], - types.MaterializedColumn(name='features__age', value=[3.0]), - ) - self.assertEqual( - result['features__language'], - types.MaterializedColumn(name='features__language', value=[b'english']), - ) - self.assertEqual( - result['features__label'], - types.MaterializedColumn(name='features__label', value=[1.0]), - ) - - def testMaterializeFeaturesWithBadSource(self): - example1 = self._makeExample(age=3.0, language='english', label=1.0) - - extracts = {constants.INPUT_KEY: example1.SerializeToString()} - - self.assertRaises( - RuntimeError, - feature_extractor._MaterializeFeatures, - extracts, - None, - '10', - ) - - def testMaterializeFeaturesWithExcludes(self): - example1 = self._makeExample( - age=3.0, language='english', label=1.0, slice_key='first_slice' - ) - - features = { - 'f': {_ENCODING_NODE_SUFFIX: np.array([1])}, - 's': { - _ENCODING_NODE_SUFFIX: tf.compat.v1.SparseTensorValue( - indices=[[0, 5], [1, 2], [3, 6]], - values=[100.0, 200.0, 300.0], - dense_shape=[4, 10], - ) - }, - } - predictions = {'p': {_ENCODING_NODE_SUFFIX: np.array([2])}} - labels = {'l': {_ENCODING_NODE_SUFFIX: np.array([3])}} - - extracts = { - constants.INPUT_KEY: example1.SerializeToString(), - constants.FEATURES_PREDICTIONS_LABELS_KEY: ( - types.FeaturesPredictionsLabels( - input_ref=0, - features=features, - predictions=predictions, - labels=labels, - ) - ), - } - result = feature_extractor._MaterializeFeatures(extracts, excludes=['s']) - self.assertNotIn('features__s', result) - - -if __name__ == '__main__': - tf.test.main() + def _getTempDir(self): + return tempfile.mkdtemp() + + def _exportEvalSavedModel(self, classifier): + temp_model_location = os.path.join(self._getTempDir(), "eval_export_dir") + _, model_location = classifier(None, temp_model_location) + return model_location + + def testMaterializeFeaturesNoFpl(self): + example1 = self._makeExample( + age=3.0, language="english", label=1.0, slice_key="first_slice" + ) + + extracts = {constants.INPUT_KEY: example1.SerializeToString()} + self.assertRaises( + RuntimeError, feature_extractor._MaterializeFeatures, extracts + ) + + def testMaterializeFeaturesBadFPL(self): + example1 = self._makeExample( + age=3.0, language="english", label=1.0, slice_key="first_slice" + ) + + extracts = { + constants.INPUT_KEY: example1.SerializeToString(), + constants.FEATURES_PREDICTIONS_LABELS_KEY: 123, + } + self.assertRaises(TypeError, feature_extractor._MaterializeFeatures, extracts) + + def testMaterializeFeaturesNoMaterializedColumns(self): + example1 = self._makeExample( + age=3.0, language="english", label=1.0, slice_key="first_slice" + ) + + features = { + "f": {_ENCODING_NODE_SUFFIX: np.array([1])}, + "s": { + _ENCODING_NODE_SUFFIX: tf.compat.v1.SparseTensorValue( + indices=[[0, 5], [1, 2], [3, 6]], + values=[100.0, 200.0, 300.0], + dense_shape=[4, 10], + ) + }, + } + predictions = {"p": {_ENCODING_NODE_SUFFIX: np.array([2])}} + labels = {"l": {_ENCODING_NODE_SUFFIX: np.array([3])}} + + extracts = { + constants.INPUT_KEY: example1.SerializeToString(), + constants.FEATURES_PREDICTIONS_LABELS_KEY: ( + types.FeaturesPredictionsLabels( + input_ref=0, + features=features, + predictions=predictions, + labels=labels, + ) + ), + } + fpl = extracts[constants.FEATURES_PREDICTIONS_LABELS_KEY] + result = feature_extractor._MaterializeFeatures(extracts) + self.assertIsInstance(result, dict) + self.assertEqual( + result[constants.FEATURES_PREDICTIONS_LABELS_KEY], fpl + ) # should still be there. 
+ self.assertEqual( + result["features__f"], + types.MaterializedColumn(name="features__f", value=[1]), + ) + self.assertEqual( + result["predictions__p"], + types.MaterializedColumn(name="predictions__p", value=[2]), + ) + self.assertEqual( + result["labels__l"], + types.MaterializedColumn(name="labels__l", value=[3]), + ) + self.assertEqual( + result["features__s"], + types.MaterializedColumn(name="features__s", value=[100.0, 200.0, 300.0]), + ) + + def testAugmentFPLFromTfExample(self): + example1 = self._makeExample( + age=3.0, language="english", label=1.0, slice_key="first_slice", f=0.0 + ) + + features = { + "f": {_ENCODING_NODE_SUFFIX: np.array([1])}, + "s": { + _ENCODING_NODE_SUFFIX: tf.compat.v1.SparseTensorValue( + indices=[[0, 5], [1, 2], [3, 6]], + values=[100.0, 200.0, 300.0], + dense_shape=[4, 10], + ) + }, + } + predictions = {"p": {_ENCODING_NODE_SUFFIX: np.array([2])}} + labels = {"l": {_ENCODING_NODE_SUFFIX: np.array([3])}} + + extracts = { + constants.INPUT_KEY: example1.SerializeToString(), + constants.FEATURES_PREDICTIONS_LABELS_KEY: ( + types.FeaturesPredictionsLabels( + input_ref=0, + features=features, + predictions=predictions, + labels=labels, + ) + ), + } + result = feature_extractor._MaterializeFeatures( + extracts, + source=constants.INPUT_KEY, + dest=constants.FEATURES_PREDICTIONS_LABELS_KEY, + ) + self.assertIsInstance(result, dict) + # Assert that materialized columns are not added. + self.assertNotIn("features__f", result) + self.assertNotIn("features__age", result) + # But that tf.Example features not present in FPL are. + result_fpl = result[constants.FEATURES_PREDICTIONS_LABELS_KEY] + self.assertEqual( + result_fpl.features["age"], {_ENCODING_NODE_SUFFIX: np.array([3.0])} + ) + self.assertEqual( + result_fpl.features["language"], + {"node": np.array([["english"]], dtype="|S7")}, + ) + self.assertEqual( + result_fpl.features["slice_key"], + {"node": np.array([["first_slice"]], dtype="|S11")}, + ) + # And that features present in both are not overwritten by tf.Example value. + self.assertEqual( + result_fpl.features["f"], {_ENCODING_NODE_SUFFIX: np.array([1])} + ) + + def testMaterializeFeaturesFromTfExample(self): + example1 = self._makeExample(age=3.0, language="english", label=1.0) + + extracts = {constants.INPUT_KEY: example1.SerializeToString()} + input_example = extracts[constants.INPUT_KEY] + result = feature_extractor._MaterializeFeatures( + extracts, source=constants.INPUT_KEY + ) + self.assertIsInstance(result, dict) + self.assertEqual( + result[constants.INPUT_KEY], input_example + ) # should still be there. 
+ self.assertEqual( + result["features__age"], + types.MaterializedColumn(name="features__age", value=[3.0]), + ) + self.assertEqual( + result["features__language"], + types.MaterializedColumn(name="features__language", value=[b"english"]), + ) + self.assertEqual( + result["features__label"], + types.MaterializedColumn(name="features__label", value=[1.0]), + ) + + def testMaterializeFeaturesWithBadSource(self): + example1 = self._makeExample(age=3.0, language="english", label=1.0) + + extracts = {constants.INPUT_KEY: example1.SerializeToString()} + + self.assertRaises( + RuntimeError, + feature_extractor._MaterializeFeatures, + extracts, + None, + "10", + ) + + def testMaterializeFeaturesWithExcludes(self): + example1 = self._makeExample( + age=3.0, language="english", label=1.0, slice_key="first_slice" + ) + + features = { + "f": {_ENCODING_NODE_SUFFIX: np.array([1])}, + "s": { + _ENCODING_NODE_SUFFIX: tf.compat.v1.SparseTensorValue( + indices=[[0, 5], [1, 2], [3, 6]], + values=[100.0, 200.0, 300.0], + dense_shape=[4, 10], + ) + }, + } + predictions = {"p": {_ENCODING_NODE_SUFFIX: np.array([2])}} + labels = {"l": {_ENCODING_NODE_SUFFIX: np.array([3])}} + + extracts = { + constants.INPUT_KEY: example1.SerializeToString(), + constants.FEATURES_PREDICTIONS_LABELS_KEY: ( + types.FeaturesPredictionsLabels( + input_ref=0, + features=features, + predictions=predictions, + labels=labels, + ) + ), + } + result = feature_extractor._MaterializeFeatures(extracts, excludes=["s"]) + self.assertNotIn("features__s", result) + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow_model_analysis/extractors/legacy_input_extractor.py b/tensorflow_model_analysis/extractors/legacy_input_extractor.py index ff95619db2..a5013aed45 100644 --- a/tensorflow_model_analysis/extractors/legacy_input_extractor.py +++ b/tensorflow_model_analysis/extractors/legacy_input_extractor.py @@ -18,118 +18,120 @@ import apache_beam as beam import numpy as np +from tfx_bsl.coders import example_coder + from tensorflow_model_analysis import constants from tensorflow_model_analysis.api import types from tensorflow_model_analysis.extractors import extractor from tensorflow_model_analysis.proto import config_pb2 -from tfx_bsl.coders import example_coder -_INPUT_EXTRACTOR_STAGE_NAME = 'ExtractInputs' +_INPUT_EXTRACTOR_STAGE_NAME = "ExtractInputs" def InputExtractor(eval_config: config_pb2.EvalConfig) -> extractor.Extractor: - """Creates an extractor for extracting features, labels, and example weights. - - The extractor's PTransform parses tf.train.Example protos stored under the - tfma.INPUT_KEY in the incoming extracts and adds the resulting features, - labels, and example weights to the extracts under the keys tfma.FEATURES_KEY, - tfma.LABELS_KEY, and tfma.EXAMPLE_WEIGHTS_KEY. If the eval_config contains a - prediction_key and a corresponding key is found in the parse example, then - predictions will also be extracted and stored under the tfma.PREDICTIONS_KEY. - Any extracts that already exist will be merged with the values parsed by this - extractor with this extractor's values taking precedence when duplicate keys - are detected. - - Note that the use of a prediction_key in an eval_config serves two use cases: - (1) as a key into the dict of predictions output by predict extractor - (2) as the key for a pre-computed prediction stored as a feature. - The InputExtractor can be used to handle case (2). These cases are meant to be - exclusive (i.e. 
if approach (2) is used then a predict extractor would not be - configured and if (1) is used then a key matching the predictons would not be - stored in the features). However, if a feature key happens to match the same - name as the prediction output key then both paths may be executed. In this - case, the value stored here will be replaced by the predict extractor (though - it will still be popped from the features). - - Args: - eval_config: Eval config. - - Returns: - Extractor for extracting features, labels, and example weights inputs. - """ - # pylint: disable=no-value-for-parameter - return extractor.Extractor( - stage_name=_INPUT_EXTRACTOR_STAGE_NAME, - ptransform=_ExtractInputs(eval_config=eval_config), - ) + """Creates an extractor for extracting features, labels, and example weights. + + The extractor's PTransform parses tf.train.Example protos stored under the + tfma.INPUT_KEY in the incoming extracts and adds the resulting features, + labels, and example weights to the extracts under the keys tfma.FEATURES_KEY, + tfma.LABELS_KEY, and tfma.EXAMPLE_WEIGHTS_KEY. If the eval_config contains a + prediction_key and a corresponding key is found in the parse example, then + predictions will also be extracted and stored under the tfma.PREDICTIONS_KEY. + Any extracts that already exist will be merged with the values parsed by this + extractor with this extractor's values taking precedence when duplicate keys + are detected. + + Note that the use of a prediction_key in an eval_config serves two use cases: + (1) as a key into the dict of predictions output by predict extractor + (2) as the key for a pre-computed prediction stored as a feature. + The InputExtractor can be used to handle case (2). These cases are meant to be + exclusive (i.e. if approach (2) is used then a predict extractor would not be + configured and if (1) is used then a key matching the predictons would not be + stored in the features). However, if a feature key happens to match the same + name as the prediction output key then both paths may be executed. In this + case, the value stored here will be replaced by the predict extractor (though + it will still be popped from the features). + + Args: + ---- + eval_config: Eval config. + + Returns: + ------- + Extractor for extracting features, labels, and example weights inputs. 
+ """ + # pylint: disable=no-value-for-parameter + return extractor.Extractor( + stage_name=_INPUT_EXTRACTOR_STAGE_NAME, + ptransform=_ExtractInputs(eval_config=eval_config), + ) def _keys_and_values( # pylint: disable=invalid-name key_maybe_dict: Union[str, Dict[str, str]], features: Dict[str, np.ndarray] -) -> Tuple[ - Optional[List[str]], Optional[Union[np.ndarray, Dict[str, np.ndarray]]] -]: - """Returns keys and values in dict given key (or dict of keys).""" - if isinstance(key_maybe_dict, dict): - values = {} - keys = set() - for output_name, key in key_maybe_dict.items(): - if key in features: - values[output_name] = features[key] - if key not in keys: - keys.add(key) - return (list(keys), values) - elif key_maybe_dict in features: - return ([key_maybe_dict], features[key_maybe_dict]) - else: - return ([], None) +) -> Tuple[Optional[List[str]], Optional[Union[np.ndarray, Dict[str, np.ndarray]]]]: + """Returns keys and values in dict given key (or dict of keys).""" + if isinstance(key_maybe_dict, dict): + values = {} + keys = set() + for output_name, key in key_maybe_dict.items(): + if key in features: + values[output_name] = features[key] + if key not in keys: + keys.add(key) + return (list(keys), values) + elif key_maybe_dict in features: + return ([key_maybe_dict], features[key_maybe_dict]) + else: + return ([], None) def _ParseExample(extracts: types.Extracts, eval_config: config_pb2.EvalConfig): - """Parses serialized tf.train.Example to create additional extracts. - - Args: - extracts: PCollection containing serialized examples under tfma.INPUT_KEY. - eval_config: Eval config. - - Returns: - Extracts with additional keys added for features, labels, and example - weights. - """ - - features = example_coder.ExampleToNumpyDict(extracts[constants.INPUT_KEY]) - extracts = copy.copy(extracts) - - def add_to_extracts( # pylint: disable=invalid-name - key: str, model_name: str, feature_values: Any - ): - """Adds features_values to extracts and feature_keys to keys_to_pop.""" - # Only key by model name if multiple models. - if len(eval_config.model_specs) > 1: - if key not in extracts: - extracts[key] = {} - extracts[key][model_name] = feature_values - else: - extracts[key] = feature_values - - for spec in eval_config.model_specs: - if spec.label_key or spec.label_keys: - _, values = _keys_and_values( - spec.label_key or dict(spec.label_keys), features - ) - add_to_extracts(constants.LABELS_KEY, spec.name, values) - if spec.example_weight_key or spec.example_weight_keys: - _, values = _keys_and_values( - spec.example_weight_key or dict(spec.example_weight_keys), features - ) - add_to_extracts(constants.EXAMPLE_WEIGHTS_KEY, spec.name, values) - if spec.prediction_key or spec.prediction_keys: - _, values = _keys_and_values( - spec.prediction_key or dict(spec.prediction_keys), features - ) - add_to_extracts(constants.PREDICTIONS_KEY, spec.name, values) - extracts[constants.FEATURES_KEY] = features - return extracts + """Parses serialized tf.train.Example to create additional extracts. + + Args: + ---- + extracts: PCollection containing serialized examples under tfma.INPUT_KEY. + eval_config: Eval config. + + Returns: + ------- + Extracts with additional keys added for features, labels, and example + weights. 
+ """ + features = example_coder.ExampleToNumpyDict(extracts[constants.INPUT_KEY]) + extracts = copy.copy(extracts) + + def add_to_extracts( # pylint: disable=invalid-name + key: str, model_name: str, feature_values: Any + ): + """Adds features_values to extracts and feature_keys to keys_to_pop.""" + # Only key by model name if multiple models. + if len(eval_config.model_specs) > 1: + if key not in extracts: + extracts[key] = {} + extracts[key][model_name] = feature_values + else: + extracts[key] = feature_values + + for spec in eval_config.model_specs: + if spec.label_key or spec.label_keys: + _, values = _keys_and_values( + spec.label_key or dict(spec.label_keys), features + ) + add_to_extracts(constants.LABELS_KEY, spec.name, values) + if spec.example_weight_key or spec.example_weight_keys: + _, values = _keys_and_values( + spec.example_weight_key or dict(spec.example_weight_keys), features + ) + add_to_extracts(constants.EXAMPLE_WEIGHTS_KEY, spec.name, values) + if spec.prediction_key or spec.prediction_keys: + _, values = _keys_and_values( + spec.prediction_key or dict(spec.prediction_keys), features + ) + add_to_extracts(constants.PREDICTIONS_KEY, spec.name, values) + extracts[constants.FEATURES_KEY] = features + return extracts @beam.ptransform_fn @@ -138,15 +140,17 @@ def add_to_extracts( # pylint: disable=invalid-name def _ExtractInputs( extracts: beam.pvalue.PCollection, eval_config: config_pb2.EvalConfig ) -> beam.pvalue.PCollection: - """Extracts inputs from serialized tf.train.Example protos. - - Args: - extracts: PCollection containing serialized examples under tfma.INPUT_KEY. - eval_config: Eval config. - - Returns: - PCollection of extracts with additional features, labels, and weights added - under the keys tfma.FEATURES_KEY, tfma.LABELS_KEY, and - tfma.EXAMPLE_WEIGHTS_KEY. - """ - return extracts | 'ParseExample' >> beam.Map(_ParseExample, eval_config) + """Extracts inputs from serialized tf.train.Example protos. + + Args: + ---- + extracts: PCollection containing serialized examples under tfma.INPUT_KEY. + eval_config: Eval config. + + Returns: + ------- + PCollection of extracts with additional features, labels, and weights added + under the keys tfma.FEATURES_KEY, tfma.LABELS_KEY, and + tfma.EXAMPLE_WEIGHTS_KEY. 
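For orientation, here is a minimal sketch of how the reformatted InputExtractor is wired into a Beam pipeline. It mirrors the tests in legacy_input_extractor_test.py later in this diff; the feature names and values are illustrative only:

```
import apache_beam as beam
import tensorflow as tf

from tensorflow_model_analysis.api import model_eval_lib
from tensorflow_model_analysis.extractors import (
    legacy_input_extractor as input_extractor,
)
from tensorflow_model_analysis.proto import config_pb2

# Single-model config, as in the first test below.
eval_config = config_pb2.EvalConfig(
    model_specs=[
        config_pb2.ModelSpec(label_key="label", example_weight_key="example_weight")
    ]
)
input_ext = input_extractor.InputExtractor(eval_config=eval_config)

# One serialized tf.train.Example with illustrative feature names.
example = tf.train.Example(
    features=tf.train.Features(
        feature={
            "label": tf.train.Feature(float_list=tf.train.FloatList(value=[1.0])),
            "example_weight": tf.train.Feature(
                float_list=tf.train.FloatList(value=[0.5])
            ),
        }
    )
)

with beam.Pipeline() as pipeline:
    _ = (
        pipeline
        | "Create" >> beam.Create([example.SerializeToString()], reshuffle=False)
        | "InputsToExtracts" >> model_eval_lib.InputsToExtracts()
        | input_ext.stage_name >> input_ext.ptransform
        # Each extract now carries tfma.FEATURES_KEY, tfma.LABELS_KEY and
        # tfma.EXAMPLE_WEIGHTS_KEY in addition to the original input.
        | "PrintExtracts" >> beam.Map(print)
    )
```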
+ """ + return extracts | "ParseExample" >> beam.Map(_ParseExample, eval_config) diff --git a/tensorflow_model_analysis/extractors/legacy_input_extractor_test.py b/tensorflow_model_analysis/extractors/legacy_input_extractor_test.py index f83fa164ca..ca77c706c2 100644 --- a/tensorflow_model_analysis/extractors/legacy_input_extractor_test.py +++ b/tensorflow_model_analysis/extractors/legacy_input_extractor_test.py @@ -14,380 +14,386 @@ """Tests for input extractor.""" import apache_beam as beam -from apache_beam.testing import util import numpy as np import tensorflow as tf +from apache_beam.testing import util + from tensorflow_model_analysis import constants from tensorflow_model_analysis.api import model_eval_lib -from tensorflow_model_analysis.extractors import legacy_input_extractor as input_extractor +from tensorflow_model_analysis.extractors import ( + legacy_input_extractor as input_extractor, +) from tensorflow_model_analysis.proto import config_pb2 from tensorflow_model_analysis.utils import test_util class InputExtractorTest(test_util.TensorflowModelAnalysisTest): + def testInputExtractor(self): + model_spec = config_pb2.ModelSpec( + label_key="label", example_weight_key="example_weight" + ) + extractor = input_extractor.InputExtractor( + eval_config=config_pb2.EvalConfig(model_specs=[model_spec]) + ) - def testInputExtractor(self): - model_spec = config_pb2.ModelSpec( - label_key='label', example_weight_key='example_weight' - ) - extractor = input_extractor.InputExtractor( - eval_config=config_pb2.EvalConfig(model_specs=[model_spec]) - ) - - examples = [ - self._makeExample( - label=1.0, - example_weight=0.5, - fixed_int=1, - fixed_float=1.0, - fixed_string='fixed_string1', - ), - self._makeExample( - label=0.0, - example_weight=0.0, - fixed_int=1, - fixed_float=1.0, - fixed_string='fixed_string2', - ), - self._makeExample( - label=0.0, - example_weight=1.0, - fixed_int=2, - fixed_float=0.0, - fixed_string='fixed_string3', - ), - ] + examples = [ + self._makeExample( + label=1.0, + example_weight=0.5, + fixed_int=1, + fixed_float=1.0, + fixed_string="fixed_string1", + ), + self._makeExample( + label=0.0, + example_weight=0.0, + fixed_int=1, + fixed_float=1.0, + fixed_string="fixed_string2", + ), + self._makeExample( + label=0.0, + example_weight=1.0, + fixed_int=2, + fixed_float=0.0, + fixed_string="fixed_string3", + ), + ] - with beam.Pipeline() as pipeline: - # pylint: disable=no-value-for-parameter - result = ( - pipeline - | 'Create' - >> beam.Create( - [e.SerializeToString() for e in examples], reshuffle=False - ) - | 'InputsToExtracts' >> model_eval_lib.InputsToExtracts() - | extractor.stage_name >> extractor.ptransform - ) + with beam.Pipeline() as pipeline: + # pylint: disable=no-value-for-parameter + result = ( + pipeline + | "Create" + >> beam.Create( + [e.SerializeToString() for e in examples], reshuffle=False + ) + | "InputsToExtracts" >> model_eval_lib.InputsToExtracts() + | extractor.stage_name >> extractor.ptransform + ) - # pylint: enable=no-value-for-parameter + # pylint: enable=no-value-for-parameter - def check_result(got): - try: - self.assertLen(got, 3) - self.assertDictElementsAlmostEqual( - got[0][constants.FEATURES_KEY], - { - 'fixed_int': np.array([1]), - 'fixed_float': np.array([1.0]), - 'label': np.array([1.0]), - 'example_weight': np.array([0.5]), - }, - ) - self.assertEqual( - got[0][constants.FEATURES_KEY]['fixed_string'], - np.array([b'fixed_string1']), - ) - self.assertAlmostEqual(got[0][constants.LABELS_KEY], np.array([1.0])) - 
self.assertAlmostEqual( - got[0][constants.EXAMPLE_WEIGHTS_KEY], np.array([0.5]) - ) - self.assertDictElementsAlmostEqual( - got[1][constants.FEATURES_KEY], - { - 'fixed_int': np.array([1]), - 'fixed_float': np.array([1.0]), - 'label': np.array([0.0]), - 'example_weight': np.array([0.0]), - }, - ) - self.assertEqual( - got[1][constants.FEATURES_KEY]['fixed_string'], - np.array([b'fixed_string2']), - ) - self.assertAlmostEqual(got[1][constants.LABELS_KEY], np.array([0.0])) - self.assertAlmostEqual( - got[1][constants.EXAMPLE_WEIGHTS_KEY], np.array([0.0]) - ) - self.assertDictElementsAlmostEqual( - got[2][constants.FEATURES_KEY], - { - 'fixed_int': np.array([2]), - 'fixed_float': np.array([0.0]), - 'label': np.array([0.0]), - 'example_weight': np.array([1.0]), - }, - ) - self.assertEqual( - got[2][constants.FEATURES_KEY]['fixed_string'], - np.array([b'fixed_string3']), - ) - self.assertAlmostEqual(got[2][constants.LABELS_KEY], np.array([0.0])) - self.assertAlmostEqual( - got[2][constants.EXAMPLE_WEIGHTS_KEY], np.array([1.0]) - ) + def check_result(got): + try: + self.assertLen(got, 3) + self.assertDictElementsAlmostEqual( + got[0][constants.FEATURES_KEY], + { + "fixed_int": np.array([1]), + "fixed_float": np.array([1.0]), + "label": np.array([1.0]), + "example_weight": np.array([0.5]), + }, + ) + self.assertEqual( + got[0][constants.FEATURES_KEY]["fixed_string"], + np.array([b"fixed_string1"]), + ) + self.assertAlmostEqual( + got[0][constants.LABELS_KEY], np.array([1.0]) + ) + self.assertAlmostEqual( + got[0][constants.EXAMPLE_WEIGHTS_KEY], np.array([0.5]) + ) + self.assertDictElementsAlmostEqual( + got[1][constants.FEATURES_KEY], + { + "fixed_int": np.array([1]), + "fixed_float": np.array([1.0]), + "label": np.array([0.0]), + "example_weight": np.array([0.0]), + }, + ) + self.assertEqual( + got[1][constants.FEATURES_KEY]["fixed_string"], + np.array([b"fixed_string2"]), + ) + self.assertAlmostEqual( + got[1][constants.LABELS_KEY], np.array([0.0]) + ) + self.assertAlmostEqual( + got[1][constants.EXAMPLE_WEIGHTS_KEY], np.array([0.0]) + ) + self.assertDictElementsAlmostEqual( + got[2][constants.FEATURES_KEY], + { + "fixed_int": np.array([2]), + "fixed_float": np.array([0.0]), + "label": np.array([0.0]), + "example_weight": np.array([1.0]), + }, + ) + self.assertEqual( + got[2][constants.FEATURES_KEY]["fixed_string"], + np.array([b"fixed_string3"]), + ) + self.assertAlmostEqual( + got[2][constants.LABELS_KEY], np.array([0.0]) + ) + self.assertAlmostEqual( + got[2][constants.EXAMPLE_WEIGHTS_KEY], np.array([1.0]) + ) - except AssertionError as err: - raise util.BeamAssertException(err) + except AssertionError as err: + raise util.BeamAssertException(err) - util.assert_that(result, check_result, label='result') + util.assert_that(result, check_result, label="result") - def testInputExtractorMultiOutput(self): - model_spec = config_pb2.ModelSpec( - label_keys={'output1': 'label1', 'output2': 'label2'}, - example_weight_keys={ - 'output1': 'example_weight1', - 'output2': 'example_weight2', - }, - ) - extractor = input_extractor.InputExtractor( - eval_config=config_pb2.EvalConfig(model_specs=[model_spec]) - ) + def testInputExtractorMultiOutput(self): + model_spec = config_pb2.ModelSpec( + label_keys={"output1": "label1", "output2": "label2"}, + example_weight_keys={ + "output1": "example_weight1", + "output2": "example_weight2", + }, + ) + extractor = input_extractor.InputExtractor( + eval_config=config_pb2.EvalConfig(model_specs=[model_spec]) + ) - examples = [ - self._makeExample( - label1=1.0, - 
label2=0.0, - example_weight1=0.5, - example_weight2=0.5, - fixed_int=1, - fixed_float=1.0, - fixed_string='fixed_string1', - ), - self._makeExample( - label1=1.0, - label2=1.0, - example_weight1=0.0, - example_weight2=1.0, - fixed_int=1, - fixed_float=1.0, - fixed_string='fixed_string2', - ), - ] + examples = [ + self._makeExample( + label1=1.0, + label2=0.0, + example_weight1=0.5, + example_weight2=0.5, + fixed_int=1, + fixed_float=1.0, + fixed_string="fixed_string1", + ), + self._makeExample( + label1=1.0, + label2=1.0, + example_weight1=0.0, + example_weight2=1.0, + fixed_int=1, + fixed_float=1.0, + fixed_string="fixed_string2", + ), + ] - with beam.Pipeline() as pipeline: - # pylint: disable=no-value-for-parameter - result = ( - pipeline - | 'Create' - >> beam.Create( - [e.SerializeToString() for e in examples], reshuffle=False - ) - | 'InputsToExtracts' >> model_eval_lib.InputsToExtracts() - | extractor.stage_name >> extractor.ptransform - ) + with beam.Pipeline() as pipeline: + # pylint: disable=no-value-for-parameter + result = ( + pipeline + | "Create" + >> beam.Create( + [e.SerializeToString() for e in examples], reshuffle=False + ) + | "InputsToExtracts" >> model_eval_lib.InputsToExtracts() + | extractor.stage_name >> extractor.ptransform + ) - # pylint: enable=no-value-for-parameter + # pylint: enable=no-value-for-parameter - def check_result(got): - try: - self.assertLen(got, 2) - self.assertDictElementsAlmostEqual( - got[0][constants.FEATURES_KEY], - { - 'fixed_int': np.array([1]), - 'fixed_float': np.array([1.0]), - 'label1': np.array([1.0]), - 'label2': np.array([0.0]), - 'example_weight1': np.array([0.5]), - 'example_weight2': np.array([0.5]), - }, - ) - self.assertEqual( - got[0][constants.FEATURES_KEY]['fixed_string'], - np.array([b'fixed_string1']), - ) - self.assertDictElementsAlmostEqual( - got[0][constants.LABELS_KEY], - {'output1': np.array([1.0]), 'output2': np.array([0.0])}, - ) - self.assertDictElementsAlmostEqual( - got[0][constants.EXAMPLE_WEIGHTS_KEY], - {'output1': np.array([0.5]), 'output2': np.array([0.5])}, - ) - self.assertDictElementsAlmostEqual( - got[1][constants.FEATURES_KEY], - { - 'fixed_int': np.array([1]), - 'fixed_float': np.array([1.0]), - 'label1': np.array([1.0]), - 'label2': np.array([1.0]), - 'example_weight1': np.array([0.0]), - 'example_weight2': np.array([1.0]), - }, - ) - self.assertEqual( - got[1][constants.FEATURES_KEY]['fixed_string'], - np.array([b'fixed_string2']), - ) - self.assertDictElementsAlmostEqual( - got[1][constants.LABELS_KEY], - {'output1': np.array([1.0]), 'output2': np.array([1.0])}, - ) - self.assertDictElementsAlmostEqual( - got[1][constants.EXAMPLE_WEIGHTS_KEY], - {'output1': np.array([0.0]), 'output2': np.array([1.0])}, - ) + def check_result(got): + try: + self.assertLen(got, 2) + self.assertDictElementsAlmostEqual( + got[0][constants.FEATURES_KEY], + { + "fixed_int": np.array([1]), + "fixed_float": np.array([1.0]), + "label1": np.array([1.0]), + "label2": np.array([0.0]), + "example_weight1": np.array([0.5]), + "example_weight2": np.array([0.5]), + }, + ) + self.assertEqual( + got[0][constants.FEATURES_KEY]["fixed_string"], + np.array([b"fixed_string1"]), + ) + self.assertDictElementsAlmostEqual( + got[0][constants.LABELS_KEY], + {"output1": np.array([1.0]), "output2": np.array([0.0])}, + ) + self.assertDictElementsAlmostEqual( + got[0][constants.EXAMPLE_WEIGHTS_KEY], + {"output1": np.array([0.5]), "output2": np.array([0.5])}, + ) + self.assertDictElementsAlmostEqual( + got[1][constants.FEATURES_KEY], + { + 
"fixed_int": np.array([1]), + "fixed_float": np.array([1.0]), + "label1": np.array([1.0]), + "label2": np.array([1.0]), + "example_weight1": np.array([0.0]), + "example_weight2": np.array([1.0]), + }, + ) + self.assertEqual( + got[1][constants.FEATURES_KEY]["fixed_string"], + np.array([b"fixed_string2"]), + ) + self.assertDictElementsAlmostEqual( + got[1][constants.LABELS_KEY], + {"output1": np.array([1.0]), "output2": np.array([1.0])}, + ) + self.assertDictElementsAlmostEqual( + got[1][constants.EXAMPLE_WEIGHTS_KEY], + {"output1": np.array([0.0]), "output2": np.array([1.0])}, + ) - except AssertionError as err: - raise util.BeamAssertException(err) + except AssertionError as err: + raise util.BeamAssertException(err) - util.assert_that(result, check_result, label='result') + util.assert_that(result, check_result, label="result") - def testInputExtractorMultiModel(self): - model_spec1 = config_pb2.ModelSpec( - name='model1', - label_key='label', - example_weight_key='example_weight', - prediction_key='fixed_float', - ) - model_spec2 = config_pb2.ModelSpec( - name='model2', - label_keys={'output1': 'label1', 'output2': 'label2'}, - example_weight_keys={ - 'output1': 'example_weight1', - 'output2': 'example_weight2', - }, - prediction_keys={'output1': 'fixed_float', 'output2': 'fixed_float'}, - ) - extractor = input_extractor.InputExtractor( - eval_config=config_pb2.EvalConfig( - model_specs=[model_spec1, model_spec2] + def testInputExtractorMultiModel(self): + model_spec1 = config_pb2.ModelSpec( + name="model1", + label_key="label", + example_weight_key="example_weight", + prediction_key="fixed_float", + ) + model_spec2 = config_pb2.ModelSpec( + name="model2", + label_keys={"output1": "label1", "output2": "label2"}, + example_weight_keys={ + "output1": "example_weight1", + "output2": "example_weight2", + }, + prediction_keys={"output1": "fixed_float", "output2": "fixed_float"}, + ) + extractor = input_extractor.InputExtractor( + eval_config=config_pb2.EvalConfig(model_specs=[model_spec1, model_spec2]) ) - ) - examples = [ - self._makeExample( - label=1.0, - label1=1.0, - label2=0.0, - example_weight=0.5, - example_weight1=0.5, - example_weight2=0.5, - fixed_int=1, - fixed_float=1.0, - fixed_string='fixed_string1', - ), - self._makeExample( - label=1.0, - label1=1.0, - label2=1.0, - example_weight=0.0, - example_weight1=0.0, - example_weight2=1.0, - fixed_int=1, - fixed_float=2.0, - fixed_string='fixed_string2', - ), - ] + examples = [ + self._makeExample( + label=1.0, + label1=1.0, + label2=0.0, + example_weight=0.5, + example_weight1=0.5, + example_weight2=0.5, + fixed_int=1, + fixed_float=1.0, + fixed_string="fixed_string1", + ), + self._makeExample( + label=1.0, + label1=1.0, + label2=1.0, + example_weight=0.0, + example_weight1=0.0, + example_weight2=1.0, + fixed_int=1, + fixed_float=2.0, + fixed_string="fixed_string2", + ), + ] - with beam.Pipeline() as pipeline: - # pylint: disable=no-value-for-parameter - result = ( - pipeline - | 'Create' - >> beam.Create( - [e.SerializeToString() for e in examples], reshuffle=False - ) - | 'InputsToExtracts' >> model_eval_lib.InputsToExtracts() - | extractor.stage_name >> extractor.ptransform - ) + with beam.Pipeline() as pipeline: + # pylint: disable=no-value-for-parameter + result = ( + pipeline + | "Create" + >> beam.Create( + [e.SerializeToString() for e in examples], reshuffle=False + ) + | "InputsToExtracts" >> model_eval_lib.InputsToExtracts() + | extractor.stage_name >> extractor.ptransform + ) - # pylint: enable=no-value-for-parameter + # 
pylint: enable=no-value-for-parameter - def check_result(got): - try: - self.assertLen(got, 2) - self.assertDictElementsAlmostEqual( - got[0][constants.FEATURES_KEY], - { - 'fixed_int': np.array([1]), - 'label': np.array([1.0]), - 'label1': np.array([1.0]), - 'label2': np.array([0.0]), - 'example_weight': np.array([0.5]), - 'example_weight1': np.array([0.5]), - 'example_weight2': np.array([0.5]), - }, - ) - self.assertEqual( - got[0][constants.FEATURES_KEY]['fixed_string'], - np.array([b'fixed_string1']), - ) - for model_name in ('model1', 'model2'): - self.assertIn(model_name, got[0][constants.LABELS_KEY]) - self.assertIn(model_name, got[0][constants.EXAMPLE_WEIGHTS_KEY]) - self.assertIn(model_name, got[0][constants.PREDICTIONS_KEY]) - self.assertAlmostEqual( - got[0][constants.LABELS_KEY]['model1'], np.array([1.0]) - ) - self.assertDictElementsAlmostEqual( - got[0][constants.LABELS_KEY]['model2'], - {'output1': np.array([1.0]), 'output2': np.array([0.0])}, - ) - self.assertAlmostEqual( - got[0][constants.EXAMPLE_WEIGHTS_KEY]['model1'], np.array([0.5]) - ) - self.assertDictElementsAlmostEqual( - got[0][constants.EXAMPLE_WEIGHTS_KEY]['model2'], - {'output1': np.array([0.5]), 'output2': np.array([0.5])}, - ) - self.assertAlmostEqual( - got[0][constants.PREDICTIONS_KEY]['model1'], np.array([1.0]) - ) - self.assertDictElementsAlmostEqual( - got[0][constants.PREDICTIONS_KEY]['model2'], - {'output1': np.array([1.0]), 'output2': np.array([1.0])}, - ) + def check_result(got): + try: + self.assertLen(got, 2) + self.assertDictElementsAlmostEqual( + got[0][constants.FEATURES_KEY], + { + "fixed_int": np.array([1]), + "label": np.array([1.0]), + "label1": np.array([1.0]), + "label2": np.array([0.0]), + "example_weight": np.array([0.5]), + "example_weight1": np.array([0.5]), + "example_weight2": np.array([0.5]), + }, + ) + self.assertEqual( + got[0][constants.FEATURES_KEY]["fixed_string"], + np.array([b"fixed_string1"]), + ) + for model_name in ("model1", "model2"): + self.assertIn(model_name, got[0][constants.LABELS_KEY]) + self.assertIn(model_name, got[0][constants.EXAMPLE_WEIGHTS_KEY]) + self.assertIn(model_name, got[0][constants.PREDICTIONS_KEY]) + self.assertAlmostEqual( + got[0][constants.LABELS_KEY]["model1"], np.array([1.0]) + ) + self.assertDictElementsAlmostEqual( + got[0][constants.LABELS_KEY]["model2"], + {"output1": np.array([1.0]), "output2": np.array([0.0])}, + ) + self.assertAlmostEqual( + got[0][constants.EXAMPLE_WEIGHTS_KEY]["model1"], np.array([0.5]) + ) + self.assertDictElementsAlmostEqual( + got[0][constants.EXAMPLE_WEIGHTS_KEY]["model2"], + {"output1": np.array([0.5]), "output2": np.array([0.5])}, + ) + self.assertAlmostEqual( + got[0][constants.PREDICTIONS_KEY]["model1"], np.array([1.0]) + ) + self.assertDictElementsAlmostEqual( + got[0][constants.PREDICTIONS_KEY]["model2"], + {"output1": np.array([1.0]), "output2": np.array([1.0])}, + ) - self.assertDictElementsAlmostEqual( - got[1][constants.FEATURES_KEY], - { - 'fixed_int': np.array([1]), - 'label': np.array([1.0]), - 'label1': np.array([1.0]), - 'label2': np.array([1.0]), - 'example_weight': np.array([0.0]), - 'example_weight1': np.array([0.0]), - 'example_weight2': np.array([1.0]), - }, - ) - self.assertEqual( - got[1][constants.FEATURES_KEY]['fixed_string'], - np.array([b'fixed_string2']), - ) - for model_name in ('model1', 'model2'): - self.assertIn(model_name, got[1][constants.LABELS_KEY]) - self.assertIn(model_name, got[1][constants.EXAMPLE_WEIGHTS_KEY]) - self.assertIn(model_name, got[1][constants.PREDICTIONS_KEY]) - 
self.assertAlmostEqual( - got[1][constants.LABELS_KEY]['model1'], np.array([1.0]) - ) - self.assertDictElementsAlmostEqual( - got[1][constants.LABELS_KEY]['model2'], - {'output1': np.array([1.0]), 'output2': np.array([1.0])}, - ) - self.assertAlmostEqual( - got[1][constants.EXAMPLE_WEIGHTS_KEY]['model1'], np.array([0.0]) - ) - self.assertDictElementsAlmostEqual( - got[1][constants.EXAMPLE_WEIGHTS_KEY]['model2'], - {'output1': np.array([0.0]), 'output2': np.array([1.0])}, - ) - self.assertAlmostEqual( - got[1][constants.PREDICTIONS_KEY]['model1'], np.array([2.0]) - ) - self.assertDictElementsAlmostEqual( - got[1][constants.PREDICTIONS_KEY]['model2'], - {'output1': np.array([2.0]), 'output2': np.array([2.0])}, - ) + self.assertDictElementsAlmostEqual( + got[1][constants.FEATURES_KEY], + { + "fixed_int": np.array([1]), + "label": np.array([1.0]), + "label1": np.array([1.0]), + "label2": np.array([1.0]), + "example_weight": np.array([0.0]), + "example_weight1": np.array([0.0]), + "example_weight2": np.array([1.0]), + }, + ) + self.assertEqual( + got[1][constants.FEATURES_KEY]["fixed_string"], + np.array([b"fixed_string2"]), + ) + for model_name in ("model1", "model2"): + self.assertIn(model_name, got[1][constants.LABELS_KEY]) + self.assertIn(model_name, got[1][constants.EXAMPLE_WEIGHTS_KEY]) + self.assertIn(model_name, got[1][constants.PREDICTIONS_KEY]) + self.assertAlmostEqual( + got[1][constants.LABELS_KEY]["model1"], np.array([1.0]) + ) + self.assertDictElementsAlmostEqual( + got[1][constants.LABELS_KEY]["model2"], + {"output1": np.array([1.0]), "output2": np.array([1.0])}, + ) + self.assertAlmostEqual( + got[1][constants.EXAMPLE_WEIGHTS_KEY]["model1"], np.array([0.0]) + ) + self.assertDictElementsAlmostEqual( + got[1][constants.EXAMPLE_WEIGHTS_KEY]["model2"], + {"output1": np.array([0.0]), "output2": np.array([1.0])}, + ) + self.assertAlmostEqual( + got[1][constants.PREDICTIONS_KEY]["model1"], np.array([2.0]) + ) + self.assertDictElementsAlmostEqual( + got[1][constants.PREDICTIONS_KEY]["model2"], + {"output1": np.array([2.0]), "output2": np.array([2.0])}, + ) - except AssertionError as err: - raise util.BeamAssertException(err) + except AssertionError as err: + raise util.BeamAssertException(err) - util.assert_that(result, check_result, label='result') + util.assert_that(result, check_result, label="result") -if __name__ == '__main__': - tf.compat.v1.enable_v2_behavior() - tf.test.main() +if __name__ == "__main__": + tf.compat.v1.enable_v2_behavior() + tf.test.main() diff --git a/tensorflow_model_analysis/extractors/legacy_meta_feature_extractor.py b/tensorflow_model_analysis/extractors/legacy_meta_feature_extractor.py index e0fea8b6ee..a8a72ec118 100644 --- a/tensorflow_model_analysis/extractors/legacy_meta_feature_extractor.py +++ b/tensorflow_model_analysis/extractors/legacy_meta_feature_extractor.py @@ -22,21 +22,19 @@ import apache_beam as beam import numpy as np import tensorflow as tf + from tensorflow_model_analysis import constants from tensorflow_model_analysis.api import types - -_ENCODING_NODE_SUFFIX = 'node' +_ENCODING_NODE_SUFFIX = "node" -def get_feature_value( - fpl: types.FeaturesPredictionsLabels, feature_key: str -) -> Any: - """Helper to get value from FPL dict.""" - node_value = fpl.features[feature_key][_ENCODING_NODE_SUFFIX] - if isinstance(node_value, tf.compat.v1.SparseTensorValue): - return node_value.values - return node_value +def get_feature_value(fpl: types.FeaturesPredictionsLabels, feature_key: str) -> Any: + """Helper to get value from FPL dict.""" + 
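The FPL feature values handled here are wrapped one level deep under a 'node' key (the make_features_dict helper in the test file below uses the same layout). A minimal sketch of that encoding follows; it calls the private _set_feature_value helper for illustration only, and the feature names are made up:

```
import numpy as np
import tensorflow as tf

from tensorflow_model_analysis.api import types
from tensorflow_model_analysis.extractors import (
    legacy_meta_feature_extractor as meta_feature_extractor,
)

fpl = types.FeaturesPredictionsLabels(
    input_ref=0,
    features={"age": {"node": np.array([13])}},
    predictions={},
    labels={},
)

# Dense values come back as the stored array ...
assert meta_feature_extractor.get_feature_value(fpl, "age")[0] == 13

# ... while SparseTensorValue features are unwrapped to their .values.
sparse = tf.compat.v1.SparseTensorValue(
    indices=[[0, 0]], values=["cars"], dense_shape=[1, 1]
)
meta_feature_extractor._set_feature_value(fpl.features, "interest", sparse)
assert list(meta_feature_extractor.get_feature_value(fpl, "interest")) == ["cars"]
```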
node_value = fpl.features[feature_key][_ENCODING_NODE_SUFFIX] + if isinstance(node_value, tf.compat.v1.SparseTensorValue): + return node_value.values + return node_value def _set_feature_value( @@ -44,42 +42,42 @@ def _set_feature_value( feature_key: str, feature_value: Any, ) -> types.DictOfFetchedTensorValues: - """Helper to set feature in FPL dict.""" - if not isinstance(feature_value, np.ndarray) and not isinstance( - feature_value, tf.compat.v1.SparseTensorValue - ): - feature_value = np.array([feature_value]) - features[feature_key] = {_ENCODING_NODE_SUFFIX: feature_value} - return features # pytype: disable=bad-return-type + """Helper to set feature in FPL dict.""" + if not isinstance(feature_value, np.ndarray) and not isinstance( + feature_value, tf.compat.v1.SparseTensorValue + ): + feature_value = np.array([feature_value]) + features[feature_key] = {_ENCODING_NODE_SUFFIX: feature_value} + return features # pytype: disable=bad-return-type def get_fpl_copy(extracts: types.Extracts) -> types.FeaturesPredictionsLabels: - """Get a copy of the FPL in the extracts of extracts.""" - fpl_orig = extracts.get(constants.FEATURES_PREDICTIONS_LABELS_KEY) - if not fpl_orig: - raise RuntimeError('FPL missing, Please ensure _Predict() was called.') - - # We must make a copy of the FPL tuple as well, so that we don't mutate the - # original which is disallowed by Beam. - fpl_copy = types.FeaturesPredictionsLabels( - features=copy.copy(fpl_orig.features), - labels=fpl_orig.labels, - predictions=fpl_orig.predictions, - input_ref=fpl_orig.input_ref, - ) - return fpl_copy + """Get a copy of the FPL in the extracts of extracts.""" + fpl_orig = extracts.get(constants.FEATURES_PREDICTIONS_LABELS_KEY) + if not fpl_orig: + raise RuntimeError("FPL missing, Please ensure _Predict() was called.") + + # We must make a copy of the FPL tuple as well, so that we don't mutate the + # original which is disallowed by Beam. + fpl_copy = types.FeaturesPredictionsLabels( + features=copy.copy(fpl_orig.features), + labels=fpl_orig.labels, + predictions=fpl_orig.predictions, + input_ref=fpl_orig.input_ref, + ) + return fpl_copy def update_fpl_features( fpl: types.FeaturesPredictionsLabels, new_features: types.DictOfFetchedTensorValues, ): - """Add new features to the FPL.""" - for key, value in new_features.items(): - # if the key already exists in the dictionary, throw an error. - if key in fpl.features: - raise ValueError('Modification of existing keys is not allowed.') - _set_feature_value(fpl.features, key, value) + """Add new features to the FPL.""" + for key, value in new_features.items(): + # if the key already exists in the dictionary, throw an error. + if key in fpl.features: + raise ValueError("Modification of existing keys is not allowed.") + _set_feature_value(fpl.features, key, value) def _ExtractMetaFeature( # pylint: disable=invalid-name @@ -88,17 +86,17 @@ def _ExtractMetaFeature( # pylint: disable=invalid-name [types.FeaturesPredictionsLabels], types.DictOfFetchedTensorValues ], ) -> types.Extracts: - """Augments FPL dict with new feature(s).""" - # Create a new feature from existing ones. - fpl_copy = get_fpl_copy(extracts) - new_features = new_features_fn(fpl_copy) + """Augments FPL dict with new feature(s).""" + # Create a new feature from existing ones. + fpl_copy = get_fpl_copy(extracts) + new_features = new_features_fn(fpl_copy) - # Add the new features to the existing ones. - update_fpl_features(fpl_copy, new_features) + # Add the new features to the existing ones. 
+ update_fpl_features(fpl_copy, new_features) - result = copy.copy(extracts) - result[constants.FEATURES_PREDICTIONS_LABELS_KEY] = fpl_copy - return result + result = copy.copy(extracts) + result[constants.FEATURES_PREDICTIONS_LABELS_KEY] = fpl_copy + return result @beam.ptransform_fn @@ -110,23 +108,25 @@ def ExtractMetaFeature( # pylint: disable=invalid-name [types.FeaturesPredictionsLabels], types.DictOfFetchedTensorValues ], ) -> beam.pvalue.PCollection: - """Extracts meta-features derived from existing features. - - It must be the case that the PredictExtractor was called before calling this - function. - - Args: - extracts: PCollection containing the Extracts that will have - MaterializedColumn added to its extracts. - new_features_fn: A function that adds new features. Must take a - FeaturesPredictionsLabel tuple as an argument, and return a a dict of new - features to add, where the keys are new feature names and the values are - the associated values.Only adding new features is permitted to prevent - inadvertently removing useful data. - - Returns: - PCollection of Extracts - """ - return extracts | 'ExtractMetaFeature' >> beam.Map( - _ExtractMetaFeature, new_features_fn - ) + """Extracts meta-features derived from existing features. + + It must be the case that the PredictExtractor was called before calling this + function. + + Args: + ---- + extracts: PCollection containing the Extracts that will have + MaterializedColumn added to its extracts. + new_features_fn: A function that adds new features. Must take a + FeaturesPredictionsLabels tuple as an argument, and return a dict of new + features to add, where the keys are new feature names and the values are + the associated values. Only adding new features is permitted to prevent + inadvertently removing useful data.
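To make the new_features_fn contract above concrete, here is a minimal sketch modeled on the get_num_interests helper used in the tests further down; the 'interest' feature name is taken from those tests and is otherwise illustrative:

```
from tensorflow_model_analysis.extractors import (
    legacy_meta_feature_extractor as meta_feature_extractor,
)


def add_num_interests(fpl):
    # Derive a new feature from an existing one; returning an existing key such
    # as "interest" would raise ValueError inside update_fpl_features.
    interests = meta_feature_extractor.get_feature_value(fpl, "interest")
    return {"num_interests": len(interests)}


# Applied downstream of the predict stage, e.g.:
#   extracts | "AddNumInterests" >> meta_feature_extractor.ExtractMetaFeature(add_num_interests)
```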
+ + Returns: + ------- + PCollection of Extracts + """ + return extracts | "ExtractMetaFeature" >> beam.Map( + _ExtractMetaFeature, new_features_fn + ) diff --git a/tensorflow_model_analysis/extractors/legacy_meta_feature_extractor_test.py b/tensorflow_model_analysis/extractors/legacy_meta_feature_extractor_test.py index cb20e1d2b0..a505095cf4 100644 --- a/tensorflow_model_analysis/extractors/legacy_meta_feature_extractor_test.py +++ b/tensorflow_model_analysis/extractors/legacy_meta_feature_extractor_test.py @@ -14,175 +14,180 @@ """Test for using the MetaFeatureExtractor as part of TFMA.""" import apache_beam as beam -from apache_beam.testing import util import numpy as np import tensorflow as tf +from apache_beam.testing import util + from tensorflow_model_analysis import constants from tensorflow_model_analysis.api import types -from tensorflow_model_analysis.extractors import legacy_meta_feature_extractor as meta_feature_extractor +from tensorflow_model_analysis.extractors import ( + legacy_meta_feature_extractor as meta_feature_extractor, +) from tensorflow_model_analysis.extractors import slice_key_extractor from tensorflow_model_analysis.slicer import slicer_lib as slicer from tensorflow_model_analysis.utils import test_util def make_features_dict(features_dict): - result = {} - for key, value in features_dict.items(): - result[key] = {'node': np.array(value)} - return result + result = {} + for key, value in features_dict.items(): + result[key] = {"node": np.array(value)} + return result def create_fpls(): - """Create test FPL dicts that can be used for verification.""" - fpl1 = types.FeaturesPredictionsLabels( - input_ref=0, - features=make_features_dict( - {'gender': ['f'], 'age': [13], 'interest': ['cars']} - ), - predictions=make_features_dict({ - 'kb': [1], - }), - labels=make_features_dict({'ad_risk_score': [0]}), - ) - fpl2 = types.FeaturesPredictionsLabels( - input_ref=1, - features=make_features_dict( - {'gender': ['m'], 'age': [10], 'interest': ['cars', 'movies']} - ), - predictions=make_features_dict({ - 'kb': [1], - }), - labels=make_features_dict({'ad_risk_score': [0]}), - ) - return [fpl1, fpl2] + """Create test FPL dicts that can be used for verification.""" + fpl1 = types.FeaturesPredictionsLabels( + input_ref=0, + features=make_features_dict( + {"gender": ["f"], "age": [13], "interest": ["cars"]} + ), + predictions=make_features_dict( + { + "kb": [1], + } + ), + labels=make_features_dict({"ad_risk_score": [0]}), + ) + fpl2 = types.FeaturesPredictionsLabels( + input_ref=1, + features=make_features_dict( + {"gender": ["m"], "age": [10], "interest": ["cars", "movies"]} + ), + predictions=make_features_dict( + { + "kb": [1], + } + ), + labels=make_features_dict({"ad_risk_score": [0]}), + ) + return [fpl1, fpl2] def wrap_fpl(fpl): - return { - constants.INPUT_KEY: 'xyz', - constants.FEATURES_PREDICTIONS_LABELS_KEY: fpl, - } + return { + constants.INPUT_KEY: "xyz", + constants.FEATURES_PREDICTIONS_LABELS_KEY: fpl, + } def get_num_interests(fpl): - interests = meta_feature_extractor.get_feature_value(fpl, 'interest') - new_features = {'num_interests': len(interests)} - return new_features + interests = meta_feature_extractor.get_feature_value(fpl, "interest") + new_features = {"num_interests": len(interests)} + return new_features class MetaFeatureExtractorTest(test_util.TensorflowModelAnalysisTest): - - def testMetaFeatures(self): - with beam.Pipeline() as pipeline: - fpls = create_fpls() - - metrics = ( - pipeline - | 'CreateTestInput' >> beam.Create(fpls) - | 
'WrapFpls' >> beam.Map(wrap_fpl) - | 'ExtractInterestsNum' - >> meta_feature_extractor.ExtractMetaFeature(get_num_interests) - ) - - def check_result(got): - try: - self.assertEqual(2, len(got), 'got: %s' % got) - for res in got: - self.assertIn( - 'num_interests', - res[constants.FEATURES_PREDICTIONS_LABELS_KEY].features, - ) - self.assertEqual( - len( - meta_feature_extractor.get_feature_value( - res[constants.FEATURES_PREDICTIONS_LABELS_KEY], - 'interest', - ) - ), - meta_feature_extractor.get_feature_value( - res[constants.FEATURES_PREDICTIONS_LABELS_KEY], - 'num_interests', - ), + def testMetaFeatures(self): + with beam.Pipeline() as pipeline: + fpls = create_fpls() + + metrics = ( + pipeline + | "CreateTestInput" >> beam.Create(fpls) + | "WrapFpls" >> beam.Map(wrap_fpl) + | "ExtractInterestsNum" + >> meta_feature_extractor.ExtractMetaFeature(get_num_interests) ) - except AssertionError as err: - raise util.BeamAssertException(err) - util.assert_that(metrics, check_result) - - def testNoModificationOfExistingKeys(self): + def check_result(got): + try: + self.assertEqual(2, len(got), "got: %s" % got) + for res in got: + self.assertIn( + "num_interests", + res[constants.FEATURES_PREDICTIONS_LABELS_KEY].features, + ) + self.assertEqual( + len( + meta_feature_extractor.get_feature_value( + res[constants.FEATURES_PREDICTIONS_LABELS_KEY], + "interest", + ) + ), + meta_feature_extractor.get_feature_value( + res[constants.FEATURES_PREDICTIONS_LABELS_KEY], + "num_interests", + ), + ) + except AssertionError as err: + raise util.BeamAssertException(err) + + util.assert_that(metrics, check_result) + + def testNoModificationOfExistingKeys(self): + def bad_meta_feature_fn(_): + return {"interest": ["bad", "key"]} + + with self.assertRaises(ValueError): + with beam.Pipeline() as pipeline: + fpls = create_fpls() + + _ = ( + pipeline + | "CreateTestInput" >> beam.Create(fpls) + | "WrapFpls" >> beam.Map(wrap_fpl) + | "ExtractInterestsNum" + >> meta_feature_extractor.ExtractMetaFeature(bad_meta_feature_fn) + ) + + def testSliceOnMetaFeature(self): + # We want to make sure that slicing on the newly added feature works, so + # pulling in slice here. 
+ with beam.Pipeline() as pipeline: + fpls = create_fpls() + metrics = ( + pipeline + | "CreateTestInput" >> beam.Create(fpls) + | "WrapFpls" >> beam.Map(wrap_fpl) + | "ExtractInterestsNum" + >> meta_feature_extractor.ExtractMetaFeature(get_num_interests) + | "ExtractSlices" + >> slice_key_extractor.ExtractSliceKeys( + [ + slicer.SingleSliceSpec(), + slicer.SingleSliceSpec(columns=["num_interests"]), + ] + ) + | "FanoutSlices" >> slicer.FanoutSlices() + ) - def bad_meta_feature_fn(_): - return {'interest': ['bad', 'key']} + def check_result(got): + try: + self.assertEqual(4, len(got), "got: %s" % got) + expected_slice_keys = [ + (), + (), + (("num_interests", 1),), + (("num_interests", 2),), + ] + self.assertCountEqual( + sorted(slice_key for slice_key, _ in got), + sorted(expected_slice_keys), + ) + except AssertionError as err: + raise util.BeamAssertException(err) - with self.assertRaises(ValueError): - with beam.Pipeline() as pipeline: - fpls = create_fpls() + util.assert_that(metrics, check_result) - _ = ( - pipeline - | 'CreateTestInput' >> beam.Create(fpls) - | 'WrapFpls' >> beam.Map(wrap_fpl) - | 'ExtractInterestsNum' - >> meta_feature_extractor.ExtractMetaFeature(bad_meta_feature_fn) + def testGetSparseTensorValue(self): + sparse_tensor_value = tf.compat.v1.SparseTensorValue( + indices=[[0, 0, 0], [0, 1, 0], [0, 1, 1]], + values=["", "one", "two"], + dense_shape=[1, 2, 2], + ) + fpl_with_sparse_tensor = types.FeaturesPredictionsLabels( + input_ref=0, features={}, predictions={}, labels={} ) - def testSliceOnMetaFeature(self): - # We want to make sure that slicing on the newly added feature works, so - # pulling in slice here. - with beam.Pipeline() as pipeline: - fpls = create_fpls() - metrics = ( - pipeline - | 'CreateTestInput' >> beam.Create(fpls) - | 'WrapFpls' >> beam.Map(wrap_fpl) - | 'ExtractInterestsNum' - >> meta_feature_extractor.ExtractMetaFeature(get_num_interests) - | 'ExtractSlices' - >> slice_key_extractor.ExtractSliceKeys([ - slicer.SingleSliceSpec(), - slicer.SingleSliceSpec(columns=['num_interests']), - ]) - | 'FanoutSlices' >> slicer.FanoutSlices() - ) - - def check_result(got): - try: - self.assertEqual(4, len(got), 'got: %s' % got) - expected_slice_keys = [ - (), - (), - (('num_interests', 1),), - (('num_interests', 2),), - ] - self.assertCountEqual( - sorted(slice_key for slice_key, _ in got), - sorted(expected_slice_keys), - ) - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that(metrics, check_result) - - def testGetSparseTensorValue(self): - sparse_tensor_value = tf.compat.v1.SparseTensorValue( - indices=[[0, 0, 0], [0, 1, 0], [0, 1, 1]], - values=['', 'one', 'two'], - dense_shape=[1, 2, 2], - ) - fpl_with_sparse_tensor = types.FeaturesPredictionsLabels( - input_ref=0, features={}, predictions={}, labels={} - ) - - meta_feature_extractor._set_feature_value( - fpl_with_sparse_tensor.features, 'sparse', sparse_tensor_value - ) - self.assertEqual( - ['', 'one', 'two'], - meta_feature_extractor.get_feature_value( - fpl_with_sparse_tensor, 'sparse' - ), - ) + meta_feature_extractor._set_feature_value( + fpl_with_sparse_tensor.features, "sparse", sparse_tensor_value + ) + self.assertEqual( + ["", "one", "two"], + meta_feature_extractor.get_feature_value(fpl_with_sparse_tensor, "sparse"), + ) -if __name__ == '__main__': - tf.test.main() +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow_model_analysis/extractors/legacy_predict_extractor.py b/tensorflow_model_analysis/extractors/legacy_predict_extractor.py 
index 4562c28716..1d64dbdfe6 100644 --- a/tensorflow_model_analysis/extractors/legacy_predict_extractor.py +++ b/tensorflow_model_analysis/extractors/legacy_predict_extractor.py @@ -20,16 +20,18 @@ import apache_beam as beam import numpy as np import pyarrow as pa + from tensorflow_model_analysis import constants from tensorflow_model_analysis.api import types -from tensorflow_model_analysis.eval_saved_model import constants as eval_saved_model_constants -from tensorflow_model_analysis.extractors import extractor -from tensorflow_model_analysis.extractors import legacy_feature_extractor +from tensorflow_model_analysis.eval_saved_model import ( + constants as eval_saved_model_constants, +) +from tensorflow_model_analysis.extractors import extractor, legacy_feature_extractor from tensorflow_model_analysis.proto import config_pb2 from tensorflow_model_analysis.utils import model_util # TODO(b/372967361): Add new tests for this file based on new Keras model. -_PREDICT_EXTRACTOR_STAGE_NAME = 'Predict' +_PREDICT_EXTRACTOR_STAGE_NAME = "Predict" _FEATURES_PREDICTIONS_LABELS_KEY_MAP = { eval_saved_model_constants.FEATURES_NAME: constants.FEATURES_KEY, @@ -44,174 +46,168 @@ def PredictExtractor( # pylint: disable=invalid-name materialize: Optional[bool] = True, eval_config: Optional[config_pb2.EvalConfig] = None, ) -> extractor.Extractor: - """Creates an Extractor for TFMAPredict. - - The extractor's PTransform loads and runs the eval_saved_model against every - example yielding a copy of the Extracts input with an additional extract - of type FeaturesPredictionsLabels keyed by - tfma.FEATURES_PREDICTIONS_LABELS_KEY unless eval_config is not None in which - case the features, predictions, and labels will be stored separately under - tfma.FEATURES_KEY, tfma.PREDICTIONS_KEY, and tfma.LABELS_KEY respectively. - - Args: - eval_shared_model: Shared model (single-model evaluation) or list of shared - models (multi-model evaluation). - desired_batch_size: Optional batch size for batching in Aggregate. - materialize: True to call the FeatureExtractor to add MaterializedColumn - entries for the features, predictions, and labels. - eval_config: Eval config. - - Returns: - Extractor for extracting features, predictions, labels, and other tensors - during predict. - """ - eval_shared_models = model_util.verify_and_update_eval_shared_models( - eval_shared_model - ) - - # pylint: disable=no-value-for-parameter - return extractor.Extractor( - stage_name=_PREDICT_EXTRACTOR_STAGE_NAME, - ptransform=_TFMAPredict( - eval_shared_models={m.model_name: m for m in eval_shared_models}, - desired_batch_size=desired_batch_size, - materialize=materialize, - eval_config=eval_config, - ), - ) + """Creates an Extractor for TFMAPredict. + + The extractor's PTransform loads and runs the eval_saved_model against every + example yielding a copy of the Extracts input with an additional extract + of type FeaturesPredictionsLabels keyed by + tfma.FEATURES_PREDICTIONS_LABELS_KEY unless eval_config is not None in which + case the features, predictions, and labels will be stored separately under + tfma.FEATURES_KEY, tfma.PREDICTIONS_KEY, and tfma.LABELS_KEY respectively. + + Args: + ---- + eval_shared_model: Shared model (single-model evaluation) or list of shared + models (multi-model evaluation). + desired_batch_size: Optional batch size for batching in Aggregate. + materialize: True to call the FeatureExtractor to add MaterializedColumn + entries for the features, predictions, and labels. + eval_config: Eval config. 
+ + Returns: + ------- + Extractor for extracting features, predictions, labels, and other tensors + during predict. + """ + eval_shared_models = model_util.verify_and_update_eval_shared_models( + eval_shared_model + ) + + # pylint: disable=no-value-for-parameter + return extractor.Extractor( + stage_name=_PREDICT_EXTRACTOR_STAGE_NAME, + ptransform=_TFMAPredict( + eval_shared_models={m.model_name: m for m in eval_shared_models}, + desired_batch_size=desired_batch_size, + materialize=materialize, + eval_config=eval_config, + ), + ) @beam.typehints.with_input_types(beam.typehints.List[types.Extracts]) @beam.typehints.with_output_types(types.Extracts) class _TFMAPredictionDoFn(model_util.BatchReducibleDoFnWithModels): - """A DoFn that loads the model and predicts.""" - - def __init__( - self, eval_shared_models: Dict[str, types.EvalSharedModel], eval_config - ): - super().__init__({k: v.model_loader for k, v in eval_shared_models.items()}) - self._eval_config = eval_config - - def _get_example_weights( - self, model_name: str, features: Dict[str, Any] - ) -> Any: - spec = model_util.get_model_spec(self._eval_config, model_name) - if not spec: - raise ValueError( - 'Missing model_spec for model_name "{}"'.format(model_name) - ) - if spec.example_weight_key: - if spec.example_weight_key not in features: - raise ValueError( - 'Missing feature for example_weight_key "{}": features={}'.format( - spec.example_weight_key, features - ) - ) - return features[spec.example_weight_key] - elif spec.example_weight_keys: - example_weights = {} - for k, v in spec.example_weight_keys.items(): - if v not in features: - raise ValueError( - 'Missing feature for example_weight_key "{}": features={}'.format( - k, features - ) - ) - example_weights[k] = features[v] - return example_weights - else: - return np.array([1.0]) - - def _batch_reducible_process( - self, elements: List[types.Extracts] - ) -> Sequence[types.Extracts]: - serialized_examples = [x[constants.INPUT_KEY] for x in elements] - - # Compute features, predictions, and labels for each serialized_example - result = [] - for model_name, loaded_model in self._loaded_models.items(): - for i, fetched in enumerate( - loaded_model.predict_list(serialized_examples) - ): - if i >= len(result): - element_copy = copy.copy(elements[fetched.input_ref]) - for key in fetched.values: - if key in _FEATURES_PREDICTIONS_LABELS_KEY_MAP: - if self._eval_config: - element_copy[_FEATURES_PREDICTIONS_LABELS_KEY_MAP[key]] = ( - fetched.values[key] - ) - continue - element_copy[key] = fetched.values[key] - if self._eval_config: - element_copy[constants.EXAMPLE_WEIGHTS_KEY] = ( - self._get_example_weights( - model_name, element_copy[constants.FEATURES_KEY] + """A DoFn that loads the model and predicts.""" + + def __init__( + self, eval_shared_models: Dict[str, types.EvalSharedModel], eval_config + ): + super().__init__({k: v.model_loader for k, v in eval_shared_models.items()}) + self._eval_config = eval_config + + def _get_example_weights(self, model_name: str, features: Dict[str, Any]) -> Any: + spec = model_util.get_model_spec(self._eval_config, model_name) + if not spec: + raise ValueError(f'Missing model_spec for model_name "{model_name}"') + if spec.example_weight_key: + if spec.example_weight_key not in features: + raise ValueError( + f'Missing feature for example_weight_key "{spec.example_weight_key}": features={features}' ) - ) - if len(self._loaded_models) == 1: - if not self._eval_config: - element_copy[constants.FEATURES_PREDICTIONS_LABELS_KEY] = ( - 
loaded_model.as_features_predictions_labels([fetched])[0] - ) - else: - if not self._eval_config: - raise ValueError( - 'PredictExtractor can only be used with multi-output models ' - 'if eval_config is passed.' - ) - # If only one model, the predictions are stored without using a dict - element_copy[constants.PREDICTIONS_KEY] = { - model_name: element_copy[constants.PREDICTIONS_KEY] - } - result.append(element_copy) + return features[spec.example_weight_key] + elif spec.example_weight_keys: + example_weights = {} + for k, v in spec.example_weight_keys.items(): + if v not in features: + raise ValueError( + f'Missing feature for example_weight_key "{k}": features={features}' + ) + example_weights[k] = features[v] + return example_weights else: - element_copy = result[i] - # Assume values except for predictions are same for all models. - element_copy[constants.PREDICTIONS_KEY][model_name] = fetched.values[ - eval_saved_model_constants.PREDICTIONS_NAME - ] - if self._eval_config: - return [_wrap_as_batched_extract(result)] - return result + return np.array([1.0]) + + def _batch_reducible_process( + self, elements: List[types.Extracts] + ) -> Sequence[types.Extracts]: + serialized_examples = [x[constants.INPUT_KEY] for x in elements] + + # Compute features, predictions, and labels for each serialized_example + result = [] + for model_name, loaded_model in self._loaded_models.items(): + for i, fetched in enumerate(loaded_model.predict_list(serialized_examples)): + if i >= len(result): + element_copy = copy.copy(elements[fetched.input_ref]) + for key in fetched.values: + if key in _FEATURES_PREDICTIONS_LABELS_KEY_MAP: + if self._eval_config: + element_copy[ + _FEATURES_PREDICTIONS_LABELS_KEY_MAP[key] + ] = fetched.values[key] + continue + element_copy[key] = fetched.values[key] + if self._eval_config: + element_copy[constants.EXAMPLE_WEIGHTS_KEY] = ( + self._get_example_weights( + model_name, element_copy[constants.FEATURES_KEY] + ) + ) + if len(self._loaded_models) == 1: + if not self._eval_config: + element_copy[constants.FEATURES_PREDICTIONS_LABELS_KEY] = ( + loaded_model.as_features_predictions_labels( + [fetched] + )[0] + ) + else: + if not self._eval_config: + raise ValueError( + "PredictExtractor can only be used with multi-output models " + "if eval_config is passed." + ) + # If only one model, the predictions are stored without using a dict + element_copy[constants.PREDICTIONS_KEY] = { + model_name: element_copy[constants.PREDICTIONS_KEY] + } + result.append(element_copy) + else: + element_copy = result[i] + # Assume values except for predictions are same for all models. + element_copy[constants.PREDICTIONS_KEY][model_name] = ( + fetched.values[eval_saved_model_constants.PREDICTIONS_NAME] + ) + if self._eval_config: + return [_wrap_as_batched_extract(result)] + return result # TODO(b/178158073): Currently the batched extract has a list of per-example # feature dicts. Convert this to a batched feature dict where each feature # will have a batch of ndarrays. 
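The TODO above refers to the list-valued layout produced by _wrap_as_batched_extract, defined just below. A minimal sketch of that wrapping behaviour, using plain string keys in place of the tfma constants:

```
from tensorflow_model_analysis.extractors import legacy_predict_extractor

per_example = [
    {"features": {"age": [13]}, "predictions": [0.8]},
    {"features": {"age": [10]}, "predictions": [0.2]},
]

# Values for the same key are collected into a list, one entry per example.
batched = legacy_predict_extractor._wrap_as_batched_extract(per_example)
assert batched["predictions"] == [[0.8], [0.2]]
assert [f["age"] for f in batched["features"]] == [[13], [10]]
```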
def _wrap_as_batched_extract(extracts: List[types.Extracts]) -> types.Extracts: - """Wrap list of per-example extracts as a batched extract.""" - result = collections.defaultdict(list) - for e in extracts: - for key, value in e.items(): - result[key].append(value) - return result + """Wrap list of per-example extracts as a batched extract.""" + result = collections.defaultdict(list) + for e in extracts: + for key, value in e.items(): + result[key].append(value) + return result def _fetch_raw_data_column(record_batch: pa.RecordBatch) -> np.ndarray: - """Fetch the raw data column. + """Fetch the raw data column. - Args: - record_batch: An Arrow RecordBatch. + Args: + ---- + record_batch: An Arrow RecordBatch. - Returns: - Raw data column. - """ - column_index = record_batch.schema.get_field_index( - constants.ARROW_INPUT_COLUMN - ) - assert column_index >= 0, 'Arrow input column not found.' - return np.asarray(record_batch.column(column_index).flatten()) + Returns: + ------- + Raw data column. + """ + column_index = record_batch.schema.get_field_index(constants.ARROW_INPUT_COLUMN) + assert column_index >= 0, "Arrow input column not found." + return np.asarray(record_batch.column(column_index).flatten()) def _unwrap_batched_extract( batched_extract: types.Extracts, ) -> List[types.Extracts]: - """Unwraps batched extract.""" - serialized_examples = _fetch_raw_data_column( - batched_extract[constants.ARROW_RECORD_BATCH_KEY] - ) - return [{constants.INPUT_KEY: e} for e in serialized_examples] + """Unwraps batched extract.""" + serialized_examples = _fetch_raw_data_column( + batched_extract[constants.ARROW_RECORD_BATCH_KEY] + ) + return [{constants.INPUT_KEY: e} for e in serialized_examples] @beam.ptransform_fn @@ -224,59 +220,61 @@ def _TFMAPredict( # pylint: disable=invalid-name materialize: Optional[bool] = True, eval_config: Optional[config_pb2.EvalConfig] = None, ) -> beam.pvalue.PCollection: - """A PTransform that adds predictions to Extracts. - - Args: - extracts: PCollection of Extracts containing a serialized example to be fed - to the model. - eval_shared_models: Shared model parameters keyed by model name. - desired_batch_size: Optional. Desired batch size for prediction. - materialize: True to call the FeatureExtractor to add MaterializedColumn - entries for the features, predictions, and labels. - eval_config: Eval config. - - Returns: - PCollection of Extracts, where the extracts contains the features, - predictions, labels retrieved. - """ - if not eval_config: - batch_args = {} - - # TODO(b/143484017): Consider removing this option if autotuning is better - # able to handle batch size selection. - if desired_batch_size: - batch_args = dict( - min_batch_size=desired_batch_size, max_batch_size=desired_batch_size - ) - - extracts = extracts | 'Batch' >> beam.BatchElements(**batch_args) - else: - extracts = extracts | 'UnwrapBatchedExtract' >> beam.Map( - _unwrap_batched_extract - ) + """A PTransform that adds predictions to Extracts. + + Args: + ---- + extracts: PCollection of Extracts containing a serialized example to be fed + to the model. + eval_shared_models: Shared model parameters keyed by model name. + desired_batch_size: Optional. Desired batch size for prediction. + materialize: True to call the FeatureExtractor to add MaterializedColumn + entries for the features, predictions, and labels. + eval_config: Eval config. + + Returns: + ------- + PCollection of Extracts, where the extracts contains the features, + predictions, labels retrieved. 
+ """ + if not eval_config: + batch_args = {} + + # TODO(b/143484017): Consider removing this option if autotuning is better + # able to handle batch size selection. + if desired_batch_size: + batch_args = dict( + min_batch_size=desired_batch_size, max_batch_size=desired_batch_size + ) + + extracts = extracts | "Batch" >> beam.BatchElements(**batch_args) + else: + extracts = extracts | "UnwrapBatchedExtract" >> beam.Map( + _unwrap_batched_extract + ) - # We don't actually need to add the add_metrics_callbacks to do Predict, - # but because if we want to share the model between Predict and subsequent - # stages (i.e. we use same shared handle for this and subsequent stages), - # then if we don't add the metrics callbacks here, they won't be present - # in the model in the later stages if we reuse the model from this stage. - extracts = extracts | 'Predict' >> beam.ParDo( - _TFMAPredictionDoFn( - eval_shared_models=eval_shared_models, eval_config=eval_config - ) - ) - - if materialize and not eval_config: - additional_fetches = [] - for m in eval_shared_models.values(): - if m.additional_fetches: - additional_fetches.extend(m.additional_fetches) - return ( - extracts - | 'ExtractFeatures' - >> legacy_feature_extractor._ExtractFeatures( # pylint: disable=protected-access - additional_extracts=additional_fetches or None + # We don't actually need to add the add_metrics_callbacks to do Predict, + # but because if we want to share the model between Predict and subsequent + # stages (i.e. we use same shared handle for this and subsequent stages), + # then if we don't add the metrics callbacks here, they won't be present + # in the model in the later stages if we reuse the model from this stage. + extracts = extracts | "Predict" >> beam.ParDo( + _TFMAPredictionDoFn( + eval_shared_models=eval_shared_models, eval_config=eval_config ) ) - return extracts + if materialize and not eval_config: + additional_fetches = [] + for m in eval_shared_models.values(): + if m.additional_fetches: + additional_fetches.extend(m.additional_fetches) + return ( + extracts + | "ExtractFeatures" + >> legacy_feature_extractor._ExtractFeatures( # pylint: disable=protected-access + additional_extracts=additional_fetches or None + ) + ) + + return extracts diff --git a/tensorflow_model_analysis/extractors/materialized_predictions_extractor.py b/tensorflow_model_analysis/extractors/materialized_predictions_extractor.py index e334390cb1..2a3f4009ce 100644 --- a/tensorflow_model_analysis/extractors/materialized_predictions_extractor.py +++ b/tensorflow_model_analysis/extractors/materialized_predictions_extractor.py @@ -16,44 +16,44 @@ from typing import Sequence import apache_beam as beam + from tensorflow_model_analysis import constants from tensorflow_model_analysis.api import types from tensorflow_model_analysis.extractors import extractor from tensorflow_model_analysis.proto import config_pb2 -from tensorflow_model_analysis.utils import model_util -from tensorflow_model_analysis.utils import util +from tensorflow_model_analysis.utils import model_util, util -_MATERIALIZED_PREDICTIONS_EXTRACTOR_STAGE_NAME = ( - 'ExtractMaterializedPredictions' -) +_MATERIALIZED_PREDICTIONS_EXTRACTOR_STAGE_NAME = "ExtractMaterializedPredictions" def MaterializedPredictionsExtractor( eval_config: config_pb2.EvalConfig, output_keypath: Sequence[str] = (constants.PREDICTIONS_KEY,), ) -> extractor.Extractor: - """Creates an extractor for rekeying preexisting predictions. + """Creates an extractor for rekeying preexisting predictions. 
- The extractor's PTransform uses the config's ModelSpec.prediction_key(s) - to lookup the associated prediction values stored as features under the - tfma.FEATURES_KEY in extracts. The resulting values are then added to the - extracts under the key tfma.PREDICTIONS_KEY. + The extractor's PTransform uses the config's ModelSpec.prediction_key(s) + to lookup the associated prediction values stored as features under the + tfma.FEATURES_KEY in extracts. The resulting values are then added to the + extracts under the key tfma.PREDICTIONS_KEY. - Args: - eval_config: Eval config. - output_keypath: A list of keys to be used as the path to traverse and insert - the outputs in the extract. + Args: + ---- + eval_config: Eval config. + output_keypath: A list of keys to be used as the path to traverse and insert + the outputs in the extract. - Returns: - Extractor for rekeying preexisting predictions. - """ - # pylint: disable=no-value-for-parameter - return extractor.Extractor( - stage_name=_MATERIALIZED_PREDICTIONS_EXTRACTOR_STAGE_NAME, - ptransform=_ExtractMaterializedPredictions( - eval_config=eval_config, output_keypath=output_keypath - ), - ) + Returns: + ------- + Extractor for rekeying preexisting predictions. + """ + # pylint: disable=no-value-for-parameter + return extractor.Extractor( + stage_name=_MATERIALIZED_PREDICTIONS_EXTRACTOR_STAGE_NAME, + ptransform=_ExtractMaterializedPredictions( + eval_config=eval_config, output_keypath=output_keypath + ), + ) @beam.ptransform_fn @@ -64,35 +64,37 @@ def _ExtractMaterializedPredictions( # pylint: disable=invalid-name eval_config: config_pb2.EvalConfig, output_keypath: Sequence[str], ) -> beam.pvalue.PCollection: - """A PTransform that populates the predictions key in the extracts. + """A PTransform that populates the predictions key in the extracts. - Args: - extracts: PCollection of extracts containing model inputs keyed by - tfma.FEATURES_KEY (if model inputs are named) or tfma.INPUTS_KEY (if model - takes raw tf.Examples as input). - eval_config: Eval config. - output_keypath: A list of keys that indicates the location to which the - predictions are inserted. + Args: + ---- + extracts: PCollection of extracts containing model inputs keyed by + tfma.FEATURES_KEY (if model inputs are named) or tfma.INPUTS_KEY (if model + takes raw tf.Examples as input). + eval_config: Eval config. + output_keypath: A list of keys that indicates the location to which the + predictions are inserted. - Returns: - PCollection of Extracts updated with the predictions. - """ + Returns: + ------- + PCollection of Extracts updated with the predictions. 
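For orientation, a minimal sketch of constructing the rekeying extractor described above, reusing the two model specs from the test that follows; it assumes a FeaturesExtractor stage has already populated the features:

```
from tensorflow_model_analysis.extractors import materialized_predictions_extractor
from tensorflow_model_analysis.proto import config_pb2

eval_config = config_pb2.EvalConfig(
    model_specs=[
        config_pb2.ModelSpec(name="model1", prediction_key="prediction"),
        config_pb2.ModelSpec(
            name="model2",
            prediction_keys={"output1": "prediction1", "output2": "prediction2"},
        ),
    ]
)

# Runs after a FeaturesExtractor stage has populated tfma.FEATURES_KEY (see the
# test below); the per-model prediction features are then copied into
# tfma.PREDICTIONS_KEY, keyed by model name when there is more than one model.
prediction_extractor = (
    materialized_predictions_extractor.MaterializedPredictionsExtractor(eval_config)
)
```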
+ """ - def rekey_predictions( # pylint: disable=invalid-name - batched_extracts: types.Extracts, - ) -> types.Extracts: - """Extract predictions from extracts containing features.""" - predictions = model_util.get_feature_values_for_model_spec_field( - list(eval_config.model_specs), - 'prediction_key', - 'prediction_keys', - batched_extracts, - ) - if predictions is not None: - return util.copy_and_set_by_keys( - batched_extracts, list(output_keypath), predictions - ) - else: - return batched_extracts + def rekey_predictions( # pylint: disable=invalid-name + batched_extracts: types.Extracts, + ) -> types.Extracts: + """Extract predictions from extracts containing features.""" + predictions = model_util.get_feature_values_for_model_spec_field( + list(eval_config.model_specs), + "prediction_key", + "prediction_keys", + batched_extracts, + ) + if predictions is not None: + return util.copy_and_set_by_keys( + batched_extracts, list(output_keypath), predictions + ) + else: + return batched_extracts - return extracts | 'RekeyPredictions' >> beam.Map(rekey_predictions) + return extracts | "RekeyPredictions" >> beam.Map(rekey_predictions) diff --git a/tensorflow_model_analysis/extractors/materialized_predictions_extractor_test.py b/tensorflow_model_analysis/extractors/materialized_predictions_extractor_test.py index cd920e2fb3..4bd3859966 100644 --- a/tensorflow_model_analysis/extractors/materialized_predictions_extractor_test.py +++ b/tensorflow_model_analysis/extractors/materialized_predictions_extractor_test.py @@ -14,37 +14,33 @@ """Test for batched materialized predictions extractor.""" import apache_beam as beam -from apache_beam.testing import util import numpy as np import tensorflow as tf +from apache_beam.testing import util +from google.protobuf import text_format +from tensorflow_metadata.proto.v0 import schema_pb2 +from tfx_bsl.tfxio import tensor_adapter, test_util + from tensorflow_model_analysis import constants from tensorflow_model_analysis.api import model_eval_lib -from tensorflow_model_analysis.extractors import features_extractor -from tensorflow_model_analysis.extractors import materialized_predictions_extractor +from tensorflow_model_analysis.extractors import ( + features_extractor, + materialized_predictions_extractor, +) from tensorflow_model_analysis.proto import config_pb2 from tensorflow_model_analysis.utils import test_util as testutil -from tfx_bsl.tfxio import tensor_adapter -from tfx_bsl.tfxio import test_util - -from google.protobuf import text_format -from tensorflow_metadata.proto.v0 import schema_pb2 - -class MaterializedPredictionsExtractorTest( - testutil.TensorflowModelAnalysisTest -): - def test_rekey_predictions_in_features(self): - model_spec1 = config_pb2.ModelSpec( - name='model1', prediction_key='prediction' - ) - model_spec2 = config_pb2.ModelSpec( - name='model2', - prediction_keys={'output1': 'prediction1', 'output2': 'prediction2'}, - ) - eval_config = config_pb2.EvalConfig(model_specs=[model_spec1, model_spec2]) - schema = text_format.Parse( - """ +class MaterializedPredictionsExtractorTest(testutil.TensorflowModelAnalysisTest): + def test_rekey_predictions_in_features(self): + model_spec1 = config_pb2.ModelSpec(name="model1", prediction_key="prediction") + model_spec2 = config_pb2.ModelSpec( + name="model2", + prediction_keys={"output1": "prediction1", "output2": "prediction2"}, + ) + eval_config = config_pb2.EvalConfig(model_specs=[model_spec1, model_spec2]) + schema = text_format.Parse( + """ tensor_representation_group { key: "" value { @@ 
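The `rekey_predictions` function above relies on `util.copy_and_set_by_keys` to write the looked-up prediction values at `output_keypath` without mutating the incoming extract. A rough stand-in for that idea on plain nested dicts (the helper below is hypothetical and only illustrates the key-path mechanics, not TFMA's real implementation):

```python
from typing import Any, Mapping, Sequence


def copy_and_set_by_keypath(
    extract: Mapping[str, Any], keypath: Sequence[str], value: Any
) -> dict:
    """Returns a shallow copy of `extract` with `value` stored at `keypath` (non-empty)."""
    result = dict(extract)  # never mutate the original extract
    node = result
    for key in keypath[:-1]:
        # Copy each intermediate mapping so nested state is not shared.
        node[key] = dict(node.get(key, {}))
        node = node[key]
    node[keypath[-1]] = value
    return result


extract = {"features": {"prediction": [0.7]}, "labels": [1]}
updated = copy_and_set_by_keypath(
    extract, ["predictions"], extract["features"]["prediction"]
)
# `updated` now carries the prediction under the "predictions" key;
# the original `extract` is left untouched.
```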
-81,76 +77,77 @@ def test_rekey_predictions_in_features(self): type: INT } """, - schema_pb2.Schema(), - ) - # TODO(b/73109633): Remove when field is removed or its default changes to - # False. - if hasattr(schema, 'generate_legacy_feature_spec'): - schema.generate_legacy_feature_spec = False - tfx_io = test_util.InMemoryTFExampleRecord( - schema=schema, raw_record_column_name=constants.ARROW_INPUT_COLUMN - ) - tensor_adapter_config = tensor_adapter.TensorAdapterConfig( - arrow_schema=tfx_io.ArrowSchema(), - tensor_representations=tfx_io.TensorRepresentations(), - ) - feature_extractor = features_extractor.FeaturesExtractor( - eval_config=eval_config, - tensor_representations=tensor_adapter_config.tensor_representations, - ) - prediction_extractor = ( - materialized_predictions_extractor.MaterializedPredictionsExtractor( - eval_config + schema_pb2.Schema(), + ) + # TODO(b/73109633): Remove when field is removed or its default changes to + # False. + if hasattr(schema, "generate_legacy_feature_spec"): + schema.generate_legacy_feature_spec = False + tfx_io = test_util.InMemoryTFExampleRecord( + schema=schema, raw_record_column_name=constants.ARROW_INPUT_COLUMN + ) + tensor_adapter_config = tensor_adapter.TensorAdapterConfig( + arrow_schema=tfx_io.ArrowSchema(), + tensor_representations=tfx_io.TensorRepresentations(), + ) + feature_extractor = features_extractor.FeaturesExtractor( + eval_config=eval_config, + tensor_representations=tensor_adapter_config.tensor_representations, + ) + prediction_extractor = ( + materialized_predictions_extractor.MaterializedPredictionsExtractor( + eval_config + ) ) - ) - examples = [ - self._makeExample( - prediction=1.0, prediction1=1.0, prediction2=0.0, fixed_int=1 - ), - self._makeExample( - prediction=1.0, prediction1=1.0, prediction2=1.0, fixed_int=1 - ), - ] + examples = [ + self._makeExample( + prediction=1.0, prediction1=1.0, prediction2=0.0, fixed_int=1 + ), + self._makeExample( + prediction=1.0, prediction1=1.0, prediction2=1.0, fixed_int=1 + ), + ] - with beam.Pipeline() as pipeline: - # pylint: disable=no-value-for-parameter - result = ( - pipeline - | 'Create' - >> beam.Create( - [e.SerializeToString() for e in examples], reshuffle=False - ) - | 'BatchExamples' >> tfx_io.BeamSource(batch_size=2) - | 'InputsToExtracts' >> model_eval_lib.BatchedInputsToExtracts() - | feature_extractor.stage_name >> feature_extractor.ptransform - | prediction_extractor.stage_name >> prediction_extractor.ptransform - ) + with beam.Pipeline() as pipeline: + # pylint: disable=no-value-for-parameter + result = ( + pipeline + | "Create" + >> beam.Create( + [e.SerializeToString() for e in examples], reshuffle=False + ) + | "BatchExamples" >> tfx_io.BeamSource(batch_size=2) + | "InputsToExtracts" >> model_eval_lib.BatchedInputsToExtracts() + | feature_extractor.stage_name >> feature_extractor.ptransform + | prediction_extractor.stage_name >> prediction_extractor.ptransform + ) - # pylint: enable=no-value-for-parameter + # pylint: enable=no-value-for-parameter - def check_result(got): - try: - self.assertLen(got, 1) - for model_name in ('model1', 'model2'): - self.assertIn(model_name, got[0][constants.PREDICTIONS_KEY]) - self.assertAllClose( - got[0][constants.PREDICTIONS_KEY]['model1'], np.array([1.0, 1.0]) - ) - self.assertAllClose( - got[0][constants.PREDICTIONS_KEY]['model2'], - { - 'output1': np.array([1.0, 1.0]), - 'output2': np.array([0.0, 1.0]), - }, - ) + def check_result(got): + try: + self.assertLen(got, 1) + for model_name in ("model1", "model2"): + 
self.assertIn(model_name, got[0][constants.PREDICTIONS_KEY]) + self.assertAllClose( + got[0][constants.PREDICTIONS_KEY]["model1"], + np.array([1.0, 1.0]), + ) + self.assertAllClose( + got[0][constants.PREDICTIONS_KEY]["model2"], + { + "output1": np.array([1.0, 1.0]), + "output2": np.array([0.0, 1.0]), + }, + ) - except AssertionError as err: - raise util.BeamAssertException(err) + except AssertionError as err: + raise util.BeamAssertException(err) - util.assert_that(result, check_result, label='result') + util.assert_that(result, check_result, label="result") -if __name__ == '__main__': - tf.compat.v1.enable_v2_behavior() - tf.test.main() +if __name__ == "__main__": + tf.compat.v1.enable_v2_behavior() + tf.test.main() diff --git a/tensorflow_model_analysis/extractors/predictions_extractor.py b/tensorflow_model_analysis/extractors/predictions_extractor.py index 7ba9975911..c609b716dc 100644 --- a/tensorflow_model_analysis/extractors/predictions_extractor.py +++ b/tensorflow_model_analysis/extractors/predictions_extractor.py @@ -15,17 +15,19 @@ from typing import List, Optional, Sequence -from absl import logging import apache_beam as beam +from absl import logging + from tensorflow_model_analysis import constants from tensorflow_model_analysis.api import types -from tensorflow_model_analysis.extractors import extractor -from tensorflow_model_analysis.extractors import materialized_predictions_extractor +from tensorflow_model_analysis.extractors import ( + extractor, + materialized_predictions_extractor, +) from tensorflow_model_analysis.proto import config_pb2 from tensorflow_model_analysis.utils import model_util - -PREDICTIONS_EXTRACTOR_STAGE_NAME = 'ExtractPredictions' +PREDICTIONS_EXTRACTOR_STAGE_NAME = "ExtractPredictions" def PredictionsExtractor( @@ -33,58 +35,60 @@ def PredictionsExtractor( eval_shared_model: Optional[types.MaybeMultipleEvalSharedModels] = None, output_keypath: Sequence[str] = (constants.PREDICTIONS_KEY,), ) -> extractor.Extractor: - """Creates an extractor for performing predictions over a batch. - - The extractor's PTransform loads and runs the serving saved_model(s) against - every Extracts yielding a copy of the incoming Extracts with an additional - Extracts added for the predictions keyed by tfma.PREDICTIONS_KEY. The model - inputs are searched for under tfma.FEATURES_KEY (keras only) or tfma.INPUT_KEY - (if tfma.FEATURES_KEY is not set or the model is non-keras). If multiple - models are used the predictions will be stored in a dict keyed by model name. - - Note that the prediction_key in the ModelSpecs also serves as a key into the - dict of the prediction's output. - - Args: - eval_config: Eval config. - eval_shared_model: Shared model (single-model evaluation) or list of shared - models (multi-model evaluation) or None (predictions obtained from - features). - output_keypath: A sequence of keys to be used as the path to traverse and - insert the outputs in the extract. - - Returns: - Extractor for extracting predictions. - """ - # TODO(b/239975835): Remove this Optional support for version 1.0. - if eval_shared_model is None: - logging.warning( - 'Calling the PredictionsExtractor with eval_shared_model=None is ' - 'deprecated and no longer supported. This will break in version 1.0. ' - 'Please update your implementation to call ' - 'MaterializedPredictionsExtractor directly.' 
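As the `PredictionsExtractor` docstring notes, with multiple models the predictions extract becomes a dict keyed by model name, and `ModelSpec.prediction_key(s)` select values out of each model's output. A small illustration with made-up values (the array shapes and model/output names are assumptions for illustration, not output from a real model):

```python
import numpy as np

# What a multi-model predictions extract can look like after extraction:
# one entry per model name from the EvalConfig.
predictions_extract = {
    "model1": np.array([0.9, 0.1]),  # single-output model
    "model2": {                      # multi-output model
        "output1": np.array([0.8, 0.2]),
        "output2": np.array([0.4, 0.6]),
    },
}

# A ModelSpec with prediction_keys={"output1": ...} effectively indexes into
# the per-model dict like this:
output1_scores = predictions_extract["model2"]["output1"]
print(output1_scores)  # [0.8 0.2]
```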
- ) - _, ptransform = ( - materialized_predictions_extractor.MaterializedPredictionsExtractor( - eval_config, output_keypath=output_keypath + """Creates an extractor for performing predictions over a batch. + + The extractor's PTransform loads and runs the serving saved_model(s) against + every Extracts yielding a copy of the incoming Extracts with an additional + Extracts added for the predictions keyed by tfma.PREDICTIONS_KEY. The model + inputs are searched for under tfma.FEATURES_KEY (keras only) or tfma.INPUT_KEY + (if tfma.FEATURES_KEY is not set or the model is non-keras). If multiple + models are used the predictions will be stored in a dict keyed by model name. + + Note that the prediction_key in the ModelSpecs also serves as a key into the + dict of the prediction's output. + + Args: + ---- + eval_config: Eval config. + eval_shared_model: Shared model (single-model evaluation) or list of shared + models (multi-model evaluation) or None (predictions obtained from + features). + output_keypath: A sequence of keys to be used as the path to traverse and + insert the outputs in the extract. + + Returns: + ------- + Extractor for extracting predictions. + """ + # TODO(b/239975835): Remove this Optional support for version 1.0. + if eval_shared_model is None: + logging.warning( + "Calling the PredictionsExtractor with eval_shared_model=None is " + "deprecated and no longer supported. This will break in version 1.0. " + "Please update your implementation to call " + "MaterializedPredictionsExtractor directly." ) - ) - # Note we are changing the stage name here for backwards compatibility. Old - # clients expect these code paths to have the same stage name. New clients - # should never reference the private stage name. + _, ptransform = ( + materialized_predictions_extractor.MaterializedPredictionsExtractor( + eval_config, output_keypath=output_keypath + ) + ) + # Note we are changing the stage name here for backwards compatibility. Old + # clients expect these code paths to have the same stage name. New clients + # should never reference the private stage name. + return extractor.Extractor( + stage_name=PREDICTIONS_EXTRACTOR_STAGE_NAME, ptransform=ptransform + ) + return extractor.Extractor( - stage_name=PREDICTIONS_EXTRACTOR_STAGE_NAME, ptransform=ptransform + stage_name=PREDICTIONS_EXTRACTOR_STAGE_NAME, + ptransform=_ModelSignaturesInferenceWrapper( # pylint: disable=no-value-for-parameter + model_specs=list(eval_config.model_specs), + eval_shared_model=eval_shared_model, + output_keypath=output_keypath, + ), ) - return extractor.Extractor( - stage_name=PREDICTIONS_EXTRACTOR_STAGE_NAME, - ptransform=_ModelSignaturesInferenceWrapper( # pylint: disable=no-value-for-parameter - model_specs=list(eval_config.model_specs), - eval_shared_model=eval_shared_model, - output_keypath=output_keypath, - ), - ) - @beam.ptransform_fn @beam.typehints.with_input_types(types.Extracts) @@ -95,45 +99,47 @@ def _ModelSignaturesInferenceWrapper( eval_shared_model: types.MaybeMultipleEvalSharedModels, output_keypath: Sequence[str], ) -> beam.pvalue.PCollection: - """A PTransform that adds predictions and possibly other tensors to Extracts. - - Args: - extracts: PCollection of Extracts containing model inputs keyed by - tfma.FEATURES_KEY (if model inputs are named) or tfma.INPUTS_KEY (if model - takes raw tf.Examples as input). - model_specs: Model specs each of which corresponds to each of the - eval_shared_models. - eval_shared_model: Shared model parameters keyed by model name. 
- output_keypath: A sequence of keys to be used as the path to traverse and - insert the outputs in the extract. - - Returns: - PCollection of Extracts updated with the predictions. - """ - eval_shared_models = model_util.verify_and_update_eval_shared_models( - eval_shared_model - ) - # This should never happen, but verify_and_update_eval_shared_models can - # theoretically return None or empty iterables. - if not eval_shared_models: - raise ValueError( - 'No valid model(s) were provided. Please ensure that ' - 'EvalConfig.ModelSpec is correctly configured to enable ' - 'using the PredictionsExtractor.' + """A PTransform that adds predictions and possibly other tensors to Extracts. + + Args: + ---- + extracts: PCollection of Extracts containing model inputs keyed by + tfma.FEATURES_KEY (if model inputs are named) or tfma.INPUTS_KEY (if model + takes raw tf.Examples as input). + model_specs: Model specs each of which corresponds to each of the + eval_shared_models. + eval_shared_model: Shared model parameters keyed by model name. + output_keypath: A sequence of keys to be used as the path to traverse and + insert the outputs in the extract. + + Returns: + ------- + PCollection of Extracts updated with the predictions. + """ + eval_shared_models = model_util.verify_and_update_eval_shared_models( + eval_shared_model ) + # This should never happen, but verify_and_update_eval_shared_models can + # theoretically return None or empty iterables. + if not eval_shared_models: + raise ValueError( + "No valid model(s) were provided. Please ensure that " + "EvalConfig.ModelSpec is correctly configured to enable " + "using the PredictionsExtractor." + ) - name_to_eval_shared_model = {m.model_name: m for m in eval_shared_models} - signature_names = {} - for model_spec in model_specs: - model_name = '' if len(model_specs) == 1 else model_spec.name - signature_names[model_name] = [model_spec.signature_name] - - return extracts | 'Inference' >> beam.ParDo( - model_util.ModelSignaturesDoFn( - model_specs=model_specs, - eval_shared_models=name_to_eval_shared_model, - output_keypath=output_keypath, - signature_names=signature_names, - prefer_dict_outputs=False, - ) - ) + name_to_eval_shared_model = {m.model_name: m for m in eval_shared_models} + signature_names = {} + for model_spec in model_specs: + model_name = "" if len(model_specs) == 1 else model_spec.name + signature_names[model_name] = [model_spec.signature_name] + + return extracts | "Inference" >> beam.ParDo( + model_util.ModelSignaturesDoFn( + model_specs=model_specs, + eval_shared_models=name_to_eval_shared_model, + output_keypath=output_keypath, + signature_names=signature_names, + prefer_dict_outputs=False, + ) + ) diff --git a/tensorflow_model_analysis/extractors/predictions_extractor_test.py b/tensorflow_model_analysis/extractors/predictions_extractor_test.py index 5975cc9fe7..a05054b876 100644 --- a/tensorflow_model_analysis/extractors/predictions_extractor_test.py +++ b/tensorflow_model_analysis/extractors/predictions_extractor_test.py @@ -15,97 +15,98 @@ import os -from absl.testing import parameterized import apache_beam as beam -from apache_beam.testing import util import numpy as np import tensorflow as tf +from absl.testing import parameterized +from apache_beam.testing import util +from google.protobuf import text_format +from tensorflow_metadata.proto.v0 import schema_pb2 +from tfx_bsl.tfxio import tensor_adapter +from tfx_bsl.tfxio import test_util as tfx_bsl_test_util + from tensorflow_model_analysis import constants from 
tensorflow_model_analysis.api import model_eval_lib -from tensorflow_model_analysis.extractors import features_extractor -from tensorflow_model_analysis.extractors import predictions_extractor +from tensorflow_model_analysis.extractors import ( + features_extractor, + predictions_extractor, +) from tensorflow_model_analysis.proto import config_pb2 from tensorflow_model_analysis.utils import test_util from tensorflow_model_analysis.utils.keras_lib import tf_keras -from tfx_bsl.tfxio import tensor_adapter -from tfx_bsl.tfxio import test_util as tfx_bsl_test_util - -from google.protobuf import text_format -from tensorflow_metadata.proto.v0 import schema_pb2 class PredictionsExtractorTest( test_util.TensorflowModelAnalysisTest, parameterized.TestCase ): - - def _getExportDir(self): - return os.path.join(self._getTempDir(), 'export_dir') - - def _create_tfxio_and_feature_extractor( - self, eval_config: config_pb2.EvalConfig, schema: schema_pb2.Schema - ): - tfx_io = tfx_bsl_test_util.InMemoryTFExampleRecord( - schema=schema, raw_record_column_name=constants.ARROW_INPUT_COLUMN - ) - tensor_adapter_config = tensor_adapter.TensorAdapterConfig( - arrow_schema=tfx_io.ArrowSchema(), - tensor_representations=tfx_io.TensorRepresentations(), - ) - feature_extractor = features_extractor.FeaturesExtractor( - eval_config=eval_config, - tensor_representations=tensor_adapter_config.tensor_representations, - ) - return tfx_io, feature_extractor - - # Note: The funtionality covered in this unit test is not supported by - # PredictionExtractorOSS. This Keras model accepts multiple input tensors, - # and does not include a signature that # accepts serialized input - # (i.e. string). This is a requirement for using the bulk inference APIs which - # only support serialized input right now. - @parameterized.named_parameters( - ('ModelSignaturesDoFnInferenceCallableModel', ''), - ('ModelSignaturesDoFnInferenceServingDefault', 'serving_default'), - ) - def testPredictionsExtractorWithKerasModel(self, signature_name): - input1 = tf_keras.layers.Input(shape=(2,), name='input1') - input2 = tf_keras.layers.Input(shape=(2,), name='input2') - inputs = [input1, input2] - input_layer = tf_keras.layers.concatenate(inputs) - output_layer = tf_keras.layers.Dense( - 1, activation=tf.nn.sigmoid, name='output' - )(input_layer) - model = tf_keras.models.Model(inputs, output_layer) - model.compile( - optimizer=tf_keras.optimizers.Adam(lr=0.001), - loss=tf_keras.losses.binary_crossentropy, - metrics=['accuracy'], + def _getExportDir(self): + return os.path.join(self._getTempDir(), "export_dir") + + def _create_tfxio_and_feature_extractor( + self, eval_config: config_pb2.EvalConfig, schema: schema_pb2.Schema + ): + tfx_io = tfx_bsl_test_util.InMemoryTFExampleRecord( + schema=schema, raw_record_column_name=constants.ARROW_INPUT_COLUMN + ) + tensor_adapter_config = tensor_adapter.TensorAdapterConfig( + arrow_schema=tfx_io.ArrowSchema(), + tensor_representations=tfx_io.TensorRepresentations(), + ) + feature_extractor = features_extractor.FeaturesExtractor( + eval_config=eval_config, + tensor_representations=tensor_adapter_config.tensor_representations, + ) + return tfx_io, feature_extractor + + # Note: The funtionality covered in this unit test is not supported by + # PredictionExtractorOSS. This Keras model accepts multiple input tensors, + # and does not include a signature that # accepts serialized input + # (i.e. string). This is a requirement for using the bulk inference APIs which + # only support serialized input right now. 
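The Keras tests that follow use `absl.testing.parameterized` to run the same body once per signature name: the empty string (call the Keras model directly) and `serving_default`. A minimal, self-contained sketch of that pattern with a placeholder assertion (the test names and check are illustrative; the real tests build an `EvalConfig` with `ModelSpec(signature_name=...)` and run the extractor):

```python
from absl.testing import absltest, parameterized


class SignatureNameTest(parameterized.TestCase):
    @parameterized.named_parameters(
        ("CallableModel", ""),  # empty name -> call the model directly
        ("ServingDefault", "serving_default"),
    )
    def test_signature_name_is_a_string(self, signature_name):
        # Stand-in assertion to keep the sketch runnable on its own.
        self.assertIsInstance(signature_name, str)


if __name__ == "__main__":
    absltest.main()
```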
+ @parameterized.named_parameters( + ("ModelSignaturesDoFnInferenceCallableModel", ""), + ("ModelSignaturesDoFnInferenceServingDefault", "serving_default"), ) + def testPredictionsExtractorWithKerasModel(self, signature_name): + input1 = tf_keras.layers.Input(shape=(2,), name="input1") + input2 = tf_keras.layers.Input(shape=(2,), name="input2") + inputs = [input1, input2] + input_layer = tf_keras.layers.concatenate(inputs) + output_layer = tf_keras.layers.Dense( + 1, activation=tf.nn.sigmoid, name="output" + )(input_layer) + model = tf_keras.models.Model(inputs, output_layer) + model.compile( + optimizer=tf_keras.optimizers.Adam(lr=0.001), + loss=tf_keras.losses.binary_crossentropy, + metrics=["accuracy"], + ) - train_features = { - 'input1': [[0.0, 0.0], [1.0, 1.0]], - 'input2': [[1.0, 1.0], [0.0, 0.0]], - } - labels = [[1], [0]] - example_weights = [1.0, 0.5] - dataset = tf.data.Dataset.from_tensor_slices( - (train_features, labels, example_weights) - ) - dataset = dataset.shuffle(buffer_size=1).repeat().batch(2) - model.fit(dataset, steps_per_epoch=1) + train_features = { + "input1": [[0.0, 0.0], [1.0, 1.0]], + "input2": [[1.0, 1.0], [0.0, 0.0]], + } + labels = [[1], [0]] + example_weights = [1.0, 0.5] + dataset = tf.data.Dataset.from_tensor_slices( + (train_features, labels, example_weights) + ) + dataset = dataset.shuffle(buffer_size=1).repeat().batch(2) + model.fit(dataset, steps_per_epoch=1) - export_dir = self._getExportDir() - model.save(export_dir, save_format='tf') + export_dir = self._getExportDir() + model.save(export_dir, save_format="tf") - eval_config = config_pb2.EvalConfig( - model_specs=[config_pb2.ModelSpec(signature_name=signature_name)] - ) - eval_shared_model = self.createKerasTestEvalSharedModel( - eval_saved_model_path=export_dir, eval_config=eval_config - ) - tfx_io, feature_extractor = self._create_tfxio_and_feature_extractor( - eval_config, - text_format.Parse( - """ + eval_config = config_pb2.EvalConfig( + model_specs=[config_pb2.ModelSpec(signature_name=signature_name)] + ) + eval_shared_model = self.createKerasTestEvalSharedModel( + eval_saved_model_path=export_dir, eval_config=eval_config + ) + tfx_io, feature_extractor = self._create_tfxio_and_feature_extractor( + eval_config, + text_format.Parse( + """ tensor_representation_group { key: "" value { @@ -142,92 +143,94 @@ def testPredictionsExtractorWithKerasModel(self, signature_name): type: INT } """, - schema_pb2.Schema(), - ), - ) - prediction_extractor = predictions_extractor.PredictionsExtractor( - eval_config=eval_config, eval_shared_model=eval_shared_model - ) - - examples = [ - self._makeExample( - input1=[0.0, 0.0], input2=[1.0, 1.0], non_model_feature=0 - ), # should be ignored by model - self._makeExample( - input1=[1.0, 1.0], input2=[0.0, 0.0], non_model_feature=1 - ), # should be ignored by model - ] - num_examples = len(examples) - - with beam.Pipeline() as pipeline: - # pylint: disable=no-value-for-parameter - result = ( - pipeline - | 'Create' - >> beam.Create( - [e.SerializeToString() for e in examples], reshuffle=False - ) - | 'BatchExamples' >> tfx_io.BeamSource(batch_size=num_examples) - | 'InputsToExtracts' >> model_eval_lib.BatchedInputsToExtracts() - | feature_extractor.stage_name >> feature_extractor.ptransform - | prediction_extractor.stage_name >> prediction_extractor.ptransform - ) - # pylint: enable=no-value-for-parameter - - def check_result(got): - try: - self.assertLen(got, 1) - # We can't verify the actual predictions, but we can verify the keys. 
- self.assertIn(constants.PREDICTIONS_KEY, got[0]) - - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that(result, check_result) - - # Note: The funtionality covered in this unit test is not supported by - # PredictionExtractorOSS. This Keras model does not include a signature that - # accepts serialized input (i.e. string). This is a requirement for using the - # bulk inference APIs which only support serialized input right now. - @parameterized.named_parameters( - ('ModelSignaturesDoFnInferenceCallableModel', ''), - ('ModelSignaturesDoFnInferenceServingDefault', 'serving_default'), - ) - def testPredictionsExtractorWithSequentialKerasModel(self, signature_name): - # Note that the input will be called 'test_input' - model = tf_keras.models.Sequential([ - tf_keras.layers.Dense( - 1, activation=tf.nn.sigmoid, input_shape=(2,), name='test' + schema_pb2.Schema(), + ), + ) + prediction_extractor = predictions_extractor.PredictionsExtractor( + eval_config=eval_config, eval_shared_model=eval_shared_model ) - ]) - model.compile( - optimizer=tf_keras.optimizers.Adam(lr=0.001), - loss=tf_keras.losses.binary_crossentropy, - metrics=['accuracy'], - ) - train_features = {'test_input': [[0.0, 0.0], [1.0, 1.0]]} - labels = [[1], [0]] - example_weights = [1.0, 0.5] - dataset = tf.data.Dataset.from_tensor_slices( - (train_features, labels, example_weights) + examples = [ + self._makeExample( + input1=[0.0, 0.0], input2=[1.0, 1.0], non_model_feature=0 + ), # should be ignored by model + self._makeExample( + input1=[1.0, 1.0], input2=[0.0, 0.0], non_model_feature=1 + ), # should be ignored by model + ] + num_examples = len(examples) + + with beam.Pipeline() as pipeline: + # pylint: disable=no-value-for-parameter + result = ( + pipeline + | "Create" + >> beam.Create( + [e.SerializeToString() for e in examples], reshuffle=False + ) + | "BatchExamples" >> tfx_io.BeamSource(batch_size=num_examples) + | "InputsToExtracts" >> model_eval_lib.BatchedInputsToExtracts() + | feature_extractor.stage_name >> feature_extractor.ptransform + | prediction_extractor.stage_name >> prediction_extractor.ptransform + ) + # pylint: enable=no-value-for-parameter + + def check_result(got): + try: + self.assertLen(got, 1) + # We can't verify the actual predictions, but we can verify the keys. + self.assertIn(constants.PREDICTIONS_KEY, got[0]) + + except AssertionError as err: + raise util.BeamAssertException(err) + + util.assert_that(result, check_result) + + # Note: The funtionality covered in this unit test is not supported by + # PredictionExtractorOSS. This Keras model does not include a signature that + # accepts serialized input (i.e. string). This is a requirement for using the + # bulk inference APIs which only support serialized input right now. 
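Each pipeline test in this file follows the same Beam testing pattern: build the pipeline, then verify the materialized PCollection with a custom matcher that re-raises assertion failures as `BeamAssertException` so the runner surfaces them. A stripped-down, runnable version of that pattern (the data is a placeholder):

```python
import apache_beam as beam
from apache_beam.testing import util


def check_result(got):
    # `got` is the full list of elements in the PCollection.
    try:
        assert len(got) == 2, f"expected 2 elements, got {len(got)}"
        assert set(got) == {1, 2}
    except AssertionError as err:
        # Re-raise so Beam reports the failure from inside the runner.
        raise util.BeamAssertException(err)


with beam.Pipeline() as pipeline:
    result = pipeline | "Create" >> beam.Create([1, 2])
    util.assert_that(result, check_result, label="result")
```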
+ @parameterized.named_parameters( + ("ModelSignaturesDoFnInferenceCallableModel", ""), + ("ModelSignaturesDoFnInferenceServingDefault", "serving_default"), ) - dataset = dataset.shuffle(buffer_size=1).repeat().batch(2) - model.fit(dataset, steps_per_epoch=1) + def testPredictionsExtractorWithSequentialKerasModel(self, signature_name): + # Note that the input will be called 'test_input' + model = tf_keras.models.Sequential( + [ + tf_keras.layers.Dense( + 1, activation=tf.nn.sigmoid, input_shape=(2,), name="test" + ) + ] + ) + model.compile( + optimizer=tf_keras.optimizers.Adam(lr=0.001), + loss=tf_keras.losses.binary_crossentropy, + metrics=["accuracy"], + ) - export_dir = self._getExportDir() - model.save(export_dir, save_format='tf') + train_features = {"test_input": [[0.0, 0.0], [1.0, 1.0]]} + labels = [[1], [0]] + example_weights = [1.0, 0.5] + dataset = tf.data.Dataset.from_tensor_slices( + (train_features, labels, example_weights) + ) + dataset = dataset.shuffle(buffer_size=1).repeat().batch(2) + model.fit(dataset, steps_per_epoch=1) - eval_config = config_pb2.EvalConfig( - model_specs=[config_pb2.ModelSpec(signature_name=signature_name)] - ) - eval_shared_model = self.createKerasTestEvalSharedModel( - eval_saved_model_path=export_dir, eval_config=eval_config - ) - tfx_io, feature_extractor = self._create_tfxio_and_feature_extractor( - eval_config, - text_format.Parse( - """ + export_dir = self._getExportDir() + model.save(export_dir, save_format="tf") + + eval_config = config_pb2.EvalConfig( + model_specs=[config_pb2.ModelSpec(signature_name=signature_name)] + ) + eval_shared_model = self.createKerasTestEvalSharedModel( + eval_saved_model_path=export_dir, eval_config=eval_config + ) + tfx_io, feature_extractor = self._create_tfxio_and_feature_extractor( + eval_config, + text_format.Parse( + """ tensor_representation_group { key: "" value { @@ -251,86 +254,86 @@ def testPredictionsExtractorWithSequentialKerasModel(self, signature_name): type: INT } """, - schema_pb2.Schema(), - ), - ) - prediction_extractor = predictions_extractor.PredictionsExtractor( - eval_config=eval_config, eval_shared_model=eval_shared_model - ) + schema_pb2.Schema(), + ), + ) + prediction_extractor = predictions_extractor.PredictionsExtractor( + eval_config=eval_config, eval_shared_model=eval_shared_model + ) - # Notice that the features are 'test' but the model expects 'test_input'. - # This tests that the PredictExtractor properly handles this case. - examples = [ - self._makeExample( - test=[0.0, 0.0], non_model_feature=0 - ), # should be ignored by model - self._makeExample( - test=[1.0, 1.0], non_model_feature=1 - ), # should be ignored by model - ] - num_examples = len(examples) - - with beam.Pipeline() as pipeline: - # pylint: disable=no-value-for-parameter - result = ( - pipeline - | 'Create' - >> beam.Create( - [e.SerializeToString() for e in examples], reshuffle=False - ) - | 'BatchExamples' >> tfx_io.BeamSource(batch_size=num_examples) - | 'InputsToExtracts' >> model_eval_lib.BatchedInputsToExtracts() - | feature_extractor.stage_name >> feature_extractor.ptransform - | prediction_extractor.stage_name >> prediction_extractor.ptransform - ) - # pylint: enable=no-value-for-parameter - - def check_result(got): - try: - self.assertLen(got, 1) - # We can't verify the actual predictions, but we can verify the keys. 
- self.assertIn(constants.PREDICTIONS_KEY, got[0]) - - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that(result, check_result) - - # Note: The funtionality covered in this unit test is not supported by - # PredictionExtractorOSS. This Keras model accepts multiple input tensors, - # and does not include a signature that # accepts serialized input - # (i.e. string). This is a requirement for using the bulk inference APIs which - # only support serialized input right now. - def testBatchSizeLimitWithKerasModel(self): - input1 = tf_keras.layers.Input(shape=(1,), batch_size=1, name='input1') - input2 = tf_keras.layers.Input(shape=(1,), batch_size=1, name='input2') - - inputs = [input1, input2] - input_layer = tf_keras.layers.concatenate(inputs) - - def add_1(tensor): - return tf.add_n([tensor, tf.constant(1.0, shape=(1, 2))]) - - assert_layer = tf_keras.layers.Lambda(add_1)(input_layer) - - model = tf_keras.models.Model(inputs, assert_layer) - model.compile( - optimizer=tf_keras.optimizers.Adam(lr=0.001), - loss=tf_keras.losses.binary_crossentropy, - metrics=['accuracy'], - ) + # Notice that the features are 'test' but the model expects 'test_input'. + # This tests that the PredictExtractor properly handles this case. + examples = [ + self._makeExample( + test=[0.0, 0.0], non_model_feature=0 + ), # should be ignored by model + self._makeExample( + test=[1.0, 1.0], non_model_feature=1 + ), # should be ignored by model + ] + num_examples = len(examples) + + with beam.Pipeline() as pipeline: + # pylint: disable=no-value-for-parameter + result = ( + pipeline + | "Create" + >> beam.Create( + [e.SerializeToString() for e in examples], reshuffle=False + ) + | "BatchExamples" >> tfx_io.BeamSource(batch_size=num_examples) + | "InputsToExtracts" >> model_eval_lib.BatchedInputsToExtracts() + | feature_extractor.stage_name >> feature_extractor.ptransform + | prediction_extractor.stage_name >> prediction_extractor.ptransform + ) + # pylint: enable=no-value-for-parameter + + def check_result(got): + try: + self.assertLen(got, 1) + # We can't verify the actual predictions, but we can verify the keys. + self.assertIn(constants.PREDICTIONS_KEY, got[0]) + + except AssertionError as err: + raise util.BeamAssertException(err) + + util.assert_that(result, check_result) + + # Note: The funtionality covered in this unit test is not supported by + # PredictionExtractorOSS. This Keras model accepts multiple input tensors, + # and does not include a signature that # accepts serialized input + # (i.e. string). This is a requirement for using the bulk inference APIs which + # only support serialized input right now. 
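The batch-size test that follows builds a Keras model whose `Lambda` layer adds a constant of shape `(1, 2)`; since `tf.add_n` does not broadcast, inference only succeeds when the runtime batch size is 1, which the test pins via `BeamSource(batch_size=1)`. A standalone sketch of that trick using plain `tf.keras` (the original test goes through TFMA's `tf_keras` shim, and the input values here are arbitrary):

```python
import tensorflow as tf

# Two scalar inputs pinned to a batch size of 1.
input1 = tf.keras.layers.Input(shape=(1,), batch_size=1, name="input1")
input2 = tf.keras.layers.Input(shape=(1,), batch_size=1, name="input2")
concat = tf.keras.layers.concatenate([input1, input2])  # shape (1, 2)

# tf.add_n does not broadcast, so this layer only works when the runtime
# batch size matches the constant's leading dimension of 1.
add_one = tf.keras.layers.Lambda(
    lambda t: tf.add_n([t, tf.constant(1.0, shape=(1, 2))])
)(concat)

model = tf.keras.models.Model([input1, input2], add_one)

# A single example succeeds; feeding a larger batch would fail inside the
# Lambda layer, which is what makes the test sensitive to batch size.
print(model.predict({"input1": [[0.0]], "input2": [[1.0]]}))
```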
+ def testBatchSizeLimitWithKerasModel(self): + input1 = tf_keras.layers.Input(shape=(1,), batch_size=1, name="input1") + input2 = tf_keras.layers.Input(shape=(1,), batch_size=1, name="input2") + + inputs = [input1, input2] + input_layer = tf_keras.layers.concatenate(inputs) + + def add_1(tensor): + return tf.add_n([tensor, tf.constant(1.0, shape=(1, 2))]) + + assert_layer = tf_keras.layers.Lambda(add_1)(input_layer) + + model = tf_keras.models.Model(inputs, assert_layer) + model.compile( + optimizer=tf_keras.optimizers.Adam(lr=0.001), + loss=tf_keras.losses.binary_crossentropy, + metrics=["accuracy"], + ) - export_dir = self._getExportDir() - model.save(export_dir, save_format='tf') + export_dir = self._getExportDir() + model.save(export_dir, save_format="tf") - eval_config = config_pb2.EvalConfig(model_specs=[config_pb2.ModelSpec()]) - eval_shared_model = self.createKerasTestEvalSharedModel( - eval_saved_model_path=export_dir, eval_config=eval_config - ) - tfx_io, feature_extractor = self._create_tfxio_and_feature_extractor( - eval_config, - text_format.Parse( - """ + eval_config = config_pb2.EvalConfig(model_specs=[config_pb2.ModelSpec()]) + eval_shared_model = self.createKerasTestEvalSharedModel( + eval_saved_model_path=export_dir, eval_config=eval_config + ) + tfx_io, feature_extractor = self._create_tfxio_and_feature_extractor( + eval_config, + text_format.Parse( + """ tensor_representation_group { key: "" value { @@ -363,54 +366,52 @@ def add_1(tensor): type: FLOAT } """, - schema_pb2.Schema(), - ), - ) - prediction_extractor = predictions_extractor.PredictionsExtractor( - eval_config=eval_config, eval_shared_model=eval_shared_model - ) + schema_pb2.Schema(), + ), + ) + prediction_extractor = predictions_extractor.PredictionsExtractor( + eval_config=eval_config, eval_shared_model=eval_shared_model + ) - examples = [] - for _ in range(4): - examples.append(self._makeExample(input1=0.0, input2=1.0)) - - with beam.Pipeline() as pipeline: - predict_extracts = ( - pipeline - | 'Create' - >> beam.Create( - [e.SerializeToString() for e in examples], reshuffle=False - ) - | 'BatchExamples' >> tfx_io.BeamSource(batch_size=1) - | 'InputsToExtracts' >> model_eval_lib.BatchedInputsToExtracts() - | feature_extractor.stage_name >> feature_extractor.ptransform - | prediction_extractor.stage_name >> prediction_extractor.ptransform - ) - - def check_result(got): - try: - self.assertLen(got, 4) - # We can't verify the actual predictions, but we can verify the keys. - for item in got: - self.assertIn(constants.PREDICTIONS_KEY, item) - - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that(predict_extracts, check_result) - - # TODO(b/239975835): Remove this test for version 1.0. 
- def testRekeyPredictionsInFeaturesForPrematerializedPredictions(self): - model_spec1 = config_pb2.ModelSpec( - name='model1', prediction_key='prediction' - ) - model_spec2 = config_pb2.ModelSpec( - name='model2', - prediction_keys={'output1': 'prediction1', 'output2': 'prediction2'}, - ) - eval_config = config_pb2.EvalConfig(model_specs=[model_spec1, model_spec2]) - schema = text_format.Parse( - """ + examples = [] + for _ in range(4): + examples.append(self._makeExample(input1=0.0, input2=1.0)) + + with beam.Pipeline() as pipeline: + predict_extracts = ( + pipeline + | "Create" + >> beam.Create( + [e.SerializeToString() for e in examples], reshuffle=False + ) + | "BatchExamples" >> tfx_io.BeamSource(batch_size=1) + | "InputsToExtracts" >> model_eval_lib.BatchedInputsToExtracts() + | feature_extractor.stage_name >> feature_extractor.ptransform + | prediction_extractor.stage_name >> prediction_extractor.ptransform + ) + + def check_result(got): + try: + self.assertLen(got, 4) + # We can't verify the actual predictions, but we can verify the keys. + for item in got: + self.assertIn(constants.PREDICTIONS_KEY, item) + + except AssertionError as err: + raise util.BeamAssertException(err) + + util.assert_that(predict_extracts, check_result) + + # TODO(b/239975835): Remove this test for version 1.0. + def testRekeyPredictionsInFeaturesForPrematerializedPredictions(self): + model_spec1 = config_pb2.ModelSpec(name="model1", prediction_key="prediction") + model_spec2 = config_pb2.ModelSpec( + name="model2", + prediction_keys={"output1": "prediction1", "output2": "prediction2"}, + ) + eval_config = config_pb2.EvalConfig(model_specs=[model_spec1, model_spec2]) + schema = text_format.Parse( + """ tensor_representation_group { key: "" value { @@ -447,68 +448,69 @@ def testRekeyPredictionsInFeaturesForPrematerializedPredictions(self): type: INT } """, - schema_pb2.Schema(), - ) - # TODO(b/73109633): Remove when field is removed or its default changes to - # False. - if hasattr(schema, 'generate_legacy_feature_spec'): - schema.generate_legacy_feature_spec = False - tfx_io, feature_extractor = self._create_tfxio_and_feature_extractor( - eval_config, schema - ) + schema_pb2.Schema(), + ) + # TODO(b/73109633): Remove when field is removed or its default changes to + # False. 
+ if hasattr(schema, "generate_legacy_feature_spec"): + schema.generate_legacy_feature_spec = False + tfx_io, feature_extractor = self._create_tfxio_and_feature_extractor( + eval_config, schema + ) - examples = [ - self._makeExample( - prediction=1.0, prediction1=1.0, prediction2=0.0, fixed_int=1 - ), - self._makeExample( - prediction=1.0, prediction1=1.0, prediction2=1.0, fixed_int=1 - ), - ] - num_examples = len(examples) - - prediction_extractor = predictions_extractor.PredictionsExtractor( - eval_config=eval_config, eval_shared_model=None - ) + examples = [ + self._makeExample( + prediction=1.0, prediction1=1.0, prediction2=0.0, fixed_int=1 + ), + self._makeExample( + prediction=1.0, prediction1=1.0, prediction2=1.0, fixed_int=1 + ), + ] + num_examples = len(examples) + + prediction_extractor = predictions_extractor.PredictionsExtractor( + eval_config=eval_config, eval_shared_model=None + ) - with beam.Pipeline() as pipeline: - # pylint: disable=no-value-for-parameter - result = ( - pipeline - | 'Create' - >> beam.Create( - [e.SerializeToString() for e in examples], reshuffle=False - ) - | 'BatchExamples' >> tfx_io.BeamSource(batch_size=num_examples) - | 'InputsToExtracts' >> model_eval_lib.BatchedInputsToExtracts() - | feature_extractor.stage_name >> feature_extractor.ptransform - | prediction_extractor.stage_name >> prediction_extractor.ptransform - ) - # pylint: enable=no-value-for-parameter - - def check_result(got): - try: - self.assertLen(got, 1) - for model_name in ('model1', 'model2'): - self.assertIn(model_name, got[0][constants.PREDICTIONS_KEY]) - self.assertAllClose( - np.array([1.0, 1.0]), got[0][constants.PREDICTIONS_KEY]['model1'] - ) - - self.assertAllClose( - { - 'output1': np.array([1.0, 1.0]), - 'output2': np.array([0.0, 1.0]), - }, - got[0][constants.PREDICTIONS_KEY]['model2'], - ) - - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that(result, check_result) - - -if __name__ == '__main__': - tf.compat.v1.enable_v2_behavior() - tf.test.main() + with beam.Pipeline() as pipeline: + # pylint: disable=no-value-for-parameter + result = ( + pipeline + | "Create" + >> beam.Create( + [e.SerializeToString() for e in examples], reshuffle=False + ) + | "BatchExamples" >> tfx_io.BeamSource(batch_size=num_examples) + | "InputsToExtracts" >> model_eval_lib.BatchedInputsToExtracts() + | feature_extractor.stage_name >> feature_extractor.ptransform + | prediction_extractor.stage_name >> prediction_extractor.ptransform + ) + # pylint: enable=no-value-for-parameter + + def check_result(got): + try: + self.assertLen(got, 1) + for model_name in ("model1", "model2"): + self.assertIn(model_name, got[0][constants.PREDICTIONS_KEY]) + self.assertAllClose( + np.array([1.0, 1.0]), + got[0][constants.PREDICTIONS_KEY]["model1"], + ) + + self.assertAllClose( + { + "output1": np.array([1.0, 1.0]), + "output2": np.array([0.0, 1.0]), + }, + got[0][constants.PREDICTIONS_KEY]["model2"], + ) + + except AssertionError as err: + raise util.BeamAssertException(err) + + util.assert_that(result, check_result) + + +if __name__ == "__main__": + tf.compat.v1.enable_v2_behavior() + tf.test.main() diff --git a/tensorflow_model_analysis/extractors/slice_key_extractor.py b/tensorflow_model_analysis/extractors/slice_key_extractor.py index ae80b45ba0..d99cc8ae5e 100644 --- a/tensorflow_model_analysis/extractors/slice_key_extractor.py +++ b/tensorflow_model_analysis/extractors/slice_key_extractor.py @@ -17,6 +17,7 @@ from typing import List, Optional import apache_beam as beam + 
from tensorflow_model_analysis import constants from tensorflow_model_analysis.api import types from tensorflow_model_analysis.extractors import extractor @@ -24,7 +25,7 @@ from tensorflow_model_analysis.slicer import slicer_lib as slicer from tensorflow_model_analysis.utils import util -SLICE_KEY_EXTRACTOR_STAGE_NAME = 'ExtractSliceKeys' +SLICE_KEY_EXTRACTOR_STAGE_NAME = "ExtractSliceKeys" def SliceKeyExtractor( @@ -32,126 +33,126 @@ def SliceKeyExtractor( eval_config: Optional[config_pb2.EvalConfig] = None, materialize: Optional[bool] = True, ) -> extractor.Extractor: - """Creates an extractor for extracting slice keys. - - The incoming Extracts must contain features stored under tfma.FEATURES_KEY - and optionally under tfma.TRANSFORMED_FEATURES. - - The extractor's PTransform yields a copy of the Extracts input with an - additional extract pointing at the list of SliceKeyType values keyed by - tfma.SLICE_KEY_TYPES_KEY. If materialize is True then a materialized version - of the slice keys will be added under the key tfma.SLICE_KEYS_KEY. - - Args: - slice_spec: Deprecated (use EvalConfig). - eval_config: Optional EvalConfig containing slicing_specs specifying the - slices to slice the data into. If slicing_specs are empty, defaults to - overall slice. - materialize: True to add MaterializedColumn entries for the slice keys. - - Returns: - Extractor for slice keys. - """ - if slice_spec and eval_config: - raise ValueError('slice_spec is deprecated, only use eval_config') - if eval_config: - slice_spec = [ - slicer.SingleSliceSpec(spec=spec) for spec in eval_config.slicing_specs - ] - for cross_slice_spec in eval_config.cross_slicing_specs: - baseline_slice_spec = slicer.SingleSliceSpec( - spec=cross_slice_spec.baseline_spec - ) - if baseline_slice_spec not in slice_spec: - slice_spec.append(baseline_slice_spec) - for spec in cross_slice_spec.slicing_specs: - comparison_slice_spec = slicer.SingleSliceSpec(spec=spec) - if comparison_slice_spec not in slice_spec: - slice_spec.append(comparison_slice_spec) - if not slice_spec: - slice_spec = [slicer.SingleSliceSpec()] - return extractor.Extractor( - stage_name=SLICE_KEY_EXTRACTOR_STAGE_NAME, - ptransform=ExtractSliceKeys(slice_spec, eval_config, materialize), - ) + """Creates an extractor for extracting slice keys. + + The incoming Extracts must contain features stored under tfma.FEATURES_KEY + and optionally under tfma.TRANSFORMED_FEATURES. + + The extractor's PTransform yields a copy of the Extracts input with an + additional extract pointing at the list of SliceKeyType values keyed by + tfma.SLICE_KEY_TYPES_KEY. If materialize is True then a materialized version + of the slice keys will be added under the key tfma.SLICE_KEYS_KEY. + + Args: + ---- + slice_spec: Deprecated (use EvalConfig). + eval_config: Optional EvalConfig containing slicing_specs specifying the + slices to slice the data into. If slicing_specs are empty, defaults to + overall slice. + materialize: True to add MaterializedColumn entries for the slice keys. + + Returns: + ------- + Extractor for slice keys. 
+ """ + if slice_spec and eval_config: + raise ValueError("slice_spec is deprecated, only use eval_config") + if eval_config: + slice_spec = [ + slicer.SingleSliceSpec(spec=spec) for spec in eval_config.slicing_specs + ] + for cross_slice_spec in eval_config.cross_slicing_specs: + baseline_slice_spec = slicer.SingleSliceSpec( + spec=cross_slice_spec.baseline_spec + ) + if baseline_slice_spec not in slice_spec: + slice_spec.append(baseline_slice_spec) + for spec in cross_slice_spec.slicing_specs: + comparison_slice_spec = slicer.SingleSliceSpec(spec=spec) + if comparison_slice_spec not in slice_spec: + slice_spec.append(comparison_slice_spec) + if not slice_spec: + slice_spec = [slicer.SingleSliceSpec()] + return extractor.Extractor( + stage_name=SLICE_KEY_EXTRACTOR_STAGE_NAME, + ptransform=ExtractSliceKeys(slice_spec, eval_config, materialize), + ) def _not_empty(elems) -> bool: # pylint: disable=invalid-name - if hasattr(elems, '__array__'): - return elems.size > 0 - return bool(elems) + if hasattr(elems, "__array__"): + return elems.size > 0 + return bool(elems) @beam.typehints.with_input_types(types.Extracts, List[slicer.SingleSliceSpec]) @beam.typehints.with_output_types(types.Extracts) class ExtractSliceKeysFn(beam.DoFn): - """A DoFn that extracts slice keys that apply per example.""" - - def __init__( - self, eval_config: Optional[config_pb2.EvalConfig], materialize: bool - ): - self._eval_config = eval_config - self._materialize = materialize - self._duplicate_slice_keys_counter = beam.metrics.Metrics.counter( - constants.METRICS_NAMESPACE, 'num_examples_with_duplicate_slice_keys' - ) + """A DoFn that extracts slice keys that apply per example.""" - def process( - self, element: types.Extracts, slice_spec: List[slicer.SingleSliceSpec] - ) -> List[types.Extracts]: - # Slice on transformed features if available. - features_dicts = [] - if ( - constants.TRANSFORMED_FEATURES_KEY in element - and element[constants.TRANSFORMED_FEATURES_KEY] is not None - ): - transformed_features = element[constants.TRANSFORMED_FEATURES_KEY] - # If only one model, the output is stored without keying on model name. - if not self._eval_config or len(self._eval_config.model_specs) == 1: - features_dicts.append(transformed_features) - else: - # Search for slices in each model's transformed features output. - for spec in self._eval_config.model_specs: - if spec.name in transformed_features: - features_dicts.append(transformed_features[spec.name]) - # Search for slices first in transformed features (if any). If a match is - # not found there then search in raw features. - slice_keys = list( - slicer.get_slices_for_features_dicts( - features_dicts, util.get_features_from_extracts(element), slice_spec + def __init__(self, eval_config: Optional[config_pb2.EvalConfig], materialize: bool): + self._eval_config = eval_config + self._materialize = materialize + self._duplicate_slice_keys_counter = beam.metrics.Metrics.counter( + constants.METRICS_NAMESPACE, "num_examples_with_duplicate_slice_keys" ) - ) - # If SLICE_KEY_TYPES_KEY already exists, that means the - # SqlSliceKeyExtractor has generated some slice keys. We need to add - # them to current slice_keys list. - if constants.SLICE_KEY_TYPES_KEY in element and _not_empty( - element[constants.SLICE_KEY_TYPES_KEY] - ): - slice_keys.extend(element[constants.SLICE_KEY_TYPES_KEY]) + def process( + self, element: types.Extracts, slice_spec: List[slicer.SingleSliceSpec] + ) -> List[types.Extracts]: + # Slice on transformed features if available. 
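`SliceKeyExtractor` turns the `slicing_specs` in the `EvalConfig` into `SingleSliceSpec` objects and falls back to the overall slice when none are given. A short sketch of that spec-building step in isolation, assuming the standard `feature_keys` field on `SlicingSpec` (cross-slicing specs, handled above, are omitted for brevity):

```python
from tensorflow_model_analysis.proto import config_pb2
from tensorflow_model_analysis.slicer import slicer_lib as slicer

eval_config = config_pb2.EvalConfig(
    slicing_specs=[config_pb2.SlicingSpec(feature_keys=["gender"])]
)

# One SingleSliceSpec per slicing_spec in the config...
slice_spec = [
    slicer.SingleSliceSpec(spec=spec) for spec in eval_config.slicing_specs
]
# ...defaulting to the overall slice when the config specifies none.
if not slice_spec:
    slice_spec = [slicer.SingleSliceSpec()]

print(slice_spec)
```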
+ features_dicts = [] + if ( + constants.TRANSFORMED_FEATURES_KEY in element + and element[constants.TRANSFORMED_FEATURES_KEY] is not None + ): + transformed_features = element[constants.TRANSFORMED_FEATURES_KEY] + # If only one model, the output is stored without keying on model name. + if not self._eval_config or len(self._eval_config.model_specs) == 1: + features_dicts.append(transformed_features) + else: + # Search for slices in each model's transformed features output. + for spec in self._eval_config.model_specs: + if spec.name in transformed_features: + features_dicts.append(transformed_features[spec.name]) + # Search for slices first in transformed features (if any). If a match is + # not found there then search in raw features. + slice_keys = list( + slicer.get_slices_for_features_dicts( + features_dicts, util.get_features_from_extracts(element), slice_spec + ) + ) - unique_slice_keys = list(set(slice_keys)) - if len(slice_keys) != len(unique_slice_keys): - self._duplicate_slice_keys_counter.inc() + # If SLICE_KEY_TYPES_KEY already exists, that means the + # SqlSliceKeyExtractor has generated some slice keys. We need to add + # them to current slice_keys list. + if constants.SLICE_KEY_TYPES_KEY in element and _not_empty( + element[constants.SLICE_KEY_TYPES_KEY] + ): + slice_keys.extend(element[constants.SLICE_KEY_TYPES_KEY]) - # Make a a shallow copy, so we don't mutate the original. - element_copy = copy.copy(element) + unique_slice_keys = list(set(slice_keys)) + if len(slice_keys) != len(unique_slice_keys): + self._duplicate_slice_keys_counter.inc() - element_copy[constants.SLICE_KEY_TYPES_KEY] = ( - slicer.slice_keys_to_numpy_array(unique_slice_keys) - ) - # Add a list of stringified slice keys to be materialized to output table. - if self._materialize: - element_copy[constants.SLICE_KEYS_KEY] = types.MaterializedColumn( - name=constants.SLICE_KEYS_KEY, - value=( - list( - slicer.stringify_slice_key(x).encode('utf-8') - for x in unique_slice_keys - ) - ), - ) - return [element_copy] + # Make a a shallow copy, so we don't mutate the original. + element_copy = copy.copy(element) + + element_copy[constants.SLICE_KEY_TYPES_KEY] = slicer.slice_keys_to_numpy_array( + unique_slice_keys + ) + # Add a list of stringified slice keys to be materialized to output table. + if self._materialize: + element_copy[constants.SLICE_KEYS_KEY] = types.MaterializedColumn( + name=constants.SLICE_KEYS_KEY, + value=( + list( + slicer.stringify_slice_key(x).encode("utf-8") + for x in unique_slice_keys + ) + ), + ) + return [element_copy] @beam.ptransform_fn @@ -163,6 +164,6 @@ def ExtractSliceKeys( eval_config: Optional[config_pb2.EvalConfig] = None, materialize: bool = True, ) -> beam.pvalue.PCollection: - return extracts | 'ExtractSliceKeys' >> beam.ParDo( - ExtractSliceKeysFn(eval_config, materialize), slice_spec=slice_spec - ) + return extracts | "ExtractSliceKeys" >> beam.ParDo( + ExtractSliceKeysFn(eval_config, materialize), slice_spec=slice_spec + ) diff --git a/tensorflow_model_analysis/extractors/slice_key_extractor_test.py b/tensorflow_model_analysis/extractors/slice_key_extractor_test.py index 39de56933e..c2e1a75c05 100644 --- a/tensorflow_model_analysis/extractors/slice_key_extractor_test.py +++ b/tensorflow_model_analysis/extractors/slice_key_extractor_test.py @@ -13,11 +13,12 @@ # limitations under the License. 
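Before materializing, `ExtractSliceKeysFn` dedupes the slice keys, bumps a counter when duplicates were present, and stringifies each key as UTF-8 bytes. A small sketch of that post-processing on plain slice-key tuples (the string forms follow the expectations in the tests in this file, e.g. `b'Overall'` for the empty slice; the duplicate handling here prints instead of incrementing a Beam counter):

```python
from tensorflow_model_analysis.slicer import slicer_lib as slicer

# Slice keys are tuples of (column, value) pairs; () is the overall slice.
slice_keys = [(), (("gender", "f"),), (("gender", "f"),)]

# Dedupe before materializing, and note when duplicates were present.
unique_slice_keys = list(set(slice_keys))
if len(unique_slice_keys) != len(slice_keys):
    print("example produced duplicate slice keys")

materialized = [
    slicer.stringify_slice_key(x).encode("utf-8") for x in unique_slice_keys
]
print(sorted(materialized))  # e.g. [b'Overall', b'gender:f']
```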
"""Test for slice_key_extractor.""" -from absl.testing import parameterized import apache_beam as beam -from apache_beam.testing import util import numpy as np import tensorflow as tf +from absl.testing import parameterized +from apache_beam.testing import util + from tensorflow_model_analysis import constants from tensorflow_model_analysis.api import types from tensorflow_model_analysis.extractors import slice_key_extractor @@ -27,296 +28,307 @@ def make_features_dict(features_dict): - result = {} - for key, value in features_dict.items(): - result[key] = {'node': np.array(value)} - return result + result = {} + for key, value in features_dict.items(): + result[key] = {"node": np.array(value)} + return result def create_fpls(): - fpl1 = types.FeaturesPredictionsLabels( - input_ref=0, - features=make_features_dict( - {'gender': ['f'], 'age': [13], 'interest': ['cars']} - ), - predictions=make_features_dict({ - 'kb': [1], - }), - labels=make_features_dict({'ad_risk_score': [0]}), - ) - fpl2 = types.FeaturesPredictionsLabels( - input_ref=0, - features=make_features_dict( - {'gender': ['m'], 'age': [10], 'interest': ['cars']} - ), - predictions=make_features_dict({ - 'kb': [1], - }), - labels=make_features_dict({'ad_risk_score': [0]}), - ) - return [fpl1, fpl2] + fpl1 = types.FeaturesPredictionsLabels( + input_ref=0, + features=make_features_dict( + {"gender": ["f"], "age": [13], "interest": ["cars"]} + ), + predictions=make_features_dict( + { + "kb": [1], + } + ), + labels=make_features_dict({"ad_risk_score": [0]}), + ) + fpl2 = types.FeaturesPredictionsLabels( + input_ref=0, + features=make_features_dict( + {"gender": ["m"], "age": [10], "interest": ["cars"]} + ), + predictions=make_features_dict( + { + "kb": [1], + } + ), + labels=make_features_dict({"ad_risk_score": [0]}), + ) + return [fpl1, fpl2] def wrap_fpl(fpl): - return { - constants.INPUT_KEY: fpl, - constants.FEATURES_PREDICTIONS_LABELS_KEY: fpl, - } + return { + constants.INPUT_KEY: fpl, + constants.FEATURES_PREDICTIONS_LABELS_KEY: fpl, + } class SliceTest(test_util.TensorflowModelAnalysisTest, parameterized.TestCase): - - @parameterized.named_parameters( - ( - 'features_only', - [''], - [ - { - constants.FEATURES_KEY: make_features_dict( - {'gender': ['m'], 'age': [10], 'interest': ['cars']} - ) - }, - { - constants.FEATURES_KEY: make_features_dict( - {'gender': ['f'], 'age': [12], 'interest': ['cars']} - ) - }, - ], - [slicer.SingleSliceSpec(columns=['gender'])], - [[(('gender', 'm'),)], [(('gender', 'f'),)]], - ), - ( - 'duplicate_feature_keys', - [''], - [ - { - constants.FEATURES_KEY: make_features_dict( - {'gender': ['m'], 'age': [10], 'interest': ['cars']} - ) - }, - { - constants.FEATURES_KEY: make_features_dict( - {'gender': ['f'], 'age': [12], 'interest': ['cars']} - ) - }, - ], - [ - slicer.SingleSliceSpec(columns=['gender']), - slicer.SingleSliceSpec(columns=['gender']), - ], - [[(('gender', 'm'),)], [(('gender', 'f'),)]], - ), - ( - 'transformed_features', - [''], - [ - { - constants.FEATURES_KEY: make_features_dict( - {'gender': ['m'], 'age': [10], 'interest': ['cars']} - ), - constants.TRANSFORMED_FEATURES_KEY: make_features_dict( - {'gender': ['m'], 'age': [10], 'interest': ['boats']} - ), - }, - { - constants.FEATURES_KEY: make_features_dict( - {'gender': ['f'], 'age': [12], 'interest': ['cars']} - ), - constants.TRANSFORMED_FEATURES_KEY: make_features_dict( - {'gender': ['m'], 'age': [10], 'interest': ['planes']} - ), - }, - ], - [slicer.SingleSliceSpec(columns=['interest'])], - [[(('interest', 'boats'),)], 
[(('interest', 'planes'),)]], - ), - ( - 'missing_features', - [''], - [ - { - constants.TRANSFORMED_FEATURES_KEY: make_features_dict( - {'gender': ['m'], 'age': [10], 'interest': ['boats']} - ) - }, - { - constants.TRANSFORMED_FEATURES_KEY: make_features_dict( - {'gender': ['m'], 'age': [10], 'interest': ['planes']} - ) - }, - ], - [slicer.SingleSliceSpec(columns=['interest'])], - [[(('interest', 'boats'),)], [(('interest', 'planes'),)]], - ), - ( - 'transformed_features_with_multiple_models', - ['model1', 'model2'], - [ - { - constants.FEATURES_KEY: make_features_dict( - {'gender': ['m'], 'age': [10], 'interest': ['cars']} - ), - constants.TRANSFORMED_FEATURES_KEY: { - 'model1': make_features_dict({'interest': ['boats']}), - 'model2': make_features_dict({'interest': ['planes']}), - }, - }, - { - constants.FEATURES_KEY: make_features_dict( - {'gender': ['f'], 'age': [12], 'interest': ['planes']} - ), - constants.TRANSFORMED_FEATURES_KEY: { - 'model1': make_features_dict({'interest': ['trains']}), - 'model2': make_features_dict({'interest': ['planes']}), - }, - }, - ], - [slicer.SingleSliceSpec(columns=['interest'])], - [ - [(('interest', 'boats'),), (('interest', 'planes'),)], - [(('interest', 'planes'),), (('interest', 'trains'),)], - ], - ), - ( - 'features_with_batched_slices_keys', - [''], - [ - { - constants.FEATURES_KEY: make_features_dict( - {'gender': ['m'], 'age': [10], 'interest': ['cars']} - ), - constants.SLICE_KEY_TYPES_KEY: [( - ('age', '10'), - ('interest', 'cars'), - )], - }, - { - constants.FEATURES_KEY: make_features_dict( - {'gender': ['f'], 'age': [12], 'interest': ['cars']} - ), - constants.SLICE_KEY_TYPES_KEY: [( - ('age', '12'), - ('interest', 'cars'), - )], - }, - ], - [slicer.SingleSliceSpec(columns=['gender'])], - [ - [ - ( - ('age', '10'), - ('interest', 'cars'), - ), - (('gender', 'm'),), - ], - [ - ( - ('age', '12'), - ('interest', 'cars'), - ), - (('gender', 'f'),), - ], - ], - ), - ) - def testSliceKeys(self, model_names, extracts, slice_specs, expected_slices): - eval_config = config_pb2.EvalConfig( - model_specs=[config_pb2.ModelSpec(name=name) for name in model_names] + @parameterized.named_parameters( + ( + "features_only", + [""], + [ + { + constants.FEATURES_KEY: make_features_dict( + {"gender": ["m"], "age": [10], "interest": ["cars"]} + ) + }, + { + constants.FEATURES_KEY: make_features_dict( + {"gender": ["f"], "age": [12], "interest": ["cars"]} + ) + }, + ], + [slicer.SingleSliceSpec(columns=["gender"])], + [[(("gender", "m"),)], [(("gender", "f"),)]], + ), + ( + "duplicate_feature_keys", + [""], + [ + { + constants.FEATURES_KEY: make_features_dict( + {"gender": ["m"], "age": [10], "interest": ["cars"]} + ) + }, + { + constants.FEATURES_KEY: make_features_dict( + {"gender": ["f"], "age": [12], "interest": ["cars"]} + ) + }, + ], + [ + slicer.SingleSliceSpec(columns=["gender"]), + slicer.SingleSliceSpec(columns=["gender"]), + ], + [[(("gender", "m"),)], [(("gender", "f"),)]], + ), + ( + "transformed_features", + [""], + [ + { + constants.FEATURES_KEY: make_features_dict( + {"gender": ["m"], "age": [10], "interest": ["cars"]} + ), + constants.TRANSFORMED_FEATURES_KEY: make_features_dict( + {"gender": ["m"], "age": [10], "interest": ["boats"]} + ), + }, + { + constants.FEATURES_KEY: make_features_dict( + {"gender": ["f"], "age": [12], "interest": ["cars"]} + ), + constants.TRANSFORMED_FEATURES_KEY: make_features_dict( + {"gender": ["m"], "age": [10], "interest": ["planes"]} + ), + }, + ], + [slicer.SingleSliceSpec(columns=["interest"])], + 
[[(("interest", "boats"),)], [(("interest", "planes"),)]], + ), + ( + "missing_features", + [""], + [ + { + constants.TRANSFORMED_FEATURES_KEY: make_features_dict( + {"gender": ["m"], "age": [10], "interest": ["boats"]} + ) + }, + { + constants.TRANSFORMED_FEATURES_KEY: make_features_dict( + {"gender": ["m"], "age": [10], "interest": ["planes"]} + ) + }, + ], + [slicer.SingleSliceSpec(columns=["interest"])], + [[(("interest", "boats"),)], [(("interest", "planes"),)]], + ), + ( + "transformed_features_with_multiple_models", + ["model1", "model2"], + [ + { + constants.FEATURES_KEY: make_features_dict( + {"gender": ["m"], "age": [10], "interest": ["cars"]} + ), + constants.TRANSFORMED_FEATURES_KEY: { + "model1": make_features_dict({"interest": ["boats"]}), + "model2": make_features_dict({"interest": ["planes"]}), + }, + }, + { + constants.FEATURES_KEY: make_features_dict( + {"gender": ["f"], "age": [12], "interest": ["planes"]} + ), + constants.TRANSFORMED_FEATURES_KEY: { + "model1": make_features_dict({"interest": ["trains"]}), + "model2": make_features_dict({"interest": ["planes"]}), + }, + }, + ], + [slicer.SingleSliceSpec(columns=["interest"])], + [ + [(("interest", "boats"),), (("interest", "planes"),)], + [(("interest", "planes"),), (("interest", "trains"),)], + ], + ), + ( + "features_with_batched_slices_keys", + [""], + [ + { + constants.FEATURES_KEY: make_features_dict( + {"gender": ["m"], "age": [10], "interest": ["cars"]} + ), + constants.SLICE_KEY_TYPES_KEY: [ + ( + ("age", "10"), + ("interest", "cars"), + ) + ], + }, + { + constants.FEATURES_KEY: make_features_dict( + {"gender": ["f"], "age": [12], "interest": ["cars"]} + ), + constants.SLICE_KEY_TYPES_KEY: [ + ( + ("age", "12"), + ("interest", "cars"), + ) + ], + }, + ], + [slicer.SingleSliceSpec(columns=["gender"])], + [ + [ + ( + ("age", "10"), + ("interest", "cars"), + ), + (("gender", "m"),), + ], + [ + ( + ("age", "12"), + ("interest", "cars"), + ), + (("gender", "f"),), + ], + ], + ), ) - with beam.Pipeline() as pipeline: - slice_keys_extracts = ( - pipeline - | 'CreateTestInput' >> beam.Create(extracts) - | 'ExtractSlices' - >> slice_key_extractor.ExtractSliceKeys( - slice_spec=slice_specs, eval_config=eval_config - ) - ) - - def check_result(got): - try: - self.assertLen(got, 2) - got_results = [] - for item in got: - self.assertIn(constants.SLICE_KEY_TYPES_KEY, item) - got_results.append(sorted(item[constants.SLICE_KEY_TYPES_KEY])) - self.assertCountEqual(got_results, expected_slices) - except AssertionError as err: - raise util.BeamAssertException(err) + def testSliceKeys(self, model_names, extracts, slice_specs, expected_slices): + eval_config = config_pb2.EvalConfig( + model_specs=[config_pb2.ModelSpec(name=name) for name in model_names] + ) + with beam.Pipeline() as pipeline: + slice_keys_extracts = ( + pipeline + | "CreateTestInput" >> beam.Create(extracts) + | "ExtractSlices" + >> slice_key_extractor.ExtractSliceKeys( + slice_spec=slice_specs, eval_config=eval_config + ) + ) - util.assert_that(slice_keys_extracts, check_result) + def check_result(got): + try: + self.assertLen(got, 2) + got_results = [] + for item in got: + self.assertIn(constants.SLICE_KEY_TYPES_KEY, item) + got_results.append(sorted(item[constants.SLICE_KEY_TYPES_KEY])) + self.assertCountEqual(got_results, expected_slices) + except AssertionError as err: + raise util.BeamAssertException(err) - def testLegacySliceKeys(self): - with beam.Pipeline() as pipeline: - fpls = create_fpls() - slice_keys_extracts = ( - pipeline - | 'CreateTestInput' >> 
beam.Create(fpls) - | 'WrapFpls' >> beam.Map(wrap_fpl) - | 'ExtractSlices' - >> slice_key_extractor.ExtractSliceKeys([ - slicer.SingleSliceSpec(), - slicer.SingleSliceSpec(columns=['gender']), - ]) - ) + util.assert_that(slice_keys_extracts, check_result) - def check_result(got): - try: - self.assertLen(got, 2) - expected_results = sorted( - [[(), (('gender', 'f'),)], [(), (('gender', 'm'),)]] - ) - got_results = [] - for item in got: - self.assertIn(constants.SLICE_KEY_TYPES_KEY, item) - got_results.append(sorted(item[constants.SLICE_KEY_TYPES_KEY])) - self.assertCountEqual(got_results, expected_results) - except AssertionError as err: - raise util.BeamAssertException(err) + def testLegacySliceKeys(self): + with beam.Pipeline() as pipeline: + fpls = create_fpls() + slice_keys_extracts = ( + pipeline + | "CreateTestInput" >> beam.Create(fpls) + | "WrapFpls" >> beam.Map(wrap_fpl) + | "ExtractSlices" + >> slice_key_extractor.ExtractSliceKeys( + [ + slicer.SingleSliceSpec(), + slicer.SingleSliceSpec(columns=["gender"]), + ] + ) + ) - util.assert_that(slice_keys_extracts, check_result) + def check_result(got): + try: + self.assertLen(got, 2) + expected_results = sorted( + [[(), (("gender", "f"),)], [(), (("gender", "m"),)]] + ) + got_results = [] + for item in got: + self.assertIn(constants.SLICE_KEY_TYPES_KEY, item) + got_results.append(sorted(item[constants.SLICE_KEY_TYPES_KEY])) + self.assertCountEqual(got_results, expected_results) + except AssertionError as err: + raise util.BeamAssertException(err) - def testMaterializedLegacySliceKeys(self): - with beam.Pipeline() as pipeline: - fpls = create_fpls() - slice_keys_extracts = ( - pipeline - | 'CreateTestInput' >> beam.Create(fpls) - | 'WrapFpls' >> beam.Map(wrap_fpl) - | 'ExtractSlices' - >> slice_key_extractor.ExtractSliceKeys( - [ - slicer.SingleSliceSpec(), - slicer.SingleSliceSpec(columns=['gender']), - ], - materialize=True, - ) - ) + util.assert_that(slice_keys_extracts, check_result) - def check_result(got): - try: - self.assertLen(got, 2) - expected_results = [ - types.MaterializedColumn( - name=constants.SLICE_KEYS_KEY, value=[b'Overall', b'gender:f'] - ), - types.MaterializedColumn( - name=constants.SLICE_KEYS_KEY, value=[b'Overall', b'gender:m'] - ), - ] - got_results = [] - for item in got: - self.assertIn(constants.SLICE_KEYS_KEY, item) - got_result = item[constants.SLICE_KEYS_KEY] - got_results.append( - got_result._replace(value=sorted(got_result.value)) + def testMaterializedLegacySliceKeys(self): + with beam.Pipeline() as pipeline: + fpls = create_fpls() + slice_keys_extracts = ( + pipeline + | "CreateTestInput" >> beam.Create(fpls) + | "WrapFpls" >> beam.Map(wrap_fpl) + | "ExtractSlices" + >> slice_key_extractor.ExtractSliceKeys( + [ + slicer.SingleSliceSpec(), + slicer.SingleSliceSpec(columns=["gender"]), + ], + materialize=True, + ) ) - self.assertCountEqual(got_results, expected_results) - except AssertionError as err: - raise util.BeamAssertException(err) - util.assert_that(slice_keys_extracts, check_result) + def check_result(got): + try: + self.assertLen(got, 2) + expected_results = [ + types.MaterializedColumn( + name=constants.SLICE_KEYS_KEY, + value=[b"Overall", b"gender:f"], + ), + types.MaterializedColumn( + name=constants.SLICE_KEYS_KEY, + value=[b"Overall", b"gender:m"], + ), + ] + got_results = [] + for item in got: + self.assertIn(constants.SLICE_KEYS_KEY, item) + got_result = item[constants.SLICE_KEYS_KEY] + got_results.append( + got_result._replace(value=sorted(got_result.value)) + ) + 
self.assertCountEqual(got_results, expected_results) + except AssertionError as err: + raise util.BeamAssertException(err) + + util.assert_that(slice_keys_extracts, check_result) -if __name__ == '__main__': - tf.test.main() +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow_model_analysis/extractors/sql_slice_key_extractor.py b/tensorflow_model_analysis/extractors/sql_slice_key_extractor.py index e4a301e33d..372c1d8694 100644 --- a/tensorflow_model_analysis/extractors/sql_slice_key_extractor.py +++ b/tensorflow_model_analysis/extractors/sql_slice_key_extractor.py @@ -20,150 +20,150 @@ import apache_beam as beam import pyarrow as pa import tensorflow as tf +from tfx_bsl.arrow import sql_util +from tfx_bsl.tfxio import tensor_to_arrow + from tensorflow_model_analysis import constants from tensorflow_model_analysis.api import types from tensorflow_model_analysis.extractors import extractor from tensorflow_model_analysis.proto import config_pb2 from tensorflow_model_analysis.slicer import slicer_lib from tensorflow_model_analysis.utils import util -from tfx_bsl.arrow import sql_util -from tfx_bsl.tfxio import tensor_to_arrow -_TF_MAJOR_VERSION = int(tf.version.VERSION.split('.')[0]) +_TF_MAJOR_VERSION = int(tf.version.VERSION.split(".")[0]) -_SQL_SLICE_KEY_EXTRACTOR_STAGE_NAME = 'ExtractSqlSliceKeys' +_SQL_SLICE_KEY_EXTRACTOR_STAGE_NAME = "ExtractSqlSliceKeys" def SqlSliceKeyExtractor( eval_config: config_pb2.EvalConfig, ) -> extractor.Extractor: - """Creates an extractor for sql slice keys. - - This extractor extracts slices keys in a batch based on the SQL statement in - the eval config. - - Args: - eval_config: EvalConfig containing slicing_specs specifying the slices to - slice the data into. - - Returns: - Extractor for extracting slice keys in batch. - """ - # pylint: disable=no-value-for-parameter - return extractor.Extractor( - stage_name=_SQL_SLICE_KEY_EXTRACTOR_STAGE_NAME, - ptransform=_ExtractSqlSliceKey(eval_config), - ) + """Creates an extractor for sql slice keys. + + This extractor extracts slices keys in a batch based on the SQL statement in + the eval config. + + Args: + ---- + eval_config: EvalConfig containing slicing_specs specifying the slices to + slice the data into. + + Returns: + ------- + Extractor for extracting slice keys in batch. 
+ """ + # pylint: disable=no-value-for-parameter + return extractor.Extractor( + stage_name=_SQL_SLICE_KEY_EXTRACTOR_STAGE_NAME, + ptransform=_ExtractSqlSliceKey(eval_config), + ) @beam.typehints.with_input_types(types.Extracts) @beam.typehints.with_output_types(types.Extracts) class ExtractSqlSliceKeyFn(beam.DoFn): - """A DoFn that extracts slice keys in batch.""" + """A DoFn that extracts slice keys in batch.""" - def __init__(self, eval_config: config_pb2.EvalConfig): - self._eval_config = eval_config - self._sqls = [ - """ + def __init__(self, eval_config: config_pb2.EvalConfig): + self._eval_config = eval_config + self._sqls = [ + f""" SELECT ARRAY( - {} + {spec.slice_keys_sql} ) as slice_key - FROM Examples as example;""".format(spec.slice_keys_sql) - for spec in eval_config.slicing_specs - if spec.slice_keys_sql - ] - self._sql_slicer_num_record_batch_schemas = ( - beam.metrics.Metrics.distribution( - constants.METRICS_NAMESPACE, 'sql_slicer_num_record_batch_schemas' + FROM Examples as example;""" + for spec in eval_config.slicing_specs + if spec.slice_keys_sql + ] + self._sql_slicer_num_record_batch_schemas = beam.metrics.Metrics.distribution( + constants.METRICS_NAMESPACE, "sql_slicer_num_record_batch_schemas" ) - ) - def setup(self): - - def _GenerateQueries( - schema: pa.Schema, - ) -> List[sql_util.RecordBatchSQLSliceQuery]: - result = [] - for sql in self._sqls: - try: - result.append(sql_util.RecordBatchSQLSliceQuery(sql, schema)) - except Exception as e: - raise RuntimeError(f'Failed to parse sql:\n\n{sql}') from e - return result - - # A cache for compiled sql queries, keyed by record batch schemas. - # This way the extractor can work with record batches of different schemas, - # which is legit but uncommon. - self._cached_queries = functools.lru_cache(maxsize=3)(_GenerateQueries) - - def process(self, batched_extract: types.Extracts) -> List[types.Extracts]: - features = batched_extract[constants.FEATURES_KEY] - # Slice on transformed features if available. - if ( - constants.TRANSFORMED_FEATURES_KEY in batched_extract - and batched_extract[constants.TRANSFORMED_FEATURES_KEY] is not None - ): - transformed_features = batched_extract[constants.TRANSFORMED_FEATURES_KEY] - # If only one model, the output is stored without keying on model name. - if not self._eval_config or len(self._eval_config.model_specs) == 1: - features.update(transformed_features) - else: - # Models listed earlier have precedence in feature lookup. - for spec in reversed(self._eval_config.model_specs): - if spec.name in transformed_features: - features.update(transformed_features[spec.name]) - - tensors = util.to_tensorflow_tensors(features) - tensor_specs = util.infer_tensor_specs(tensors) - - if _TF_MAJOR_VERSION < 2 or not tf.executing_eagerly(): - # TODO(b/228456048): TFX-BSL doesn't support passing tensorflow tensors - # for non-sparse/ragged values in TF 1.x (i.e. it only accepts np.ndarray - # for dense) so we need to convert dense tensors to numpy. 
- sess = tf.compat.v1.Session() - - def _convert_dense_to_numpy(values): # pylint: disable=invalid-name - if isinstance(values, Mapping): - for k, v in values.items(): - if isinstance(v, Mapping): - values[k] = _convert_dense_to_numpy(v) - elif isinstance(v, tf.Tensor): - values[k] = v.eval(session=sess) - return values - - tensors = _convert_dense_to_numpy(tensors) - - converter = tensor_to_arrow.TensorsToRecordBatchConverter(tensor_specs) - record_batch = converter.convert(tensors) - sql_slice_keys = [[] for _ in range(record_batch.num_rows)] - - for query in self._cached_queries(record_batch.schema): - # Example of result with batch size = 3: - # result = [[[('feature', 'value_1')]], - # [[('feature', 'value_2')]], - # [] - # ] - result = query.Execute(record_batch) - for row_index, row_result in enumerate(result): - sql_slice_keys[row_index].extend([tuple(s) for s in row_result]) - - # convert sql_slice_keys into a VarLenTensorValue where each row has dtype - # object. - dense_rows = [] - for row_slice_keys in sql_slice_keys: - dense_rows.append(slicer_lib.slice_keys_to_numpy_array(row_slice_keys)) - varlen_sql_slice_keys = types.VarLenTensorValue.from_dense_rows(dense_rows) - - # Make a a shallow copy, so we don't mutate the original. - batched_extract_copy = copy.copy(batched_extract) - batched_extract_copy[constants.SLICE_KEY_TYPES_KEY] = varlen_sql_slice_keys - - self._sql_slicer_num_record_batch_schemas.update( - self._cached_queries.cache_info().currsize - ) + def setup(self): + def _GenerateQueries( + schema: pa.Schema, + ) -> List[sql_util.RecordBatchSQLSliceQuery]: + result = [] + for sql in self._sqls: + try: + result.append(sql_util.RecordBatchSQLSliceQuery(sql, schema)) + except Exception as e: + raise RuntimeError(f"Failed to parse sql:\n\n{sql}") from e + return result + + # A cache for compiled sql queries, keyed by record batch schemas. + # This way the extractor can work with record batches of different schemas, + # which is legit but uncommon. + self._cached_queries = functools.lru_cache(maxsize=3)(_GenerateQueries) + + def process(self, batched_extract: types.Extracts) -> List[types.Extracts]: + features = batched_extract[constants.FEATURES_KEY] + # Slice on transformed features if available. + if ( + constants.TRANSFORMED_FEATURES_KEY in batched_extract + and batched_extract[constants.TRANSFORMED_FEATURES_KEY] is not None + ): + transformed_features = batched_extract[constants.TRANSFORMED_FEATURES_KEY] + # If only one model, the output is stored without keying on model name. + if not self._eval_config or len(self._eval_config.model_specs) == 1: + features.update(transformed_features) + else: + # Models listed earlier have precedence in feature lookup. + for spec in reversed(self._eval_config.model_specs): + if spec.name in transformed_features: + features.update(transformed_features[spec.name]) + + tensors = util.to_tensorflow_tensors(features) + tensor_specs = util.infer_tensor_specs(tensors) + + if _TF_MAJOR_VERSION < 2 or not tf.executing_eagerly(): + # TODO(b/228456048): TFX-BSL doesn't support passing tensorflow tensors + # for non-sparse/ragged values in TF 1.x (i.e. it only accepts np.ndarray + # for dense) so we need to convert dense tensors to numpy. 
+ sess = tf.compat.v1.Session() + + def _convert_dense_to_numpy(values): # pylint: disable=invalid-name + if isinstance(values, Mapping): + for k, v in values.items(): + if isinstance(v, Mapping): + values[k] = _convert_dense_to_numpy(v) + elif isinstance(v, tf.Tensor): + values[k] = v.eval(session=sess) + return values + + tensors = _convert_dense_to_numpy(tensors) + + converter = tensor_to_arrow.TensorsToRecordBatchConverter(tensor_specs) + record_batch = converter.convert(tensors) + sql_slice_keys = [[] for _ in range(record_batch.num_rows)] + + for query in self._cached_queries(record_batch.schema): + # Example of result with batch size = 3: + # result = [[[('feature', 'value_1')]], + # [[('feature', 'value_2')]], + # [] + # ] + result = query.Execute(record_batch) + for row_index, row_result in enumerate(result): + sql_slice_keys[row_index].extend([tuple(s) for s in row_result]) + + # convert sql_slice_keys into a VarLenTensorValue where each row has dtype + # object. + dense_rows = [] + for row_slice_keys in sql_slice_keys: + dense_rows.append(slicer_lib.slice_keys_to_numpy_array(row_slice_keys)) + varlen_sql_slice_keys = types.VarLenTensorValue.from_dense_rows(dense_rows) + + # Make a a shallow copy, so we don't mutate the original. + batched_extract_copy = copy.copy(batched_extract) + batched_extract_copy[constants.SLICE_KEY_TYPES_KEY] = varlen_sql_slice_keys + + self._sql_slicer_num_record_batch_schemas.update( + self._cached_queries.cache_info().currsize + ) - return [batched_extract_copy] + return [batched_extract_copy] @beam.ptransform_fn @@ -172,4 +172,4 @@ def _convert_dense_to_numpy(values): # pylint: disable=invalid-name def _ExtractSqlSliceKey( extracts: beam.pvalue.PCollection, eval_config: config_pb2.EvalConfig ) -> beam.pvalue.PCollection: - return extracts | beam.ParDo(ExtractSqlSliceKeyFn(eval_config)) + return extracts | beam.ParDo(ExtractSqlSliceKeyFn(eval_config)) diff --git a/tensorflow_model_analysis/extractors/sql_slice_key_extractor_test.py b/tensorflow_model_analysis/extractors/sql_slice_key_extractor_test.py index e429bef4c8..0508538f69 100644 --- a/tensorflow_model_analysis/extractors/sql_slice_key_extractor_test.py +++ b/tensorflow_model_analysis/extractors/sql_slice_key_extractor_test.py @@ -14,22 +14,23 @@ """Tests for tensorflow_model_analysis.google.extractors.sql_slice_key_extractor.""" import apache_beam as beam -from apache_beam.testing import util import numpy as np import pyarrow as pa import tensorflow as tf +from apache_beam.testing import util +from google.protobuf import text_format +from tensorflow_metadata.proto.v0 import schema_pb2 +from tfx_bsl.tfxio import tf_example_record + from tensorflow_model_analysis import constants -from tensorflow_model_analysis.api import model_eval_lib -from tensorflow_model_analysis.api import types -from tensorflow_model_analysis.extractors import features_extractor -from tensorflow_model_analysis.extractors import sql_slice_key_extractor +from tensorflow_model_analysis.api import model_eval_lib, types +from tensorflow_model_analysis.extractors import ( + features_extractor, + sql_slice_key_extractor, +) from tensorflow_model_analysis.proto import config_pb2 from tensorflow_model_analysis.slicer import slicer_lib from tensorflow_model_analysis.utils import test_util -from tfx_bsl.tfxio import tf_example_record - -from google.protobuf import text_format -from tensorflow_metadata.proto.v0 import schema_pb2 _SCHEMA = text_format.Parse( """ @@ -51,373 +52,398 @@ class 
SqlSliceKeyExtractorTest(test_util.TensorflowModelAnalysisTest): - - def testSqlSliceKeyExtractor(self): - eval_config = config_pb2.EvalConfig( - slicing_specs=[config_pb2.SlicingSpec(slice_keys_sql=""" + def testSqlSliceKeyExtractor(self): + eval_config = config_pb2.EvalConfig( + slicing_specs=[ + config_pb2.SlicingSpec( + slice_keys_sql=""" SELECT STRUCT(fixed_string) FROM example.fixed_string, example.fixed_int WHERE fixed_int = 1 - """)] - ) - feature_extractor = features_extractor.FeaturesExtractor( - eval_config=eval_config - ) - slice_key_extractor = sql_slice_key_extractor.SqlSliceKeyExtractor( - eval_config - ) - - tfx_io = tf_example_record.TFExampleBeamRecord( - physical_format='inmem', - telemetry_descriptors=['test', 'component'], - schema=_SCHEMA, - raw_record_column_name=constants.ARROW_INPUT_COLUMN, - ) - examples = [ - self._makeExample( - fixed_int=1, fixed_float=1.0, fixed_string='fixed_string1' - ), - self._makeExample( - fixed_int=1, fixed_float=1.0, fixed_string='fixed_string2' - ), - self._makeExample( - fixed_int=2, fixed_float=0.0, fixed_string='fixed_string3' - ), - ] - - with beam.Pipeline() as pipeline: - # pylint: disable=no-value-for-parameter - result = ( - pipeline - | 'Create' - >> beam.Create( - [e.SerializeToString() for e in examples], reshuffle=False - ) - | 'BatchExamples' >> tfx_io.BeamSource(batch_size=3) - | 'InputsToExtracts' >> model_eval_lib.BatchedInputsToExtracts() - | feature_extractor.stage_name >> feature_extractor.ptransform - | slice_key_extractor.stage_name >> slice_key_extractor.ptransform - ) - - # pylint: enable=no-value-for-parameter - - def check_result(got): - try: - self.assertLen(got, 1) - np.testing.assert_equal( - got[0][constants.SLICE_KEY_TYPES_KEY], - types.VarLenTensorValue.from_dense_rows([ - slicer_lib.slice_keys_to_numpy_array([( - ('fixed_string', 'fixed_string1'), - )]), - slicer_lib.slice_keys_to_numpy_array([( - ('fixed_string', 'fixed_string2'), - )]), - np.array([]), - ]), - ) - - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that(result, check_result) - - def testSqlSliceKeyExtractorWithTransformedFeatures(self): - eval_config = config_pb2.EvalConfig( - model_specs=[ - config_pb2.ModelSpec(name='model1'), - config_pb2.ModelSpec(name='model2'), - ], - slicing_specs=[config_pb2.SlicingSpec(slice_keys_sql=""" + """ + ) + ] + ) + feature_extractor = features_extractor.FeaturesExtractor( + eval_config=eval_config + ) + slice_key_extractor = sql_slice_key_extractor.SqlSliceKeyExtractor(eval_config) + + tfx_io = tf_example_record.TFExampleBeamRecord( + physical_format="inmem", + telemetry_descriptors=["test", "component"], + schema=_SCHEMA, + raw_record_column_name=constants.ARROW_INPUT_COLUMN, + ) + examples = [ + self._makeExample( + fixed_int=1, fixed_float=1.0, fixed_string="fixed_string1" + ), + self._makeExample( + fixed_int=1, fixed_float=1.0, fixed_string="fixed_string2" + ), + self._makeExample( + fixed_int=2, fixed_float=0.0, fixed_string="fixed_string3" + ), + ] + + with beam.Pipeline() as pipeline: + # pylint: disable=no-value-for-parameter + result = ( + pipeline + | "Create" + >> beam.Create( + [e.SerializeToString() for e in examples], reshuffle=False + ) + | "BatchExamples" >> tfx_io.BeamSource(batch_size=3) + | "InputsToExtracts" >> model_eval_lib.BatchedInputsToExtracts() + | feature_extractor.stage_name >> feature_extractor.ptransform + | slice_key_extractor.stage_name >> slice_key_extractor.ptransform + ) + + # pylint: enable=no-value-for-parameter + + def 
check_result(got): + try: + self.assertLen(got, 1) + np.testing.assert_equal( + got[0][constants.SLICE_KEY_TYPES_KEY], + types.VarLenTensorValue.from_dense_rows( + [ + slicer_lib.slice_keys_to_numpy_array( + [(("fixed_string", "fixed_string1"),)] + ), + slicer_lib.slice_keys_to_numpy_array( + [(("fixed_string", "fixed_string2"),)] + ), + np.array([]), + ] + ), + ) + + except AssertionError as err: + raise util.BeamAssertException(err) + + util.assert_that(result, check_result) + + def testSqlSliceKeyExtractorWithTransformedFeatures(self): + eval_config = config_pb2.EvalConfig( + model_specs=[ + config_pb2.ModelSpec(name="model1"), + config_pb2.ModelSpec(name="model2"), + ], + slicing_specs=[ + config_pb2.SlicingSpec( + slice_keys_sql=""" SELECT STRUCT(fixed_string) FROM example.fixed_string, example.fixed_int WHERE fixed_int = 1 - """)], - ) - slice_key_extractor = sql_slice_key_extractor.SqlSliceKeyExtractor( - eval_config - ) - - extracts = { - constants.FEATURES_KEY: { - 'fixed_int': np.array([1, 1, 2]), - }, - constants.TRANSFORMED_FEATURES_KEY: { - 'model1': { - 'fixed_int': np.array([1, 1, 2]), - 'fixed_float': np.array([1.0, 1.0, 0.0]), - 'fixed_string': np.array( - ['fixed_string1', 'fixed_string2', 'fixed_string3'] - ), + """ + ) + ], + ) + slice_key_extractor = sql_slice_key_extractor.SqlSliceKeyExtractor(eval_config) + + extracts = { + constants.FEATURES_KEY: { + "fixed_int": np.array([1, 1, 2]), }, - 'model2': { - 'fixed_int': np.array([1, 1, 2]), - 'fixed_string': np.array( - ['fixed_string1', 'fixed_string2', 'fixed_string3'] - ), + constants.TRANSFORMED_FEATURES_KEY: { + "model1": { + "fixed_int": np.array([1, 1, 2]), + "fixed_float": np.array([1.0, 1.0, 0.0]), + "fixed_string": np.array( + ["fixed_string1", "fixed_string2", "fixed_string3"] + ), + }, + "model2": { + "fixed_int": np.array([1, 1, 2]), + "fixed_string": np.array( + ["fixed_string1", "fixed_string2", "fixed_string3"] + ), + }, }, - }, - } - - with beam.Pipeline() as pipeline: - # pylint: disable=no-value-for-parameter - result = ( - pipeline - | 'CreateTestInput' >> beam.Create([extracts]) - | slice_key_extractor.stage_name >> slice_key_extractor.ptransform - ) - - # pylint: enable=no-value-for-parameter - - def check_result(got): - try: - self.assertLen(got, 1) - np.testing.assert_equal( - got[0][constants.SLICE_KEY_TYPES_KEY], - types.VarLenTensorValue.from_dense_rows([ - slicer_lib.slice_keys_to_numpy_array([( - ('fixed_string', 'fixed_string1'), - )]), - slicer_lib.slice_keys_to_numpy_array([( - ('fixed_string', 'fixed_string2'), - )]), - np.array([]), - ]), - ) - - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that(result, check_result) - - def testSqlSliceKeyExtractorWithCrossSlices(self): - eval_config = config_pb2.EvalConfig( - slicing_specs=[config_pb2.SlicingSpec(slice_keys_sql=""" + } + + with beam.Pipeline() as pipeline: + # pylint: disable=no-value-for-parameter + result = ( + pipeline + | "CreateTestInput" >> beam.Create([extracts]) + | slice_key_extractor.stage_name >> slice_key_extractor.ptransform + ) + + # pylint: enable=no-value-for-parameter + + def check_result(got): + try: + self.assertLen(got, 1) + np.testing.assert_equal( + got[0][constants.SLICE_KEY_TYPES_KEY], + types.VarLenTensorValue.from_dense_rows( + [ + slicer_lib.slice_keys_to_numpy_array( + [(("fixed_string", "fixed_string1"),)] + ), + slicer_lib.slice_keys_to_numpy_array( + [(("fixed_string", "fixed_string2"),)] + ), + np.array([]), + ] + ), + ) + + except AssertionError as err: + raise 
util.BeamAssertException(err) + + util.assert_that(result, check_result) + + def testSqlSliceKeyExtractorWithCrossSlices(self): + eval_config = config_pb2.EvalConfig( + slicing_specs=[ + config_pb2.SlicingSpec( + slice_keys_sql=""" SELECT STRUCT(fixed_string, fixed_int) FROM example.fixed_string, example.fixed_int WHERE fixed_int = 1 - """)] - ) - feature_extractor = features_extractor.FeaturesExtractor( - eval_config=eval_config - ) - slice_key_extractor = sql_slice_key_extractor.SqlSliceKeyExtractor( - eval_config - ) - - tfx_io = tf_example_record.TFExampleBeamRecord( - physical_format='inmem', - telemetry_descriptors=['test', 'component'], - schema=_SCHEMA, - raw_record_column_name=constants.ARROW_INPUT_COLUMN, - ) - examples = [ - self._makeExample( - fixed_int=1, fixed_float=1.0, fixed_string='fixed_string1' - ), - self._makeExample( - fixed_int=1, fixed_float=1.0, fixed_string='fixed_string2' - ), - self._makeExample( - fixed_int=2, fixed_float=0.0, fixed_string='fixed_string3' - ), - ] - - with beam.Pipeline() as pipeline: - # pylint: disable=no-value-for-parameter - result = ( - pipeline - | 'Create' - >> beam.Create( - [e.SerializeToString() for e in examples], reshuffle=False - ) - | 'BatchExamples' >> tfx_io.BeamSource(batch_size=3) - | 'InputsToExtracts' >> model_eval_lib.BatchedInputsToExtracts() - | feature_extractor.stage_name >> feature_extractor.ptransform - | slice_key_extractor.stage_name >> slice_key_extractor.ptransform - ) - - # pylint: enable=no-value-for-parameter - - def check_result(got): - try: - self.assertLen(got, 1) - np.testing.assert_equal( - got[0][constants.SLICE_KEY_TYPES_KEY], - types.VarLenTensorValue.from_dense_rows([ - slicer_lib.slice_keys_to_numpy_array( - [(('fixed_string', 'fixed_string1'), ('fixed_int', '1'))] - ), - slicer_lib.slice_keys_to_numpy_array( - [(('fixed_string', 'fixed_string2'), ('fixed_int', '1'))] - ), - np.array([]), - ]), - ) - - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that(result, check_result) - - def testSqlSliceKeyExtractorWithEmptySqlConfig(self): - eval_config = config_pb2.EvalConfig() - feature_extractor = features_extractor.FeaturesExtractor( - eval_config=eval_config - ) - slice_key_extractor = sql_slice_key_extractor.SqlSliceKeyExtractor( - eval_config - ) - - tfx_io = tf_example_record.TFExampleBeamRecord( - physical_format='inmem', - telemetry_descriptors=['test', 'component'], - schema=_SCHEMA, - raw_record_column_name=constants.ARROW_INPUT_COLUMN, - ) - examples = [ - self._makeExample( - fixed_int=1, fixed_float=1.0, fixed_string='fixed_string1' - ), - self._makeExample( - fixed_int=1, fixed_float=1.0, fixed_string='fixed_string2' - ), - self._makeExample( - fixed_int=2, fixed_float=0.0, fixed_string='fixed_string3' - ), - ] - - with beam.Pipeline() as pipeline: - # pylint: disable=no-value-for-parameter - result = ( - pipeline - | 'Create' - >> beam.Create( - [e.SerializeToString() for e in examples], reshuffle=False - ) - | 'BatchExamples' >> tfx_io.BeamSource(batch_size=3) - | 'InputsToExtracts' >> model_eval_lib.BatchedInputsToExtracts() - | feature_extractor.stage_name >> feature_extractor.ptransform - | slice_key_extractor.stage_name >> slice_key_extractor.ptransform - ) - - # pylint: enable=no-value-for-parameter - - def check_result(got): - try: - self.assertLen(got, 1) - np.testing.assert_equal( - got[0][constants.SLICE_KEY_TYPES_KEY], - types.VarLenTensorValue.from_dense_rows( - [np.array([]), np.array([]), np.array([])] - ), - ) - - except AssertionError as 
err: - raise util.BeamAssertException(err) - - util.assert_that(result, check_result) - - def testSqlSliceKeyExtractorWithMultipleSchema(self): - eval_config = config_pb2.EvalConfig( - slicing_specs=[config_pb2.SlicingSpec(slice_keys_sql=""" + """ + ) + ] + ) + feature_extractor = features_extractor.FeaturesExtractor( + eval_config=eval_config + ) + slice_key_extractor = sql_slice_key_extractor.SqlSliceKeyExtractor(eval_config) + + tfx_io = tf_example_record.TFExampleBeamRecord( + physical_format="inmem", + telemetry_descriptors=["test", "component"], + schema=_SCHEMA, + raw_record_column_name=constants.ARROW_INPUT_COLUMN, + ) + examples = [ + self._makeExample( + fixed_int=1, fixed_float=1.0, fixed_string="fixed_string1" + ), + self._makeExample( + fixed_int=1, fixed_float=1.0, fixed_string="fixed_string2" + ), + self._makeExample( + fixed_int=2, fixed_float=0.0, fixed_string="fixed_string3" + ), + ] + + with beam.Pipeline() as pipeline: + # pylint: disable=no-value-for-parameter + result = ( + pipeline + | "Create" + >> beam.Create( + [e.SerializeToString() for e in examples], reshuffle=False + ) + | "BatchExamples" >> tfx_io.BeamSource(batch_size=3) + | "InputsToExtracts" >> model_eval_lib.BatchedInputsToExtracts() + | feature_extractor.stage_name >> feature_extractor.ptransform + | slice_key_extractor.stage_name >> slice_key_extractor.ptransform + ) + + # pylint: enable=no-value-for-parameter + + def check_result(got): + try: + self.assertLen(got, 1) + np.testing.assert_equal( + got[0][constants.SLICE_KEY_TYPES_KEY], + types.VarLenTensorValue.from_dense_rows( + [ + slicer_lib.slice_keys_to_numpy_array( + [ + ( + ("fixed_string", "fixed_string1"), + ("fixed_int", "1"), + ) + ] + ), + slicer_lib.slice_keys_to_numpy_array( + [ + ( + ("fixed_string", "fixed_string2"), + ("fixed_int", "1"), + ) + ] + ), + np.array([]), + ] + ), + ) + + except AssertionError as err: + raise util.BeamAssertException(err) + + util.assert_that(result, check_result) + + def testSqlSliceKeyExtractorWithEmptySqlConfig(self): + eval_config = config_pb2.EvalConfig() + feature_extractor = features_extractor.FeaturesExtractor( + eval_config=eval_config + ) + slice_key_extractor = sql_slice_key_extractor.SqlSliceKeyExtractor(eval_config) + + tfx_io = tf_example_record.TFExampleBeamRecord( + physical_format="inmem", + telemetry_descriptors=["test", "component"], + schema=_SCHEMA, + raw_record_column_name=constants.ARROW_INPUT_COLUMN, + ) + examples = [ + self._makeExample( + fixed_int=1, fixed_float=1.0, fixed_string="fixed_string1" + ), + self._makeExample( + fixed_int=1, fixed_float=1.0, fixed_string="fixed_string2" + ), + self._makeExample( + fixed_int=2, fixed_float=0.0, fixed_string="fixed_string3" + ), + ] + + with beam.Pipeline() as pipeline: + # pylint: disable=no-value-for-parameter + result = ( + pipeline + | "Create" + >> beam.Create( + [e.SerializeToString() for e in examples], reshuffle=False + ) + | "BatchExamples" >> tfx_io.BeamSource(batch_size=3) + | "InputsToExtracts" >> model_eval_lib.BatchedInputsToExtracts() + | feature_extractor.stage_name >> feature_extractor.ptransform + | slice_key_extractor.stage_name >> slice_key_extractor.ptransform + ) + + # pylint: enable=no-value-for-parameter + + def check_result(got): + try: + self.assertLen(got, 1) + np.testing.assert_equal( + got[0][constants.SLICE_KEY_TYPES_KEY], + types.VarLenTensorValue.from_dense_rows( + [np.array([]), np.array([]), np.array([])] + ), + ) + + except AssertionError as err: + raise util.BeamAssertException(err) + + 
util.assert_that(result, check_result) + + def testSqlSliceKeyExtractorWithMultipleSchema(self): + eval_config = config_pb2.EvalConfig( + slicing_specs=[ + config_pb2.SlicingSpec( + slice_keys_sql=""" SELECT STRUCT(fixed_string) FROM example.fixed_string, example.fixed_int WHERE fixed_int = 1 - """)] - ) - feature_extractor = features_extractor.FeaturesExtractor( - eval_config=eval_config - ) - slice_key_extractor = sql_slice_key_extractor.SqlSliceKeyExtractor( - eval_config - ) - - record_batch_1 = pa.RecordBatch.from_arrays( - [ - pa.array([[1], [1], [2]], type=pa.list_(pa.int64())), - pa.array([[1.0], [1.0], [2.0]], type=pa.list_(pa.float64())), - pa.array( - [['fixed_string1'], ['fixed_string2'], ['fixed_string3']], - type=pa.list_(pa.string()), - ), - ], - ['fixed_int', 'fixed_float', 'fixed_string'], - ) - record_batch_2 = pa.RecordBatch.from_arrays( - [ - pa.array([[1], [1], [2]], type=pa.list_(pa.int64())), - pa.array([[1.0], [1.0], [2.0]], type=pa.list_(pa.float64())), - pa.array( - [['fixed_string1'], ['fixed_string2'], ['fixed_string3']], - type=pa.list_(pa.string()), - ), - pa.array( - [['extra_field1'], ['extra_field2'], ['extra_field3']], - type=pa.list_(pa.string()), - ), - ], - ['fixed_int', 'fixed_float', 'fixed_string', 'extra_field'], - ) - - with beam.Pipeline() as pipeline: - # pylint: disable=no-value-for-parameter - result = ( - pipeline - | 'Create' - >> beam.Create([record_batch_1, record_batch_2], reshuffle=False) - | 'InputsToExtracts' >> model_eval_lib.BatchedInputsToExtracts() - | feature_extractor.stage_name >> feature_extractor.ptransform - | slice_key_extractor.stage_name >> slice_key_extractor.ptransform - ) - - # pylint: enable=no-value-for-parameter - - def check_result(got): - try: - self.assertLen(got, 2) - np.testing.assert_equal( - got[0][constants.SLICE_KEY_TYPES_KEY], - types.VarLenTensorValue.from_dense_rows([ - slicer_lib.slice_keys_to_numpy_array([( - ('fixed_string', 'fixed_string1'), - )]), - slicer_lib.slice_keys_to_numpy_array([( - ('fixed_string', 'fixed_string2'), - )]), - np.array([]), - ]), - ) - np.testing.assert_equal( - got[1][constants.SLICE_KEY_TYPES_KEY], - types.VarLenTensorValue.from_dense_rows([ - slicer_lib.slice_keys_to_numpy_array([( - ('fixed_string', 'fixed_string1'), - )]), - slicer_lib.slice_keys_to_numpy_array([( - ('fixed_string', 'fixed_string2'), - )]), - np.array([]), - ]), - ) - - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that(result, check_result) - - -if __name__ == '__main__': - tf.test.main() + """ + ) + ] + ) + feature_extractor = features_extractor.FeaturesExtractor( + eval_config=eval_config + ) + slice_key_extractor = sql_slice_key_extractor.SqlSliceKeyExtractor(eval_config) + + record_batch_1 = pa.RecordBatch.from_arrays( + [ + pa.array([[1], [1], [2]], type=pa.list_(pa.int64())), + pa.array([[1.0], [1.0], [2.0]], type=pa.list_(pa.float64())), + pa.array( + [["fixed_string1"], ["fixed_string2"], ["fixed_string3"]], + type=pa.list_(pa.string()), + ), + ], + ["fixed_int", "fixed_float", "fixed_string"], + ) + record_batch_2 = pa.RecordBatch.from_arrays( + [ + pa.array([[1], [1], [2]], type=pa.list_(pa.int64())), + pa.array([[1.0], [1.0], [2.0]], type=pa.list_(pa.float64())), + pa.array( + [["fixed_string1"], ["fixed_string2"], ["fixed_string3"]], + type=pa.list_(pa.string()), + ), + pa.array( + [["extra_field1"], ["extra_field2"], ["extra_field3"]], + type=pa.list_(pa.string()), + ), + ], + ["fixed_int", "fixed_float", "fixed_string", "extra_field"], + ) + + with 
beam.Pipeline() as pipeline: + # pylint: disable=no-value-for-parameter + result = ( + pipeline + | "Create" + >> beam.Create([record_batch_1, record_batch_2], reshuffle=False) + | "InputsToExtracts" >> model_eval_lib.BatchedInputsToExtracts() + | feature_extractor.stage_name >> feature_extractor.ptransform + | slice_key_extractor.stage_name >> slice_key_extractor.ptransform + ) + + # pylint: enable=no-value-for-parameter + + def check_result(got): + try: + self.assertLen(got, 2) + np.testing.assert_equal( + got[0][constants.SLICE_KEY_TYPES_KEY], + types.VarLenTensorValue.from_dense_rows( + [ + slicer_lib.slice_keys_to_numpy_array( + [(("fixed_string", "fixed_string1"),)] + ), + slicer_lib.slice_keys_to_numpy_array( + [(("fixed_string", "fixed_string2"),)] + ), + np.array([]), + ] + ), + ) + np.testing.assert_equal( + got[1][constants.SLICE_KEY_TYPES_KEY], + types.VarLenTensorValue.from_dense_rows( + [ + slicer_lib.slice_keys_to_numpy_array( + [(("fixed_string", "fixed_string1"),)] + ), + slicer_lib.slice_keys_to_numpy_array( + [(("fixed_string", "fixed_string2"),)] + ), + np.array([]), + ] + ), + ) + + except AssertionError as err: + raise util.BeamAssertException(err) + + util.assert_that(result, check_result) + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow_model_analysis/extractors/tfjs_predict_extractor.py b/tensorflow_model_analysis/extractors/tfjs_predict_extractor.py index dbdf3383ea..c7c9a90b21 100644 --- a/tensorflow_model_analysis/extractors/tfjs_predict_extractor.py +++ b/tensorflow_model_analysis/extractors/tfjs_predict_extractor.py @@ -19,254 +19,243 @@ import os import subprocess import tempfile -from typing import Dict, List, NamedTuple, Sequence, Union import uuid +from typing import Dict, List, NamedTuple, Sequence, Union import apache_beam as beam import numpy as np import tensorflow as tf + from tensorflow_model_analysis import constants from tensorflow_model_analysis.api import types from tensorflow_model_analysis.extractors import extractor -from tensorflow_model_analysis.extractors.tfjs_predict_extractor_util import get_tfjs_binary +from tensorflow_model_analysis.extractors.tfjs_predict_extractor_util import ( + get_tfjs_binary, +) from tensorflow_model_analysis.proto import config_pb2 -from tensorflow_model_analysis.utils import model_util -from tensorflow_model_analysis.utils import util +from tensorflow_model_analysis.utils import model_util, util -_TFJS_PREDICT_EXTRACTOR_STAGE_NAME = 'ExtractTFJSPredictions' +_TFJS_PREDICT_EXTRACTOR_STAGE_NAME = "ExtractTFJSPredictions" -_MODELS_SUBDIR = 'Models' -_EXAMPLES_SUBDIR = 'Input_Examples' -_OUTPUTS_SUBDIR = 'Inference_Results' +_MODELS_SUBDIR = "Models" +_EXAMPLES_SUBDIR = "Input_Examples" +_OUTPUTS_SUBDIR = "Inference_Results" -_MODEL_JSON = 'model.json' -_DATA_JSON = 'data.json' -_DTYPE_JSON = 'dtype.json' -_SHAPE_JSON = 'shape.json' -_TF_INPUT_NAME_JSON = 'tf_input_name.json' +_MODEL_JSON = "model.json" +_DATA_JSON = "data.json" +_DTYPE_JSON = "dtype.json" +_SHAPE_JSON = "shape.json" +_TF_INPUT_NAME_JSON = "tf_input_name.json" -class _InputSpec( - NamedTuple('_InputSpec', [('dim', List[int]), ('dtype', str)]) -): - """_InputSpec encapsulates the processed input spec from TFJS model. +class _InputSpec(NamedTuple("_InputSpec", [("dim", List[int]), ("dtype", str)])): + """_InputSpec encapsulates the processed input spec from TFJS model. - Attributes: - dim: The expected dim of the input. - dtype: The expected dtype of the input. 
- """ + Attributes + ---------- + dim: The expected dim of the input. + dtype: The expected dtype of the input. + """ # TODO(b/149981535) Determine if we should merge with RunInference. @beam.typehints.with_input_types(types.Extracts) @beam.typehints.with_output_types(types.Extracts) class _TFJSPredictionDoFn(model_util.BatchReducibleBatchedDoFnWithModels): - """A DoFn that loads tfjs models and predicts.""" - - def __init__( - self, - eval_config: config_pb2.EvalConfig, - eval_shared_models: Dict[str, types.EvalSharedModel], - ) -> None: - super().__init__({k: v.model_loader for k, v in eval_shared_models.items()}) - self._eval_config = eval_config - self._src_model_paths = { - k: v.model_path for k, v in eval_shared_models.items() - } - - def setup(self): - super().setup() - self._binary_path = get_tfjs_binary() - - base_path = tempfile.mkdtemp() - base_model_path = os.path.join(base_path, _MODELS_SUBDIR) - - self._model_properties = {} - for model_name, model_path in self._src_model_paths.items(): - with tf.io.gfile.GFile(os.path.join(model_path, _MODEL_JSON)) as f: - model_json = json.load(f) - if ( - 'userDefinedMetadata' in model_json - and 'signature' in model_json['userDefinedMetadata'] - ): - model_signature = model_json['userDefinedMetadata']['signature'] - else: - model_signature = model_json['signature'] - model_inputs = {} - for k, v in model_signature['inputs'].items(): - model_inputs[k] = _InputSpec( - [int(d['size']) for d in v['tensorShape']['dim']], v['dtype'] - ) - - model_outputs = {} - for k, v in model_signature['outputs'].items(): - model_outputs[k] = [int(i['size']) for i in v['tensorShape']['dim']] - - cur_model_path = os.path.join(base_model_path, model_name) - self._model_properties[model_name] = { - 'inputs': model_inputs, - 'outputs': model_outputs, - 'path': cur_model_path, - } - - # We copy models to local tmp storage so that the tfjs binary can - # access them. 
- tf.io.gfile.makedirs(cur_model_path) - for directory, _, files in tf.io.gfile.walk(model_path): - cur_path = os.path.join( - cur_model_path, os.path.relpath(directory, model_path) - ) - tf.io.gfile.makedirs(cur_path) - for f in files: - src_path = os.path.join(directory, f) - tf.io.gfile.copy(src_path, os.path.join(cur_path, f)) - - def _batch_reducible_process( - self, element: types.Extracts - ) -> Sequence[types.Extracts]: - """Invokes the tfjs model on the provided inputs and stores the result.""" - result = copy.copy(element) - - batched_features = collections.defaultdict(list) - for key, value in element[constants.FEATURES_KEY].items(): - if value.dtype == np.int64: - value = value.astype(np.int32) - batched_features[key] = value - batch_size = util.batch_size(batched_features) - - for spec in self._eval_config.model_specs: - model_name = spec.name if len(self._eval_config.model_specs) > 1 else '' - if model_name not in self._loaded_models: - raise ValueError( - 'model for "{}" not found: eval_config={}'.format( - spec.name, self._eval_config + """A DoFn that loads tfjs models and predicts.""" + + def __init__( + self, + eval_config: config_pb2.EvalConfig, + eval_shared_models: Dict[str, types.EvalSharedModel], + ) -> None: + super().__init__({k: v.model_loader for k, v in eval_shared_models.items()}) + self._eval_config = eval_config + self._src_model_paths = {k: v.model_path for k, v in eval_shared_models.items()} + + def setup(self): + super().setup() + self._binary_path = get_tfjs_binary() + + base_path = tempfile.mkdtemp() + base_model_path = os.path.join(base_path, _MODELS_SUBDIR) + + self._model_properties = {} + for model_name, model_path in self._src_model_paths.items(): + with tf.io.gfile.GFile(os.path.join(model_path, _MODEL_JSON)) as f: + model_json = json.load(f) + if ( + "userDefinedMetadata" in model_json + and "signature" in model_json["userDefinedMetadata"] + ): + model_signature = model_json["userDefinedMetadata"]["signature"] + else: + model_signature = model_json["signature"] + model_inputs = {} + for k, v in model_signature["inputs"].items(): + model_inputs[k] = _InputSpec( + [int(d["size"]) for d in v["tensorShape"]["dim"]], v["dtype"] + ) + + model_outputs = {} + for k, v in model_signature["outputs"].items(): + model_outputs[k] = [int(i["size"]) for i in v["tensorShape"]["dim"]] + + cur_model_path = os.path.join(base_model_path, model_name) + self._model_properties[model_name] = { + "inputs": model_inputs, + "outputs": model_outputs, + "path": cur_model_path, + } + + # We copy models to local tmp storage so that the tfjs binary can + # access them. 
+ tf.io.gfile.makedirs(cur_model_path) + for directory, _, files in tf.io.gfile.walk(model_path): + cur_path = os.path.join( + cur_model_path, os.path.relpath(directory, model_path) + ) + tf.io.gfile.makedirs(cur_path) + for f in files: + src_path = os.path.join(directory, f) + tf.io.gfile.copy(src_path, os.path.join(cur_path, f)) + + def _batch_reducible_process( + self, element: types.Extracts + ) -> Sequence[types.Extracts]: + """Invokes the tfjs model on the provided inputs and stores the result.""" + result = copy.copy(element) + + batched_features = collections.defaultdict(list) + for key, value in element[constants.FEATURES_KEY].items(): + if value.dtype == np.int64: + value = value.astype(np.int32) + batched_features[key] = value + batch_size = util.batch_size(batched_features) + + for spec in self._eval_config.model_specs: + model_name = spec.name if len(self._eval_config.model_specs) > 1 else "" + if model_name not in self._loaded_models: + raise ValueError( + f'model for "{spec.name}" not found: eval_config={self._eval_config}' + ) + + model_features = {} + for k in self._model_properties[model_name]["inputs"]: + k_name = k.split(":")[0] + if k_name not in batched_features: + raise ValueError( + f'model requires feature "{k_name}" not available in input.' + ) + dim = self._model_properties[model_name]["inputs"][k].dim + elems = [] + for i in batched_features[k_name]: + if np.ndim(i) > len(dim): + raise ValueError( + f'ranks for input "{k_name}" are not compatible ' + "with the model." + ) + # TODO(dzats): See if we can support case where multiple dimensions + # are not defined. + elems.append(np.reshape(i, dim)) + model_features[k] = elems + + model_features = {k: np.concatenate(v) for k, v in model_features.items()} + + batched_entries = collections.defaultdict(list) + for feature, value in model_features.items(): + if ( + self._model_properties[model_name]["inputs"][feature].dtype + == "DT_STRING" + ): + # For numpy array, even we cast it to string, we cannot get 'string' + # by directly calling str(value.tdype). + value = value.astype(str) + dtype_str = "string" + else: + dtype_str = str(value.dtype) + batched_entries[_DATA_JSON].append(value.tolist()) + batched_entries[_DTYPE_JSON].append(dtype_str) + batched_entries[_SHAPE_JSON].append(value.shape) + batched_entries[_TF_INPUT_NAME_JSON].append(feature) + + cur_subdir = str(uuid.uuid4()) + cur_input_path = os.path.join( + self._model_properties[model_name]["path"], + _EXAMPLES_SUBDIR, + cur_subdir, ) - ) - - model_features = {} - for k in self._model_properties[model_name]['inputs']: - k_name = k.split(':')[0] - if k_name not in batched_features: - raise ValueError( - 'model requires feature "{}" not available in input.'.format( - k_name - ) - ) - dim = self._model_properties[model_name]['inputs'][k].dim - elems = [] - for i in batched_features[k_name]: - if np.ndim(i) > len(dim): - raise ValueError( - 'ranks for input "{}" are not compatible ' - 'with the model.'.format(k_name) - ) - # TODO(dzats): See if we can support case where multiple dimensions - # are not defined. - elems.append(np.reshape(i, dim)) - model_features[k] = elems - - model_features = {k: np.concatenate(v) for k, v in model_features.items()} - - batched_entries = collections.defaultdict(list) - for feature, value in model_features.items(): - if ( - self._model_properties[model_name]['inputs'][feature].dtype - == 'DT_STRING' - ): - # For numpy array, even we cast it to string, we cannot get 'string' - # by directly calling str(value.tdype). 
- value = value.astype(str) - dtype_str = 'string' - else: - dtype_str = str(value.dtype) - batched_entries[_DATA_JSON].append(value.tolist()) - batched_entries[_DTYPE_JSON].append(dtype_str) - batched_entries[_SHAPE_JSON].append(value.shape) - batched_entries[_TF_INPUT_NAME_JSON].append(feature) - - cur_subdir = str(uuid.uuid4()) - cur_input_path = os.path.join( - self._model_properties[model_name]['path'], - _EXAMPLES_SUBDIR, - cur_subdir, - ) - tf.io.gfile.makedirs(cur_input_path) - for entry, value in batched_entries.items(): - with tf.io.gfile.GFile(os.path.join(cur_input_path, entry), 'w') as f: - f.write(json.dumps(value)) - - cur_output_path = os.path.join( - self._model_properties[model_name]['path'], - _OUTPUTS_SUBDIR, - cur_subdir, - ) - tf.io.gfile.makedirs(cur_output_path) - inference_command = [ - self._binary_path, - '--model_path=' - + os.path.join( - self._model_properties[model_name]['path'], _MODEL_JSON - ), - '--inputs_dir=' + cur_input_path, - '--outputs_dir=' + cur_output_path, - ] - - popen = subprocess.Popen( - inference_command, - stdin=subprocess.PIPE, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - ) - stdout, stderr = popen.communicate() - if popen.returncode != 0: - raise ValueError( - 'Inference failed with status {}\nstdout:\n{}\nstderr:\n{}'.format( - popen.returncode, stdout, stderr + tf.io.gfile.makedirs(cur_input_path) + for entry, value in batched_entries.items(): + with tf.io.gfile.GFile(os.path.join(cur_input_path, entry), "w") as f: + f.write(json.dumps(value)) + + cur_output_path = os.path.join( + self._model_properties[model_name]["path"], + _OUTPUTS_SUBDIR, + cur_subdir, ) - ) - - try: - with tf.io.gfile.GFile(os.path.join(cur_output_path, _DATA_JSON)) as f: - data = json.load(f) - with tf.io.gfile.GFile(os.path.join(cur_output_path, _DTYPE_JSON)) as f: - dtype = json.load(f) - with tf.io.gfile.GFile(os.path.join(cur_output_path, _SHAPE_JSON)) as f: - shape = json.load(f) - except (FileNotFoundError, tf.errors.NotFoundError) as e: - raise FileNotFoundError( - 'Unable to find files containing inference result. This likely ' - 'means that inference did not succeed. 
Error {}.\n Inference failed' - ' with status {}\nstdout:\n{}\nstderr:\n{}'.format( - e, popen.returncode, stdout, stderr + tf.io.gfile.makedirs(cur_output_path) + inference_command = [ + self._binary_path, + "--model_path=" + + os.path.join(self._model_properties[model_name]["path"], _MODEL_JSON), + "--inputs_dir=" + cur_input_path, + "--outputs_dir=" + cur_output_path, + ] + + popen = subprocess.Popen( + inference_command, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, ) - ) from e - - name = [ - n.split(':')[0] for n in self._model_properties[model_name]['outputs'] - ] - - tf.io.gfile.rmtree(cur_input_path) - tf.io.gfile.rmtree(cur_output_path) - - outputs = {} - for n, s, t, d in zip(name, shape, dtype, data): - d_val = [d[str(i)] for i in range(len(d))] - outputs[n] = np.reshape(np.array(d_val, t), s) - - for v in outputs.values(): - if len(v) != batch_size: - raise ValueError('Did not get the expected number of results.') - - if len(outputs) == 1: - outputs = list(outputs.values())[0] - - if len(self._eval_config.model_specs) == 1: - result[constants.PREDICTIONS_KEY] = outputs - else: - if constants.PREDICTIONS_KEY not in result: - result[constants.PREDICTIONS_KEY] = {} - result[constants.PREDICTIONS_KEY][spec.name] = outputs - return [result] + stdout, stderr = popen.communicate() + if popen.returncode != 0: + raise ValueError( + f"Inference failed with status {popen.returncode}\nstdout:\n{stdout}\nstderr:\n{stderr}" + ) + + try: + with tf.io.gfile.GFile(os.path.join(cur_output_path, _DATA_JSON)) as f: + data = json.load(f) + with tf.io.gfile.GFile(os.path.join(cur_output_path, _DTYPE_JSON)) as f: + dtype = json.load(f) + with tf.io.gfile.GFile(os.path.join(cur_output_path, _SHAPE_JSON)) as f: + shape = json.load(f) + except (FileNotFoundError, tf.errors.NotFoundError) as e: + raise FileNotFoundError( + "Unable to find files containing inference result. This likely " + f"means that inference did not succeed. Error {e}.\n Inference failed" + f" with status {popen.returncode}\nstdout:\n{stdout}\nstderr:\n{stderr}" + ) from e + + name = [ + n.split(":")[0] for n in self._model_properties[model_name]["outputs"] + ] + + tf.io.gfile.rmtree(cur_input_path) + tf.io.gfile.rmtree(cur_output_path) + + outputs = {} + for n, s, t, d in zip(name, shape, dtype, data): + d_val = [d[str(i)] for i in range(len(d))] + outputs[n] = np.reshape(np.array(d_val, t), s) + + for v in outputs.values(): + if len(v) != batch_size: + raise ValueError("Did not get the expected number of results.") + + if len(outputs) == 1: + outputs = list(outputs.values())[0] + + if len(self._eval_config.model_specs) == 1: + result[constants.PREDICTIONS_KEY] = outputs + else: + if constants.PREDICTIONS_KEY not in result: + result[constants.PREDICTIONS_KEY] = {} + result[constants.PREDICTIONS_KEY][spec.name] = outputs + return [result] @beam.ptransform_fn @@ -277,55 +266,57 @@ def _ExtractTFJSPredictions( # pylint: disable=invalid-name eval_config: config_pb2.EvalConfig, eval_shared_models: Dict[str, types.EvalSharedModel], ) -> beam.pvalue.PCollection: - """A PTransform that adds predictions and possibly other tensors to extracts. - - Args: - extracts: PCollection of extracts containing model inputs keyed by - tfma.FEATURES_KEY. - eval_config: Eval config. - eval_shared_models: Shared model parameters keyed by model name. - - Returns: - PCollection of Extracts updated with the predictions. 
- """ - return extracts | 'Predict' >> beam.ParDo( - _TFJSPredictionDoFn( - eval_config=eval_config, eval_shared_models=eval_shared_models - ) - ) + """A PTransform that adds predictions and possibly other tensors to extracts. + + Args: + ---- + extracts: PCollection of extracts containing model inputs keyed by + tfma.FEATURES_KEY. + eval_config: Eval config. + eval_shared_models: Shared model parameters keyed by model name. + + Returns: + ------- + PCollection of Extracts updated with the predictions. + """ + return extracts | "Predict" >> beam.ParDo( + _TFJSPredictionDoFn( + eval_config=eval_config, eval_shared_models=eval_shared_models + ) + ) def TFJSPredictExtractor( # pylint: disable=invalid-name eval_config: config_pb2.EvalConfig, - eval_shared_model: Union[ - types.EvalSharedModel, Dict[str, types.EvalSharedModel] - ], + eval_shared_model: Union[types.EvalSharedModel, Dict[str, types.EvalSharedModel]], ) -> extractor.Extractor: - """Creates an extractor for performing predictions on tfjs models. - - The extractor's PTransform loads and interprets the tfjs model against - every extract yielding a copy of the incoming extracts with an additional - extract added for the predictions keyed by tfma.PREDICTIONS_KEY. The model - inputs are searched for under tfma.FEATURES_KEY. If multiple - models are used the predictions will be stored in a dict keyed by model name. - - Args: - eval_config: Eval config. - eval_shared_model: Shared model (single-model evaluation) or dict of shared - models keyed by model name (multi-model evaluation). - - Returns: - Extractor for extracting predictions. - """ - eval_shared_models = model_util.verify_and_update_eval_shared_models( - eval_shared_model - ) - - # pylint: disable=no-value-for-parameter - return extractor.Extractor( - stage_name=_TFJS_PREDICT_EXTRACTOR_STAGE_NAME, - ptransform=_ExtractTFJSPredictions( - eval_config=eval_config, - eval_shared_models={m.model_name: m for m in eval_shared_models}, - ), - ) + """Creates an extractor for performing predictions on tfjs models. + + The extractor's PTransform loads and interprets the tfjs model against + every extract yielding a copy of the incoming extracts with an additional + extract added for the predictions keyed by tfma.PREDICTIONS_KEY. The model + inputs are searched for under tfma.FEATURES_KEY. If multiple + models are used the predictions will be stored in a dict keyed by model name. + + Args: + ---- + eval_config: Eval config. + eval_shared_model: Shared model (single-model evaluation) or dict of shared + models keyed by model name (multi-model evaluation). + + Returns: + ------- + Extractor for extracting predictions. 
+ """ + eval_shared_models = model_util.verify_and_update_eval_shared_models( + eval_shared_model + ) + + # pylint: disable=no-value-for-parameter + return extractor.Extractor( + stage_name=_TFJS_PREDICT_EXTRACTOR_STAGE_NAME, + ptransform=_ExtractTFJSPredictions( + eval_config=eval_config, + eval_shared_models={m.model_name: m for m in eval_shared_models}, + ), + ) diff --git a/tensorflow_model_analysis/extractors/tfjs_predict_extractor_test.py b/tensorflow_model_analysis/extractors/tfjs_predict_extractor_test.py index 92270b81e8..a194f4f38c 100644 --- a/tensorflow_model_analysis/extractors/tfjs_predict_extractor_test.py +++ b/tensorflow_model_analysis/extractors/tfjs_predict_extractor_test.py @@ -15,124 +15,127 @@ import tempfile -from absl.testing import parameterized import apache_beam as beam -from apache_beam.testing import util import tensorflow as tf +from absl.testing import parameterized +from apache_beam.testing import util +from google.protobuf import text_format +from tensorflow_metadata.proto.v0 import schema_pb2 +from tfx_bsl.tfxio import test_util + from tensorflow_model_analysis import constants from tensorflow_model_analysis.api import model_eval_lib -from tensorflow_model_analysis.extractors import features_extractor -from tensorflow_model_analysis.extractors import tfjs_predict_extractor +from tensorflow_model_analysis.extractors import ( + features_extractor, + tfjs_predict_extractor, +) from tensorflow_model_analysis.proto import config_pb2 from tensorflow_model_analysis.utils import test_util as testutil from tensorflow_model_analysis.utils.keras_lib import tf_keras -from tfx_bsl.tfxio import test_util - -from google.protobuf import text_format -from tensorflow_metadata.proto.v0 import schema_pb2 try: - from tensorflowjs.converters import converter # pylint: disable=g-import-not-at-top + from tensorflowjs.converters import converter # pylint: disable=g-import-not-at-top - _TFJS_IMPORTED = True + _TFJS_IMPORTED = True except ModuleNotFoundError: - _TFJS_IMPORTED = False + _TFJS_IMPORTED = False class TFJSPredictExtractorTest( testutil.TensorflowModelAnalysisTest, parameterized.TestCase ): - - @parameterized.named_parameters( - ('single_model_single_output', False, False), - ('single_model_multi_output', False, True), - ('multi_model_single_output', True, False), - ('multi_model_multi_output_batched_examples_batched_inputs', True, True), - ) - def testTFJSPredictExtractorWithKerasModel(self, multi_model, multi_output): - if not _TFJS_IMPORTED: - self.skipTest('This test requires TensorFlow JS.') - - input1 = tf_keras.layers.Input(shape=(1,), name='input1') - input2 = tf_keras.layers.Input(shape=(1,), name='input2', dtype=tf.int64) - input3 = tf_keras.layers.Input(shape=(1,), name='input3', dtype=tf.string) - inputs = [input1, input2, input3] - input_layer = tf_keras.layers.concatenate([ - inputs[0], - tf.cast(inputs[1], tf.float32), - tf.cast(inputs[2] == 'a', tf.float32), - ]) - output_layers = {} - output_layers['output1'] = tf_keras.layers.Dense( - 1, activation=tf.nn.sigmoid, name='output1' - )(input_layer) - if multi_output: - output_layers['output2'] = tf_keras.layers.Dense( - 1, activation=tf.nn.sigmoid, name='output2' - )(input_layer) - - model = tf_keras.models.Model(inputs, output_layers) - model.compile( - optimizer=tf_keras.optimizers.Adam(lr=0.001), - loss=tf_keras.losses.binary_crossentropy, - metrics=['accuracy'], + @parameterized.named_parameters( + ("single_model_single_output", False, False), + ("single_model_multi_output", False, True), + 
("multi_model_single_output", True, False), + ("multi_model_multi_output_batched_examples_batched_inputs", True, True), ) + def testTFJSPredictExtractorWithKerasModel(self, multi_model, multi_output): + if not _TFJS_IMPORTED: + self.skipTest("This test requires TensorFlow JS.") + + input1 = tf_keras.layers.Input(shape=(1,), name="input1") + input2 = tf_keras.layers.Input(shape=(1,), name="input2", dtype=tf.int64) + input3 = tf_keras.layers.Input(shape=(1,), name="input3", dtype=tf.string) + inputs = [input1, input2, input3] + input_layer = tf_keras.layers.concatenate( + [ + inputs[0], + tf.cast(inputs[1], tf.float32), + tf.cast(inputs[2] == "a", tf.float32), + ] + ) + output_layers = {} + output_layers["output1"] = tf_keras.layers.Dense( + 1, activation=tf.nn.sigmoid, name="output1" + )(input_layer) + if multi_output: + output_layers["output2"] = tf_keras.layers.Dense( + 1, activation=tf.nn.sigmoid, name="output2" + )(input_layer) + + model = tf_keras.models.Model(inputs, output_layers) + model.compile( + optimizer=tf_keras.optimizers.Adam(lr=0.001), + loss=tf_keras.losses.binary_crossentropy, + metrics=["accuracy"], + ) - train_features = { - 'input1': [[0.0], [1.0]], - 'input2': [[1], [0]], - 'input3': [[b'a'], [b'b']], - } - labels = {'output1': [[1], [0]]} - if multi_output: - labels['output2'] = [[1], [0]] - - example_weights = {'output1': [1.0, 0.5]} - if multi_output: - example_weights['output2'] = [1.0, 0.5] - dataset = tf.data.Dataset.from_tensor_slices( - (train_features, labels, example_weights) - ) - dataset = dataset.shuffle(buffer_size=1).repeat().batch(2) - model.fit(dataset, steps_per_epoch=1) - - src_model_path = tempfile.mkdtemp() - model.save(src_model_path) - - dst_model_path = tempfile.mkdtemp() - converter.convert([ - '--input_format=tf_saved_model', - '--saved_model_tags=serve', - '--signature_name=serving_default', - src_model_path, - dst_model_path, - ]) - - model_specs = [config_pb2.ModelSpec(name='model1', model_type='tf_js')] - if multi_model: - model_specs.append( - config_pb2.ModelSpec(name='model2', model_type='tf_js') - ) - - eval_config = config_pb2.EvalConfig(model_specs=model_specs) - eval_shared_models = [ - self.createTestEvalSharedModel( - model_name='model1', - model_path=dst_model_path, - model_type='tf_js', + train_features = { + "input1": [[0.0], [1.0]], + "input2": [[1], [0]], + "input3": [[b"a"], [b"b"]], + } + labels = {"output1": [[1], [0]]} + if multi_output: + labels["output2"] = [[1], [0]] + + example_weights = {"output1": [1.0, 0.5]} + if multi_output: + example_weights["output2"] = [1.0, 0.5] + dataset = tf.data.Dataset.from_tensor_slices( + (train_features, labels, example_weights) + ) + dataset = dataset.shuffle(buffer_size=1).repeat().batch(2) + model.fit(dataset, steps_per_epoch=1) + + src_model_path = tempfile.mkdtemp() + model.save(src_model_path) + + dst_model_path = tempfile.mkdtemp() + converter.convert( + [ + "--input_format=tf_saved_model", + "--saved_model_tags=serve", + "--signature_name=serving_default", + src_model_path, + dst_model_path, + ] ) - ] - if multi_model: - eval_shared_models.append( - self.createTestEvalSharedModel( - model_name='model2', - model_path=dst_model_path, - model_type='tf_js', - ) - ) - - schema = text_format.Parse( - """ + + model_specs = [config_pb2.ModelSpec(name="model1", model_type="tf_js")] + if multi_model: + model_specs.append(config_pb2.ModelSpec(name="model2", model_type="tf_js")) + + eval_config = config_pb2.EvalConfig(model_specs=model_specs) + eval_shared_models = [ + 
self.createTestEvalSharedModel( + model_name="model1", + model_path=dst_model_path, + model_type="tf_js", + ) + ] + if multi_model: + eval_shared_models.append( + self.createTestEvalSharedModel( + model_name="model2", + model_path=dst_model_path, + model_type="tf_js", + ) + ) + + schema = text_format.Parse( + """ feature { name: "input1" type: FLOAT @@ -150,64 +153,62 @@ def testTFJSPredictExtractorWithKerasModel(self, multi_model, multi_output): type: INT } """, - schema_pb2.Schema(), - ) - tfx_io = test_util.InMemoryTFExampleRecord( - schema=schema, raw_record_column_name=constants.ARROW_INPUT_COLUMN - ) - feature_extractor = features_extractor.FeaturesExtractor(eval_config) - predictor = tfjs_predict_extractor.TFJSPredictExtractor( - eval_config=eval_config, eval_shared_model=eval_shared_models - ) + schema_pb2.Schema(), + ) + tfx_io = test_util.InMemoryTFExampleRecord( + schema=schema, raw_record_column_name=constants.ARROW_INPUT_COLUMN + ) + feature_extractor = features_extractor.FeaturesExtractor(eval_config) + predictor = tfjs_predict_extractor.TFJSPredictExtractor( + eval_config=eval_config, eval_shared_model=eval_shared_models + ) - examples = [ - self._makeExample( - input1=0.0, input2=1, input3=b'a', non_model_feature=0 - ), - self._makeExample( - input1=1.0, input2=0, input3=b'b', non_model_feature=1 - ), - ] - - with beam.Pipeline() as pipeline: - # pylint: disable=no-value-for-parameter - result = ( - pipeline - | 'Create' - >> beam.Create( - [e.SerializeToString() for e in examples], reshuffle=False - ) - | 'BatchExamples' >> tfx_io.BeamSource(batch_size=2) - | 'InputsToExtracts' >> model_eval_lib.BatchedInputsToExtracts() - | feature_extractor.stage_name >> feature_extractor.ptransform - | predictor.stage_name >> predictor.ptransform - ) - - # pylint: enable=no-value-for-parameter - - def check_result(got): - try: - self.assertLen(got, 1) - got = got[0] - self.assertIn(constants.PREDICTIONS_KEY, got) - for model in ('model1', 'model2') if multi_model else (''): - per_model_result = got[constants.PREDICTIONS_KEY] - if model: - self.assertIn(model, per_model_result) - per_model_result = per_model_result[model] - for output in ('Identity', 'Identity_1') if multi_output else (''): - per_output_result = per_model_result - if output: - self.assertIn(output, per_output_result) - per_output_result = per_output_result[output] - self.assertLen(per_output_result, 2) - - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that(result, check_result, label='result') - - -if __name__ == '__main__': - tf.compat.v1.enable_v2_behavior() - tf.test.main() + examples = [ + self._makeExample(input1=0.0, input2=1, input3=b"a", non_model_feature=0), + self._makeExample(input1=1.0, input2=0, input3=b"b", non_model_feature=1), + ] + + with beam.Pipeline() as pipeline: + # pylint: disable=no-value-for-parameter + result = ( + pipeline + | "Create" + >> beam.Create( + [e.SerializeToString() for e in examples], reshuffle=False + ) + | "BatchExamples" >> tfx_io.BeamSource(batch_size=2) + | "InputsToExtracts" >> model_eval_lib.BatchedInputsToExtracts() + | feature_extractor.stage_name >> feature_extractor.ptransform + | predictor.stage_name >> predictor.ptransform + ) + + # pylint: enable=no-value-for-parameter + + def check_result(got): + try: + self.assertLen(got, 1) + got = got[0] + self.assertIn(constants.PREDICTIONS_KEY, got) + for model in ("model1", "model2") if multi_model else (""): + per_model_result = got[constants.PREDICTIONS_KEY] + if model: + 
self.assertIn(model, per_model_result) + per_model_result = per_model_result[model] + for output in ( + ("Identity", "Identity_1") if multi_output else ("") + ): + per_output_result = per_model_result + if output: + self.assertIn(output, per_output_result) + per_output_result = per_output_result[output] + self.assertLen(per_output_result, 2) + + except AssertionError as err: + raise util.BeamAssertException(err) + + util.assert_that(result, check_result, label="result") + + +if __name__ == "__main__": + tf.compat.v1.enable_v2_behavior() + tf.test.main() diff --git a/tensorflow_model_analysis/extractors/tfjs_predict_extractor_util.py b/tensorflow_model_analysis/extractors/tfjs_predict_extractor_util.py index af9fd12002..e917f93e25 100644 --- a/tensorflow_model_analysis/extractors/tfjs_predict_extractor_util.py +++ b/tensorflow_model_analysis/extractors/tfjs_predict_extractor_util.py @@ -24,16 +24,16 @@ def get_tfjs_binary(): - """Download and return the path to the tfjs binary.""" - if sys.platform == 'darwin': - url = 'http://storage.googleapis.com/tfjs-inference/tfjs-inference-macos' - else: - url = 'http://storage.googleapis.com/tfjs-inference/tfjs-inference-linux' + """Download and return the path to the tfjs binary.""" + if sys.platform == "darwin": + url = "http://storage.googleapis.com/tfjs-inference/tfjs-inference-macos" + else: + url = "http://storage.googleapis.com/tfjs-inference/tfjs-inference-linux" - base_path = tempfile.mkdtemp() - path = os.path.join(base_path, 'binary') - with urllib.request.urlopen(url) as response: - with tf.io.gfile.GFile(path, 'w') as file: - shutil.copyfileobj(response, file) - subprocess.check_call(['chmod', '+x', path]) - return path + base_path = tempfile.mkdtemp() + path = os.path.join(base_path, "binary") + with urllib.request.urlopen(url) as response: + with tf.io.gfile.GFile(path, "w") as file: + shutil.copyfileobj(response, file) + subprocess.check_call(["chmod", "+x", path]) + return path diff --git a/tensorflow_model_analysis/extractors/tflite_predict_extractor.py b/tensorflow_model_analysis/extractors/tflite_predict_extractor.py index 2b7917b41d..9acbc024b9 100644 --- a/tensorflow_model_analysis/extractors/tflite_predict_extractor.py +++ b/tensorflow_model_analysis/extractors/tflite_predict_extractor.py @@ -17,200 +17,200 @@ import copy from typing import Callable, Dict, Sequence, Union -from absl import logging import apache_beam as beam import numpy as np import tensorflow as tf +from absl import logging + from tensorflow_model_analysis import constants from tensorflow_model_analysis.api import types from tensorflow_model_analysis.extractors import extractor from tensorflow_model_analysis.proto import config_pb2 -from tensorflow_model_analysis.utils import model_util -from tensorflow_model_analysis.utils import util +from tensorflow_model_analysis.utils import model_util, util _OpResolverType = tf.lite.experimental.OpResolverType -_TFLITE_PREDICT_EXTRACTOR_STAGE_NAME = 'ExtractTFLitePredictions' +_TFLITE_PREDICT_EXTRACTOR_STAGE_NAME = "ExtractTFLitePredictions" -_QUANTIZATION_PARAMETERS = 'quantization_parameters' -_INDEX = 'index' -_SCALES = 'scales' -_ZERO_POINTS = 'zero_points' -_DTYPE = 'dtype' +_QUANTIZATION_PARAMETERS = "quantization_parameters" +_INDEX = "index" +_SCALES = "scales" +_ZERO_POINTS = "zero_points" +_DTYPE = "dtype" # TODO(b/149981535) Determine if we should merge with RunInference. 
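Both the TFJS extractor above and the TFLite extractor defined in this file plug into an evaluation pipeline through the same FeaturesExtractor -> predict-extractor chain that the tests in this diff exercise. The following is a condensed sketch of that wiring, not part of the patch itself; `eval_shared_model`, `tfx_io`, and `serialized_examples` are assumed to be constructed as in those tests (a TFXIO source over serialized tf.Examples with the Arrow input column, and a shared model pointing at the converted model directory).

import apache_beam as beam

from tensorflow_model_analysis.api import model_eval_lib
from tensorflow_model_analysis.extractors import (
    features_extractor,
    tflite_predict_extractor,
)
from tensorflow_model_analysis.proto import config_pb2

# Assumed to exist, built as in the tests in this diff:
#   eval_shared_model, tfx_io, serialized_examples
eval_config = config_pb2.EvalConfig(
    model_specs=[config_pb2.ModelSpec(name="model1", model_type="tf_lite")]
)

feature_extractor = features_extractor.FeaturesExtractor(eval_config)
predictor = tflite_predict_extractor.TFLitePredictExtractor(
    eval_config=eval_config, eval_shared_model=eval_shared_model
)

with beam.Pipeline() as pipeline:
    _ = (
        pipeline
        | "Create" >> beam.Create(serialized_examples, reshuffle=False)
        | "BatchExamples" >> tfx_io.BeamSource(batch_size=2)
        | "InputsToExtracts" >> model_eval_lib.BatchedInputsToExtracts()
        | feature_extractor.stage_name >> feature_extractor.ptransform
        # Downstream stages read tfma.PREDICTIONS_KEY from the resulting Extracts.
        | predictor.stage_name >> predictor.ptransform
    )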
@beam.typehints.with_input_types(types.Extracts) @beam.typehints.with_output_types(types.Extracts) class TFLitePredictionDoFn(model_util.BatchReducibleBatchedDoFnWithModels): - """A DoFn that loads tflite models and predicts.""" - - def __init__( - self, - eval_config: config_pb2.EvalConfig, - eval_shared_models: Dict[str, types.EvalSharedModel], - ) -> None: - super().__init__({k: v.model_loader for k, v in eval_shared_models.items()}) - self._eval_config = eval_config - - def setup(self): - super().setup() - self._interpreters = {} - - major, minor, _ = tf.version.VERSION.split('.') - op_resolver_type = _OpResolverType.AUTO - # TODO(b/207600661): drop BUILTIN_WITHOUT_DEFAULT_DELEGATES once the issue - # is fixed. - if int(major) > 2 or (int(major) == 2 and int(minor) >= 5): - op_resolver_type = _OpResolverType.BUILTIN_WITHOUT_DEFAULT_DELEGATES - for model_name, model_contents in self._loaded_models.items(): - self._interpreters[model_name] = self._make_interpreter( - model_content=model_contents.contents, - experimental_op_resolver_type=op_resolver_type, - ) - - def _make_interpreter(self, **kwargs) -> tf.lite.Interpreter: - return tf.lite.Interpreter(**kwargs) - - def _post_process_result(self, input_tensor: np.ndarray) -> np.ndarray: - """Custom post processor for TFLite predictions, default is no-op.""" - return input_tensor - - def _get_input_name_from_input_detail(self, input_detail): - """Get input name from input detail. - - Args: - input_detail: the details for a model input. - - Returns: - Input name. The signature key prefix and argument postfix will be removed. - """ - input_name = input_detail['name'] - # TFLite saved model converter inserts the signature key name at beginning - # of the input names. TFLite rewriter assumes that the default signature key - # ('serving_default') will be used as an exported name when saving. - if input_name.startswith('serving_default_'): - input_name = input_name[len('serving_default_') :] - # Remove argument that starts with ':'. - input_name = input_name.split(':')[0] - return input_name - - # TODO: cr/540071949 - Consider deduplicating this code. 
- def _dequantize( - self, tensor: np.ndarray, scale: np.ndarray, zero_point: np.ndarray - ) -> np.ndarray: - """Performs dequantization according to the spec: http://shortn/_QPyWQx5mhW.""" - if scale.size == 0 or zero_point.size == 0: - return tensor.astype(np.float64) - return (tensor - zero_point) * scale - - def _quantize( - self, - tensor: np.ndarray, - scale: np.ndarray, - zero_point: np.ndarray, - dtype: np.dtype, - ) -> np.ndarray: - """Performs quantization according to the spec: http://shortn/_QPyWQx5mhW.""" - if scale.size == 0 or zero_point.size == 0: - return tensor - return (tensor / scale + zero_point).astype(dtype) - - def _batch_reducible_process( - self, element: types.Extracts - ) -> Sequence[types.Extracts]: - """Invokes the tflite model on the provided inputs and stores the result.""" - result = copy.copy(element) - - batched_features = element[constants.FEATURES_KEY] - batch_size = util.batch_size(batched_features) - - for spec in self._eval_config.model_specs: - model_name = spec.name if len(self._eval_config.model_specs) > 1 else '' - if model_name not in self._loaded_models: - raise ValueError( - 'model for "{}" not found: eval_config={}'.format( - spec.name, self._eval_config + """A DoFn that loads tflite models and predicts.""" + + def __init__( + self, + eval_config: config_pb2.EvalConfig, + eval_shared_models: Dict[str, types.EvalSharedModel], + ) -> None: + super().__init__({k: v.model_loader for k, v in eval_shared_models.items()}) + self._eval_config = eval_config + + def setup(self): + super().setup() + self._interpreters = {} + + major, minor, _ = tf.version.VERSION.split(".") + op_resolver_type = _OpResolverType.AUTO + # TODO(b/207600661): drop BUILTIN_WITHOUT_DEFAULT_DELEGATES once the issue + # is fixed. + if int(major) > 2 or (int(major) == 2 and int(minor) >= 5): + op_resolver_type = _OpResolverType.BUILTIN_WITHOUT_DEFAULT_DELEGATES + for model_name, model_contents in self._loaded_models.items(): + self._interpreters[model_name] = self._make_interpreter( + model_content=model_contents.contents, + experimental_op_resolver_type=op_resolver_type, ) - ) - - interpreter = self._interpreters[model_name] - - input_details = interpreter.get_input_details() - output_details = interpreter.get_output_details() - - input_features = collections.defaultdict(list) - for i in input_details: - input_name = self._get_input_name_from_input_detail(i) - # The batch dimension is the specific batch size of the last time the - # model was invoked. Set it to 1 to "reset". - input_shape = [1] + list(i['shape'])[1:] - input_type = i[_DTYPE] - for idx in range(batch_size): - if input_name in batched_features: - value = batched_features[input_name][idx] - else: - value = None - if value is None or np.any(np.equal(value, None)): - default = -1 if input_type in [np.float32, np.int64] else '' - value = np.full(input_shape, default, dtype=input_type) - logging.log_every_n( - logging.WARNING, - 'Feature %s not found. Setting default value.', - 100, - input_name, - ) - else: - value = np.reshape(value, input_shape) - input_features[input_name].append(value) - # Concatenate with numpy to avoid implicit conversion to tf.Tensor - # which causes byte-string inputs to fail the set_tensor call. 
- input_features[input_name] = np.concatenate( - input_features[input_name], axis=0 - ) - if np.shape(input_features[input_name]) != tuple(i['shape']): - interpreter.resize_tensor_input( - i[_INDEX], np.shape(input_features[input_name]) - ) - interpreter.allocate_tensors() - - for i in input_details: - input_name = self._get_input_name_from_input_detail(i) - params = i[_QUANTIZATION_PARAMETERS] - interpreter.set_tensor( - i[_INDEX], - self._quantize( - input_features[input_name], - params[_SCALES], - params[_ZERO_POINTS], - i[_DTYPE], - ), - ) - interpreter.invoke() - - outputs = {} - for o in output_details: - tensor = interpreter.get_tensor(o[_INDEX]) - params = o[_QUANTIZATION_PARAMETERS] - dequantized_tensor = self._dequantize( - tensor, params[_SCALES], params[_ZERO_POINTS] - ) - - outputs[o['name']] = self._post_process_result(dequantized_tensor) - for v in outputs.values(): - if util.batch_size(v) != batch_size: - raise ValueError('Did not get the expected number of results.') - - if len(outputs) == 1: - outputs = list(outputs.values())[0] - - if len(self._eval_config.model_specs) == 1: - result[constants.PREDICTIONS_KEY] = outputs - else: - if constants.PREDICTIONS_KEY not in result: - result[constants.PREDICTIONS_KEY] = {} - result[constants.PREDICTIONS_KEY][spec.name] = outputs - return [result] + def _make_interpreter(self, **kwargs) -> tf.lite.Interpreter: + return tf.lite.Interpreter(**kwargs) + + def _post_process_result(self, input_tensor: np.ndarray) -> np.ndarray: + """Custom post processor for TFLite predictions, default is no-op.""" + return input_tensor + + def _get_input_name_from_input_detail(self, input_detail): + """Get input name from input detail. + + Args: + ---- + input_detail: the details for a model input. + + Returns: + ------- + Input name. The signature key prefix and argument postfix will be removed. + """ + input_name = input_detail["name"] + # TFLite saved model converter inserts the signature key name at beginning + # of the input names. TFLite rewriter assumes that the default signature key + # ('serving_default') will be used as an exported name when saving. + if input_name.startswith("serving_default_"): + input_name = input_name[len("serving_default_") :] + # Remove argument that starts with ':'. + input_name = input_name.split(":")[0] + return input_name + + # TODO: cr/540071949 - Consider deduplicating this code. 
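The `_quantize` and `_dequantize` helpers that follow implement the usual affine mapping between real values and quantized integers (real = (quantized - zero_point) * scale), short-circuiting when the scale/zero-point arrays are empty (unquantized tensors). A small numeric sketch with made-up parameters, not part of the patch; note that the helpers cast without rounding or clamping, which this sketch mirrors:

import numpy as np

scale = np.array([0.02], dtype=np.float32)    # hypothetical quantization params
zero_point = np.array([128], dtype=np.int32)
real = np.array([[-0.5], [0.0], [0.5]], dtype=np.float32)

# Mirrors _quantize: q = (r / scale + zero_point).astype(dtype)
quantized = (real / scale + zero_point).astype(np.uint8)
# Mirrors _dequantize: r = (q - zero_point) * scale
restored = (quantized - zero_point) * scale

print(quantized.ravel())  # [103 128 153]
print(restored.ravel())   # approximately [-0.5  0.   0.5]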
+ def _dequantize( + self, tensor: np.ndarray, scale: np.ndarray, zero_point: np.ndarray + ) -> np.ndarray: + """Performs dequantization according to the spec: http://shortn/_QPyWQx5mhW.""" + if scale.size == 0 or zero_point.size == 0: + return tensor.astype(np.float64) + return (tensor - zero_point) * scale + + def _quantize( + self, + tensor: np.ndarray, + scale: np.ndarray, + zero_point: np.ndarray, + dtype: np.dtype, + ) -> np.ndarray: + """Performs quantization according to the spec: http://shortn/_QPyWQx5mhW.""" + if scale.size == 0 or zero_point.size == 0: + return tensor + return (tensor / scale + zero_point).astype(dtype) + + def _batch_reducible_process( + self, element: types.Extracts + ) -> Sequence[types.Extracts]: + """Invokes the tflite model on the provided inputs and stores the result.""" + result = copy.copy(element) + + batched_features = element[constants.FEATURES_KEY] + batch_size = util.batch_size(batched_features) + + for spec in self._eval_config.model_specs: + model_name = spec.name if len(self._eval_config.model_specs) > 1 else "" + if model_name not in self._loaded_models: + raise ValueError( + f'model for "{spec.name}" not found: eval_config={self._eval_config}' + ) + + interpreter = self._interpreters[model_name] + + input_details = interpreter.get_input_details() + output_details = interpreter.get_output_details() + + input_features = collections.defaultdict(list) + for i in input_details: + input_name = self._get_input_name_from_input_detail(i) + # The batch dimension is the specific batch size of the last time the + # model was invoked. Set it to 1 to "reset". + input_shape = [1] + list(i["shape"])[1:] + input_type = i[_DTYPE] + for idx in range(batch_size): + if input_name in batched_features: + value = batched_features[input_name][idx] + else: + value = None + if value is None or np.any(np.equal(value, None)): + default = -1 if input_type in [np.float32, np.int64] else "" + value = np.full(input_shape, default, dtype=input_type) + logging.log_every_n( + logging.WARNING, + "Feature %s not found. Setting default value.", + 100, + input_name, + ) + else: + value = np.reshape(value, input_shape) + input_features[input_name].append(value) + # Concatenate with numpy to avoid implicit conversion to tf.Tensor + # which causes byte-string inputs to fail the set_tensor call. 
+ input_features[input_name] = np.concatenate( + input_features[input_name], axis=0 + ) + if np.shape(input_features[input_name]) != tuple(i["shape"]): + interpreter.resize_tensor_input( + i[_INDEX], np.shape(input_features[input_name]) + ) + interpreter.allocate_tensors() + + for i in input_details: + input_name = self._get_input_name_from_input_detail(i) + params = i[_QUANTIZATION_PARAMETERS] + interpreter.set_tensor( + i[_INDEX], + self._quantize( + input_features[input_name], + params[_SCALES], + params[_ZERO_POINTS], + i[_DTYPE], + ), + ) + interpreter.invoke() + + outputs = {} + for o in output_details: + tensor = interpreter.get_tensor(o[_INDEX]) + params = o[_QUANTIZATION_PARAMETERS] + dequantized_tensor = self._dequantize( + tensor, params[_SCALES], params[_ZERO_POINTS] + ) + + outputs[o["name"]] = self._post_process_result(dequantized_tensor) + + for v in outputs.values(): + if util.batch_size(v) != batch_size: + raise ValueError("Did not get the expected number of results.") + + if len(outputs) == 1: + outputs = list(outputs.values())[0] + + if len(self._eval_config.model_specs) == 1: + result[constants.PREDICTIONS_KEY] = outputs + else: + if constants.PREDICTIONS_KEY not in result: + result[constants.PREDICTIONS_KEY] = {} + result[constants.PREDICTIONS_KEY][spec.name] = outputs + return [result] @beam.ptransform_fn @@ -222,60 +222,62 @@ def _ExtractTFLitePredictions( # pylint: disable=invalid-name eval_shared_models: Dict[str, types.EvalSharedModel], do_fn: Callable[..., TFLitePredictionDoFn], ) -> beam.pvalue.PCollection: - """A PTransform that adds predictions and possibly other tensors to extracts. - - Args: - extracts: PCollection of extracts containing model inputs keyed by - tfma.FEATURES_KEY. - eval_config: Eval config. - eval_shared_models: Shared model parameters keyed by model name. - do_fn: Constructor for TFLitePredictionDoFn. - - Returns: - PCollection of Extracts updated with the predictions. - """ - return extracts | 'Predict' >> beam.ParDo( - do_fn( - eval_config=eval_config, - eval_shared_models=eval_shared_models, - ) - ) + """A PTransform that adds predictions and possibly other tensors to extracts. + + Args: + ---- + extracts: PCollection of extracts containing model inputs keyed by + tfma.FEATURES_KEY. + eval_config: Eval config. + eval_shared_models: Shared model parameters keyed by model name. + do_fn: Constructor for TFLitePredictionDoFn. + + Returns: + ------- + PCollection of Extracts updated with the predictions. + """ + return extracts | "Predict" >> beam.ParDo( + do_fn( + eval_config=eval_config, + eval_shared_models=eval_shared_models, + ) + ) def TFLitePredictExtractor( eval_config: config_pb2.EvalConfig, - eval_shared_model: Union[ - types.EvalSharedModel, Dict[str, types.EvalSharedModel] - ], + eval_shared_model: Union[types.EvalSharedModel, Dict[str, types.EvalSharedModel]], do_fn: Callable[..., TFLitePredictionDoFn] = TFLitePredictionDoFn, ) -> extractor.Extractor: - """Creates an extractor for performing predictions on tflite models. - - The extractor's PTransform loads and interprets the tflite flatbuffer against - every extract yielding a copy of the incoming extracts with an additional - extract added for the predictions keyed by tfma.PREDICTIONS_KEY. The model - inputs are searched for under tfma.FEATURES_KEY. If multiple - models are used the predictions will be stored in a dict keyed by model name. - - Args: - eval_config: Eval config. 
- eval_shared_model: Shared model (single-model evaluation) or dict of shared - models keyed by model name (multi-model evaluation). - do_fn: Constructor for TFLitePredictionDoFn. - - Returns: - Extractor for extracting predictions. - """ - eval_shared_models = model_util.verify_and_update_eval_shared_models( - eval_shared_model - ) - - # pylint: disable=no-value-for-parameter - return extractor.Extractor( - stage_name=_TFLITE_PREDICT_EXTRACTOR_STAGE_NAME, - ptransform=_ExtractTFLitePredictions( - eval_config=eval_config, - eval_shared_models={m.model_name: m for m in eval_shared_models}, - do_fn=do_fn, - ), - ) + """Creates an extractor for performing predictions on tflite models. + + The extractor's PTransform loads and interprets the tflite flatbuffer against + every extract yielding a copy of the incoming extracts with an additional + extract added for the predictions keyed by tfma.PREDICTIONS_KEY. The model + inputs are searched for under tfma.FEATURES_KEY. If multiple + models are used the predictions will be stored in a dict keyed by model name. + + Args: + ---- + eval_config: Eval config. + eval_shared_model: Shared model (single-model evaluation) or dict of shared + models keyed by model name (multi-model evaluation). + do_fn: Constructor for TFLitePredictionDoFn. + + Returns: + ------- + Extractor for extracting predictions. + """ + eval_shared_models = model_util.verify_and_update_eval_shared_models( + eval_shared_model + ) + + # pylint: disable=no-value-for-parameter + return extractor.Extractor( + stage_name=_TFLITE_PREDICT_EXTRACTOR_STAGE_NAME, + ptransform=_ExtractTFLitePredictions( + eval_config=eval_config, + eval_shared_models={m.model_name: m for m in eval_shared_models}, + do_fn=do_fn, + ), + ) diff --git a/tensorflow_model_analysis/extractors/tflite_predict_extractor_test.py b/tensorflow_model_analysis/extractors/tflite_predict_extractor_test.py index cfd69d75ab..dffced36eb 100644 --- a/tensorflow_model_analysis/extractors/tflite_predict_extractor_test.py +++ b/tensorflow_model_analysis/extractors/tflite_predict_extractor_test.py @@ -17,23 +17,25 @@ import os import tempfile -from absl.testing import parameterized import apache_beam as beam -from apache_beam.testing import util import tensorflow as tf +from absl.testing import parameterized +from apache_beam.testing import util +from google.protobuf import text_format +from tensorflow_metadata.proto.v0 import schema_pb2 +from tfx_bsl.tfxio import test_util + from tensorflow_model_analysis import constants from tensorflow_model_analysis.api import model_eval_lib -from tensorflow_model_analysis.extractors import features_extractor -from tensorflow_model_analysis.extractors import tflite_predict_extractor +from tensorflow_model_analysis.extractors import ( + features_extractor, + tflite_predict_extractor, +) from tensorflow_model_analysis.proto import config_pb2 from tensorflow_model_analysis.utils import test_util as testutil from tensorflow_model_analysis.utils.keras_lib import tf_keras -from tfx_bsl.tfxio import test_util - -from google.protobuf import text_format -from tensorflow_metadata.proto.v0 import schema_pb2 -_TF_MAJOR_VERSION = int(tf.version.VERSION.split('.')[0]) +_TF_MAJOR_VERSION = int(tf.version.VERSION.split(".")[0]) _MULTI_MODEL_CASES = [False, True] _MULTI_OUTPUT_CASES = [False, True] @@ -43,116 +45,115 @@ def random_genenerator(): - generator: tf.random.Generator = tf.random.Generator.from_seed(42) - for unused_i in range(10): - r = { - 'input1': generator.uniform(shape=(2, 1), minval=0.0, 
maxval=1.0), - 'input2': generator.uniform(shape=(2, 1), minval=0.0, maxval=1.0), - 'input3': tf.constant([[b'a'], [b'b']], shape=(2, 1), dtype=tf.string), - } - yield r + generator: tf.random.Generator = tf.random.Generator.from_seed(42) + for unused_i in range(10): + r = { + "input1": generator.uniform(shape=(2, 1), minval=0.0, maxval=1.0), + "input2": generator.uniform(shape=(2, 1), minval=0.0, maxval=1.0), + "input3": tf.constant([[b"a"], [b"b"]], shape=(2, 1), dtype=tf.string), + } + yield r class TFLitePredictExtractorTest( testutil.TensorflowModelAnalysisTest, parameterized.TestCase ): - - @parameterized.parameters( - itertools.product( - _MULTI_MODEL_CASES, - _MULTI_OUTPUT_CASES, - _BYTES_FEATURE_CASES, - _QUANTIZATION_CASES, - ) - ) - def testTFlitePredictExtractorWithKerasModel( - self, multi_model, multi_output, use_bytes_feature, use_quantization - ): - input1 = tf_keras.layers.Input(shape=(1,), name='input1') - input2 = tf_keras.layers.Input(shape=(1,), name='input2') - input3 = tf_keras.layers.Input(shape=(1,), name='input3', dtype=tf.string) - inputs = [input1, input2, input3] - if use_bytes_feature: - input_layer = tf_keras.layers.concatenate( - [inputs[0], inputs[1], tf.cast(inputs[2] == 'a', tf.float32)] - ) - else: - input_layer = tf_keras.layers.concatenate([inputs[0], inputs[1]]) - output_layers = {} - output_layers['output1'] = tf_keras.layers.Dense( - 1, activation=tf.nn.sigmoid, name='output1' - )(input_layer) - if multi_output: - output_layers['output2'] = tf_keras.layers.Dense( - 1, activation=tf.nn.sigmoid, name='output2' - )(input_layer) - - model = tf_keras.models.Model(inputs, output_layers) - model.compile( - optimizer=tf_keras.optimizers.Adam(lr=0.001), - loss=tf_keras.losses.binary_crossentropy, - metrics=['accuracy'], + @parameterized.parameters( + itertools.product( + _MULTI_MODEL_CASES, + _MULTI_OUTPUT_CASES, + _BYTES_FEATURE_CASES, + _QUANTIZATION_CASES, + ) ) + def testTFlitePredictExtractorWithKerasModel( + self, multi_model, multi_output, use_bytes_feature, use_quantization + ): + input1 = tf_keras.layers.Input(shape=(1,), name="input1") + input2 = tf_keras.layers.Input(shape=(1,), name="input2") + input3 = tf_keras.layers.Input(shape=(1,), name="input3", dtype=tf.string) + inputs = [input1, input2, input3] + if use_bytes_feature: + input_layer = tf_keras.layers.concatenate( + [inputs[0], inputs[1], tf.cast(inputs[2] == "a", tf.float32)] + ) + else: + input_layer = tf_keras.layers.concatenate([inputs[0], inputs[1]]) + output_layers = {} + output_layers["output1"] = tf_keras.layers.Dense( + 1, activation=tf.nn.sigmoid, name="output1" + )(input_layer) + if multi_output: + output_layers["output2"] = tf_keras.layers.Dense( + 1, activation=tf.nn.sigmoid, name="output2" + )(input_layer) + + model = tf_keras.models.Model(inputs, output_layers) + model.compile( + optimizer=tf_keras.optimizers.Adam(lr=0.001), + loss=tf_keras.losses.binary_crossentropy, + metrics=["accuracy"], + ) - train_features = { - 'input1': [[0.0], [1.0]], - 'input2': [[1.0], [0.0]], - 'input3': [[b'a'], [b'b']], - } - labels = {'output1': [[1], [0]]} - if multi_output: - labels['output2'] = [[1], [0]] - - example_weights = {'output1': [1.0, 0.5]} - if multi_output: - example_weights['output2'] = [1.0, 0.5] - dataset = tf.data.Dataset.from_tensor_slices( - (train_features, labels, example_weights) - ) - dataset = dataset.shuffle(buffer_size=1).repeat().batch(2) - model.fit(dataset, steps_per_epoch=1) - - converter = tf.compat.v2.lite.TFLiteConverter.from_keras_model(model) - if 
use_quantization: - converter.optimizations = [tf.lite.Optimize.DEFAULT] - converter.target_spec.supported_ops = [ - tf.lite.OpsSet.TFLITE_BUILTINS_INT8, - tf.lite.OpsSet.SELECT_TF_OPS, - ] - converter.inference_input_type = tf.uint8 - converter.inference_output_type = tf.uint8 - converter.representative_dataset = random_genenerator - tflite_model = converter.convert() - - tflite_model_dir = tempfile.mkdtemp() - with tf.io.gfile.GFile(os.path.join(tflite_model_dir, 'tflite'), 'wb') as f: - f.write(tflite_model) - - model_specs = [config_pb2.ModelSpec(name='model1', model_type='tf_lite')] - if multi_model: - model_specs.append( - config_pb2.ModelSpec(name='model2', model_type='tf_lite') - ) - - eval_config = config_pb2.EvalConfig(model_specs=model_specs) - eval_shared_models = [ - self.createTestEvalSharedModel( - model_name='model1', - model_path=tflite_model_dir, - model_type='tf_lite', + train_features = { + "input1": [[0.0], [1.0]], + "input2": [[1.0], [0.0]], + "input3": [[b"a"], [b"b"]], + } + labels = {"output1": [[1], [0]]} + if multi_output: + labels["output2"] = [[1], [0]] + + example_weights = {"output1": [1.0, 0.5]} + if multi_output: + example_weights["output2"] = [1.0, 0.5] + dataset = tf.data.Dataset.from_tensor_slices( + (train_features, labels, example_weights) ) - ] - if multi_model: - eval_shared_models.append( - self.createTestEvalSharedModel( - model_name='model2', - model_path=tflite_model_dir, - model_type='tf_lite', - ) - ) - - schema = text_format.Parse( - """ + dataset = dataset.shuffle(buffer_size=1).repeat().batch(2) + model.fit(dataset, steps_per_epoch=1) + + converter = tf.compat.v2.lite.TFLiteConverter.from_keras_model(model) + if use_quantization: + converter.optimizations = [tf.lite.Optimize.DEFAULT] + converter.target_spec.supported_ops = [ + tf.lite.OpsSet.TFLITE_BUILTINS_INT8, + tf.lite.OpsSet.SELECT_TF_OPS, + ] + converter.inference_input_type = tf.uint8 + converter.inference_output_type = tf.uint8 + converter.representative_dataset = random_genenerator + tflite_model = converter.convert() + + tflite_model_dir = tempfile.mkdtemp() + with tf.io.gfile.GFile(os.path.join(tflite_model_dir, "tflite"), "wb") as f: + f.write(tflite_model) + + model_specs = [config_pb2.ModelSpec(name="model1", model_type="tf_lite")] + if multi_model: + model_specs.append( + config_pb2.ModelSpec(name="model2", model_type="tf_lite") + ) + + eval_config = config_pb2.EvalConfig(model_specs=model_specs) + eval_shared_models = [ + self.createTestEvalSharedModel( + model_name="model1", + model_path=tflite_model_dir, + model_type="tf_lite", + ) + ] + if multi_model: + eval_shared_models.append( + self.createTestEvalSharedModel( + model_name="model2", + model_path=tflite_model_dir, + model_type="tf_lite", + ) + ) + + schema = text_format.Parse( + """ feature { name: "input1" type: FLOAT @@ -170,64 +171,62 @@ def testTFlitePredictExtractorWithKerasModel( type: INT } """, - schema_pb2.Schema(), - ) - tfx_io = test_util.InMemoryTFExampleRecord( - schema=schema, raw_record_column_name=constants.ARROW_INPUT_COLUMN - ) - feature_extractor = features_extractor.FeaturesExtractor(eval_config) - predictor = tflite_predict_extractor.TFLitePredictExtractor( - eval_config=eval_config, eval_shared_model=eval_shared_models - ) + schema_pb2.Schema(), + ) + tfx_io = test_util.InMemoryTFExampleRecord( + schema=schema, raw_record_column_name=constants.ARROW_INPUT_COLUMN + ) + feature_extractor = features_extractor.FeaturesExtractor(eval_config) + predictor = tflite_predict_extractor.TFLitePredictExtractor( 
+ eval_config=eval_config, eval_shared_model=eval_shared_models + ) - examples = [ - self._makeExample( - input1=0.0, input2=1.0, input3=b'a', non_model_feature=0 - ), - self._makeExample( - input1=1.0, input2=0.0, input3=b'b', non_model_feature=1 - ), - ] - - with beam.Pipeline() as pipeline: - # pylint: disable=no-value-for-parameter - result = ( - pipeline - | 'Create' - >> beam.Create( - [e.SerializeToString() for e in examples], reshuffle=False - ) - | 'BatchExamples' >> tfx_io.BeamSource(batch_size=2) - | 'InputsToExtracts' >> model_eval_lib.BatchedInputsToExtracts() - | feature_extractor.stage_name >> feature_extractor.ptransform - | predictor.stage_name >> predictor.ptransform - ) - - # pylint: enable=no-value-for-parameter - - def check_result(got): - try: - self.assertLen(got, 1) - got = got[0] - self.assertIn(constants.PREDICTIONS_KEY, got) - for model in ('model1', 'model2') if multi_model else (''): - per_model_result = got[constants.PREDICTIONS_KEY] - if model: - self.assertIn(model, per_model_result) - per_model_result = per_model_result[model] - for output in ('Identity', 'Identity_1') if multi_output else (''): - per_output_result = per_model_result - if output: - self.assertIn(output, per_output_result) - per_output_result = per_output_result[output] - self.assertLen(per_output_result, 2) - - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that(result, check_result, label='result') - - -if __name__ == '__main__': - tf.compat.v1.enable_v2_behavior() - tf.test.main() + examples = [ + self._makeExample(input1=0.0, input2=1.0, input3=b"a", non_model_feature=0), + self._makeExample(input1=1.0, input2=0.0, input3=b"b", non_model_feature=1), + ] + + with beam.Pipeline() as pipeline: + # pylint: disable=no-value-for-parameter + result = ( + pipeline + | "Create" + >> beam.Create( + [e.SerializeToString() for e in examples], reshuffle=False + ) + | "BatchExamples" >> tfx_io.BeamSource(batch_size=2) + | "InputsToExtracts" >> model_eval_lib.BatchedInputsToExtracts() + | feature_extractor.stage_name >> feature_extractor.ptransform + | predictor.stage_name >> predictor.ptransform + ) + + # pylint: enable=no-value-for-parameter + + def check_result(got): + try: + self.assertLen(got, 1) + got = got[0] + self.assertIn(constants.PREDICTIONS_KEY, got) + for model in ("model1", "model2") if multi_model else (""): + per_model_result = got[constants.PREDICTIONS_KEY] + if model: + self.assertIn(model, per_model_result) + per_model_result = per_model_result[model] + for output in ( + ("Identity", "Identity_1") if multi_output else ("") + ): + per_output_result = per_model_result + if output: + self.assertIn(output, per_output_result) + per_output_result = per_output_result[output] + self.assertLen(per_output_result, 2) + + except AssertionError as err: + raise util.BeamAssertException(err) + + util.assert_that(result, check_result, label="result") + + +if __name__ == "__main__": + tf.compat.v1.enable_v2_behavior() + tf.test.main() diff --git a/tensorflow_model_analysis/extractors/tfx_bsl_predictions_extractor.py b/tensorflow_model_analysis/extractors/tfx_bsl_predictions_extractor.py index ad113dac28..f2214b5a4a 100644 --- a/tensorflow_model_analysis/extractors/tfx_bsl_predictions_extractor.py +++ b/tensorflow_model_analysis/extractors/tfx_bsl_predictions_extractor.py @@ -18,20 +18,21 @@ import apache_beam as beam import tensorflow as tf +from tensorflow_serving.apis import prediction_log_pb2 +from tfx_bsl.public.beam import run_inference +from 
tfx_bsl.public.proto import model_spec_pb2 + from tensorflow_model_analysis import constants from tensorflow_model_analysis.api import types -from tensorflow_model_analysis.extractors import extractor -from tensorflow_model_analysis.extractors import inference_base -from tensorflow_model_analysis.extractors import predictions_extractor +from tensorflow_model_analysis.extractors import ( + extractor, + inference_base, + predictions_extractor, +) from tensorflow_model_analysis.proto import config_pb2 from tensorflow_model_analysis.utils import model_util -from tfx_bsl.public.beam import run_inference -from tfx_bsl.public.proto import model_spec_pb2 - -from tensorflow_serving.apis import prediction_log_pb2 - -_K = TypeVar('_K') +_K = TypeVar("_K") PossibleInputTypes = Union[tf.train.Example, bytes] KeyAndOutput = Tuple[_K, PossibleInputTypes] MapModelNameToOutput = Dict[str, prediction_log_pb2.PredictionLog] @@ -40,62 +41,63 @@ # TODO(b/372308995) Add new tests based on Keras models. class TfxBslInferenceWrapper(beam.PTransform): - """Wrapper for TFX-BSL bulk inference implementation.""" - - def __init__( - self, - model_specs: List[config_pb2.ModelSpec], - name_to_eval_shared_model: Dict[str, types.EvalSharedModel], - ): - """Converts TFMA config into library-specific configuration. - - Args: - model_specs: TFMA ModelSpec config to be translated to TFX-BSL Config. - name_to_eval_shared_model: Map of model name to associated EvalSharedModel - object. - """ - super().__init__() - model_names = [] - inference_specs = [] - for model_spec in model_specs: - eval_shared_model = model_util.get_eval_shared_model( - model_spec.name, name_to_eval_shared_model - ) - inference_spec_type = model_spec_pb2.InferenceSpecType( - saved_model_spec=model_spec_pb2.SavedModelSpec( - model_path=eval_shared_model.model_path, - tag=eval_shared_model.model_loader.tags, - signature_name=[model_spec.signature_name], - ), - batch_parameters=model_spec_pb2.BatchParameters( - min_batch_size=model_spec.inference_batch_size, - max_batch_size=model_spec.inference_batch_size, - ), - ) - model_names.append(model_spec.name) - inference_specs.append(inference_spec_type) - self._aligned_model_names = tuple(model_names) - self._aligned_inference_specs = tuple(inference_specs) - - def expand( - self, pcoll: beam.PCollection[KeyAndOutput] - ) -> beam.PCollection[KeyAndOutputMap]: - # TODO(b/241022420): Set load_override_fn here to avoid loading the model - # twice. - return ( - pcoll - | 'TfxBslBulkInference' - >> run_inference.RunInferencePerModel( - inference_spec_types=self._aligned_inference_specs - ) - | 'CreateModelNameToPredictionLog' - >> beam.MapTuple( - lambda extracts, logs: ( # pylint: disable=g-long-lambda - extracts, - dict(zip(self._aligned_model_names, logs)), + """Wrapper for TFX-BSL bulk inference implementation.""" + + def __init__( + self, + model_specs: List[config_pb2.ModelSpec], + name_to_eval_shared_model: Dict[str, types.EvalSharedModel], + ): + """Converts TFMA config into library-specific configuration. + + Args: + ---- + model_specs: TFMA ModelSpec config to be translated to TFX-BSL Config. + name_to_eval_shared_model: Map of model name to associated EvalSharedModel + object. 
+ """ + super().__init__() + model_names = [] + inference_specs = [] + for model_spec in model_specs: + eval_shared_model = model_util.get_eval_shared_model( + model_spec.name, name_to_eval_shared_model + ) + inference_spec_type = model_spec_pb2.InferenceSpecType( + saved_model_spec=model_spec_pb2.SavedModelSpec( + model_path=eval_shared_model.model_path, + tag=eval_shared_model.model_loader.tags, + signature_name=[model_spec.signature_name], + ), + batch_parameters=model_spec_pb2.BatchParameters( + min_batch_size=model_spec.inference_batch_size, + max_batch_size=model_spec.inference_batch_size, + ), + ) + model_names.append(model_spec.name) + inference_specs.append(inference_spec_type) + self._aligned_model_names = tuple(model_names) + self._aligned_inference_specs = tuple(inference_specs) + + def expand( + self, pcoll: beam.PCollection[KeyAndOutput] + ) -> beam.PCollection[KeyAndOutputMap]: + # TODO(b/241022420): Set load_override_fn here to avoid loading the model + # twice. + return ( + pcoll + | "TfxBslBulkInference" + >> run_inference.RunInferencePerModel( + inference_spec_types=self._aligned_inference_specs + ) + | "CreateModelNameToPredictionLog" + >> beam.MapTuple( + lambda extracts, logs: ( # pylint: disable=g-long-lambda + extracts, + dict(zip(self._aligned_model_names, logs)), + ) ) ) - ) def TfxBslPredictionsExtractor( @@ -104,69 +106,71 @@ def TfxBslPredictionsExtractor( output_batch_size: Optional[int] = None, output_keypath: Iterable[str] = (constants.PREDICTIONS_KEY,), ) -> extractor.Extractor: - """Creates an extractor for performing predictions over a batch. - - The extractor's PTransform loads and runs the serving saved_model(s) against - every Extracts yielding a copy of the incoming Extracts with an additional - Extracts added for the predictions keyed by tfma.PREDICTIONS_KEY. The model - inputs are searched for under tfma.FEATURES_KEY (keras only) or tfma.INPUT_KEY - (if tfma.FEATURES_KEY is not set or the model is non-keras). If multiple - models are used the predictions will be stored in a dict keyed by model name. - - Note that the prediction_key in the ModelSpecs also serves as a key into the - dict of the prediction's output. - - Args: - eval_config: Eval config. - eval_shared_model: Shared model (single-model evaluation) or list of shared - models (multi-model evaluation). - output_batch_size: Sets a static output batch size for bulk inference. Note: - this only affects the rebatched output batch size to set inference batch - size set ModelSpec.inference_batch_size. - output_keypath: A sequence of keys to be used as the path to traverse and - insert the outputs in the extract. - - Returns: - Extractor for extracting predictions. - """ - eval_shared_models = model_util.verify_and_update_eval_shared_models( - eval_shared_model - ) - # This should never happen, but verify_and_update_eval_shared_models can - # theoretically return None or empty iterables. - if not eval_shared_models: - raise ValueError( - 'No valid model(s) were provided. Please ensure that ' - 'EvalConfig.ModelSpec is correctly configured to enable ' - 'using the PredictionsExtractor.' + """Creates an extractor for performing predictions over a batch. + + The extractor's PTransform loads and runs the serving saved_model(s) against + every Extracts yielding a copy of the incoming Extracts with an additional + Extracts added for the predictions keyed by tfma.PREDICTIONS_KEY. 
The model + inputs are searched for under tfma.FEATURES_KEY (keras only) or tfma.INPUT_KEY + (if tfma.FEATURES_KEY is not set or the model is non-keras). If multiple + models are used the predictions will be stored in a dict keyed by model name. + + Note that the prediction_key in the ModelSpecs also serves as a key into the + dict of the prediction's output. + + Args: + ---- + eval_config: Eval config. + eval_shared_model: Shared model (single-model evaluation) or list of shared + models (multi-model evaluation). + output_batch_size: Sets a static output batch size for bulk inference. Note: + this only affects the rebatched output batch size to set inference batch + size set ModelSpec.inference_batch_size. + output_keypath: A sequence of keys to be used as the path to traverse and + insert the outputs in the extract. + + Returns: + ------- + Extractor for extracting predictions. + """ + eval_shared_models = model_util.verify_and_update_eval_shared_models( + eval_shared_model ) + # This should never happen, but verify_and_update_eval_shared_models can + # theoretically return None or empty iterables. + if not eval_shared_models: + raise ValueError( + "No valid model(s) were provided. Please ensure that " + "EvalConfig.ModelSpec is correctly configured to enable " + "using the PredictionsExtractor." + ) - name_to_eval_shared_model = {m.model_name: m for m in eval_shared_models} - model_specs = [] - for model_spec in eval_config.model_specs: - if not model_spec.signature_name: - eval_shared_model = model_util.get_eval_shared_model( - model_spec.name, name_to_eval_shared_model - ) - model_spec = copy.copy(model_spec) - # Select a default signature. Note that this may differ from the - # 'serving_default' signature. - model_spec.signature_name = ( - model_util.get_default_signature_name_from_model_path( - eval_shared_model.model_path - ) - ) - model_specs.append(model_spec) - - tfx_bsl_inference_ptransform = inference_base.RunInference( - inference_ptransform=TfxBslInferenceWrapper( - model_specs, name_to_eval_shared_model - ), - output_batch_size=output_batch_size, - output_keypath=output_keypath, - ) - # pylint: disable=no-value-for-parameter - return extractor.Extractor( - stage_name=predictions_extractor.PREDICTIONS_EXTRACTOR_STAGE_NAME, - ptransform=tfx_bsl_inference_ptransform, - ) + name_to_eval_shared_model = {m.model_name: m for m in eval_shared_models} + model_specs = [] + for model_spec in eval_config.model_specs: + if not model_spec.signature_name: + eval_shared_model = model_util.get_eval_shared_model( + model_spec.name, name_to_eval_shared_model + ) + model_spec = copy.copy(model_spec) + # Select a default signature. Note that this may differ from the + # 'serving_default' signature. 
+ model_spec.signature_name = ( + model_util.get_default_signature_name_from_model_path( + eval_shared_model.model_path + ) + ) + model_specs.append(model_spec) + + tfx_bsl_inference_ptransform = inference_base.RunInference( + inference_ptransform=TfxBslInferenceWrapper( + model_specs, name_to_eval_shared_model + ), + output_batch_size=output_batch_size, + output_keypath=output_keypath, + ) + # pylint: disable=no-value-for-parameter + return extractor.Extractor( + stage_name=predictions_extractor.PREDICTIONS_EXTRACTOR_STAGE_NAME, + ptransform=tfx_bsl_inference_ptransform, + ) diff --git a/tensorflow_model_analysis/extractors/transformed_features_extractor.py b/tensorflow_model_analysis/extractors/transformed_features_extractor.py index ff499bb070..c522304a48 100644 --- a/tensorflow_model_analysis/extractors/transformed_features_extractor.py +++ b/tensorflow_model_analysis/extractors/transformed_features_extractor.py @@ -16,49 +16,52 @@ from typing import Dict import apache_beam as beam + from tensorflow_model_analysis import constants from tensorflow_model_analysis.api import types from tensorflow_model_analysis.extractors import extractor from tensorflow_model_analysis.proto import config_pb2 from tensorflow_model_analysis.utils import model_util -_TRANSFORMED_FEATURES_EXTRACTOR_STAGE_NAME = 'ExtractTransformedFeatures' +_TRANSFORMED_FEATURES_EXTRACTOR_STAGE_NAME = "ExtractTransformedFeatures" # TODO(b/173029091): Re-add tft_layer. -_DEFAULT_SIGNATURE_NAMES = ('transformed_features', 'transformed_labels') +_DEFAULT_SIGNATURE_NAMES = ("transformed_features", "transformed_labels") def TransformedFeaturesExtractor( eval_config: config_pb2.EvalConfig, eval_shared_model: types.MaybeMultipleEvalSharedModels, ) -> extractor.Extractor: - """Creates an extractor for extracting transformed features. - - The extractor's PTransform loads the saved_model(s) invoking the preprocessing - functions against every extract yielding a copy of the incoming extracts with - a tfma.TRANSFORMED_FEATURES_KEY containing the output from the preprocessing - functions. - - Args: - eval_config: Eval config. - eval_shared_model: Shared model (single-model evaluation) or list of shared - models (multi-model evaluation). - - Returns: - Extractor for extracting preprocessed features. - """ - eval_shared_models = model_util.verify_and_update_eval_shared_models( - eval_shared_model - ) - - # pylint: disable=no-value-for-parameter - return extractor.Extractor( - stage_name=_TRANSFORMED_FEATURES_EXTRACTOR_STAGE_NAME, - ptransform=_ExtractTransformedFeatures( - eval_config=eval_config, - eval_shared_models={m.model_name: m for m in eval_shared_models}, - ), - ) + """Creates an extractor for extracting transformed features. + + The extractor's PTransform loads the saved_model(s) invoking the preprocessing + functions against every extract yielding a copy of the incoming extracts with + a tfma.TRANSFORMED_FEATURES_KEY containing the output from the preprocessing + functions. + + Args: + ---- + eval_config: Eval config. + eval_shared_model: Shared model (single-model evaluation) or list of shared + models (multi-model evaluation). + + Returns: + ------- + Extractor for extracting preprocessed features. 
+ """ + eval_shared_models = model_util.verify_and_update_eval_shared_models( + eval_shared_model + ) + + # pylint: disable=no-value-for-parameter + return extractor.Extractor( + stage_name=_TRANSFORMED_FEATURES_EXTRACTOR_STAGE_NAME, + ptransform=_ExtractTransformedFeatures( + eval_config=eval_config, + eval_shared_models={m.model_name: m for m in eval_shared_models}, + ), + ) @beam.ptransform_fn @@ -69,31 +72,33 @@ def _ExtractTransformedFeatures( # pylint: disable=invalid-name eval_config: config_pb2.EvalConfig, eval_shared_models: Dict[str, types.EvalSharedModel], ) -> beam.pvalue.PCollection: - """A PTransform that updates extracts to include transformed features. - - Args: - extracts: PCollection of extracts containing raw inputs keyed by - tfma.FEATURES_KEY (if preprocessing function inputs are named) or - tfma.INPUTS_KEY (if preprocessing functions take raw tf.Examples as input) - eval_config: Eval config. - eval_shared_models: Shared model parameters keyed by model name. - - Returns: - PCollection of Extracts updated with the to include transformed features - stored under the key tfma.TRANSFORMED_FEATURES_KEY. - """ - signature_names = {} - for spec in eval_config.model_specs: - model_name = '' if len(eval_config.model_specs) == 1 else spec.name - signature_names[model_name] = list(spec.preprocessing_function_names) - - return extracts | 'Predict' >> beam.ParDo( - model_util.ModelSignaturesDoFn( - model_specs=eval_config.model_specs, - eval_shared_models=eval_shared_models, - output_keypath=(constants.TRANSFORMED_FEATURES_KEY,), - signature_names=signature_names, - default_signature_names=list(_DEFAULT_SIGNATURE_NAMES), - prefer_dict_outputs=True, - ) - ) + """A PTransform that updates extracts to include transformed features. + + Args: + ---- + extracts: PCollection of extracts containing raw inputs keyed by + tfma.FEATURES_KEY (if preprocessing function inputs are named) or + tfma.INPUTS_KEY (if preprocessing functions take raw tf.Examples as input) + eval_config: Eval config. + eval_shared_models: Shared model parameters keyed by model name. + + Returns: + ------- + PCollection of Extracts updated with the to include transformed features + stored under the key tfma.TRANSFORMED_FEATURES_KEY. 
+ """ + signature_names = {} + for spec in eval_config.model_specs: + model_name = "" if len(eval_config.model_specs) == 1 else spec.name + signature_names[model_name] = list(spec.preprocessing_function_names) + + return extracts | "Predict" >> beam.ParDo( + model_util.ModelSignaturesDoFn( + model_specs=eval_config.model_specs, + eval_shared_models=eval_shared_models, + output_keypath=(constants.TRANSFORMED_FEATURES_KEY,), + signature_names=signature_names, + default_signature_names=list(_DEFAULT_SIGNATURE_NAMES), + prefer_dict_outputs=True, + ) + ) diff --git a/tensorflow_model_analysis/extractors/transformed_features_extractor_test.py b/tensorflow_model_analysis/extractors/transformed_features_extractor_test.py index 1a6f1c6f31..bf36508cf7 100644 --- a/tensorflow_model_analysis/extractors/transformed_features_extractor_test.py +++ b/tensorflow_model_analysis/extractors/transformed_features_extractor_test.py @@ -16,33 +16,33 @@ import tempfile import unittest -from absl.testing import parameterized import apache_beam as beam -from apache_beam.testing import util import tensorflow as tf +from absl.testing import parameterized +from apache_beam.testing import util +from google.protobuf import text_format +from tensorflow_metadata.proto.v0 import schema_pb2 +from tfx_bsl.tfxio import tensor_adapter, test_util + from tensorflow_model_analysis import constants from tensorflow_model_analysis.api import model_eval_lib -from tensorflow_model_analysis.extractors import features_extractor -from tensorflow_model_analysis.extractors import transformed_features_extractor +from tensorflow_model_analysis.extractors import ( + features_extractor, + transformed_features_extractor, +) from tensorflow_model_analysis.proto import config_pb2 from tensorflow_model_analysis.utils import test_util as testutil from tensorflow_model_analysis.utils.keras_lib import tf_keras -from tfx_bsl.tfxio import tensor_adapter -from tfx_bsl.tfxio import test_util -from google.protobuf import text_format -from tensorflow_metadata.proto.v0 import schema_pb2 - -_TF_MAJOR_VERSION = int(tf.version.VERSION.split('.')[0]) +_TF_MAJOR_VERSION = int(tf.version.VERSION.split(".")[0]) class TransformedFeaturesExtractorTest( testutil.TensorflowModelAnalysisTest, parameterized.TestCase ): - - def createDenseInputsSchema(self): - return text_format.Parse( - """ + def createDenseInputsSchema(self): + return text_format.Parse( + """ tensor_representation_group { key: "" value { @@ -79,234 +79,226 @@ def createDenseInputsSchema(self): type: INT } """, - schema_pb2.Schema(), - ) + schema_pb2.Schema(), + ) - def createModelWithMultipleDenseInputs(self, save_as_keras): - input1 = tf_keras.layers.Input(shape=(1,), name='input_1') - input2 = tf_keras.layers.Input(shape=(1,), name='input_2') - inputs = [input1, input2] - input_layer = tf_keras.layers.concatenate(inputs) - output_layer = tf_keras.layers.Dense( - 1, activation=tf.nn.sigmoid, name='output' - )(input_layer) - model = tf_keras.models.Model(inputs, output_layer) + def createModelWithMultipleDenseInputs(self, save_as_keras): + input1 = tf_keras.layers.Input(shape=(1,), name="input_1") + input2 = tf_keras.layers.Input(shape=(1,), name="input_2") + inputs = [input1, input2] + input_layer = tf_keras.layers.concatenate(inputs) + output_layer = tf_keras.layers.Dense( + 1, activation=tf.nn.sigmoid, name="output" + )(input_layer) + model = tf_keras.models.Model(inputs, output_layer) - # Add tft_layer to model to test callables stored as attributes - model.tft_layer = tf_keras.models.Model( - 
inputs, {'tft_feature': output_layer, 'tft_label': output_layer} - ) + # Add tft_layer to model to test callables stored as attributes + model.tft_layer = tf_keras.models.Model( + inputs, {"tft_feature": output_layer, "tft_label": output_layer} + ) - @tf.function - def serving_default(serialized_tf_examples): - parsed_features = tf.io.parse_example( - serialized_tf_examples, - { - 'input_1': tf.io.FixedLenFeature([1], dtype=tf.float32), - 'input_2': tf.io.FixedLenFeature([1], dtype=tf.float32), - }, - ) - return model(parsed_features) + @tf.function + def serving_default(serialized_tf_examples): + parsed_features = tf.io.parse_example( + serialized_tf_examples, + { + "input_1": tf.io.FixedLenFeature([1], dtype=tf.float32), + "input_2": tf.io.FixedLenFeature([1], dtype=tf.float32), + }, + ) + return model(parsed_features) - @tf.function - def transformed_features(features): - return { - 'transformed_feature': features['input_1'], - } + @tf.function + def transformed_features(features): + return { + "transformed_feature": features["input_1"], + } - @tf.function - def transformed_labels(features): - return {'transformed_label': features['input_2']} + @tf.function + def transformed_labels(features): + return {"transformed_label": features["input_2"]} - @tf.function - def custom_preprocessing(features): - return { - 'custom_feature': features['input_1'], - 'custom_label': features['input_2'], - } + @tf.function + def custom_preprocessing(features): + return { + "custom_feature": features["input_1"], + "custom_label": features["input_2"], + } - single_input_spec = tf.TensorSpec( - shape=(None,), dtype=tf.string, name='examples' - ) - multi_input_spec = { - 'input_1': tf.TensorSpec( - shape=(None, 1), dtype=tf.float32, name='input_1' - ), - 'input_2': tf.TensorSpec( - shape=(None, 1), dtype=tf.float32, name='input_2' - ), - } - signatures = { - 'serving_default': serving_default.get_concrete_function( - single_input_spec + single_input_spec = tf.TensorSpec( + shape=(None,), dtype=tf.string, name="examples" + ) + multi_input_spec = { + "input_1": tf.TensorSpec(shape=(None, 1), dtype=tf.float32, name="input_1"), + "input_2": tf.TensorSpec(shape=(None, 1), dtype=tf.float32, name="input_2"), + } + signatures = { + "serving_default": serving_default.get_concrete_function(single_input_spec), + "transformed_labels": transformed_labels.get_concrete_function( + multi_input_spec + ), + "transformed_features": transformed_features.get_concrete_function( + multi_input_spec + ), + "custom_preprocessing": custom_preprocessing.get_concrete_function( + multi_input_spec + ), + } + + export_path = tempfile.mkdtemp() + if save_as_keras: + model.save(export_path, save_format="tf", signatures=signatures) + else: + tf.saved_model.save(model, export_path, signatures=signatures) + return export_path + + @parameterized.named_parameters( + ( + "keras_defaults", + True, + [], + { + "features": [ + "input_1", # raw feature + "input_2", # raw feature + "non_model_feature", # from schema + ], + "transformed_features": [ + # TODO(b/173029091): Re-add tft_layer + # 'tft_feature', # added by tft_layer + # 'tft_label', # added by tft_layer + "transformed_feature", # added by transformed_features + "transformed_label", # added by transformed_labels + ], + }, ), - 'transformed_labels': transformed_labels.get_concrete_function( - multi_input_spec + ( + "tf_defaults", + False, + [], + { + "features": [ + "input_1", # raw feature + "input_2", # raw feature + "non_model_feature", # from schema + ], + "transformed_features": [ + # 
TODO(b/173029091): Re-add tft_layer + # 'tft_feature', # added by tft_layer + # 'tft_label', # added by tft_layer + "transformed_feature", # added by transformed_features + "transformed_label", # added by transformed_labels + ], + }, ), - 'transformed_features': transformed_features.get_concrete_function( - multi_input_spec + ( + "keras_custom", + True, + ["custom_preprocessing"], + { + "features": [ + "input_1", # raw feature + "input_2", # raw feature + "non_model_feature", # from schema + ], + "transformed_features": [ + "custom_feature", # added by custom_preprocessing + "custom_label", # added by custom_preprocessing + ], + }, ), - 'custom_preprocessing': custom_preprocessing.get_concrete_function( - multi_input_spec + ( + "tf_custom", + False, + ["custom_preprocessing"], + { + "features": [ + "input_1", # raw feature + "input_2", # raw feature + "non_model_feature", # from schema + ], + "transformed_features": [ + "custom_feature", # added by custom_preprocessing + "custom_label", # added by custom_preprocessing + ], + }, ), - } - - export_path = tempfile.mkdtemp() - if save_as_keras: - model.save(export_path, save_format='tf', signatures=signatures) - else: - tf.saved_model.save(model, export_path, signatures=signatures) - return export_path - - @parameterized.named_parameters( - ( - 'keras_defaults', - True, - [], - { - 'features': [ - 'input_1', # raw feature - 'input_2', # raw feature - 'non_model_feature', # from schema - ], - 'transformed_features': [ - # TODO(b/173029091): Re-add tft_layer - # 'tft_feature', # added by tft_layer - # 'tft_label', # added by tft_layer - 'transformed_feature', # added by transformed_features - 'transformed_label', # added by transformed_labels - ], - }, - ), - ( - 'tf_defaults', - False, - [], - { - 'features': [ - 'input_1', # raw feature - 'input_2', # raw feature - 'non_model_feature', # from schema - ], - 'transformed_features': [ - # TODO(b/173029091): Re-add tft_layer - # 'tft_feature', # added by tft_layer - # 'tft_label', # added by tft_layer - 'transformed_feature', # added by transformed_features - 'transformed_label', # added by transformed_labels - ], - }, - ), - ( - 'keras_custom', - True, - ['custom_preprocessing'], - { - 'features': [ - 'input_1', # raw feature - 'input_2', # raw feature - 'non_model_feature', # from schema - ], - 'transformed_features': [ - 'custom_feature', # added by custom_preprocessing - 'custom_label', # added by custom_preprocessing - ], - }, - ), - ( - 'tf_custom', - False, - ['custom_preprocessing'], - { - 'features': [ - 'input_1', # raw feature - 'input_2', # raw feature - 'non_model_feature', # from schema - ], - 'transformed_features': [ - 'custom_feature', # added by custom_preprocessing - 'custom_label', # added by custom_preprocessing - ], - }, - ), - ) - @unittest.skipIf( - _TF_MAJOR_VERSION < 2, 'not all signatures supported for TF1' - ) - def testPreprocessedFeaturesExtractor( - self, save_as_keras, preprocessing_function_names, expected_extract_keys - ): - export_path = self.createModelWithMultipleDenseInputs(save_as_keras) + ) + @unittest.skipIf(_TF_MAJOR_VERSION < 2, "not all signatures supported for TF1") + def testPreprocessedFeaturesExtractor( + self, save_as_keras, preprocessing_function_names, expected_extract_keys + ): + export_path = self.createModelWithMultipleDenseInputs(save_as_keras) - eval_config = config_pb2.EvalConfig( - model_specs=[ - config_pb2.ModelSpec( - preprocessing_function_names=preprocessing_function_names + eval_config = config_pb2.EvalConfig( + model_specs=[ + 
config_pb2.ModelSpec( + preprocessing_function_names=preprocessing_function_names + ) + ] + ) + eval_shared_model = self.createKerasTestEvalSharedModel( + eval_saved_model_path=export_path, eval_config=eval_config + ) + schema = self.createDenseInputsSchema() + tfx_io = test_util.InMemoryTFExampleRecord( + schema=schema, raw_record_column_name=constants.ARROW_INPUT_COLUMN + ) + tensor_adapter_config = tensor_adapter.TensorAdapterConfig( + arrow_schema=tfx_io.ArrowSchema(), + tensor_representations=tfx_io.TensorRepresentations(), + ) + feature_extractor = features_extractor.FeaturesExtractor( + eval_config=eval_config, + tensor_representations=tensor_adapter_config.tensor_representations, + ) + transformation_extractor = ( + transformed_features_extractor.TransformedFeaturesExtractor( + eval_config=eval_config, eval_shared_model=eval_shared_model ) - ] - ) - eval_shared_model = self.createKerasTestEvalSharedModel( - eval_saved_model_path=export_path, eval_config=eval_config - ) - schema = self.createDenseInputsSchema() - tfx_io = test_util.InMemoryTFExampleRecord( - schema=schema, raw_record_column_name=constants.ARROW_INPUT_COLUMN - ) - tensor_adapter_config = tensor_adapter.TensorAdapterConfig( - arrow_schema=tfx_io.ArrowSchema(), - tensor_representations=tfx_io.TensorRepresentations(), - ) - feature_extractor = features_extractor.FeaturesExtractor( - eval_config=eval_config, - tensor_representations=tensor_adapter_config.tensor_representations, - ) - transformation_extractor = ( - transformed_features_extractor.TransformedFeaturesExtractor( - eval_config=eval_config, eval_shared_model=eval_shared_model ) - ) - examples = [ - self._makeExample(input_1=1.0, input_2=2.0), - self._makeExample(input_1=3.0, input_2=4.0), - self._makeExample(input_1=5.0, input_2=6.0), - ] + examples = [ + self._makeExample(input_1=1.0, input_2=2.0), + self._makeExample(input_1=3.0, input_2=4.0), + self._makeExample(input_1=5.0, input_2=6.0), + ] - with beam.Pipeline() as pipeline: - # pylint: disable=no-value-for-parameter - result = ( - pipeline - | 'Create' - >> beam.Create( - [e.SerializeToString() for e in examples], reshuffle=False - ) - | 'BatchExamples' >> tfx_io.BeamSource(batch_size=2) - | 'InputsToExtracts' >> model_eval_lib.BatchedInputsToExtracts() - | feature_extractor.stage_name >> feature_extractor.ptransform - | transformation_extractor.stage_name - >> transformation_extractor.ptransform - ) + with beam.Pipeline() as pipeline: + # pylint: disable=no-value-for-parameter + result = ( + pipeline + | "Create" + >> beam.Create( + [e.SerializeToString() for e in examples], reshuffle=False + ) + | "BatchExamples" >> tfx_io.BeamSource(batch_size=2) + | "InputsToExtracts" >> model_eval_lib.BatchedInputsToExtracts() + | feature_extractor.stage_name >> feature_extractor.ptransform + | transformation_extractor.stage_name + >> transformation_extractor.ptransform + ) - # pylint: enable=no-value-for-parameter + # pylint: enable=no-value-for-parameter - def check_result(batches): - try: - self.assertLen(batches, 2) - for got in batches: - for extracts_key, feature_keys in expected_extract_keys.items(): - self.assertIn(extracts_key, got) - self.assertCountEqual( - set(feature_keys), - set(got[extracts_key]), - msg=f'got[{extracts_key}]={got[extracts_key]}', - ) + def check_result(batches): + try: + self.assertLen(batches, 2) + for got in batches: + for extracts_key, feature_keys in expected_extract_keys.items(): + self.assertIn(extracts_key, got) + self.assertCountEqual( + set(feature_keys), + 
set(got[extracts_key]), + msg=f"got[{extracts_key}]={got[extracts_key]}", + ) - except AssertionError as err: - raise util.BeamAssertException(err) + except AssertionError as err: + raise util.BeamAssertException(err) - util.assert_that(result, check_result, label='result') + util.assert_that(result, check_result, label="result") -if __name__ == '__main__': - tf.compat.v1.enable_v2_behavior() - tf.test.main() +if __name__ == "__main__": + tf.compat.v1.enable_v2_behavior() + tf.test.main() diff --git a/tensorflow_model_analysis/extractors/unbatch_extractor.py b/tensorflow_model_analysis/extractors/unbatch_extractor.py index a3683799c6..2d6d0309c5 100644 --- a/tensorflow_model_analysis/extractors/unbatch_extractor.py +++ b/tensorflow_model_analysis/extractors/unbatch_extractor.py @@ -17,75 +17,77 @@ import apache_beam as beam import pandas as pd + from tensorflow_model_analysis import constants from tensorflow_model_analysis.api import types from tensorflow_model_analysis.extractors import extractor from tensorflow_model_analysis.utils import util -UNBATCH_EXTRACTOR_STAGE_NAME = 'ExtractUnbatchedInputs' +UNBATCH_EXTRACTOR_STAGE_NAME = "ExtractUnbatchedInputs" def UnbatchExtractor() -> extractor.Extractor: - """Creates an extractor for unbatching batched extracts. + """Creates an extractor for unbatching batched extracts. - This extractor removes Arrow RecordBatch from the batched extract and outputs - per-example extracts with the remaining keys. We assume that the remaining - keys in the input extract contain list of objects (one per example). + This extractor removes Arrow RecordBatch from the batched extract and outputs + per-example extracts with the remaining keys. We assume that the remaining + keys in the input extract contain list of objects (one per example). - Returns: - Extractor for unbatching batched extracts. - """ - # pylint: disable=no-value-for-parameter - return extractor.Extractor( - stage_name=UNBATCH_EXTRACTOR_STAGE_NAME, ptransform=_UnbatchInputs() - ) + Returns + ------- + Extractor for unbatching batched extracts. + """ + # pylint: disable=no-value-for-parameter + return extractor.Extractor( + stage_name=UNBATCH_EXTRACTOR_STAGE_NAME, ptransform=_UnbatchInputs() + ) def _extract_unbatched_inputs( # pylint: disable=invalid-name mixed_legacy_batched_extract: types.Extracts, ) -> Sequence[types.Extracts]: - """Extract features, predictions, labels and weights from batched extract.""" - batched_extract = {} - # TODO(mdreves): Remove record batch - keys_to_retain = set(mixed_legacy_batched_extract) - if constants.ARROW_RECORD_BATCH_KEY in keys_to_retain: - keys_to_retain.remove(constants.ARROW_RECORD_BATCH_KEY) - dataframe = pd.DataFrame() - for key in keys_to_retain: - # Previously a batch of transformed features were stored as a list of dicts - # instead of a dict of np.arrays with batch dimensions. These legacy - # conversions are done using dataframes instead. 
- if isinstance(mixed_legacy_batched_extract[key], list): - try: - dataframe[key] = mixed_legacy_batched_extract[key] - except Exception as e: - raise RuntimeError( - f'Exception encountered while adding key {key} with ' - f'batched length {len(mixed_legacy_batched_extract[key])}' - ) from e + """Extract features, predictions, labels and weights from batched extract.""" + batched_extract = {} + # TODO(mdreves): Remove record batch + keys_to_retain = set(mixed_legacy_batched_extract) + if constants.ARROW_RECORD_BATCH_KEY in keys_to_retain: + keys_to_retain.remove(constants.ARROW_RECORD_BATCH_KEY) + dataframe = pd.DataFrame() + for key in keys_to_retain: + # Previously a batch of transformed features were stored as a list of dicts + # instead of a dict of np.arrays with batch dimensions. These legacy + # conversions are done using dataframes instead. + if isinstance(mixed_legacy_batched_extract[key], list): + try: + dataframe[key] = mixed_legacy_batched_extract[key] + except Exception as e: + raise RuntimeError( + f"Exception encountered while adding key {key} with " + f"batched length {len(mixed_legacy_batched_extract[key])}" + ) from e + else: + batched_extract[key] = mixed_legacy_batched_extract[key] + unbatched_extracts = util.split_extracts(batched_extract) + legacy_unbatched_extracts = dataframe.to_dict(orient="records") + if unbatched_extracts and legacy_unbatched_extracts: + if len(unbatched_extracts) != len(legacy_unbatched_extracts): + raise ValueError( + f"Batch sizes have differing values: {len(unbatched_extracts)} != " + f"{len(legacy_unbatched_extracts)}, " + f"unbatched_extracts={unbatched_extracts}, " + f"legacy_unbatched_extracts={legacy_unbatched_extracts}" + ) + result = [] + for unbatched_extract, legacy_unbatched_extract in zip( + unbatched_extracts, legacy_unbatched_extracts + ): + legacy_unbatched_extract.update(unbatched_extract) + result.append(legacy_unbatched_extract) + return result + elif legacy_unbatched_extracts: + return legacy_unbatched_extracts else: - batched_extract[key] = mixed_legacy_batched_extract[key] - unbatched_extracts = util.split_extracts(batched_extract) - legacy_unbatched_extracts = dataframe.to_dict(orient='records') - if unbatched_extracts and legacy_unbatched_extracts: - if len(unbatched_extracts) != len(legacy_unbatched_extracts): - raise ValueError( - f'Batch sizes have differing values: {len(unbatched_extracts)} != ' - f'{len(legacy_unbatched_extracts)}, ' - f'unbatched_extracts={unbatched_extracts}, ' - f'legacy_unbatched_extracts={legacy_unbatched_extracts}' - ) - result = [] - for unbatched_extract, legacy_unbatched_extract in zip( - unbatched_extracts, legacy_unbatched_extracts - ): - legacy_unbatched_extract.update(unbatched_extract) - result.append(legacy_unbatched_extract) - return result - elif legacy_unbatched_extracts: - return legacy_unbatched_extracts - else: - return unbatched_extracts + return unbatched_extracts @beam.ptransform_fn @@ -94,12 +96,14 @@ def _extract_unbatched_inputs( # pylint: disable=invalid-name def _UnbatchInputs( extracts: beam.pvalue.PCollection, ) -> beam.pvalue.PCollection: - """Extracts unbatched inputs from batched extracts. + """Extracts unbatched inputs from batched extracts. - Args: - extracts: PCollection containing batched extracts. + Args: + ---- + extracts: PCollection containing batched extracts. - Returns: - PCollection of per-example extracts. - """ - return extracts | 'UnbatchInputs' >> beam.FlatMap(_extract_unbatched_inputs) + Returns: + ------- + PCollection of per-example extracts. 
+ """ + return extracts | "UnbatchInputs" >> beam.FlatMap(_extract_unbatched_inputs) diff --git a/tensorflow_model_analysis/extractors/unbatch_extractor_test.py b/tensorflow_model_analysis/extractors/unbatch_extractor_test.py index a611c7ce26..8965be0411 100644 --- a/tensorflow_model_analysis/extractors/unbatch_extractor_test.py +++ b/tensorflow_model_analysis/extractors/unbatch_extractor_test.py @@ -14,72 +14,73 @@ """Test for unbatch extractor.""" import apache_beam as beam -from apache_beam.testing import util import numpy as np import tensorflow as tf +from apache_beam.testing import util +from google.protobuf import text_format +from tensorflow_metadata.proto.v0 import schema_pb2 +from tfx_bsl.tfxio import test_util + from tensorflow_model_analysis import constants from tensorflow_model_analysis.api import model_eval_lib -from tensorflow_model_analysis.extractors import example_weights_extractor -from tensorflow_model_analysis.extractors import features_extractor -from tensorflow_model_analysis.extractors import labels_extractor -from tensorflow_model_analysis.extractors import materialized_predictions_extractor -from tensorflow_model_analysis.extractors import unbatch_extractor +from tensorflow_model_analysis.extractors import ( + example_weights_extractor, + features_extractor, + labels_extractor, + materialized_predictions_extractor, + unbatch_extractor, +) from tensorflow_model_analysis.proto import config_pb2 from tensorflow_model_analysis.utils import test_util as testutil -from tfx_bsl.tfxio import test_util - -from google.protobuf import text_format -from tensorflow_metadata.proto.v0 import schema_pb2 class UnbatchExtractorTest(testutil.TensorflowModelAnalysisTest): + def testExtractUnbatchedInputsRaisesChainedException(self): + batched_extracts = { + "features": [ + { + "label": np.array([1.0]), + "fixed_int": np.array([1]), + }, + { + "label": np.array([2.0]), + "fixed_int": np.array([2]), + }, + ], + "labels": [ + np.array([1.0]), + ], + } + with self.assertRaisesRegex( + RuntimeError, + "Exception encountered while adding key .* with batched length .", + ) as ctx: + unbatch_extractor._extract_unbatched_inputs(batched_extracts) + self.assertIsInstance(ctx.exception.__cause__, ValueError) + self.assertRegex( + str(ctx.exception.__cause__), + r"Length of values \(.\) does not match length of index \(.\)", + ) - def testExtractUnbatchedInputsRaisesChainedException(self): - batched_extracts = { - 'features': [ - { - 'label': np.array([1.0]), - 'fixed_int': np.array([1]), - }, - { - 'label': np.array([2.0]), - 'fixed_int': np.array([2]), - }, - ], - 'labels': [ - np.array([1.0]), - ], - } - with self.assertRaisesRegex( - RuntimeError, - 'Exception encountered while adding key .* with batched length .', - ) as ctx: - unbatch_extractor._extract_unbatched_inputs(batched_extracts) - self.assertIsInstance(ctx.exception.__cause__, ValueError) - self.assertRegex( - str(ctx.exception.__cause__), - r'Length of values \(.\) does not match length of index \(.\)', - ) - - def testUnbatchExtractor(self): - model_spec = config_pb2.ModelSpec( - label_key='label', example_weight_key='example_weight' - ) - eval_config = config_pb2.EvalConfig(model_specs=[model_spec]) - feature_extractor = features_extractor.FeaturesExtractor(eval_config) - label_extractor = labels_extractor.LabelsExtractor(eval_config) - example_weight_extractor = ( - example_weights_extractor.ExampleWeightsExtractor(eval_config) - ) - predict_extractor = ( - materialized_predictions_extractor.MaterializedPredictionsExtractor( + 
def testUnbatchExtractor(self): + model_spec = config_pb2.ModelSpec( + label_key="label", example_weight_key="example_weight" + ) + eval_config = config_pb2.EvalConfig(model_specs=[model_spec]) + feature_extractor = features_extractor.FeaturesExtractor(eval_config) + label_extractor = labels_extractor.LabelsExtractor(eval_config) + example_weight_extractor = example_weights_extractor.ExampleWeightsExtractor( eval_config ) - ) - unbatch_inputs_extractor = unbatch_extractor.UnbatchExtractor() + predict_extractor = ( + materialized_predictions_extractor.MaterializedPredictionsExtractor( + eval_config + ) + ) + unbatch_inputs_extractor = unbatch_extractor.UnbatchExtractor() - schema = text_format.Parse( - """ + schema = text_format.Parse( + """ feature { name: "label" type: FLOAT @@ -101,133 +102,139 @@ def testUnbatchExtractor(self): type: BYTES } """, - schema_pb2.Schema(), - ) - tfx_io = test_util.InMemoryTFExampleRecord( - schema=schema, raw_record_column_name=constants.ARROW_INPUT_COLUMN - ) - examples = [ - self._makeExample( - label=1.0, - example_weight=0.5, - fixed_int=1, - fixed_float=1.0, - fixed_string='fixed_string1', - ), - self._makeExample( - label=0.0, - example_weight=0.0, - fixed_int=1, - fixed_float=1.0, - fixed_string='fixed_string2', - ), - self._makeExample( - label=0.0, - example_weight=1.0, - fixed_int=2, - fixed_float=0.0, - fixed_string='fixed_string3', - ), - ] + schema_pb2.Schema(), + ) + tfx_io = test_util.InMemoryTFExampleRecord( + schema=schema, raw_record_column_name=constants.ARROW_INPUT_COLUMN + ) + examples = [ + self._makeExample( + label=1.0, + example_weight=0.5, + fixed_int=1, + fixed_float=1.0, + fixed_string="fixed_string1", + ), + self._makeExample( + label=0.0, + example_weight=0.0, + fixed_int=1, + fixed_float=1.0, + fixed_string="fixed_string2", + ), + self._makeExample( + label=0.0, + example_weight=1.0, + fixed_int=2, + fixed_float=0.0, + fixed_string="fixed_string3", + ), + ] - with beam.Pipeline() as pipeline: - # pylint: disable=no-value-for-parameter - result = ( - pipeline - | 'Create' - >> beam.Create( - [e.SerializeToString() for e in examples], reshuffle=False - ) - | 'BatchExamples' >> tfx_io.BeamSource(batch_size=3) - | 'InputsToExtracts' >> model_eval_lib.BatchedInputsToExtracts() - | feature_extractor.stage_name >> feature_extractor.ptransform - | label_extractor.stage_name >> label_extractor.ptransform - | example_weight_extractor.stage_name - >> example_weight_extractor.ptransform - | predict_extractor.stage_name >> predict_extractor.ptransform - | unbatch_inputs_extractor.stage_name - >> unbatch_inputs_extractor.ptransform - ) + with beam.Pipeline() as pipeline: + # pylint: disable=no-value-for-parameter + result = ( + pipeline + | "Create" + >> beam.Create( + [e.SerializeToString() for e in examples], reshuffle=False + ) + | "BatchExamples" >> tfx_io.BeamSource(batch_size=3) + | "InputsToExtracts" >> model_eval_lib.BatchedInputsToExtracts() + | feature_extractor.stage_name >> feature_extractor.ptransform + | label_extractor.stage_name >> label_extractor.ptransform + | example_weight_extractor.stage_name + >> example_weight_extractor.ptransform + | predict_extractor.stage_name >> predict_extractor.ptransform + | unbatch_inputs_extractor.stage_name + >> unbatch_inputs_extractor.ptransform + ) - # pylint: enable=no-value-for-parameter + # pylint: enable=no-value-for-parameter - def check_result(got): - try: - self.assertLen(got, 3) - self.assertDictElementsAlmostEqual( - got[0][constants.FEATURES_KEY], - { - 'fixed_int': 
np.array([1]), - 'fixed_float': np.array([1.0]), - }, - ) - self.assertEqual( - got[0][constants.FEATURES_KEY]['fixed_string'], - np.array([b'fixed_string1']), - ) - self.assertAlmostEqual(got[0][constants.LABELS_KEY], np.array([1.0])) - self.assertAlmostEqual( - got[0][constants.EXAMPLE_WEIGHTS_KEY], np.array([0.5]) - ) - self.assertDictElementsAlmostEqual( - got[1][constants.FEATURES_KEY], - { - 'fixed_int': np.array([1]), - 'fixed_float': np.array([1.0]), - }, - ) - self.assertEqual( - got[1][constants.FEATURES_KEY]['fixed_string'], - np.array([b'fixed_string2']), - ) - self.assertAlmostEqual(got[1][constants.LABELS_KEY], np.array([0.0])) - self.assertAlmostEqual( - got[1][constants.EXAMPLE_WEIGHTS_KEY], np.array([0.0]) - ) - self.assertDictElementsAlmostEqual( - got[2][constants.FEATURES_KEY], - { - 'fixed_int': np.array([2]), - 'fixed_float': np.array([0.0]), - }, - ) - self.assertEqual( - got[2][constants.FEATURES_KEY]['fixed_string'], - np.array([b'fixed_string3']), - ) - self.assertAlmostEqual(got[2][constants.LABELS_KEY], np.array([0.0])) - self.assertAlmostEqual( - got[2][constants.EXAMPLE_WEIGHTS_KEY], np.array([1.0]) - ) + def check_result(got): + try: + self.assertLen(got, 3) + self.assertDictElementsAlmostEqual( + got[0][constants.FEATURES_KEY], + { + "fixed_int": np.array([1]), + "fixed_float": np.array([1.0]), + }, + ) + self.assertEqual( + got[0][constants.FEATURES_KEY]["fixed_string"], + np.array([b"fixed_string1"]), + ) + self.assertAlmostEqual( + got[0][constants.LABELS_KEY], np.array([1.0]) + ) + self.assertAlmostEqual( + got[0][constants.EXAMPLE_WEIGHTS_KEY], np.array([0.5]) + ) + self.assertDictElementsAlmostEqual( + got[1][constants.FEATURES_KEY], + { + "fixed_int": np.array([1]), + "fixed_float": np.array([1.0]), + }, + ) + self.assertEqual( + got[1][constants.FEATURES_KEY]["fixed_string"], + np.array([b"fixed_string2"]), + ) + self.assertAlmostEqual( + got[1][constants.LABELS_KEY], np.array([0.0]) + ) + self.assertAlmostEqual( + got[1][constants.EXAMPLE_WEIGHTS_KEY], np.array([0.0]) + ) + self.assertDictElementsAlmostEqual( + got[2][constants.FEATURES_KEY], + { + "fixed_int": np.array([2]), + "fixed_float": np.array([0.0]), + }, + ) + self.assertEqual( + got[2][constants.FEATURES_KEY]["fixed_string"], + np.array([b"fixed_string3"]), + ) + self.assertAlmostEqual( + got[2][constants.LABELS_KEY], np.array([0.0]) + ) + self.assertAlmostEqual( + got[2][constants.EXAMPLE_WEIGHTS_KEY], np.array([1.0]) + ) - except AssertionError as err: - raise util.BeamAssertException(err) + except AssertionError as err: + raise util.BeamAssertException(err) - util.assert_that(result, check_result, label='result') + util.assert_that(result, check_result, label="result") - def testUnbatchExtractorMultiOutput(self): - model_spec = config_pb2.ModelSpec( - label_keys={'output1': 'label1', 'output2': 'label2'}, - example_weight_keys={ - 'output1': 'example_weight1', - 'output2': 'example_weight2', - }, - ) - eval_config = config_pb2.EvalConfig(model_specs=[model_spec]) - feature_extractor = features_extractor.FeaturesExtractor(eval_config) - label_extractor = labels_extractor.LabelsExtractor(eval_config) - example_weight_extractor = ( - example_weights_extractor.ExampleWeightsExtractor(eval_config) - ) - predict_extractor = ( - materialized_predictions_extractor.MaterializedPredictionsExtractor( + def testUnbatchExtractorMultiOutput(self): + model_spec = config_pb2.ModelSpec( + label_keys={"output1": "label1", "output2": "label2"}, + example_weight_keys={ + "output1": "example_weight1", + 
"output2": "example_weight2", + }, + ) + eval_config = config_pb2.EvalConfig(model_specs=[model_spec]) + feature_extractor = features_extractor.FeaturesExtractor(eval_config) + label_extractor = labels_extractor.LabelsExtractor(eval_config) + example_weight_extractor = example_weights_extractor.ExampleWeightsExtractor( eval_config ) - ) - unbatch_inputs_extractor = unbatch_extractor.UnbatchExtractor() + predict_extractor = ( + materialized_predictions_extractor.MaterializedPredictionsExtractor( + eval_config + ) + ) + unbatch_inputs_extractor = unbatch_extractor.UnbatchExtractor() - schema = text_format.Parse( - """ + schema = text_format.Parse( + """ feature { name: "label1" type: FLOAT @@ -257,132 +264,132 @@ def testUnbatchExtractorMultiOutput(self): type: BYTES } """, - schema_pb2.Schema(), - ) - tfx_io = test_util.InMemoryTFExampleRecord( - schema=schema, raw_record_column_name=constants.ARROW_INPUT_COLUMN - ) + schema_pb2.Schema(), + ) + tfx_io = test_util.InMemoryTFExampleRecord( + schema=schema, raw_record_column_name=constants.ARROW_INPUT_COLUMN + ) - examples = [ - self._makeExample( - label1=1.0, - label2=0.0, - example_weight1=0.5, - example_weight2=0.5, - fixed_int=1, - fixed_float=1.0, - fixed_string='fixed_string1', - ), - self._makeExample( - label1=1.0, - label2=1.0, - example_weight1=0.0, - example_weight2=1.0, - fixed_int=1, - fixed_float=1.0, - fixed_string='fixed_string2', - ), - ] + examples = [ + self._makeExample( + label1=1.0, + label2=0.0, + example_weight1=0.5, + example_weight2=0.5, + fixed_int=1, + fixed_float=1.0, + fixed_string="fixed_string1", + ), + self._makeExample( + label1=1.0, + label2=1.0, + example_weight1=0.0, + example_weight2=1.0, + fixed_int=1, + fixed_float=1.0, + fixed_string="fixed_string2", + ), + ] - with beam.Pipeline() as pipeline: - # pylint: disable=no-value-for-parameter - result = ( - pipeline - | 'Create' - >> beam.Create( - [e.SerializeToString() for e in examples], reshuffle=False - ) - | 'BatchExamples' >> tfx_io.BeamSource(batch_size=2) - | 'InputsToExtracts' >> model_eval_lib.BatchedInputsToExtracts() - | feature_extractor.stage_name >> feature_extractor.ptransform - | label_extractor.stage_name >> label_extractor.ptransform - | example_weight_extractor.stage_name - >> example_weight_extractor.ptransform - | predict_extractor.stage_name >> predict_extractor.ptransform - | unbatch_inputs_extractor.stage_name - >> unbatch_inputs_extractor.ptransform - ) + with beam.Pipeline() as pipeline: + # pylint: disable=no-value-for-parameter + result = ( + pipeline + | "Create" + >> beam.Create( + [e.SerializeToString() for e in examples], reshuffle=False + ) + | "BatchExamples" >> tfx_io.BeamSource(batch_size=2) + | "InputsToExtracts" >> model_eval_lib.BatchedInputsToExtracts() + | feature_extractor.stage_name >> feature_extractor.ptransform + | label_extractor.stage_name >> label_extractor.ptransform + | example_weight_extractor.stage_name + >> example_weight_extractor.ptransform + | predict_extractor.stage_name >> predict_extractor.ptransform + | unbatch_inputs_extractor.stage_name + >> unbatch_inputs_extractor.ptransform + ) - # pylint: enable=no-value-for-parameter + # pylint: enable=no-value-for-parameter - def check_result(got): - try: - self.assertLen(got, 2) - self.assertDictElementsAlmostEqual( - got[0][constants.FEATURES_KEY], - { - 'fixed_int': np.array([1]), - 'fixed_float': np.array([1.0]), - }, - ) - self.assertEqual( - got[0][constants.FEATURES_KEY]['fixed_string'], - np.array([b'fixed_string1']), - ) - 
self.assertDictElementsAlmostEqual( - got[0][constants.LABELS_KEY], - {'output1': np.array([1.0]), 'output2': np.array([0.0])}, - ) - self.assertDictElementsAlmostEqual( - got[0][constants.EXAMPLE_WEIGHTS_KEY], - {'output1': np.array([0.5]), 'output2': np.array([0.5])}, - ) - self.assertDictElementsAlmostEqual( - got[1][constants.FEATURES_KEY], - { - 'fixed_int': np.array([1]), - 'fixed_float': np.array([1.0]), - }, - ) - self.assertEqual( - got[1][constants.FEATURES_KEY]['fixed_string'], - np.array([b'fixed_string2']), - ) - self.assertDictElementsAlmostEqual( - got[1][constants.LABELS_KEY], - {'output1': np.array([1.0]), 'output2': np.array([1.0])}, - ) - self.assertDictElementsAlmostEqual( - got[1][constants.EXAMPLE_WEIGHTS_KEY], - {'output1': np.array([0.0]), 'output2': np.array([1.0])}, - ) + def check_result(got): + try: + self.assertLen(got, 2) + self.assertDictElementsAlmostEqual( + got[0][constants.FEATURES_KEY], + { + "fixed_int": np.array([1]), + "fixed_float": np.array([1.0]), + }, + ) + self.assertEqual( + got[0][constants.FEATURES_KEY]["fixed_string"], + np.array([b"fixed_string1"]), + ) + self.assertDictElementsAlmostEqual( + got[0][constants.LABELS_KEY], + {"output1": np.array([1.0]), "output2": np.array([0.0])}, + ) + self.assertDictElementsAlmostEqual( + got[0][constants.EXAMPLE_WEIGHTS_KEY], + {"output1": np.array([0.5]), "output2": np.array([0.5])}, + ) + self.assertDictElementsAlmostEqual( + got[1][constants.FEATURES_KEY], + { + "fixed_int": np.array([1]), + "fixed_float": np.array([1.0]), + }, + ) + self.assertEqual( + got[1][constants.FEATURES_KEY]["fixed_string"], + np.array([b"fixed_string2"]), + ) + self.assertDictElementsAlmostEqual( + got[1][constants.LABELS_KEY], + {"output1": np.array([1.0]), "output2": np.array([1.0])}, + ) + self.assertDictElementsAlmostEqual( + got[1][constants.EXAMPLE_WEIGHTS_KEY], + {"output1": np.array([0.0]), "output2": np.array([1.0])}, + ) - except AssertionError as err: - raise util.BeamAssertException(err) + except AssertionError as err: + raise util.BeamAssertException(err) - util.assert_that(result, check_result, label='result') + util.assert_that(result, check_result, label="result") - def testUnbatchExtractorMultiModel(self): - model_spec1 = config_pb2.ModelSpec( - name='model1', - label_key='label', - example_weight_key='example_weight', - prediction_key='fixed_float', - ) - model_spec2 = config_pb2.ModelSpec( - name='model2', - label_keys={'output1': 'label1', 'output2': 'label2'}, - example_weight_keys={ - 'output1': 'example_weight1', - 'output2': 'example_weight2', - }, - prediction_keys={'output1': 'fixed_float', 'output2': 'fixed_float'}, - ) - eval_config = config_pb2.EvalConfig(model_specs=[model_spec1, model_spec2]) - feature_extractor = features_extractor.FeaturesExtractor(eval_config) - label_extractor = labels_extractor.LabelsExtractor(eval_config) - example_weight_extractor = ( - example_weights_extractor.ExampleWeightsExtractor(eval_config) - ) - predict_extractor = ( - materialized_predictions_extractor.MaterializedPredictionsExtractor( + def testUnbatchExtractorMultiModel(self): + model_spec1 = config_pb2.ModelSpec( + name="model1", + label_key="label", + example_weight_key="example_weight", + prediction_key="fixed_float", + ) + model_spec2 = config_pb2.ModelSpec( + name="model2", + label_keys={"output1": "label1", "output2": "label2"}, + example_weight_keys={ + "output1": "example_weight1", + "output2": "example_weight2", + }, + prediction_keys={"output1": "fixed_float", "output2": "fixed_float"}, + ) + 
eval_config = config_pb2.EvalConfig(model_specs=[model_spec1, model_spec2]) + feature_extractor = features_extractor.FeaturesExtractor(eval_config) + label_extractor = labels_extractor.LabelsExtractor(eval_config) + example_weight_extractor = example_weights_extractor.ExampleWeightsExtractor( eval_config ) - ) - unbatch_inputs_extractor = unbatch_extractor.UnbatchExtractor() + predict_extractor = ( + materialized_predictions_extractor.MaterializedPredictionsExtractor( + eval_config + ) + ) + unbatch_inputs_extractor = unbatch_extractor.UnbatchExtractor() - schema = text_format.Parse( - """ + schema = text_format.Parse( + """ feature { name: "label" type: FLOAT @@ -420,137 +427,137 @@ def testUnbatchExtractorMultiModel(self): type: BYTES } """, - schema_pb2.Schema(), - ) - tfx_io = test_util.InMemoryTFExampleRecord( - schema=schema, raw_record_column_name=constants.ARROW_INPUT_COLUMN - ) + schema_pb2.Schema(), + ) + tfx_io = test_util.InMemoryTFExampleRecord( + schema=schema, raw_record_column_name=constants.ARROW_INPUT_COLUMN + ) - examples = [ - self._makeExample( - label=1.0, - label1=1.0, - label2=0.0, - example_weight=0.5, - example_weight1=0.5, - example_weight2=0.5, - fixed_int=1, - fixed_float=1.0, - fixed_string='fixed_string1', - ), - self._makeExample( - label=1.0, - label1=1.0, - label2=1.0, - example_weight=0.0, - example_weight1=0.0, - example_weight2=1.0, - fixed_int=1, - fixed_float=2.0, - fixed_string='fixed_string2', - ), - ] + examples = [ + self._makeExample( + label=1.0, + label1=1.0, + label2=0.0, + example_weight=0.5, + example_weight1=0.5, + example_weight2=0.5, + fixed_int=1, + fixed_float=1.0, + fixed_string="fixed_string1", + ), + self._makeExample( + label=1.0, + label1=1.0, + label2=1.0, + example_weight=0.0, + example_weight1=0.0, + example_weight2=1.0, + fixed_int=1, + fixed_float=2.0, + fixed_string="fixed_string2", + ), + ] - with beam.Pipeline() as pipeline: - # pylint: disable=no-value-for-parameter - result = ( - pipeline - | 'Create' - >> beam.Create( - [e.SerializeToString() for e in examples], reshuffle=False - ) - | 'BatchExamples' >> tfx_io.BeamSource(batch_size=2) - | 'InputsToExtracts' >> model_eval_lib.BatchedInputsToExtracts() - | feature_extractor.stage_name >> feature_extractor.ptransform - | label_extractor.stage_name >> label_extractor.ptransform - | example_weight_extractor.stage_name - >> example_weight_extractor.ptransform - | predict_extractor.stage_name >> predict_extractor.ptransform - | unbatch_inputs_extractor.stage_name - >> unbatch_inputs_extractor.ptransform - ) + with beam.Pipeline() as pipeline: + # pylint: disable=no-value-for-parameter + result = ( + pipeline + | "Create" + >> beam.Create( + [e.SerializeToString() for e in examples], reshuffle=False + ) + | "BatchExamples" >> tfx_io.BeamSource(batch_size=2) + | "InputsToExtracts" >> model_eval_lib.BatchedInputsToExtracts() + | feature_extractor.stage_name >> feature_extractor.ptransform + | label_extractor.stage_name >> label_extractor.ptransform + | example_weight_extractor.stage_name + >> example_weight_extractor.ptransform + | predict_extractor.stage_name >> predict_extractor.ptransform + | unbatch_inputs_extractor.stage_name + >> unbatch_inputs_extractor.ptransform + ) - # pylint: enable=no-value-for-parameter + # pylint: enable=no-value-for-parameter - def check_result(got): - try: - self.assertLen(got, 2) - self.assertDictElementsAlmostEqual( - got[0][constants.FEATURES_KEY], - { - 'fixed_int': np.array([1]), - }, - ) - self.assertEqual( - 
got[0][constants.FEATURES_KEY]['fixed_string'], - np.array([b'fixed_string1']), - ) - for model_name in ('model1', 'model2'): - self.assertIn(model_name, got[0][constants.LABELS_KEY]) - self.assertIn(model_name, got[0][constants.EXAMPLE_WEIGHTS_KEY]) - self.assertIn(model_name, got[0][constants.PREDICTIONS_KEY]) - self.assertAlmostEqual( - got[0][constants.LABELS_KEY]['model1'], np.array([1.0]) - ) - self.assertDictElementsAlmostEqual( - got[0][constants.LABELS_KEY]['model2'], - {'output1': np.array([1.0]), 'output2': np.array([0.0])}, - ) - self.assertAlmostEqual( - got[0][constants.EXAMPLE_WEIGHTS_KEY]['model1'], np.array([0.5]) - ) - self.assertDictElementsAlmostEqual( - got[0][constants.EXAMPLE_WEIGHTS_KEY]['model2'], - {'output1': np.array([0.5]), 'output2': np.array([0.5])}, - ) - self.assertAlmostEqual( - got[0][constants.PREDICTIONS_KEY]['model1'], np.array([1.0]) - ) - self.assertDictElementsAlmostEqual( - got[0][constants.PREDICTIONS_KEY]['model2'], - {'output1': np.array([1.0]), 'output2': np.array([1.0])}, - ) + def check_result(got): + try: + self.assertLen(got, 2) + self.assertDictElementsAlmostEqual( + got[0][constants.FEATURES_KEY], + { + "fixed_int": np.array([1]), + }, + ) + self.assertEqual( + got[0][constants.FEATURES_KEY]["fixed_string"], + np.array([b"fixed_string1"]), + ) + for model_name in ("model1", "model2"): + self.assertIn(model_name, got[0][constants.LABELS_KEY]) + self.assertIn(model_name, got[0][constants.EXAMPLE_WEIGHTS_KEY]) + self.assertIn(model_name, got[0][constants.PREDICTIONS_KEY]) + self.assertAlmostEqual( + got[0][constants.LABELS_KEY]["model1"], np.array([1.0]) + ) + self.assertDictElementsAlmostEqual( + got[0][constants.LABELS_KEY]["model2"], + {"output1": np.array([1.0]), "output2": np.array([0.0])}, + ) + self.assertAlmostEqual( + got[0][constants.EXAMPLE_WEIGHTS_KEY]["model1"], np.array([0.5]) + ) + self.assertDictElementsAlmostEqual( + got[0][constants.EXAMPLE_WEIGHTS_KEY]["model2"], + {"output1": np.array([0.5]), "output2": np.array([0.5])}, + ) + self.assertAlmostEqual( + got[0][constants.PREDICTIONS_KEY]["model1"], np.array([1.0]) + ) + self.assertDictElementsAlmostEqual( + got[0][constants.PREDICTIONS_KEY]["model2"], + {"output1": np.array([1.0]), "output2": np.array([1.0])}, + ) - self.assertDictElementsAlmostEqual( - got[1][constants.FEATURES_KEY], - { - 'fixed_int': np.array([1]), - }, - ) - self.assertEqual( - got[1][constants.FEATURES_KEY]['fixed_string'], - np.array([b'fixed_string2']), - ) - for model_name in ('model1', 'model2'): - self.assertIn(model_name, got[1][constants.LABELS_KEY]) - self.assertIn(model_name, got[1][constants.EXAMPLE_WEIGHTS_KEY]) - self.assertIn(model_name, got[1][constants.PREDICTIONS_KEY]) - self.assertAlmostEqual( - got[1][constants.LABELS_KEY]['model1'], np.array([1.0]) - ) - self.assertDictElementsAlmostEqual( - got[1][constants.LABELS_KEY]['model2'], - {'output1': np.array([1.0]), 'output2': np.array([1.0])}, - ) - self.assertAlmostEqual( - got[1][constants.EXAMPLE_WEIGHTS_KEY]['model1'], np.array([0.0]) - ) - self.assertDictElementsAlmostEqual( - got[1][constants.EXAMPLE_WEIGHTS_KEY]['model2'], - {'output1': np.array([0.0]), 'output2': np.array([1.0])}, - ) - self.assertAlmostEqual( - got[1][constants.PREDICTIONS_KEY]['model1'], np.array([2.0]) - ) - self.assertDictElementsAlmostEqual( - got[1][constants.PREDICTIONS_KEY]['model2'], - {'output1': np.array([2.0]), 'output2': np.array([2.0])}, - ) - except AssertionError as err: - raise util.BeamAssertException(err) + self.assertDictElementsAlmostEqual( + 
got[1][constants.FEATURES_KEY], + { + "fixed_int": np.array([1]), + }, + ) + self.assertEqual( + got[1][constants.FEATURES_KEY]["fixed_string"], + np.array([b"fixed_string2"]), + ) + for model_name in ("model1", "model2"): + self.assertIn(model_name, got[1][constants.LABELS_KEY]) + self.assertIn(model_name, got[1][constants.EXAMPLE_WEIGHTS_KEY]) + self.assertIn(model_name, got[1][constants.PREDICTIONS_KEY]) + self.assertAlmostEqual( + got[1][constants.LABELS_KEY]["model1"], np.array([1.0]) + ) + self.assertDictElementsAlmostEqual( + got[1][constants.LABELS_KEY]["model2"], + {"output1": np.array([1.0]), "output2": np.array([1.0])}, + ) + self.assertAlmostEqual( + got[1][constants.EXAMPLE_WEIGHTS_KEY]["model1"], np.array([0.0]) + ) + self.assertDictElementsAlmostEqual( + got[1][constants.EXAMPLE_WEIGHTS_KEY]["model2"], + {"output1": np.array([0.0]), "output2": np.array([1.0])}, + ) + self.assertAlmostEqual( + got[1][constants.PREDICTIONS_KEY]["model1"], np.array([2.0]) + ) + self.assertDictElementsAlmostEqual( + got[1][constants.PREDICTIONS_KEY]["model2"], + {"output1": np.array([2.0]), "output2": np.array([2.0])}, + ) + except AssertionError as err: + raise util.BeamAssertException(err) - util.assert_that(result, check_result, label='result') + util.assert_that(result, check_result, label="result") -if __name__ == '__main__': - tf.test.main() +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow_model_analysis/frontend/tfma-precision-recall-curve/demo/index.js b/tensorflow_model_analysis/frontend/tfma-precision-recall-curve/demo/index.js index abaa399081..694f5f4218 100644 --- a/tensorflow_model_analysis/frontend/tfma-precision-recall-curve/demo/index.js +++ b/tensorflow_model_analysis/frontend/tfma-precision-recall-curve/demo/index.js @@ -27,4 +27,3 @@ } plot.data = input; })(); - diff --git a/tensorflow_model_analysis/metrics/__init__.py b/tensorflow_model_analysis/metrics/__init__.py index eb2a04fe56..86da35708c 100644 --- a/tensorflow_model_analysis/metrics/__init__.py +++ b/tensorflow_model_analysis/metrics/__init__.py @@ -13,245 +13,289 @@ # limitations under the License. 
"""Init module for TensorFlow Model Analysis metrics.""" -from tensorflow_model_analysis.metrics import bleu -from tensorflow_model_analysis.metrics import preprocessors -from tensorflow_model_analysis.metrics import rouge -from tensorflow_model_analysis.metrics.attributions import AttributionsMetric -from tensorflow_model_analysis.metrics.attributions import has_attributions_metrics -from tensorflow_model_analysis.metrics.attributions import MeanAbsoluteAttributions -from tensorflow_model_analysis.metrics.attributions import MeanAttributions -from tensorflow_model_analysis.metrics.attributions import TotalAbsoluteAttributions -from tensorflow_model_analysis.metrics.attributions import TotalAttributions -from tensorflow_model_analysis.metrics.calibration import Calibration -from tensorflow_model_analysis.metrics.calibration import MeanLabel -from tensorflow_model_analysis.metrics.calibration import MeanPrediction +from tensorflow_model_analysis.metrics import bleu, preprocessors, rouge +from tensorflow_model_analysis.metrics.attributions import ( + AttributionsMetric, + MeanAbsoluteAttributions, + MeanAttributions, + TotalAbsoluteAttributions, + TotalAttributions, + has_attributions_metrics, +) +from tensorflow_model_analysis.metrics.calibration import ( + Calibration, + MeanLabel, + MeanPrediction, +) from tensorflow_model_analysis.metrics.calibration_plot import CalibrationPlot -from tensorflow_model_analysis.metrics.confusion_matrix_metrics import AUC -from tensorflow_model_analysis.metrics.confusion_matrix_metrics import AUCCurve -from tensorflow_model_analysis.metrics.confusion_matrix_metrics import AUCPrecisionRecall -from tensorflow_model_analysis.metrics.confusion_matrix_metrics import AUCSummationMethod -from tensorflow_model_analysis.metrics.confusion_matrix_metrics import BalancedAccuracy -from tensorflow_model_analysis.metrics.confusion_matrix_metrics import BinaryAccuracy -from tensorflow_model_analysis.metrics.confusion_matrix_metrics import ConfusionMatrixAtThresholds -from tensorflow_model_analysis.metrics.confusion_matrix_metrics import DiagnosticOddsRatio -from tensorflow_model_analysis.metrics.confusion_matrix_metrics import F1Score -from tensorflow_model_analysis.metrics.confusion_matrix_metrics import FallOut -from tensorflow_model_analysis.metrics.confusion_matrix_metrics import FalseDiscoveryRate -from tensorflow_model_analysis.metrics.confusion_matrix_metrics import FalseNegatives -from tensorflow_model_analysis.metrics.confusion_matrix_metrics import FalseOmissionRate -from tensorflow_model_analysis.metrics.confusion_matrix_metrics import FalsePositives -from tensorflow_model_analysis.metrics.confusion_matrix_metrics import FN -from tensorflow_model_analysis.metrics.confusion_matrix_metrics import FNR -from tensorflow_model_analysis.metrics.confusion_matrix_metrics import FowlkesMallowsIndex -from tensorflow_model_analysis.metrics.confusion_matrix_metrics import FP -from tensorflow_model_analysis.metrics.confusion_matrix_metrics import FPR -from tensorflow_model_analysis.metrics.confusion_matrix_metrics import Informedness -from tensorflow_model_analysis.metrics.confusion_matrix_metrics import Markedness -from tensorflow_model_analysis.metrics.confusion_matrix_metrics import MatthewsCorrelationCoefficient -from tensorflow_model_analysis.metrics.confusion_matrix_metrics import MaxRecall -from tensorflow_model_analysis.metrics.confusion_matrix_metrics import MissRate -from tensorflow_model_analysis.metrics.confusion_matrix_metrics import NegativeLikelihoodRatio -from 
tensorflow_model_analysis.metrics.confusion_matrix_metrics import NegativePredictiveValue -from tensorflow_model_analysis.metrics.confusion_matrix_metrics import NPV -from tensorflow_model_analysis.metrics.confusion_matrix_metrics import PositiveLikelihoodRatio -from tensorflow_model_analysis.metrics.confusion_matrix_metrics import PPV -from tensorflow_model_analysis.metrics.confusion_matrix_metrics import Precision -from tensorflow_model_analysis.metrics.confusion_matrix_metrics import PrecisionAtRecall -from tensorflow_model_analysis.metrics.confusion_matrix_metrics import Prevalence -from tensorflow_model_analysis.metrics.confusion_matrix_metrics import PrevalenceThreshold -from tensorflow_model_analysis.metrics.confusion_matrix_metrics import Recall -from tensorflow_model_analysis.metrics.confusion_matrix_metrics import RecallAtPrecision -from tensorflow_model_analysis.metrics.confusion_matrix_metrics import SensitivityAtSpecificity -from tensorflow_model_analysis.metrics.confusion_matrix_metrics import Specificity -from tensorflow_model_analysis.metrics.confusion_matrix_metrics import SpecificityAtSensitivity -from tensorflow_model_analysis.metrics.confusion_matrix_metrics import ThreatScore -from tensorflow_model_analysis.metrics.confusion_matrix_metrics import TN -from tensorflow_model_analysis.metrics.confusion_matrix_metrics import TNR -from tensorflow_model_analysis.metrics.confusion_matrix_metrics import TP -from tensorflow_model_analysis.metrics.confusion_matrix_metrics import TPR -from tensorflow_model_analysis.metrics.confusion_matrix_metrics import TrueNegatives -from tensorflow_model_analysis.metrics.confusion_matrix_metrics import TruePositives +from tensorflow_model_analysis.metrics.confusion_matrix_metrics import ( + AUC, + FN, + FNR, + FP, + FPR, + NPV, + PPV, + TN, + TNR, + TP, + TPR, + AUCCurve, + AUCPrecisionRecall, + AUCSummationMethod, + BalancedAccuracy, + BinaryAccuracy, + ConfusionMatrixAtThresholds, + DiagnosticOddsRatio, + F1Score, + FallOut, + FalseDiscoveryRate, + FalseNegatives, + FalseOmissionRate, + FalsePositives, + FowlkesMallowsIndex, + Informedness, + Markedness, + MatthewsCorrelationCoefficient, + MaxRecall, + MissRate, + NegativeLikelihoodRatio, + NegativePredictiveValue, + PositiveLikelihoodRatio, + Precision, + PrecisionAtRecall, + Prevalence, + PrevalenceThreshold, + Recall, + RecallAtPrecision, + SensitivityAtSpecificity, + Specificity, + SpecificityAtSensitivity, + ThreatScore, + TrueNegatives, + TruePositives, +) from tensorflow_model_analysis.metrics.confusion_matrix_plot import ConfusionMatrixPlot -from tensorflow_model_analysis.metrics.cross_entropy_metrics import BinaryCrossEntropy -from tensorflow_model_analysis.metrics.cross_entropy_metrics import CategoricalCrossEntropy +from tensorflow_model_analysis.metrics.cross_entropy_metrics import ( + BinaryCrossEntropy, + CategoricalCrossEntropy, +) from tensorflow_model_analysis.metrics.exact_match import ExactMatch from tensorflow_model_analysis.metrics.example_count import ExampleCount -from tensorflow_model_analysis.metrics.flip_metrics import BooleanFlipRates -from tensorflow_model_analysis.metrics.flip_metrics import NegToNegFlipRate -from tensorflow_model_analysis.metrics.flip_metrics import NegToPosFlipRate -from tensorflow_model_analysis.metrics.flip_metrics import PosToNegFlipRate -from tensorflow_model_analysis.metrics.flip_metrics import PosToPosFlipRate -from tensorflow_model_analysis.metrics.flip_metrics import SymmetricFlipRate -from 
tensorflow_model_analysis.metrics.mean_regression_error import MeanAbsoluteError -from tensorflow_model_analysis.metrics.mean_regression_error import MeanAbsolutePercentageError -from tensorflow_model_analysis.metrics.mean_regression_error import MeanSquaredError -from tensorflow_model_analysis.metrics.mean_regression_error import MeanSquaredLogarithmicError -from tensorflow_model_analysis.metrics.metric_specs import default_binary_classification_specs -from tensorflow_model_analysis.metrics.metric_specs import default_multi_class_classification_specs -from tensorflow_model_analysis.metrics.metric_specs import default_regression_specs -from tensorflow_model_analysis.metrics.metric_specs import metric_thresholds_from_metrics_specs -from tensorflow_model_analysis.metrics.metric_specs import specs_from_metrics -from tensorflow_model_analysis.metrics.metric_types import CombinedFeaturePreprocessor -from tensorflow_model_analysis.metrics.metric_types import DerivedMetricComputation -from tensorflow_model_analysis.metrics.metric_types import FeaturePreprocessor -from tensorflow_model_analysis.metrics.metric_types import Metric -from tensorflow_model_analysis.metrics.metric_types import MetricComputation -from tensorflow_model_analysis.metrics.metric_types import MetricComputations -from tensorflow_model_analysis.metrics.metric_types import MetricKey -from tensorflow_model_analysis.metrics.metric_types import MetricsDict -from tensorflow_model_analysis.metrics.metric_types import PlotKey -from tensorflow_model_analysis.metrics.metric_types import Preprocessor -from tensorflow_model_analysis.metrics.metric_types import StandardMetricInputs -from tensorflow_model_analysis.metrics.metric_types import SubKey -from tensorflow_model_analysis.metrics.metric_util import merge_per_key_computations -from tensorflow_model_analysis.metrics.metric_util import to_label_prediction_example_weight -from tensorflow_model_analysis.metrics.metric_util import to_standard_metric_inputs +from tensorflow_model_analysis.metrics.flip_metrics import ( + BooleanFlipRates, + NegToNegFlipRate, + NegToPosFlipRate, + PosToNegFlipRate, + PosToPosFlipRate, + SymmetricFlipRate, +) +from tensorflow_model_analysis.metrics.mean_regression_error import ( + MeanAbsoluteError, + MeanAbsolutePercentageError, + MeanSquaredError, + MeanSquaredLogarithmicError, +) +from tensorflow_model_analysis.metrics.metric_specs import ( + default_binary_classification_specs, + default_multi_class_classification_specs, + default_regression_specs, + metric_thresholds_from_metrics_specs, + specs_from_metrics, +) +from tensorflow_model_analysis.metrics.metric_types import ( + CombinedFeaturePreprocessor, + DerivedMetricComputation, + FeaturePreprocessor, + Metric, + MetricComputation, + MetricComputations, + MetricKey, + MetricsDict, + PlotKey, + Preprocessor, + StandardMetricInputs, + SubKey, +) +from tensorflow_model_analysis.metrics.metric_util import ( + merge_per_key_computations, + to_label_prediction_example_weight, + to_standard_metric_inputs, +) from tensorflow_model_analysis.metrics.min_label_position import MinLabelPosition -from tensorflow_model_analysis.metrics.model_cosine_similarity import ModelCosineSimilarity -from tensorflow_model_analysis.metrics.multi_class_confusion_matrix_metrics import MultiClassConfusionMatrixAtThresholds -from tensorflow_model_analysis.metrics.multi_class_confusion_matrix_metrics import NO_PREDICTED_CLASS_ID -from tensorflow_model_analysis.metrics.multi_class_confusion_matrix_plot import 
MultiClassConfusionMatrixPlot -from tensorflow_model_analysis.metrics.multi_label_confusion_matrix_plot import MultiLabelConfusionMatrixPlot +from tensorflow_model_analysis.metrics.model_cosine_similarity import ( + ModelCosineSimilarity, +) +from tensorflow_model_analysis.metrics.multi_class_confusion_matrix_metrics import ( + NO_PREDICTED_CLASS_ID, + MultiClassConfusionMatrixAtThresholds, +) +from tensorflow_model_analysis.metrics.multi_class_confusion_matrix_plot import ( + MultiClassConfusionMatrixPlot, +) +from tensorflow_model_analysis.metrics.multi_label_confusion_matrix_plot import ( + MultiLabelConfusionMatrixPlot, +) from tensorflow_model_analysis.metrics.ndcg import NDCG -from tensorflow_model_analysis.metrics.object_detection_confusion_matrix_metrics import ObjectDetectionMaxRecall -from tensorflow_model_analysis.metrics.object_detection_confusion_matrix_metrics import ObjectDetectionPrecision -from tensorflow_model_analysis.metrics.object_detection_confusion_matrix_metrics import ObjectDetectionPrecisionAtRecall -from tensorflow_model_analysis.metrics.object_detection_confusion_matrix_metrics import ObjectDetectionRecall -from tensorflow_model_analysis.metrics.object_detection_confusion_matrix_metrics import ObjectDetectionThresholdAtRecall -from tensorflow_model_analysis.metrics.object_detection_confusion_matrix_plot import ObjectDetectionConfusionMatrixPlot -from tensorflow_model_analysis.metrics.object_detection_metrics import COCOAveragePrecision -from tensorflow_model_analysis.metrics.object_detection_metrics import COCOAverageRecall -from tensorflow_model_analysis.metrics.object_detection_metrics import COCOMeanAveragePrecision -from tensorflow_model_analysis.metrics.object_detection_metrics import COCOMeanAverageRecall -from tensorflow_model_analysis.metrics.prediction_difference_metrics import SymmetricPredictionDifference +from tensorflow_model_analysis.metrics.object_detection_confusion_matrix_metrics import ( + ObjectDetectionMaxRecall, + ObjectDetectionPrecision, + ObjectDetectionPrecisionAtRecall, + ObjectDetectionRecall, + ObjectDetectionThresholdAtRecall, +) +from tensorflow_model_analysis.metrics.object_detection_confusion_matrix_plot import ( + ObjectDetectionConfusionMatrixPlot, +) +from tensorflow_model_analysis.metrics.object_detection_metrics import ( + COCOAveragePrecision, + COCOAverageRecall, + COCOMeanAveragePrecision, + COCOMeanAverageRecall, +) +from tensorflow_model_analysis.metrics.prediction_difference_metrics import ( + SymmetricPredictionDifference, +) from tensorflow_model_analysis.metrics.query_statistics import QueryStatistics -from tensorflow_model_analysis.metrics.score_distribution_plot import ScoreDistributionPlot -from tensorflow_model_analysis.metrics.semantic_segmentation_confusion_matrix_metrics import SemanticSegmentationConfusionMatrix -from tensorflow_model_analysis.metrics.semantic_segmentation_confusion_matrix_metrics import SemanticSegmentationFalsePositive -from tensorflow_model_analysis.metrics.semantic_segmentation_confusion_matrix_metrics import SemanticSegmentationTruePositive -from tensorflow_model_analysis.metrics.set_match_confusion_matrix_metrics import SetMatchPrecision -from tensorflow_model_analysis.metrics.set_match_confusion_matrix_metrics import SetMatchRecall -from tensorflow_model_analysis.metrics.squared_pearson_correlation import SquaredPearsonCorrelation +from tensorflow_model_analysis.metrics.score_distribution_plot import ( + ScoreDistributionPlot, +) +from 
tensorflow_model_analysis.metrics.semantic_segmentation_confusion_matrix_metrics import ( + SemanticSegmentationConfusionMatrix, + SemanticSegmentationFalsePositive, + SemanticSegmentationTruePositive, +) +from tensorflow_model_analysis.metrics.set_match_confusion_matrix_metrics import ( + SetMatchPrecision, + SetMatchRecall, +) +from tensorflow_model_analysis.metrics.squared_pearson_correlation import ( + SquaredPearsonCorrelation, +) from tensorflow_model_analysis.metrics.stats import Mean -from tensorflow_model_analysis.metrics.tjur_discrimination import CoefficientOfDiscrimination -from tensorflow_model_analysis.metrics.tjur_discrimination import RelativeCoefficientOfDiscrimination -from tensorflow_model_analysis.metrics.weighted_example_count import WeightedExampleCount +from tensorflow_model_analysis.metrics.tjur_discrimination import ( + CoefficientOfDiscrimination, + RelativeCoefficientOfDiscrimination, +) +from tensorflow_model_analysis.metrics.weighted_example_count import ( + WeightedExampleCount, +) # TODO(b/143180976): Remove WeightedExampleCount. __all__ = [ - 'AttributionsMetric', - 'AUC', - 'AUCCurve', - 'AUCPrecisionRecall', - 'AUCSummationMethod', - 'BalancedAccuracy', - 'BinaryAccuracy', - 'BinaryCrossEntropy', - 'BooleanFlipRates', - 'Calibration', - 'CalibrationPlot', - 'CategoricalCrossEntropy', - 'COCOAveragePrecision', - 'COCOAverageRecall', - 'COCOMeanAveragePrecision', - 'COCOMeanAverageRecall', - 'CoefficientOfDiscrimination', - 'CombinedFeaturePreprocessor', - 'ConfusionMatrixAtThresholds', - 'ConfusionMatrixPlot', - 'default_binary_classification_specs', - 'default_multi_class_classification_specs', - 'default_regression_specs', - 'DerivedMetricComputation', - 'DiagnosticOddsRatio', - 'ExactMatch', - 'ExampleCount', - 'F1Score', - 'FallOut', - 'FalseDiscoveryRate', - 'FalseNegatives', - 'FalseOmissionRate', - 'FalsePositives', - 'FeaturePreprocessor', - 'FN', - 'FNR', - 'FowlkesMallowsIndex', - 'FP', - 'FPR', - 'has_attributions_metrics', - 'Informedness', - 'Markedness', - 'MatthewsCorrelationCoefficient', - 'MaxRecall', - 'Mean', - 'MeanAbsoluteAttributions', - 'MeanAbsoluteError', - 'MeanAbsolutePercentageError', - 'MeanAttributions', - 'MeanLabel', - 'MeanPrediction', - 'MeanSquaredError', - 'MeanSquaredLogarithmicError', - 'merge_per_key_computations', - 'Metric', - 'metric_thresholds_from_metrics_specs', - 'MetricComputation', - 'MetricComputations', - 'MetricKey', - 'MetricsDict', - 'MinLabelPosition', - 'MissRate', - 'MultiClassConfusionMatrixAtThresholds', - 'MultiClassConfusionMatrixPlot', - 'MultiLabelConfusionMatrixPlot', - 'NDCG', - 'NegativeLikelihoodRatio', - 'NegativePredictiveValue', - 'NO_PREDICTED_CLASS_ID', - 'NPV', - 'ObjectDetectionConfusionMatrixPlot', - 'ObjectDetectionMaxRecall', - 'ObjectDetectionPrecision', - 'ObjectDetectionPrecisionAtRecall', - 'ObjectDetectionRecall', - 'ObjectDetectionThresholdAtRecall', - 'PlotKey', - 'PositiveLikelihoodRatio', - 'PPV', - 'Precision', - 'PrecisionAtRecall', - 'Preprocessor', - 'Prevalence', - 'PrevalenceThreshold', - 'QueryStatistics', - 'Recall', - 'RecallAtPrecision', - 'RelativeCoefficientOfDiscrimination', - 'ScoreDistributionPlot', - 'SemanticSegmentationConfusionMatrix', - 'SemanticSegmentationFalsePositive', - 'SemanticSegmentationTruePositive', - 'SensitivityAtSpecificity', - 'SetMatchPrecision', - 'SetMatchRecall', - 'Specificity', - 'SpecificityAtSensitivity', - 'specs_from_metrics', - 'SquaredPearsonCorrelation', - 'StandardMetricInputs', - 'SubKey', - 'SymmetricPredictionDifference', 
- 'ThreatScore', - 'TN', - 'TNR', - 'to_label_prediction_example_weight', - 'to_standard_metric_inputs', - 'TotalAbsoluteAttributions', - 'TotalAttributions', - 'TP', - 'TPR', - 'TrueNegatives', - 'TruePositives', - 'WeightedExampleCount' + "AttributionsMetric", + "AUC", + "AUCCurve", + "AUCPrecisionRecall", + "AUCSummationMethod", + "BalancedAccuracy", + "BinaryAccuracy", + "BinaryCrossEntropy", + "BooleanFlipRates", + "Calibration", + "CalibrationPlot", + "CategoricalCrossEntropy", + "COCOAveragePrecision", + "COCOAverageRecall", + "COCOMeanAveragePrecision", + "COCOMeanAverageRecall", + "CoefficientOfDiscrimination", + "CombinedFeaturePreprocessor", + "ConfusionMatrixAtThresholds", + "ConfusionMatrixPlot", + "default_binary_classification_specs", + "default_multi_class_classification_specs", + "default_regression_specs", + "DerivedMetricComputation", + "DiagnosticOddsRatio", + "ExactMatch", + "ExampleCount", + "F1Score", + "FallOut", + "FalseDiscoveryRate", + "FalseNegatives", + "FalseOmissionRate", + "FalsePositives", + "FeaturePreprocessor", + "FN", + "FNR", + "FowlkesMallowsIndex", + "FP", + "FPR", + "has_attributions_metrics", + "Informedness", + "Markedness", + "MatthewsCorrelationCoefficient", + "MaxRecall", + "Mean", + "MeanAbsoluteAttributions", + "MeanAbsoluteError", + "MeanAbsolutePercentageError", + "MeanAttributions", + "MeanLabel", + "MeanPrediction", + "MeanSquaredError", + "MeanSquaredLogarithmicError", + "merge_per_key_computations", + "Metric", + "metric_thresholds_from_metrics_specs", + "MetricComputation", + "MetricComputations", + "MetricKey", + "MetricsDict", + "MinLabelPosition", + "MissRate", + "MultiClassConfusionMatrixAtThresholds", + "MultiClassConfusionMatrixPlot", + "MultiLabelConfusionMatrixPlot", + "NDCG", + "NegativeLikelihoodRatio", + "NegativePredictiveValue", + "NO_PREDICTED_CLASS_ID", + "NPV", + "ObjectDetectionConfusionMatrixPlot", + "ObjectDetectionMaxRecall", + "ObjectDetectionPrecision", + "ObjectDetectionPrecisionAtRecall", + "ObjectDetectionRecall", + "ObjectDetectionThresholdAtRecall", + "PlotKey", + "PositiveLikelihoodRatio", + "PPV", + "Precision", + "PrecisionAtRecall", + "Preprocessor", + "Prevalence", + "PrevalenceThreshold", + "QueryStatistics", + "Recall", + "RecallAtPrecision", + "RelativeCoefficientOfDiscrimination", + "ScoreDistributionPlot", + "SemanticSegmentationConfusionMatrix", + "SemanticSegmentationFalsePositive", + "SemanticSegmentationTruePositive", + "SensitivityAtSpecificity", + "SetMatchPrecision", + "SetMatchRecall", + "Specificity", + "SpecificityAtSensitivity", + "specs_from_metrics", + "SquaredPearsonCorrelation", + "StandardMetricInputs", + "SubKey", + "SymmetricPredictionDifference", + "ThreatScore", + "TN", + "TNR", + "to_label_prediction_example_weight", + "to_standard_metric_inputs", + "TotalAbsoluteAttributions", + "TotalAttributions", + "TP", + "TPR", + "TrueNegatives", + "TruePositives", + "WeightedExampleCount", ] diff --git a/tensorflow_model_analysis/metrics/aggregation.py b/tensorflow_model_analysis/metrics/aggregation.py index 1ec17b7a83..78e2e0cef6 100644 --- a/tensorflow_model_analysis/metrics/aggregation.py +++ b/tensorflow_model_analysis/metrics/aggregation.py @@ -16,297 +16,326 @@ from typing import Any, Dict, Iterable, List, Optional import apache_beam as beam -from tensorflow_model_analysis.metrics import metric_types -from tensorflow_model_analysis.metrics import metric_util + +from tensorflow_model_analysis.metrics import metric_types, metric_util from tensorflow_model_analysis.proto import 
config_pb2 -_CLASS_WEIGHTS_FROM_LABELS_NAME = '_class_weights_from_labels' +_CLASS_WEIGHTS_FROM_LABELS_NAME = "_class_weights_from_labels" def output_average( metric_name: str, output_weights: Dict[str, float], eval_config: Optional[config_pb2.EvalConfig] = None, - model_name: str = '', + model_name: str = "", sub_key: Optional[metric_types.SubKey] = None, - example_weighted: bool = False) -> metric_types.MetricComputations: - """Returns metric computations for computing output average of given metric. - - Args: - metric_name: Name of underlying metric average is being computed for. - output_weights: Output weights to use to compute metric. - eval_config: Eval config. - model_name: Optional model name. - sub_key: Optional sub key associated with metric (e.g. top_k). - example_weighted: True if example weights should be applied. - - Returns: - Computation for performing the output average. - """ - del eval_config - - key = metric_types.MetricKey( - name=metric_name, - model_name=model_name, - sub_key=sub_key, - example_weighted=example_weighted) - - def result( - metrics: Dict[metric_types.MetricKey, float] - ) -> Dict[metric_types.MetricKey, float]: - """Returns output average.""" - total_value = 0.0 - total_weight = 0.0 - for output_name, output_weight in output_weights.items(): - child_key = metric_types.MetricKey( - name=metric_name, - model_name=model_name, - output_name=output_name, - sub_key=sub_key, - example_weighted=example_weighted) - total_value += _to_float(metrics[child_key]) * output_weight - total_weight += output_weight - average = total_value / total_weight if total_weight else float('nan') - return {key: average} - - return [metric_types.DerivedMetricComputation(keys=[key], result=result)] + example_weighted: bool = False, +) -> metric_types.MetricComputations: + """Returns metric computations for computing output average of given metric. + + Args: + ---- + metric_name: Name of underlying metric average is being computed for. + output_weights: Output weights to use to compute metric. + eval_config: Eval config. + model_name: Optional model name. + sub_key: Optional sub key associated with metric (e.g. top_k). + example_weighted: True if example weights should be applied. + + Returns: + ------- + Computation for performing the output average. 
+ """ + del eval_config + + key = metric_types.MetricKey( + name=metric_name, + model_name=model_name, + sub_key=sub_key, + example_weighted=example_weighted, + ) + + def result( + metrics: Dict[metric_types.MetricKey, float], + ) -> Dict[metric_types.MetricKey, float]: + """Returns output average.""" + total_value = 0.0 + total_weight = 0.0 + for output_name, output_weight in output_weights.items(): + child_key = metric_types.MetricKey( + name=metric_name, + model_name=model_name, + output_name=output_name, + sub_key=sub_key, + example_weighted=example_weighted, + ) + total_value += _to_float(metrics[child_key]) * output_weight + total_weight += output_weight + average = total_value / total_weight if total_weight else float("nan") + return {key: average} + + return [metric_types.DerivedMetricComputation(keys=[key], result=result)] def macro_average( metric_name: str, sub_keys: Iterable[metric_types.SubKey], eval_config: Optional[config_pb2.EvalConfig] = None, - model_name: str = '', - output_name: str = '', + model_name: str = "", + output_name: str = "", sub_key: Optional[metric_types.SubKey] = None, class_weights: Optional[Dict[int, float]] = None, - example_weighted: bool = False) -> metric_types.MetricComputations: - """Returns metric computations for computing macro average of given metric. - - Args: - metric_name: Name of underlying metric average is being computed for. - sub_keys: Sub keys used to compute the metric (e.g. class_ids, etc). - eval_config: Eval config. - model_name: Optional model name. - output_name: Optional output name. - sub_key: Optional sub key associated with aggregation metric (e.g. top_k). - class_weights: Optional class weights to apply. Required if sub_key is not - provided. If class_weights are provided, but a sub_key.class_id (if - sub_key is None) or sub_key.k (if sub_key is top_k) is not set or not - found in the dictionary then 0.0 is assumed. - example_weighted: True if example weights should be applied. - - Returns: - Computation for performing the macro average. 
- """ - del eval_config - - key = metric_types.MetricKey( - name=metric_name, - model_name=model_name, - output_name=output_name, - sub_key=sub_key, - aggregation_type=metric_types.AggregationType(macro_average=True), - example_weighted=example_weighted) - - def result( - metrics: Dict[metric_types.MetricKey, float] - ) -> Dict[metric_types.MetricKey, float]: - """Returns macro average.""" - total_value = 0.0 - total_weight = 0.0 - for sub_key in sub_keys: - child_key = metric_types.MetricKey( - name=metric_name, - model_name=model_name, - output_name=output_name, - sub_key=sub_key, - example_weighted=example_weighted) - if child_key not in metrics: - # Use private name if not found under metric name - child_key = metric_types.MetricKey( - name='_' + metric_name, - model_name=model_name, - output_name=output_name, - sub_key=sub_key, - example_weighted=example_weighted) - weight = 1.0 if not class_weights else 0.0 - offset = None - if (child_key.sub_key is not None and - child_key.sub_key.class_id is not None): - offset = child_key.sub_key.class_id - elif child_key.sub_key is not None and child_key.sub_key.k is not None: - offset = child_key.sub_key.k - if offset is not None and offset in class_weights: - weight = class_weights[offset] - total_value += _to_float(metrics[child_key]) * weight - total_weight += weight - average = total_value / total_weight if total_weight else float('nan') - return {key: average} - - return [metric_types.DerivedMetricComputation(keys=[key], result=result)] + example_weighted: bool = False, +) -> metric_types.MetricComputations: + """Returns metric computations for computing macro average of given metric. + + Args: + ---- + metric_name: Name of underlying metric average is being computed for. + sub_keys: Sub keys used to compute the metric (e.g. class_ids, etc). + eval_config: Eval config. + model_name: Optional model name. + output_name: Optional output name. + sub_key: Optional sub key associated with aggregation metric (e.g. top_k). + class_weights: Optional class weights to apply. Required if sub_key is not + provided. If class_weights are provided, but a sub_key.class_id (if + sub_key is None) or sub_key.k (if sub_key is top_k) is not set or not + found in the dictionary then 0.0 is assumed. + example_weighted: True if example weights should be applied. + + Returns: + ------- + Computation for performing the macro average. 
+ """ + del eval_config + + key = metric_types.MetricKey( + name=metric_name, + model_name=model_name, + output_name=output_name, + sub_key=sub_key, + aggregation_type=metric_types.AggregationType(macro_average=True), + example_weighted=example_weighted, + ) + + def result( + metrics: Dict[metric_types.MetricKey, float], + ) -> Dict[metric_types.MetricKey, float]: + """Returns macro average.""" + total_value = 0.0 + total_weight = 0.0 + for sub_key in sub_keys: + child_key = metric_types.MetricKey( + name=metric_name, + model_name=model_name, + output_name=output_name, + sub_key=sub_key, + example_weighted=example_weighted, + ) + if child_key not in metrics: + # Use private name if not found under metric name + child_key = metric_types.MetricKey( + name="_" + metric_name, + model_name=model_name, + output_name=output_name, + sub_key=sub_key, + example_weighted=example_weighted, + ) + weight = 1.0 if not class_weights else 0.0 + offset = None + if child_key.sub_key is not None and child_key.sub_key.class_id is not None: + offset = child_key.sub_key.class_id + elif child_key.sub_key is not None and child_key.sub_key.k is not None: + offset = child_key.sub_key.k + if offset is not None and offset in class_weights: + weight = class_weights[offset] + total_value += _to_float(metrics[child_key]) * weight + total_weight += weight + average = total_value / total_weight if total_weight else float("nan") + return {key: average} + + return [metric_types.DerivedMetricComputation(keys=[key], result=result)] def weighted_macro_average( metric_name: str, sub_keys: Iterable[metric_types.SubKey], eval_config: Optional[config_pb2.EvalConfig] = None, - model_name: str = '', - output_name: str = '', + model_name: str = "", + output_name: str = "", sub_key: Optional[metric_types.SubKey] = None, class_weights: Optional[Dict[int, float]] = None, - example_weighted: bool = False) -> metric_types.MetricComputations: - """Returns metric computations for computing weighted macro average of metric. - - The weights per class are based on the percentage of positive labels for each - class. - - Args: - metric_name: Name of metric weighted average is being computed for. - sub_keys: Sub keys used to compute the metric (e.g. class_ids, etc). - eval_config: Eval config. - model_name: Optional model name. - output_name: Optional output name. - sub_key: Optional sub key associated with aggregation metric (e.g. top_k). - class_weights: Optional class weights to apply. Required if sub_key is not - provided. If class_weights are provided, but a sub_key.class_id (if - sub_key is None) or sub_key.k (if sub_key is top_k) is not set or not - found in the dictionary then 0.0 is assumed. Note that these weights are - applied in addition to the weights based on the positive labels for each - class. - example_weighted: True if example weights should be applied. - - Returns: - Computation for performing the weighted macro average. - """ - key = metric_types.MetricKey( - name=metric_name, - model_name=model_name, - output_name=output_name, - sub_key=sub_key, - aggregation_type=metric_types.AggregationType(macro_average=True), - example_weighted=example_weighted) - - class_ids = [k.class_id for k in sub_keys if k.class_id is not None] - - # Compute the weights for labels. - computations = _class_weights_from_labels( - class_ids=class_ids, - eval_config=eval_config, - model_name=model_name, - output_name=output_name, - example_weighted=example_weighted) - # Class weights metrics are based on a single computation and key. 
- class_weights_from_labels_key = computations[0].keys[0] - - def result( - metrics: Dict[metric_types.MetricKey, Any] - ) -> Dict[metric_types.MetricKey, float]: - """Returns weighted macro average.""" - class_weights_from_labels = metrics[class_weights_from_labels_key] - total_value = 0.0 - total_weight = 0.0 - for sub_key in sub_keys: - child_key = metric_types.MetricKey( - name=metric_name, - model_name=model_name, - output_name=output_name, - sub_key=sub_key, - example_weighted=example_weighted) - if child_key not in metrics: - # Use private name if not found under metric name - child_key = metric_types.MetricKey( - name='_' + metric_name, - model_name=model_name, - output_name=output_name, - sub_key=sub_key, - example_weighted=example_weighted) - weight = 1.0 if not class_weights else 0.0 - offset = None - if (child_key.sub_key is not None and - child_key.sub_key.class_id is not None): - offset = child_key.sub_key.class_id - elif child_key.sub_key is not None and child_key.sub_key.k is not None: - offset = child_key.sub_key.k - if offset is not None: - if (class_weights_from_labels and - child_key.sub_key.class_id in class_weights_from_labels): - weight = class_weights_from_labels[offset] - if class_weights and child_key.sub_key.class_id in class_weights: - weight *= class_weights[offset] - total_value += _to_float(metrics[child_key]) * weight - total_weight += weight - average = total_value / total_weight if total_weight else float('nan') - return {key: average} - - derived_computation = metric_types.DerivedMetricComputation( - keys=[key], result=result) - computations.append(derived_computation) - return computations + example_weighted: bool = False, +) -> metric_types.MetricComputations: + """Returns metric computations for computing weighted macro average of metric. + + The weights per class are based on the percentage of positive labels for each + class. + + Args: + ---- + metric_name: Name of metric weighted average is being computed for. + sub_keys: Sub keys used to compute the metric (e.g. class_ids, etc). + eval_config: Eval config. + model_name: Optional model name. + output_name: Optional output name. + sub_key: Optional sub key associated with aggregation metric (e.g. top_k). + class_weights: Optional class weights to apply. Required if sub_key is not + provided. If class_weights are provided, but a sub_key.class_id (if + sub_key is None) or sub_key.k (if sub_key is top_k) is not set or not + found in the dictionary then 0.0 is assumed. Note that these weights are + applied in addition to the weights based on the positive labels for each + class. + example_weighted: True if example weights should be applied. + + Returns: + ------- + Computation for performing the weighted macro average. + """ + key = metric_types.MetricKey( + name=metric_name, + model_name=model_name, + output_name=output_name, + sub_key=sub_key, + aggregation_type=metric_types.AggregationType(macro_average=True), + example_weighted=example_weighted, + ) + + class_ids = [k.class_id for k in sub_keys if k.class_id is not None] + + # Compute the weights for labels. + computations = _class_weights_from_labels( + class_ids=class_ids, + eval_config=eval_config, + model_name=model_name, + output_name=output_name, + example_weighted=example_weighted, + ) + # Class weights metrics are based on a single computation and key. 
+ class_weights_from_labels_key = computations[0].keys[0] + + def result( + metrics: Dict[metric_types.MetricKey, Any], + ) -> Dict[metric_types.MetricKey, float]: + """Returns weighted macro average.""" + class_weights_from_labels = metrics[class_weights_from_labels_key] + total_value = 0.0 + total_weight = 0.0 + for sub_key in sub_keys: + child_key = metric_types.MetricKey( + name=metric_name, + model_name=model_name, + output_name=output_name, + sub_key=sub_key, + example_weighted=example_weighted, + ) + if child_key not in metrics: + # Use private name if not found under metric name + child_key = metric_types.MetricKey( + name="_" + metric_name, + model_name=model_name, + output_name=output_name, + sub_key=sub_key, + example_weighted=example_weighted, + ) + weight = 1.0 if not class_weights else 0.0 + offset = None + if child_key.sub_key is not None and child_key.sub_key.class_id is not None: + offset = child_key.sub_key.class_id + elif child_key.sub_key is not None and child_key.sub_key.k is not None: + offset = child_key.sub_key.k + if offset is not None: + if ( + class_weights_from_labels + and child_key.sub_key.class_id in class_weights_from_labels + ): + weight = class_weights_from_labels[offset] + if class_weights and child_key.sub_key.class_id in class_weights: + weight *= class_weights[offset] + total_value += _to_float(metrics[child_key]) * weight + total_weight += weight + average = total_value / total_weight if total_weight else float("nan") + return {key: average} + + derived_computation = metric_types.DerivedMetricComputation( + keys=[key], result=result + ) + computations.append(derived_computation) + return computations def _to_float(value: Any) -> float: - try: - return float(value) - except (ValueError, TypeError): - raise ValueError( - '{} is not aggregatable: value={}\n\nThis is most likely caused by a ' - 'configuration error in which the aggregate option was applied ' - 'incorrectly.'.format(value.__class__.__name__, value)) + try: + return float(value) + except (ValueError, TypeError): + raise ValueError( + f"{value.__class__.__name__} is not aggregatable: value={value}\n\nThis is most likely caused by a " + "configuration error in which the aggregate option was applied " + "incorrectly." + ) def _class_weights_from_labels( class_ids: List[int], name: str = _CLASS_WEIGHTS_FROM_LABELS_NAME, eval_config: Optional[config_pb2.EvalConfig] = None, - model_name: str = '', - output_name: str = '', - example_weighted: bool = False) -> metric_types.MetricComputations: - """Returns metric computations for class weights based on labels. - - Args: - class_ids: List of class Ids to compute weighted labels from. - name: Metric name. - eval_config: Eval config. - model_name: Optional model name (if multi-model evaluation). - output_name: Optional output name (if multi-output model type). - example_weighted: True if example weights should be applied. - """ - key = metric_types.MetricKey( - name=name, - model_name=model_name, - output_name=output_name, - example_weighted=example_weighted) - return [ - metric_types.MetricComputation( - keys=[key], - preprocessors=None, # Use default - combiner=_ClassWeightsFromLabelsCombiner( - key, - eval_config=eval_config, - example_weighted=example_weighted, - class_ids=class_ids)) - ] + model_name: str = "", + output_name: str = "", + example_weighted: bool = False, +) -> metric_types.MetricComputations: + """Returns metric computations for class weights based on labels. 
+ + Args: + ---- + class_ids: List of class Ids to compute weighted labels from. + name: Metric name. + eval_config: Eval config. + model_name: Optional model name (if multi-model evaluation). + output_name: Optional output name (if multi-output model type). + example_weighted: True if example weights should be applied. + """ + key = metric_types.MetricKey( + name=name, + model_name=model_name, + output_name=output_name, + example_weighted=example_weighted, + ) + return [ + metric_types.MetricComputation( + keys=[key], + preprocessors=None, # Use default + combiner=_ClassWeightsFromLabelsCombiner( + key, + eval_config=eval_config, + example_weighted=example_weighted, + class_ids=class_ids, + ), + ) + ] class _ClassWeightsFromLabelsCombiner(beam.CombineFn): - """Computes class weights from labels.""" - - def __init__(self, key: metric_types.MetricKey, - eval_config: Optional[config_pb2.EvalConfig], - class_ids: List[int], example_weighted: bool): - self._key = key - self._eval_config = eval_config - self._class_ids = class_ids - self._example_weighted = example_weighted - - def create_accumulator(self) -> Dict[int, float]: - return {i: 0.0 for i in self._class_ids} - - def add_input(self, accumulator: Dict[int, float], - element: metric_types.StandardMetricInputs) -> Dict[int, float]: - for label, _, example_weight in ( - metric_util.to_label_prediction_example_weight( + """Computes class weights from labels.""" + + def __init__( + self, + key: metric_types.MetricKey, + eval_config: Optional[config_pb2.EvalConfig], + class_ids: List[int], + example_weighted: bool, + ): + self._key = key + self._eval_config = eval_config + self._class_ids = class_ids + self._example_weighted = example_weighted + + def create_accumulator(self) -> Dict[int, float]: + return {i: 0.0 for i in self._class_ids} + + def add_input( + self, accumulator: Dict[int, float], element: metric_types.StandardMetricInputs + ) -> Dict[int, float]: + for label, _, example_weight in metric_util.to_label_prediction_example_weight( element, eval_config=self._eval_config, model_name=self._key.model_name, @@ -314,36 +343,39 @@ def add_input(self, accumulator: Dict[int, float], example_weighted=self._example_weighted, flatten=False, allow_none=True, - require_single_example_weight=True)): - example_weight = float(example_weight) - if label is not None: - for class_id in self._class_ids: - if label.size == 1: - label_value = float(label.item() == class_id) - else: - if class_id >= len(label): - raise ValueError( - 'class_id {} used with weighted_macro_average is outside the ' - 'range of the label provided: label={}, ' - 'StandardMetricInput={}'.format(class_id, label, element)) - label_value = float(label[class_id]) - accumulator[class_id] += label_value * example_weight - return accumulator - - def merge_accumulators( - self, accumulators: Iterable[Dict[int, float]]) -> Dict[int, float]: - accumulators = iter(accumulators) - result = next(accumulators) - for accumulator in accumulators: - for k, v in accumulator.items(): - result[k] += v - return result - - def extract_output( - self, accumulator: Dict[int, float] - ) -> Dict[metric_types.MetricKey, Dict[int, float]]: - total = sum(v for v in accumulator.values()) - class_weights = { - k: (v / total) if total else 0.0 for k, v in accumulator.items() - } - return {self._key: class_weights} + require_single_example_weight=True, + ): + example_weight = float(example_weight) + if label is not None: + for class_id in self._class_ids: + if label.size == 1: + label_value = 
float(label.item() == class_id) + else: + if class_id >= len(label): + raise ValueError( + f"class_id {class_id} used with weighted_macro_average is outside the " + f"range of the label provided: label={label}, " + f"StandardMetricInput={element}" + ) + label_value = float(label[class_id]) + accumulator[class_id] += label_value * example_weight + return accumulator + + def merge_accumulators( + self, accumulators: Iterable[Dict[int, float]] + ) -> Dict[int, float]: + accumulators = iter(accumulators) + result = next(accumulators) + for accumulator in accumulators: + for k, v in accumulator.items(): + result[k] += v + return result + + def extract_output( + self, accumulator: Dict[int, float] + ) -> Dict[metric_types.MetricKey, Dict[int, float]]: + total = sum(v for v in accumulator.values()) + class_weights = { + k: (v / total) if total else 0.0 for k, v in accumulator.items() + } + return {self._key: class_weights} diff --git a/tensorflow_model_analysis/metrics/aggregation_test.py b/tensorflow_model_analysis/metrics/aggregation_test.py index 1798ad7eac..24dbb799fb 100644 --- a/tensorflow_model_analysis/metrics/aggregation_test.py +++ b/tensorflow_model_analysis/metrics/aggregation_test.py @@ -14,212 +14,218 @@ """Tests for aggregation metrics.""" import copy + import apache_beam as beam -from apache_beam.testing import util import numpy as np import tensorflow as tf -from tensorflow_model_analysis.metrics import aggregation -from tensorflow_model_analysis.metrics import metric_types -from tensorflow_model_analysis.metrics import metric_util +from apache_beam.testing import util + +from tensorflow_model_analysis.metrics import aggregation, metric_types, metric_util from tensorflow_model_analysis.proto import config_pb2 from tensorflow_model_analysis.utils import test_util class AggregationMetricsTest(test_util.TensorflowModelAnalysisTest): - - def testOutputAverage(self): - metric_name = 'test' - computations = aggregation.output_average( - metric_name, output_weights={ - 'output_1': 0.3, - 'output_2': 0.7 - }) - metric = computations[0] - - sub_metrics = {} - output_names = ('output_1', 'output_2', 'output_3') - output_values = (0.1, 0.2, 0.3) - for output_name, output_value in zip(output_names, output_values): - key = metric_types.MetricKey(name=metric_name, output_name=output_name) - sub_metrics[key] = output_value - - with beam.Pipeline() as pipeline: - # pylint: disable=no-value-for-parameter - result = ( - pipeline - | 'Create' >> beam.Create([((), sub_metrics)]) - | 'ComputeMetric' >> beam.Map(lambda x: (x[0], metric.result(x[1])))) - - # pylint: enable=no-value-for-parameter - - def check_result(got): - try: - self.assertLen(got, 1) - got_slice_key, got_metrics = got[0] - self.assertEqual(got_slice_key, ()) - key = metric.keys[0] - expected_value = (0.3 * 0.1 + 0.7 * 0.2) / (0.3 + 0.7) - self.assertDictElementsAlmostEqual( - got_metrics, {key: expected_value}, places=5) - - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that(result, check_result, label='result') - - def testMacroAverage(self): - metric_name = 'test' - class_ids = [0, 1, 2] - sub_keys = [metric_types.SubKey(class_id=i) for i in class_ids] - sub_key_values = [0.1, 0.2, 0.3] - computations = aggregation.macro_average( - metric_name, - sub_keys, - eval_config=config_pb2.EvalConfig(), - class_weights={ - 0: 1.0, - 1: 1.0, - 2: 1.0 - }) - metric = computations[0] - - sub_metrics = {} - for sub_key, value in zip(sub_keys, sub_key_values): - key = 
metric_types.MetricKey(name=metric_name, sub_key=sub_key) - sub_metrics[key] = value - - with beam.Pipeline() as pipeline: - # pylint: disable=no-value-for-parameter - result = ( - pipeline - | 'Create' >> beam.Create([((), sub_metrics)]) - | 'ComputeMetric' >> beam.Map(lambda x: (x[0], metric.result(x[1])))) - - # pylint: enable=no-value-for-parameter - - def check_result(got): - try: - self.assertLen(got, 1) - got_slice_key, got_metrics = got[0] - self.assertEqual(got_slice_key, ()) - key = metric.keys[0] - expected_value = (0.1 + 0.2 + 0.3) / 3.0 - self.assertDictElementsAlmostEqual( - got_metrics, {key: expected_value}, places=5) - - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that(result, check_result, label='result') - - def testMacroAverageWithWeights(self): - metric_name = 'test' - class_ids = [0, 1, 2] - class_weights = {0: 0.2, 1: 0.3, 2: 0.5} - sub_keys = [metric_types.SubKey(class_id=i) for i in class_ids] - sub_key_values = [0.1, 0.2, 0.3] - computations = aggregation.macro_average( - metric_name, - sub_keys, - eval_config=config_pb2.EvalConfig(), - class_weights=class_weights) - metric = computations[0] - - sub_metrics = {} - for sub_key, value in zip(sub_keys, sub_key_values): - key = metric_types.MetricKey(name=metric_name, sub_key=sub_key) - sub_metrics[key] = value - - with beam.Pipeline() as pipeline: - # pylint: disable=no-value-for-parameter - result = ( - pipeline - | 'Create' >> beam.Create([((), sub_metrics)]) - | 'ComputeMetric' >> beam.Map(lambda x: (x[0], metric.result(x[1])))) - - # pylint: enable=no-value-for-parameter - - def check_result(got): - try: - self.assertLen(got, 1) - got_slice_key, got_metrics = got[0] - self.assertEqual(got_slice_key, ()) - key = metric.keys[0] - expected_value = (0.1 * 0.2 + 0.2 * 0.3 + 0.3 * 0.5) / (0.2 + 0.3 + - 0.5) - self.assertDictElementsAlmostEqual( - got_metrics, {key: expected_value}, places=5) - - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that(result, check_result, label='result') - - def testWeightedMacroAverage(self): - example1 = { - 'labels': np.array([0.0, 1.0, 1.0]), - 'predictions': np.array([0.0, 0.3, 0.7]), - 'example_weights': np.array([1.0]), - } - example2 = { - 'labels': np.array([0.0, 1.0, 1.0]), - 'predictions': np.array([0.5, 0.3, 0.8]), - 'example_weights': np.array([1.0]), - } - example3 = { - 'labels': np.array([1.0, 0.0, 1.0]), - 'predictions': np.array([0.3, 0.7, 0.9]), - 'example_weights': np.array([1.0]), - } - - metric_name = 'test' - class_ids = [0, 1, 2] - sub_keys = [metric_types.SubKey(class_id=i) for i in class_ids] - sub_key_values = [0.1, 0.2, 0.3] - computations = aggregation.weighted_macro_average( - metric_name, sub_keys, eval_config=config_pb2.EvalConfig()) - class_weights = computations[0] - metric = computations[1] - - def create_sub_metrics(sliced_metrics): - slice_value, metrics = sliced_metrics - metrics = copy.copy(metrics) - for sub_key, value in zip(sub_keys, sub_key_values): - key = metric_types.MetricKey(name=metric_name, sub_key=sub_key) - metrics[key] = value - return (slice_value, metrics) - - with beam.Pipeline() as pipeline: - # pylint: disable=no-value-for-parameter - result = ( - pipeline - | 'Create' >> beam.Create([example1, example2, example3]) - | 'Process' >> beam.Map(metric_util.to_standard_metric_inputs) - | 'AddSlice' >> beam.Map(lambda x: ((), x)) - | 'ComputeClassWeights' >> beam.CombinePerKey(class_weights.combiner) - | 'CreateSubMetric' >> beam.Map(create_sub_metrics) - | 
'ComputeMetric' >> beam.Map(lambda x: (x[0], metric.result(x[1])))) - - # pylint: enable=no-value-for-parameter - - def check_result(got): - try: - self.assertLen(got, 1) - got_slice_key, got_metrics = got[0] - self.assertEqual(got_slice_key, ()) - key = metric.keys[0] - # Labels: 1 x class_0, 2 x class_1, 3 x class_2 - # Class weights: 0.125, .3333, .5 - expected_value = (0.1 * 1.0 / 6.0) + (0.2 * 2.0 / 6.0) + (0.3 * 3.0 / - 6.0) - self.assertDictElementsAlmostEqual( - got_metrics, {key: expected_value}, places=5) - - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that(result, check_result, label='result') - - -if __name__ == '__main__': - tf.test.main() + def testOutputAverage(self): + metric_name = "test" + computations = aggregation.output_average( + metric_name, output_weights={"output_1": 0.3, "output_2": 0.7} + ) + metric = computations[0] + + sub_metrics = {} + output_names = ("output_1", "output_2", "output_3") + output_values = (0.1, 0.2, 0.3) + for output_name, output_value in zip(output_names, output_values): + key = metric_types.MetricKey(name=metric_name, output_name=output_name) + sub_metrics[key] = output_value + + with beam.Pipeline() as pipeline: + # pylint: disable=no-value-for-parameter + result = ( + pipeline + | "Create" >> beam.Create([((), sub_metrics)]) + | "ComputeMetric" >> beam.Map(lambda x: (x[0], metric.result(x[1]))) + ) + + # pylint: enable=no-value-for-parameter + + def check_result(got): + try: + self.assertLen(got, 1) + got_slice_key, got_metrics = got[0] + self.assertEqual(got_slice_key, ()) + key = metric.keys[0] + expected_value = (0.3 * 0.1 + 0.7 * 0.2) / (0.3 + 0.7) + self.assertDictElementsAlmostEqual( + got_metrics, {key: expected_value}, places=5 + ) + + except AssertionError as err: + raise util.BeamAssertException(err) + + util.assert_that(result, check_result, label="result") + + def testMacroAverage(self): + metric_name = "test" + class_ids = [0, 1, 2] + sub_keys = [metric_types.SubKey(class_id=i) for i in class_ids] + sub_key_values = [0.1, 0.2, 0.3] + computations = aggregation.macro_average( + metric_name, + sub_keys, + eval_config=config_pb2.EvalConfig(), + class_weights={0: 1.0, 1: 1.0, 2: 1.0}, + ) + metric = computations[0] + + sub_metrics = {} + for sub_key, value in zip(sub_keys, sub_key_values): + key = metric_types.MetricKey(name=metric_name, sub_key=sub_key) + sub_metrics[key] = value + + with beam.Pipeline() as pipeline: + # pylint: disable=no-value-for-parameter + result = ( + pipeline + | "Create" >> beam.Create([((), sub_metrics)]) + | "ComputeMetric" >> beam.Map(lambda x: (x[0], metric.result(x[1]))) + ) + + # pylint: enable=no-value-for-parameter + + def check_result(got): + try: + self.assertLen(got, 1) + got_slice_key, got_metrics = got[0] + self.assertEqual(got_slice_key, ()) + key = metric.keys[0] + expected_value = (0.1 + 0.2 + 0.3) / 3.0 + self.assertDictElementsAlmostEqual( + got_metrics, {key: expected_value}, places=5 + ) + + except AssertionError as err: + raise util.BeamAssertException(err) + + util.assert_that(result, check_result, label="result") + + def testMacroAverageWithWeights(self): + metric_name = "test" + class_ids = [0, 1, 2] + class_weights = {0: 0.2, 1: 0.3, 2: 0.5} + sub_keys = [metric_types.SubKey(class_id=i) for i in class_ids] + sub_key_values = [0.1, 0.2, 0.3] + computations = aggregation.macro_average( + metric_name, + sub_keys, + eval_config=config_pb2.EvalConfig(), + class_weights=class_weights, + ) + metric = computations[0] + + sub_metrics = {} + for 
sub_key, value in zip(sub_keys, sub_key_values): + key = metric_types.MetricKey(name=metric_name, sub_key=sub_key) + sub_metrics[key] = value + + with beam.Pipeline() as pipeline: + # pylint: disable=no-value-for-parameter + result = ( + pipeline + | "Create" >> beam.Create([((), sub_metrics)]) + | "ComputeMetric" >> beam.Map(lambda x: (x[0], metric.result(x[1]))) + ) + + # pylint: enable=no-value-for-parameter + + def check_result(got): + try: + self.assertLen(got, 1) + got_slice_key, got_metrics = got[0] + self.assertEqual(got_slice_key, ()) + key = metric.keys[0] + expected_value = (0.1 * 0.2 + 0.2 * 0.3 + 0.3 * 0.5) / ( + 0.2 + 0.3 + 0.5 + ) + self.assertDictElementsAlmostEqual( + got_metrics, {key: expected_value}, places=5 + ) + + except AssertionError as err: + raise util.BeamAssertException(err) + + util.assert_that(result, check_result, label="result") + + def testWeightedMacroAverage(self): + example1 = { + "labels": np.array([0.0, 1.0, 1.0]), + "predictions": np.array([0.0, 0.3, 0.7]), + "example_weights": np.array([1.0]), + } + example2 = { + "labels": np.array([0.0, 1.0, 1.0]), + "predictions": np.array([0.5, 0.3, 0.8]), + "example_weights": np.array([1.0]), + } + example3 = { + "labels": np.array([1.0, 0.0, 1.0]), + "predictions": np.array([0.3, 0.7, 0.9]), + "example_weights": np.array([1.0]), + } + + metric_name = "test" + class_ids = [0, 1, 2] + sub_keys = [metric_types.SubKey(class_id=i) for i in class_ids] + sub_key_values = [0.1, 0.2, 0.3] + computations = aggregation.weighted_macro_average( + metric_name, sub_keys, eval_config=config_pb2.EvalConfig() + ) + class_weights = computations[0] + metric = computations[1] + + def create_sub_metrics(sliced_metrics): + slice_value, metrics = sliced_metrics + metrics = copy.copy(metrics) + for sub_key, value in zip(sub_keys, sub_key_values): + key = metric_types.MetricKey(name=metric_name, sub_key=sub_key) + metrics[key] = value + return (slice_value, metrics) + + with beam.Pipeline() as pipeline: + # pylint: disable=no-value-for-parameter + result = ( + pipeline + | "Create" >> beam.Create([example1, example2, example3]) + | "Process" >> beam.Map(metric_util.to_standard_metric_inputs) + | "AddSlice" >> beam.Map(lambda x: ((), x)) + | "ComputeClassWeights" >> beam.CombinePerKey(class_weights.combiner) + | "CreateSubMetric" >> beam.Map(create_sub_metrics) + | "ComputeMetric" >> beam.Map(lambda x: (x[0], metric.result(x[1]))) + ) + + # pylint: enable=no-value-for-parameter + + def check_result(got): + try: + self.assertLen(got, 1) + got_slice_key, got_metrics = got[0] + self.assertEqual(got_slice_key, ()) + key = metric.keys[0] + # Labels: 1 x class_0, 2 x class_1, 3 x class_2 + # Class weights: 0.125, .3333, .5 + expected_value = ( + (0.1 * 1.0 / 6.0) + (0.2 * 2.0 / 6.0) + (0.3 * 3.0 / 6.0) + ) + self.assertDictElementsAlmostEqual( + got_metrics, {key: expected_value}, places=5 + ) + + except AssertionError as err: + raise util.BeamAssertException(err) + + util.assert_that(result, check_result, label="result") + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow_model_analysis/metrics/attributions.py b/tensorflow_model_analysis/metrics/attributions.py index 2beb1388d0..8a85e5167a 100644 --- a/tensorflow_model_analysis/metrics/attributions.py +++ b/tensorflow_model_analysis/metrics/attributions.py @@ -14,73 +14,79 @@ """Attribution related metrics.""" import functools - from typing import Any, Dict, Iterable, List, Optional, Union import apache_beam as beam import numpy as np + from tensorflow_model_analysis 
import constants -from tensorflow_model_analysis.metrics import example_count -from tensorflow_model_analysis.metrics import metric_specs -from tensorflow_model_analysis.metrics import metric_types -from tensorflow_model_analysis.metrics import metric_util +from tensorflow_model_analysis.metrics import ( + example_count, + metric_specs, + metric_types, + metric_util, +) from tensorflow_model_analysis.proto import config_pb2 from tensorflow_model_analysis.utils import util -TOTAL_ATTRIBUTIONS_NAME = 'total_attributions' -TOTAL_ABSOLUTE_ATTRIBUTIONS_NAME = 'total_absolute_attributions' -MEAN_ATTRIBUTIONS_NAME = 'mean_attributions' -MEAN_ABSOLUTE_ATTRIBUTIONS_NAME = 'mean_absolute_attributions' +TOTAL_ATTRIBUTIONS_NAME = "total_attributions" +TOTAL_ABSOLUTE_ATTRIBUTIONS_NAME = "total_absolute_attributions" +MEAN_ATTRIBUTIONS_NAME = "mean_attributions" +MEAN_ABSOLUTE_ATTRIBUTIONS_NAME = "mean_absolute_attributions" class AttributionsMetric(metric_types.Metric): - """Base type for attribution metrics.""" + """Base type for attribution metrics.""" -def has_attributions_metrics( - metrics_specs: Iterable[config_pb2.MetricsSpec]) -> bool: - """Returns true if any of the metrics_specs have attributions metrics.""" - tfma_metric_classes = metric_types.registered_metrics() - for metrics_spec in metrics_specs: - for metric_config in metrics_spec.metrics: - instance = metric_specs.metric_instance(metric_config, - tfma_metric_classes) - if isinstance(instance, AttributionsMetric): - return True - return False +def has_attributions_metrics(metrics_specs: Iterable[config_pb2.MetricsSpec]) -> bool: + """Returns true if any of the metrics_specs have attributions metrics.""" + tfma_metric_classes = metric_types.registered_metrics() + for metrics_spec in metrics_specs: + for metric_config in metrics_spec.metrics: + instance = metric_specs.metric_instance(metric_config, tfma_metric_classes) + if isinstance(instance, AttributionsMetric): + return True + return False class MeanAttributions(AttributionsMetric): - """Mean attributions metric.""" + """Mean attributions metric.""" - def __init__(self, name: str = MEAN_ATTRIBUTIONS_NAME): - """Initializes mean attributions metric. + def __init__(self, name: str = MEAN_ATTRIBUTIONS_NAME): + """Initializes mean attributions metric. - Args: - name: Attribution metric name. - """ - super().__init__( - metric_util.merge_per_key_computations( - functools.partial(_mean_attributions, False)), - name=name) + Args: + ---- + name: Attribution metric name. + """ + super().__init__( + metric_util.merge_per_key_computations( + functools.partial(_mean_attributions, False) + ), + name=name, + ) metric_types.register_metric(MeanAttributions) class MeanAbsoluteAttributions(AttributionsMetric): - """Mean aboslute attributions metric.""" + """Mean aboslute attributions metric.""" - def __init__(self, name: str = MEAN_ABSOLUTE_ATTRIBUTIONS_NAME): - """Initializes mean absolute attributions metric. + def __init__(self, name: str = MEAN_ABSOLUTE_ATTRIBUTIONS_NAME): + """Initializes mean absolute attributions metric. - Args: - name: Attribution metric name. - """ - super().__init__( - metric_util.merge_per_key_computations( - functools.partial(_mean_attributions, True)), - name=name) + Args: + ---- + name: Attribution metric name. 
+ """ + super().__init__( + metric_util.merge_per_key_computations( + functools.partial(_mean_attributions, True) + ), + name=name, + ) metric_types.register_metric(MeanAbsoluteAttributions) @@ -90,88 +96,99 @@ def _mean_attributions( absolute: bool = True, name: str = MEAN_ATTRIBUTIONS_NAME, eval_config: Optional[config_pb2.EvalConfig] = None, - model_name: str = '', - output_name: str = '', + model_name: str = "", + output_name: str = "", sub_key: Optional[metric_types.SubKey] = None, example_weighted: bool = False, ) -> metric_types.MetricComputations: - """Returns metric computations for mean attributions.""" - key = metric_types.AttributionsKey( - name=name, - model_name=model_name, - output_name=output_name, - sub_key=sub_key, - example_weighted=example_weighted) - - # Make sure total_attributions is calculated. - computations = _total_attributions_computations( - absolute=absolute, - eval_config=eval_config, - model_name=model_name, - output_name=output_name, - sub_key=sub_key, - example_weighted=example_weighted) - total_attributions_key = computations[-1].keys[-1] - # Make sure example_count is calculated - computations.extend( - example_count.example_count( - model_names=[model_name], - output_names=[output_name], - sub_keys=[sub_key], - example_weighted=example_weighted)) - example_count_key = computations[-1].keys[-1] - - def result( - metrics: Dict[metric_types.MetricKey, Any] - ) -> Dict[metric_types.AttributionsKey, Dict[str, Union[float, np.ndarray]]]: - """Returns mean attributions.""" - total_attributions = metrics[total_attributions_key] - count = metrics[example_count_key] - attributions = {} - for k, v in total_attributions.items(): - if np.isclose(count, 0.0): - attributions[k] = float('nan') - else: - attributions[k] = v / count - return {key: attributions} - - derived_computation = metric_types.DerivedMetricComputation( - keys=[key], result=result) - computations.append(derived_computation) - return computations + """Returns metric computations for mean attributions.""" + key = metric_types.AttributionsKey( + name=name, + model_name=model_name, + output_name=output_name, + sub_key=sub_key, + example_weighted=example_weighted, + ) + + # Make sure total_attributions is calculated. 
+ computations = _total_attributions_computations( + absolute=absolute, + eval_config=eval_config, + model_name=model_name, + output_name=output_name, + sub_key=sub_key, + example_weighted=example_weighted, + ) + total_attributions_key = computations[-1].keys[-1] + # Make sure example_count is calculated + computations.extend( + example_count.example_count( + model_names=[model_name], + output_names=[output_name], + sub_keys=[sub_key], + example_weighted=example_weighted, + ) + ) + example_count_key = computations[-1].keys[-1] + + def result( + metrics: Dict[metric_types.MetricKey, Any], + ) -> Dict[metric_types.AttributionsKey, Dict[str, Union[float, np.ndarray]]]: + """Returns mean attributions.""" + total_attributions = metrics[total_attributions_key] + count = metrics[example_count_key] + attributions = {} + for k, v in total_attributions.items(): + if np.isclose(count, 0.0): + attributions[k] = float("nan") + else: + attributions[k] = v / count + return {key: attributions} + + derived_computation = metric_types.DerivedMetricComputation( + keys=[key], result=result + ) + computations.append(derived_computation) + return computations class TotalAttributions(AttributionsMetric): - """Total attributions metric.""" + """Total attributions metric.""" - def __init__(self, name: str = TOTAL_ATTRIBUTIONS_NAME): - """Initializes total attributions metric. + def __init__(self, name: str = TOTAL_ATTRIBUTIONS_NAME): + """Initializes total attributions metric. - Args: - name: Attribution metric name. - """ - super().__init__( - metric_util.merge_per_key_computations( - functools.partial(_total_attributions, False)), - name=name) + Args: + ---- + name: Attribution metric name. + """ + super().__init__( + metric_util.merge_per_key_computations( + functools.partial(_total_attributions, False) + ), + name=name, + ) metric_types.register_metric(TotalAttributions) class TotalAbsoluteAttributions(AttributionsMetric): - """Total absolute attributions metric.""" + """Total absolute attributions metric.""" - def __init__(self, name: str = TOTAL_ABSOLUTE_ATTRIBUTIONS_NAME): - """Initializes total absolute attributions metric. + def __init__(self, name: str = TOTAL_ABSOLUTE_ATTRIBUTIONS_NAME): + """Initializes total absolute attributions metric. - Args: - name: Attribution metric name. - """ - super().__init__( - metric_util.merge_per_key_computations( - functools.partial(_total_attributions, True)), - name=name) + Args: + ---- + name: Attribution metric name. + """ + super().__init__( + metric_util.merge_per_key_computations( + functools.partial(_total_attributions, True) + ), + name=name, + ) metric_types.register_metric(TotalAbsoluteAttributions) @@ -179,188 +196,204 @@ def __init__(self, name: str = TOTAL_ABSOLUTE_ATTRIBUTIONS_NAME): def _total_attributions( absolute: bool = True, - name: str = '', + name: str = "", eval_config: Optional[config_pb2.EvalConfig] = None, - model_name: str = '', - output_name: str = '', + model_name: str = "", + output_name: str = "", sub_key: Optional[metric_types.SubKey] = None, - example_weighted: bool = False) -> metric_types.MetricComputations: - """Returns metric computations for total attributions.""" - key = metric_types.AttributionsKey( - name=name, - model_name=model_name, - output_name=output_name, - sub_key=sub_key, - example_weighted=example_weighted) - - # Make sure total_attributions is calculated. 
- computations = _total_attributions_computations( - absolute=absolute, - eval_config=eval_config, - model_name=model_name, - output_name=output_name, - sub_key=sub_key, - example_weighted=example_weighted) - private_key = computations[-1].keys[-1] - - def result( - metrics: Dict[metric_types.MetricKey, Any] - ) -> Dict[metric_types.AttributionsKey, Dict[str, Union[float, np.ndarray]]]: - """Returns total attributions.""" - return {key: metrics[private_key]} - - derived_computation = metric_types.DerivedMetricComputation( - keys=[key], result=result) - computations.append(derived_computation) - return computations + example_weighted: bool = False, +) -> metric_types.MetricComputations: + """Returns metric computations for total attributions.""" + key = metric_types.AttributionsKey( + name=name, + model_name=model_name, + output_name=output_name, + sub_key=sub_key, + example_weighted=example_weighted, + ) + + # Make sure total_attributions is calculated. + computations = _total_attributions_computations( + absolute=absolute, + eval_config=eval_config, + model_name=model_name, + output_name=output_name, + sub_key=sub_key, + example_weighted=example_weighted, + ) + private_key = computations[-1].keys[-1] + + def result( + metrics: Dict[metric_types.MetricKey, Any], + ) -> Dict[metric_types.AttributionsKey, Dict[str, Union[float, np.ndarray]]]: + """Returns total attributions.""" + return {key: metrics[private_key]} + + derived_computation = metric_types.DerivedMetricComputation( + keys=[key], result=result + ) + computations.append(derived_computation) + return computations def _total_attributions_computations( absolute: bool = True, - name: str = '', + name: str = "", eval_config: Optional[config_pb2.EvalConfig] = None, - model_name: str = '', - output_name: str = '', + model_name: str = "", + output_name: str = "", sub_key: Optional[metric_types.SubKey] = None, - example_weighted: bool = False) -> metric_types.MetricComputations: - """Returns metric computations for total attributions. - - Args: - absolute: True to use absolute value when summing. - name: Metric name. - eval_config: Eval config. - model_name: Optional model name (if multi-model evaluation). - output_name: Optional output name (if multi-output model type). - sub_key: Optional sub key. - example_weighted: True if example weights should be applied. - """ - if not name: - if absolute: - name = '_' + TOTAL_ABSOLUTE_ATTRIBUTIONS_NAME - else: - name = '_' + TOTAL_ATTRIBUTIONS_NAME - key = metric_types.AttributionsKey( - name=name, - model_name=model_name, - output_name=output_name, - sub_key=sub_key, - example_weighted=example_weighted) - return [ - metric_types.MetricComputation( - keys=[key], - preprocessors=[metric_types.AttributionPreprocessor(feature_keys={})], - combiner=_TotalAttributionsCombiner(key, eval_config, absolute)) - ] + example_weighted: bool = False, +) -> metric_types.MetricComputations: + """Returns metric computations for total attributions. + + Args: + ---- + absolute: True to use absolute value when summing. + name: Metric name. + eval_config: Eval config. + model_name: Optional model name (if multi-model evaluation). + output_name: Optional output name (if multi-output model type). + sub_key: Optional sub key. + example_weighted: True if example weights should be applied. 
+ """ + if not name: + if absolute: + name = "_" + TOTAL_ABSOLUTE_ATTRIBUTIONS_NAME + else: + name = "_" + TOTAL_ATTRIBUTIONS_NAME + key = metric_types.AttributionsKey( + name=name, + model_name=model_name, + output_name=output_name, + sub_key=sub_key, + example_weighted=example_weighted, + ) + return [ + metric_types.MetricComputation( + keys=[key], + preprocessors=[metric_types.AttributionPreprocessor(feature_keys={})], + combiner=_TotalAttributionsCombiner(key, eval_config, absolute), + ) + ] @beam.typehints.with_input_types(metric_types.StandardMetricInputs) -@beam.typehints.with_output_types(Dict[metric_types.AttributionsKey, - Dict[str, Union[float, np.ndarray]]]) +@beam.typehints.with_output_types( + Dict[metric_types.AttributionsKey, Dict[str, Union[float, np.ndarray]]] +) class _TotalAttributionsCombiner(beam.CombineFn): - """Computes total attributions.""" - - def __init__(self, key: metric_types.AttributionsKey, - eval_config: Optional[config_pb2.EvalConfig], absolute: bool): - self._key = key - self._eval_config = eval_config - self._absolute = absolute - - def _sum(self, a: List[float], b: Union[np.ndarray, List[float]]): - """Adds values in b to a at matching offsets.""" - if (isinstance(b, (float, np.floating)) or - (isinstance(b, np.ndarray) and b.size == 1)): - if len(a) != 1: - raise ValueError( - 'Attributions have different array sizes {} != {}'.format(a, b)) - a[0] += abs(float(b)) if self._absolute else float(b) - else: - if len(a) != len(b): - raise ValueError( - 'Attributions have different array sizes {} != {}'.format(a, b)) - for i, v in enumerate(b): - a[i] += abs(v) if self._absolute else v - - def create_accumulator(self) -> Dict[str, List[float]]: - return {} - - def add_input( - self, accumulator: Dict[str, List[float]], - extracts: metric_types.StandardMetricInputs) -> Dict[str, List[float]]: - if constants.ATTRIBUTIONS_KEY not in extracts: - raise ValueError( - '{} missing from extracts {}\n\n. 
An attribution extractor is ' - 'required to use attribution metrics'.format( - constants.ATTRIBUTIONS_KEY, extracts)) - attributions = extracts[constants.ATTRIBUTIONS_KEY] - if self._key.model_name: - attributions = util.get_by_keys(attributions, [self._key.model_name]) - if self._key.output_name: - attributions = util.get_by_keys(attributions, [self._key.output_name]) - _, _, example_weight = next( - metric_util.to_label_prediction_example_weight( - extracts, - eval_config=self._eval_config, - model_name=self._key.model_name, - output_name=self._key.output_name, - sub_key=self._key.sub_key, - example_weighted=self._key.example_weighted, - allow_none=True, - flatten=False)) - example_weight = float(example_weight) - for k, v in attributions.items(): - v = util.to_numpy(v) - if self._key.sub_key is not None: - if self._key.sub_key.class_id is not None: - v = _scores_by_class_id(self._key.sub_key.class_id, v) - elif self._key.sub_key.k is not None: - v = _scores_by_top_k(self._key.sub_key.k, v) - v = np.array(v[self._key.sub_key.k - 1]) - elif self._key.sub_key.top_k is not None: - v = _scores_by_top_k(self._key.sub_key.top_k, v) - if k not in accumulator: - accumulator[k] = [0.0] * v.size - self._sum(accumulator[k], v * example_weight) - return accumulator - - def merge_accumulators( - self, - accumulators: Iterable[Dict[str, List[float]]]) -> Dict[str, List[float]]: - accumulators = iter(accumulators) - result = next(accumulators) - for accumulator in accumulators: - for k, v in accumulator.items(): - if k in result: - self._sum(result[k], v) + """Computes total attributions.""" + + def __init__( + self, + key: metric_types.AttributionsKey, + eval_config: Optional[config_pb2.EvalConfig], + absolute: bool, + ): + self._key = key + self._eval_config = eval_config + self._absolute = absolute + + def _sum(self, a: List[float], b: Union[np.ndarray, List[float]]): + """Adds values in b to a at matching offsets.""" + if isinstance(b, (float, np.floating)) or ( + isinstance(b, np.ndarray) and b.size == 1 + ): + if len(a) != 1: + raise ValueError(f"Attributions have different array sizes {a} != {b}") + a[0] += abs(float(b)) if self._absolute else float(b) else: - result[k] = v - return result - - def extract_output( - self, accumulator: Dict[str, List[float]] - ) -> Dict[metric_types.AttributionsKey, Dict[str, Union[float, np.ndarray]]]: - result = {} - for k, v in accumulator.items(): - result[k] = v[0] if len(v) == 1 else np.array(v) - return {self._key: result} + if len(a) != len(b): + raise ValueError(f"Attributions have different array sizes {a} != {b}") + for i, v in enumerate(b): + a[i] += abs(v) if self._absolute else v + + def create_accumulator(self) -> Dict[str, List[float]]: + return {} + + def add_input( + self, + accumulator: Dict[str, List[float]], + extracts: metric_types.StandardMetricInputs, + ) -> Dict[str, List[float]]: + if constants.ATTRIBUTIONS_KEY not in extracts: + raise ValueError( + f"{constants.ATTRIBUTIONS_KEY} missing from extracts {extracts}\n\n. 
An attribution extractor is " + "required to use attribution metrics" + ) + attributions = extracts[constants.ATTRIBUTIONS_KEY] + if self._key.model_name: + attributions = util.get_by_keys(attributions, [self._key.model_name]) + if self._key.output_name: + attributions = util.get_by_keys(attributions, [self._key.output_name]) + _, _, example_weight = next( + metric_util.to_label_prediction_example_weight( + extracts, + eval_config=self._eval_config, + model_name=self._key.model_name, + output_name=self._key.output_name, + sub_key=self._key.sub_key, + example_weighted=self._key.example_weighted, + allow_none=True, + flatten=False, + ) + ) + example_weight = float(example_weight) + for k, v in attributions.items(): + v = util.to_numpy(v) + if self._key.sub_key is not None: + if self._key.sub_key.class_id is not None: + v = _scores_by_class_id(self._key.sub_key.class_id, v) + elif self._key.sub_key.k is not None: + v = _scores_by_top_k(self._key.sub_key.k, v) + v = np.array(v[self._key.sub_key.k - 1]) + elif self._key.sub_key.top_k is not None: + v = _scores_by_top_k(self._key.sub_key.top_k, v) + if k not in accumulator: + accumulator[k] = [0.0] * v.size + self._sum(accumulator[k], v * example_weight) + return accumulator + + def merge_accumulators( + self, accumulators: Iterable[Dict[str, List[float]]] + ) -> Dict[str, List[float]]: + accumulators = iter(accumulators) + result = next(accumulators) + for accumulator in accumulators: + for k, v in accumulator.items(): + if k in result: + self._sum(result[k], v) + else: + result[k] = v + return result + + def extract_output( + self, accumulator: Dict[str, List[float]] + ) -> Dict[metric_types.AttributionsKey, Dict[str, Union[float, np.ndarray]]]: + result = {} + for k, v in accumulator.items(): + result[k] = v[0] if len(v) == 1 else np.array(v) + return {self._key: result} def _scores_by_class_id(class_id: int, scores: np.ndarray) -> np.ndarray: - """Returns selected class ID or raises ValueError.""" - if class_id < 0 or class_id >= len(scores): - raise ValueError('class_id "{}" out of range for attribution {}'.format( - class_id, scores)) - return scores[class_id] + """Returns selected class ID or raises ValueError.""" + if class_id < 0 or class_id >= len(scores): + raise ValueError(f'class_id "{class_id}" out of range for attribution {scores}') + return scores[class_id] def _scores_by_top_k(top_k: int, scores: np.ndarray) -> np.ndarray: - """Returns top_k scores or raises ValueError if invalid value for top_k.""" - if scores.shape[-1] < top_k: - raise ValueError( - 'not enough attributions were provided to perform the requested ' - 'calcuations for top k. The requested value for k is {}, but the ' - 'values are {}\n\nThis may be caused by a metric configuration error ' - 'or an error in the pipeline.'.format(top_k, scores)) - - indices = np.argpartition(scores, -top_k)[-top_k:] - indices = indices[np.argsort(-scores[indices])] - return scores[indices] + """Returns top_k scores or raises ValueError if invalid value for top_k.""" + if scores.shape[-1] < top_k: + raise ValueError( + "not enough attributions were provided to perform the requested " + f"calcuations for top k. The requested value for k is {top_k}, but the " + f"values are {scores}\n\nThis may be caused by a metric configuration error " + "or an error in the pipeline." 
+ ) + + indices = np.argpartition(scores, -top_k)[-top_k:] + indices = indices[np.argsort(-scores[indices])] + return scores[indices] diff --git a/tensorflow_model_analysis/metrics/attributions_test.py b/tensorflow_model_analysis/metrics/attributions_test.py index a7c7a939a6..c273d4a096 100644 --- a/tensorflow_model_analysis/metrics/attributions_test.py +++ b/tensorflow_model_analysis/metrics/attributions_test.py @@ -13,519 +13,536 @@ # limitations under the License. """Tests for attributions metrics.""" -from absl.testing import parameterized import apache_beam as beam -from apache_beam.testing import util import numpy as np import tensorflow as tf -from tensorflow_model_analysis.metrics import attributions -from tensorflow_model_analysis.metrics import metric_specs -from tensorflow_model_analysis.metrics import metric_types -from tensorflow_model_analysis.metrics import metric_util +from absl.testing import parameterized +from apache_beam.testing import util + +from tensorflow_model_analysis.metrics import ( + attributions, + metric_specs, + metric_types, + metric_util, +) from tensorflow_model_analysis.utils import test_util from tensorflow_model_analysis.utils.keras_lib import tf_keras -class AttributionsTest( - test_util.TensorflowModelAnalysisTest, parameterized.TestCase -): - - def testHasAttributionsMetrics(self): - specs_with_attributions = metric_specs.specs_from_metrics({ - 'output_name': [ - tf_keras.metrics.MeanSquaredError('mse'), - attributions.TotalAttributions(), - ] - }) - self.assertTrue( - attributions.has_attributions_metrics(specs_with_attributions)) - specs_without_attributions = metric_specs.specs_from_metrics([ - tf_keras.metrics.MeanSquaredError('mse'), - ]) - self.assertFalse( - attributions.has_attributions_metrics(specs_without_attributions)) - - def testMeanAttributions(self): - computation = attributions.MeanAttributions().computations()[-1] - - total_attributions_key = metric_types.AttributionsKey( - name='_total_attributions') - example_count_key = metric_types.MetricKey(name='example_count') - metrics = { - total_attributions_key: { - 'feature1': 1.0, - 'feature2': -2.0 +class AttributionsTest(test_util.TensorflowModelAnalysisTest, parameterized.TestCase): + def testHasAttributionsMetrics(self): + specs_with_attributions = metric_specs.specs_from_metrics( + { + "output_name": [ + tf_keras.metrics.MeanSquaredError("mse"), + attributions.TotalAttributions(), + ] + } + ) + self.assertTrue(attributions.has_attributions_metrics(specs_with_attributions)) + specs_without_attributions = metric_specs.specs_from_metrics( + [ + tf_keras.metrics.MeanSquaredError("mse"), + ] + ) + self.assertFalse( + attributions.has_attributions_metrics(specs_without_attributions) + ) + + def testMeanAttributions(self): + computation = attributions.MeanAttributions().computations()[-1] + + total_attributions_key = metric_types.AttributionsKey( + name="_total_attributions" + ) + example_count_key = metric_types.MetricKey(name="example_count") + metrics = { + total_attributions_key: {"feature1": 1.0, "feature2": -2.0}, + example_count_key: 0.5, + } + + with beam.Pipeline() as pipeline: + # pylint: disable=no-value-for-parameter + result = ( + pipeline + | "Create" >> beam.Create([metrics]) + | "ComputeMetric" >> beam.Map(lambda x: ((), computation.result(x))) + ) + + # pylint: enable=no-value-for-parameter + + def check_result(got): + try: + self.assertLen(got, 1) + got_slice_key, got_attributions = got[0] + self.assertEqual(got_slice_key, ()) + mean_attributions_key = 
metric_types.AttributionsKey( + name="mean_attributions" + ) + self.assertIn(mean_attributions_key, got_attributions) + self.assertDictElementsAlmostEqual( + got_attributions[mean_attributions_key], + { + "feature1": 1.0 / 0.5, + "feature2": -2.0 / 0.5, + }, + ) + + except AssertionError as err: + raise util.BeamAssertException(err) + + util.assert_that(result, check_result, label="result") + + def testMeanAbsoluteAttributions(self): + computation = attributions.MeanAbsoluteAttributions().computations()[-1] + + total_absolute_attributions_key = metric_types.AttributionsKey( + name="_total_absolute_attributions" + ) + example_count_key = metric_types.MetricKey(name="example_count") + metrics = { + total_absolute_attributions_key: {"feature1": 1.0, "feature2": 2.0}, + example_count_key: 0.5, + } + + with beam.Pipeline() as pipeline: + # pylint: disable=no-value-for-parameter + result = ( + pipeline + | "Create" >> beam.Create([metrics]) + | "ComputeMetric" >> beam.Map(lambda x: ((), computation.result(x))) + ) + + # pylint: enable=no-value-for-parameter + + def check_result(got): + try: + self.assertLen(got, 1) + got_slice_key, got_attributions = got[0] + self.assertEqual(got_slice_key, ()) + mean_attributions_key = metric_types.AttributionsKey( + name="mean_absolute_attributions" + ) + self.assertIn(mean_attributions_key, got_attributions) + self.assertDictElementsAlmostEqual( + got_attributions[mean_attributions_key], + { + "feature1": 1.0 / 0.5, + "feature2": 2.0 / 0.5, + }, + ) + + except AssertionError as err: + raise util.BeamAssertException(err) + + util.assert_that(result, check_result, label="result") + + @parameterized.named_parameters( + { + "testcase_name": "basic", + "model_name": "", + "output_name": "", + "examples": [ + { + "labels": None, + "predictions": None, + "example_weights": np.array(1.0), + "attributions": { + "feature1": 1.1, + "feature2": -1.2, + }, + }, + { + "labels": None, + "predictions": None, + "example_weights": np.array(1.0), + "attributions": {"feature1": -2.1, "feature2": 2.2}, + }, + { + "labels": None, + "predictions": None, + "example_weights": np.array(1.0), + "attributions": {"feature1": 3.1, "feature2": -3.2}, + }, + ], + "expected_values": { + "feature1": (1.1 - 2.1 + 3.1), + "feature2": (-1.2 + 2.2 - 3.2), + }, }, - example_count_key: 0.5 - } - - with beam.Pipeline() as pipeline: - # pylint: disable=no-value-for-parameter - result = ( - pipeline - | 'Create' >> beam.Create([metrics]) - | 'ComputeMetric' >> beam.Map(lambda x: ((), computation.result(x)))) - - # pylint: enable=no-value-for-parameter - - def check_result(got): - try: - self.assertLen(got, 1) - got_slice_key, got_attributions = got[0] - self.assertEqual(got_slice_key, ()) - mean_attributions_key = metric_types.AttributionsKey( - name='mean_attributions') - self.assertIn(mean_attributions_key, got_attributions) - self.assertDictElementsAlmostEqual( - got_attributions[mean_attributions_key], { - 'feature1': 1.0 / 0.5, - 'feature2': -2.0 / 0.5, - }) - - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that(result, check_result, label='result') - - def testMeanAbsoluteAttributions(self): - computation = attributions.MeanAbsoluteAttributions().computations()[-1] - - total_absolute_attributions_key = metric_types.AttributionsKey( - name='_total_absolute_attributions') - example_count_key = metric_types.MetricKey(name='example_count') - metrics = { - total_absolute_attributions_key: { - 'feature1': 1.0, - 'feature2': 2.0 + { + "testcase_name": 
"multi-model", + "model_name": "model", + "output_name": "", + "examples": [ + { + "labels": None, + "predictions": None, + "example_weights": np.array(1.0), + "attributions": { + "model": {"feature1": 11.1, "feature2": -11.2}, + }, + }, + { + "labels": None, + "predictions": None, + "example_weights": np.array(1.0), + "attributions": { + "model": {"feature1": -22.1, "feature2": 22.2}, + }, + }, + { + "labels": None, + "predictions": None, + "example_weights": np.array(1.0), + "attributions": { + "model": {"feature1": 33.1, "feature2": -33.2}, + }, + }, + ], + "expected_values": { + "feature1": (11.1 - 22.1 + 33.1), + "feature2": (-11.2 + 22.2 - 33.2), + }, }, - example_count_key: 0.5 - } - - with beam.Pipeline() as pipeline: - # pylint: disable=no-value-for-parameter - result = ( - pipeline - | 'Create' >> beam.Create([metrics]) - | 'ComputeMetric' >> beam.Map(lambda x: ((), computation.result(x)))) - - # pylint: enable=no-value-for-parameter - - def check_result(got): - try: - self.assertLen(got, 1) - got_slice_key, got_attributions = got[0] - self.assertEqual(got_slice_key, ()) - mean_attributions_key = metric_types.AttributionsKey( - name='mean_absolute_attributions') - self.assertIn(mean_attributions_key, got_attributions) - self.assertDictElementsAlmostEqual( - got_attributions[mean_attributions_key], { - 'feature1': 1.0 / 0.5, - 'feature2': 2.0 / 0.5, - }) - - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that(result, check_result, label='result') - - @parameterized.named_parameters( - { - 'testcase_name': 'basic', - 'model_name': '', - 'output_name': '', - 'examples': [{ - 'labels': None, - 'predictions': None, - 'example_weights': np.array(1.0), - 'attributions': { - 'feature1': 1.1, - 'feature2': -1.2, - } - }, { - 'labels': None, - 'predictions': None, - 'example_weights': np.array(1.0), - 'attributions': { - 'feature1': -2.1, - 'feature2': 2.2 - } - }, { - 'labels': None, - 'predictions': None, - 'example_weights': np.array(1.0), - 'attributions': { - 'feature1': 3.1, - 'feature2': -3.2 - } - }], - 'expected_values': { - 'feature1': (1.1 - 2.1 + 3.1), - 'feature2': (-1.2 + 2.2 - 3.2), - }, - }, - { - 'testcase_name': 'multi-model', - 'model_name': 'model', - 'output_name': '', - 'examples': [{ - 'labels': None, - 'predictions': None, - 'example_weights': np.array(1.0), - 'attributions': { - 'model': { - 'feature1': 11.1, - 'feature2': -11.2 - }, - } - }, { - 'labels': None, - 'predictions': None, - 'example_weights': np.array(1.0), - 'attributions': { - 'model': { - 'feature1': -22.1, - 'feature2': 22.2 - }, - } - }, { - 'labels': None, - 'predictions': None, - 'example_weights': np.array(1.0), - 'attributions': { - 'model': { - 'feature1': 33.1, - 'feature2': -33.2 - }, - } - }], - 'expected_values': { - 'feature1': (11.1 - 22.1 + 33.1), - 'feature2': (-11.2 + 22.2 - 33.2), - }, - }, - { - 'testcase_name': 'multi-model-multi-output', - 'model_name': 'model', - 'output_name': 'output', - 'examples': [{ - 'labels': None, - 'predictions': None, - 'example_weights': np.array(1.0), - 'attributions': { - 'model': { - 'output': { - 'feature1': 111.1, - 'feature2': -111.2 - }, - }, - } - }, { - 'labels': None, - 'predictions': None, - 'example_weights': np.array(1.0), - 'attributions': { - 'model': { - 'output': { - 'feature1': -222.1, - 'feature2': 222.2 - }, - }, - } - }, { - 'labels': None, - 'predictions': None, - 'example_weights': np.array(1.0), - 'attributions': { - 'model': { - 'output': { - 'feature1': 333.1, - 'feature2': -333.2 - }, - 
}, - } - }], - 'expected_values': { - 'feature1': (111.1 - 222.1 + 333.1), - 'feature2': (-111.2 + 222.2 - 333.2), - }, - }, - ) - def testTotalAttributionsWithMultiModelsAndOutputs(self, model_name, - output_name, examples, - expected_values): - computations = attributions.TotalAttributions().computations( - model_names=[model_name], output_names=[output_name]) - - with beam.Pipeline() as pipeline: - # pylint: disable=no-value-for-parameter - result = ( - pipeline - | 'Create' >> beam.Create(examples) - | 'Process' >> beam.Map(metric_util.to_standard_metric_inputs) - | 'AddSlice' >> beam.Map(lambda x: ((), x)) - | - 'CombineAttributions' >> beam.CombinePerKey(computations[0].combiner) - | 'ComputeResult' >> beam.Map( # comment to add lamda on own line - lambda x: (x[0], computations[1].result(x[1])))) - - # pylint: enable=no-value-for-parameter - - def check_result(got): - try: - self.assertLen(got, 1) - got_slice_key, got_attributions = got[0] - self.assertEqual(got_slice_key, ()) - total_attributions_key = metric_types.AttributionsKey( - name='total_attributions', - model_name=model_name, - output_name=output_name) - self.assertIn(total_attributions_key, got_attributions) - self.assertDictElementsAlmostEqual( - got_attributions[total_attributions_key], expected_values) - - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that(result, check_result, label='result') - - @parameterized.named_parameters(('empty', None, { - 'feature1': np.array([6.33, 6.39, 6.36]), - 'feature2': np.array([6.63, 6.69, 6.66]), - }), ('class_id', metric_types.SubKey(class_id=0), { - 'feature1': 6.33, - 'feature2': 6.63, - }), ('k', metric_types.SubKey(k=2), { - 'feature1': 6.36, - 'feature2': 6.66, - }), ('top_k', metric_types.SubKey(top_k=2), { - 'feature1': np.array([6.39, 6.36]), - 'feature2': np.array([6.69, 6.66]), - })) - def testTotalAttributionsWithSubKeys(self, sub_key, expected_values): - computations = attributions.TotalAttributions().computations( - sub_keys=[sub_key]) - - example1 = { - 'labels': None, - 'predictions': None, - 'example_weights': np.array(1.0), - 'attributions': { - 'feature1': [1.11, 1.13, 1.12], - 'feature2': [1.21, 1.23, 1.22] + { + "testcase_name": "multi-model-multi-output", + "model_name": "model", + "output_name": "output", + "examples": [ + { + "labels": None, + "predictions": None, + "example_weights": np.array(1.0), + "attributions": { + "model": { + "output": {"feature1": 111.1, "feature2": -111.2}, + }, + }, + }, + { + "labels": None, + "predictions": None, + "example_weights": np.array(1.0), + "attributions": { + "model": { + "output": {"feature1": -222.1, "feature2": 222.2}, + }, + }, + }, + { + "labels": None, + "predictions": None, + "example_weights": np.array(1.0), + "attributions": { + "model": { + "output": {"feature1": 333.1, "feature2": -333.2}, + }, + }, + }, + ], + "expected_values": { + "feature1": (111.1 - 222.1 + 333.1), + "feature2": (-111.2 + 222.2 - 333.2), + }, + }, + ) + def testTotalAttributionsWithMultiModelsAndOutputs( + self, model_name, output_name, examples, expected_values + ): + computations = attributions.TotalAttributions().computations( + model_names=[model_name], output_names=[output_name] + ) + + with beam.Pipeline() as pipeline: + # pylint: disable=no-value-for-parameter + result = ( + pipeline + | "Create" >> beam.Create(examples) + | "Process" >> beam.Map(metric_util.to_standard_metric_inputs) + | "AddSlice" >> beam.Map(lambda x: ((), x)) + | "CombineAttributions" >> 
beam.CombinePerKey(computations[0].combiner) + | "ComputeResult" + >> beam.Map( # comment to add lamda on own line + lambda x: (x[0], computations[1].result(x[1])) + ) + ) + + # pylint: enable=no-value-for-parameter + + def check_result(got): + try: + self.assertLen(got, 1) + got_slice_key, got_attributions = got[0] + self.assertEqual(got_slice_key, ()) + total_attributions_key = metric_types.AttributionsKey( + name="total_attributions", + model_name=model_name, + output_name=output_name, + ) + self.assertIn(total_attributions_key, got_attributions) + self.assertDictElementsAlmostEqual( + got_attributions[total_attributions_key], expected_values + ) + + except AssertionError as err: + raise util.BeamAssertException(err) + + util.assert_that(result, check_result, label="result") + + @parameterized.named_parameters( + ( + "empty", + None, + { + "feature1": np.array([6.33, 6.39, 6.36]), + "feature2": np.array([6.63, 6.69, 6.66]), + }, + ), + ( + "class_id", + metric_types.SubKey(class_id=0), + { + "feature1": 6.33, + "feature2": 6.63, + }, + ), + ( + "k", + metric_types.SubKey(k=2), + { + "feature1": 6.36, + "feature2": 6.66, + }, + ), + ( + "top_k", + metric_types.SubKey(top_k=2), + { + "feature1": np.array([6.39, 6.36]), + "feature2": np.array([6.69, 6.66]), + }, + ), + ) + def testTotalAttributionsWithSubKeys(self, sub_key, expected_values): + computations = attributions.TotalAttributions().computations(sub_keys=[sub_key]) + + example1 = { + "labels": None, + "predictions": None, + "example_weights": np.array(1.0), + "attributions": { + "feature1": [1.11, 1.13, 1.12], + "feature2": [1.21, 1.23, 1.22], + }, } - } - example2 = { - 'labels': None, - 'predictions': None, - 'example_weights': np.array(1.0), - 'attributions': { - 'feature1': [2.11, 2.13, 2.12], - 'feature2': [2.21, 2.23, 2.22] + example2 = { + "labels": None, + "predictions": None, + "example_weights": np.array(1.0), + "attributions": { + "feature1": [2.11, 2.13, 2.12], + "feature2": [2.21, 2.23, 2.22], + }, } - } - example3 = { - 'labels': None, - 'predictions': None, - 'example_weights': np.array(1.0), - 'attributions': { - 'feature1': np.array([3.11, 3.13, 3.12]), - 'feature2': np.array([3.21, 3.23, 3.22]) + example3 = { + "labels": None, + "predictions": None, + "example_weights": np.array(1.0), + "attributions": { + "feature1": np.array([3.11, 3.13, 3.12]), + "feature2": np.array([3.21, 3.23, 3.22]), + }, } - } - - with beam.Pipeline() as pipeline: - # pylint: disable=no-value-for-parameter - result = ( - pipeline - | 'Create' >> beam.Create([example1, example2, example3]) - | 'Process' >> beam.Map(metric_util.to_standard_metric_inputs) - | 'AddSlice' >> beam.Map(lambda x: ((), x)) - | - 'CombineAttributions' >> beam.CombinePerKey(computations[0].combiner) - | 'ComputeResult' >> beam.Map( # comment to add lamda on own line - lambda x: (x[0], computations[1].result(x[1])))) - - # pylint: enable=no-value-for-parameter - - def check_result(got): - try: - self.assertLen(got, 1) - got_slice_key, got_attributions = got[0] - self.assertEqual(got_slice_key, ()) - total_attributions_key = metric_types.AttributionsKey( - name='total_attributions', sub_key=sub_key) - self.assertIn(total_attributions_key, got_attributions) - self.assertAllClose(got_attributions[total_attributions_key], - expected_values) - - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that(result, check_result, label='result') - - @parameterized.named_parameters( - { - 'testcase_name': 'basic', - 'model_name': '', - 
'output_name': '', - 'examples': [{ - 'labels': None, - 'predictions': None, - 'example_weights': np.array(0.5), - 'attributions': { - 'feature1': 1.1, - 'feature2': -1.2, - } - }, { - 'labels': None, - 'predictions': None, - 'example_weights': np.array(0.7), - 'attributions': { - 'feature1': 2.1, - 'feature2': -2.2 - } - }, { - 'labels': None, - 'predictions': None, - 'example_weights': np.array(0.9), - 'attributions': { - 'feature1': 3.1, - 'feature2': -3.2 - } - }], - 'expected_values': { - 'feature1': (1.1 * 0.5 + 2.1 * 0.7 + 3.1 * 0.9), - 'feature2': (1.2 * 0.5 + 2.2 * 0.7 + 3.2 * 0.9), - }, - }, - { - 'testcase_name': 'multi-model', - 'model_name': 'model', - 'output_name': '', - 'examples': [{ - 'labels': None, - 'predictions': None, - 'example_weights': None, - 'attributions': { - 'model': { - 'feature1': 11.1, - 'feature2': -11.2 - }, - } - }, { - 'labels': None, - 'predictions': None, - 'example_weights': None, - 'attributions': { - 'model': { - 'feature1': 22.1, - 'feature2': -22.2 - }, - } - }, { - 'labels': None, - 'predictions': None, - 'example_weights': None, - 'attributions': { - 'model': { - 'feature1': 33.1, - 'feature2': -33.2 - }, - } - }], - 'expected_values': { - 'feature1': (11.1 + 22.1 + 33.1), - 'feature2': (11.2 + 22.2 + 33.2), - }, - }, - { - 'testcase_name': 'multi-model-multi-output', - 'model_name': 'model', - 'output_name': 'output', - 'examples': [{ - 'labels': None, - 'predictions': None, - 'example_weights': np.array(1.0), - 'attributions': { - 'model': { - 'output': { - 'feature1': 111.1, - 'feature2': -111.2 - }, - }, - } - }, { - 'labels': None, - 'predictions': None, - 'example_weights': np.array(1.0), - 'attributions': { - 'model': { - 'output': { - 'feature1': 222.1, - 'feature2': -222.2 - }, - }, - } - }, { - 'labels': None, - 'predictions': None, - 'example_weights': np.array(1.0), - 'attributions': { - 'model': { - 'output': { - 'feature1': 333.1, - 'feature2': -333.2 - }, - }, - } - }], - 'expected_values': { - 'feature1': (111.1 + 222.1 + 333.1), - 'feature2': (111.2 + 222.2 + 333.2), - }, - }, - ) - def testTotalAbsoluteAttributionsWithMultiModelsAndOutputs( - self, model_name, output_name, examples, expected_values): - computations = attributions.TotalAbsoluteAttributions().computations( - model_names=[model_name], - output_names=[output_name], - example_weighted=True) - - with beam.Pipeline() as pipeline: - # pylint: disable=no-value-for-parameter - result = ( - pipeline - | 'Create' >> beam.Create(examples) - | 'Process' >> beam.Map(metric_util.to_standard_metric_inputs) - | 'AddSlice' >> beam.Map(lambda x: ((), x)) - | - 'CombineAttributions' >> beam.CombinePerKey(computations[0].combiner) - | 'ComputeResult' >> beam.Map( # comment to add lamda on own line - lambda x: (x[0], computations[1].result(x[1])))) - - # pylint: enable=no-value-for-parameter - - def check_result(got): - try: - self.assertLen(got, 1) - got_slice_key, got_attributions = got[0] - self.assertEqual(got_slice_key, ()) - total_attributions_key = metric_types.AttributionsKey( - name='total_absolute_attributions', - model_name=model_name, - output_name=output_name, - example_weighted=True) - self.assertIn(total_attributions_key, got_attributions) - self.assertDictElementsAlmostEqual( - got_attributions[total_attributions_key], expected_values) - - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that(result, check_result, label='result') - - -if __name__ == '__main__': - tf.test.main() + + with beam.Pipeline() as pipeline: + # pylint: 
disable=no-value-for-parameter + result = ( + pipeline + | "Create" >> beam.Create([example1, example2, example3]) + | "Process" >> beam.Map(metric_util.to_standard_metric_inputs) + | "AddSlice" >> beam.Map(lambda x: ((), x)) + | "CombineAttributions" >> beam.CombinePerKey(computations[0].combiner) + | "ComputeResult" + >> beam.Map( # comment to add lamda on own line + lambda x: (x[0], computations[1].result(x[1])) + ) + ) + + # pylint: enable=no-value-for-parameter + + def check_result(got): + try: + self.assertLen(got, 1) + got_slice_key, got_attributions = got[0] + self.assertEqual(got_slice_key, ()) + total_attributions_key = metric_types.AttributionsKey( + name="total_attributions", sub_key=sub_key + ) + self.assertIn(total_attributions_key, got_attributions) + self.assertAllClose( + got_attributions[total_attributions_key], expected_values + ) + + except AssertionError as err: + raise util.BeamAssertException(err) + + util.assert_that(result, check_result, label="result") + + @parameterized.named_parameters( + { + "testcase_name": "basic", + "model_name": "", + "output_name": "", + "examples": [ + { + "labels": None, + "predictions": None, + "example_weights": np.array(0.5), + "attributions": { + "feature1": 1.1, + "feature2": -1.2, + }, + }, + { + "labels": None, + "predictions": None, + "example_weights": np.array(0.7), + "attributions": {"feature1": 2.1, "feature2": -2.2}, + }, + { + "labels": None, + "predictions": None, + "example_weights": np.array(0.9), + "attributions": {"feature1": 3.1, "feature2": -3.2}, + }, + ], + "expected_values": { + "feature1": (1.1 * 0.5 + 2.1 * 0.7 + 3.1 * 0.9), + "feature2": (1.2 * 0.5 + 2.2 * 0.7 + 3.2 * 0.9), + }, + }, + { + "testcase_name": "multi-model", + "model_name": "model", + "output_name": "", + "examples": [ + { + "labels": None, + "predictions": None, + "example_weights": None, + "attributions": { + "model": {"feature1": 11.1, "feature2": -11.2}, + }, + }, + { + "labels": None, + "predictions": None, + "example_weights": None, + "attributions": { + "model": {"feature1": 22.1, "feature2": -22.2}, + }, + }, + { + "labels": None, + "predictions": None, + "example_weights": None, + "attributions": { + "model": {"feature1": 33.1, "feature2": -33.2}, + }, + }, + ], + "expected_values": { + "feature1": (11.1 + 22.1 + 33.1), + "feature2": (11.2 + 22.2 + 33.2), + }, + }, + { + "testcase_name": "multi-model-multi-output", + "model_name": "model", + "output_name": "output", + "examples": [ + { + "labels": None, + "predictions": None, + "example_weights": np.array(1.0), + "attributions": { + "model": { + "output": {"feature1": 111.1, "feature2": -111.2}, + }, + }, + }, + { + "labels": None, + "predictions": None, + "example_weights": np.array(1.0), + "attributions": { + "model": { + "output": {"feature1": 222.1, "feature2": -222.2}, + }, + }, + }, + { + "labels": None, + "predictions": None, + "example_weights": np.array(1.0), + "attributions": { + "model": { + "output": {"feature1": 333.1, "feature2": -333.2}, + }, + }, + }, + ], + "expected_values": { + "feature1": (111.1 + 222.1 + 333.1), + "feature2": (111.2 + 222.2 + 333.2), + }, + }, + ) + def testTotalAbsoluteAttributionsWithMultiModelsAndOutputs( + self, model_name, output_name, examples, expected_values + ): + computations = attributions.TotalAbsoluteAttributions().computations( + model_names=[model_name], output_names=[output_name], example_weighted=True + ) + + with beam.Pipeline() as pipeline: + # pylint: disable=no-value-for-parameter + result = ( + pipeline + | "Create" >> 
beam.Create(examples) + | "Process" >> beam.Map(metric_util.to_standard_metric_inputs) + | "AddSlice" >> beam.Map(lambda x: ((), x)) + | "CombineAttributions" >> beam.CombinePerKey(computations[0].combiner) + | "ComputeResult" + >> beam.Map( # comment to add lamda on own line + lambda x: (x[0], computations[1].result(x[1])) + ) + ) + + # pylint: enable=no-value-for-parameter + + def check_result(got): + try: + self.assertLen(got, 1) + got_slice_key, got_attributions = got[0] + self.assertEqual(got_slice_key, ()) + total_attributions_key = metric_types.AttributionsKey( + name="total_absolute_attributions", + model_name=model_name, + output_name=output_name, + example_weighted=True, + ) + self.assertIn(total_attributions_key, got_attributions) + self.assertDictElementsAlmostEqual( + got_attributions[total_attributions_key], expected_values + ) + + except AssertionError as err: + raise util.BeamAssertException(err) + + util.assert_that(result, check_result, label="result") + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow_model_analysis/metrics/binary_confusion_matrices.py b/tensorflow_model_analysis/metrics/binary_confusion_matrices.py index a63ba5f6e1..b51ff9b726 100644 --- a/tensorflow_model_analysis/metrics/binary_confusion_matrices.py +++ b/tensorflow_model_analysis/metrics/binary_confusion_matrices.py @@ -13,174 +13,202 @@ # limitations under the License. """Binary confusion matrices.""" -from typing import Any, Callable, Dict, Iterable, List, NamedTuple, Optional, Tuple, Union +from typing import ( + Any, + Callable, + Dict, + Iterable, + List, + NamedTuple, + Optional, + Tuple, + Union, +) import apache_beam as beam import numpy as np + from tensorflow_model_analysis.api import types -from tensorflow_model_analysis.contrib.aggregates import binary_confusion_matrices as bcm_computations -from tensorflow_model_analysis.metrics import calibration_histogram -from tensorflow_model_analysis.metrics import metric_types -from tensorflow_model_analysis.metrics import metric_util -from tensorflow_model_analysis.proto import config_pb2 -from tensorflow_model_analysis.proto import metrics_for_slice_pb2 +from tensorflow_model_analysis.contrib.aggregates import ( + binary_confusion_matrices as bcm_computations, +) +from tensorflow_model_analysis.metrics import ( + calibration_histogram, + metric_types, + metric_util, +) +from tensorflow_model_analysis.proto import config_pb2, metrics_for_slice_pb2 DEFAULT_NUM_THRESHOLDS = calibration_histogram.DEFAULT_NUM_BUCKETS _KERAS_DEFAULT_NUM_THRESHOLDS = 200 -BINARY_CONFUSION_MATRICES_NAME = '_binary_confusion_matrices' -BINARY_CONFUSION_EXAMPLES_NAME = '_binary_confusion_examples' +BINARY_CONFUSION_MATRICES_NAME = "_binary_confusion_matrices" +BINARY_CONFUSION_EXAMPLES_NAME = "_binary_confusion_examples" -_BINARY_CONFUSION_MATRIX_NAME = '_binary_confusion_matrix' +_BINARY_CONFUSION_MATRIX_NAME = "_binary_confusion_matrix" MatrixAccumulator = bcm_computations.MatrixAccumulator class Examples( - NamedTuple('Examples', [('thresholds', List[float]), - ('tp_examples', List[List[str]]), - ('tn_examples', List[List[str]]), - ('fp_examples', List[List[str]]), - ('fn_examples', List[List[str]])])): - """A set of examples for each binary confusion case at each threshold.""" + NamedTuple( + "Examples", + [ + ("thresholds", List[float]), + ("tp_examples", List[List[str]]), + ("tn_examples", List[List[str]]), + ("fp_examples", List[List[str]]), + ("fn_examples", List[List[str]]), + ], + ) +): + """A set of examples for each binary confusion case 
at each threshold.""" class Matrices( # pytype: disable=signature-mismatch # always-use-return-annotations types.StructuredMetricValue, NamedTuple( - 'Matrices', + "Matrices", [ - ('thresholds', List[float]), - ('tp', List[float]), - ('tn', List[float]), - ('fp', List[float]), - ('fn', List[float]), + ("thresholds", List[float]), + ("tp", List[float]), + ("tn", List[float]), + ("fp", List[float]), + ("fn", List[float]), ], ), ): - """A class representing a set of binary confusion matrices at thresholds. - - For each threshold, in addition to the count of examples per prediction and - label, this class also contains a sample of raw examples. Threshold values are - sorted, and the entries within tp[i], tn[i], fp[i], and fn[i] correspond to - thresholds[i]. - """ - - def _apply_binary_op_elementwise(self, other: 'Matrices', - op: Callable[[float, float], float]): - """Applies an operator elementwise on self and `other` matrices.""" - tp, tn, fp, fn = [], [], [], [] - self_idx, other_idx = 0, 0 - merged_thresholds = [] - while True: - if (self_idx < len(self.thresholds) and - other_idx < len(other.thresholds) and - self.thresholds[self_idx] == other.thresholds[other_idx]): - # threshold present in both, advance both indices - merged_thresholds.append(self.thresholds[self_idx]) - tp.append(op(self.tp[self_idx], other.tp[other_idx])) - tn.append(op(self.tn[self_idx], other.tn[other_idx])) - fp.append(op(self.fp[self_idx], other.fp[other_idx])) - fn.append(op(self.fn[self_idx], other.fn[other_idx])) - self_idx += 1 - other_idx += 1 - elif (self_idx < len(self.thresholds) and - (other_idx >= len(other.thresholds) or - self.thresholds[self_idx] < other.thresholds[other_idx])): - # threshold present in self but missing from other, use default values - # for other and advance self_idx - merged_thresholds.append(self.thresholds[self_idx]) - tp.append(op(self.tp[self_idx], 0)) - tn.append(op(self.tn[self_idx], 0)) - fp.append(op(self.fp[self_idx], 0)) - fn.append(op(self.fn[self_idx], 0)) - self_idx += 1 - elif (other_idx < len(other.thresholds) and - (self_idx >= len(self.thresholds) or - other.thresholds[self_idx] < self.thresholds[other_idx])): - # threshold present in other but missing from self, use default values - # for self and advance other_idx - merged_thresholds.append(other.thresholds[other_idx]) - tp.append(op(0, other.tp[other_idx])) - tn.append(op(0, other.tn[other_idx])) - fp.append(op(0, other.fp[other_idx])) - fn.append(op(0, other.fn[other_idx])) - other_idx += 1 - else: - assert (self_idx >= len(self.thresholds) and - other_idx >= len(other.thresholds)) - break - return Matrices(thresholds=merged_thresholds, tp=tp, tn=tn, fp=fp, fn=fn) - - def _apply_binary_op_broadcast(self, other: float, - op: Callable[[float, float], float]): - """Applies an operator on each element and the provided float.""" - return Matrices( - thresholds=self.thresholds, - tp=[op(tp, other) for tp in self.tp], - tn=[op(tn, other) for tn in self.tn], - fp=[op(fp, other) for fp in self.fp], - fn=[op(fn, other) for fn in self.fn]) - - def to_proto(self) -> metrics_for_slice_pb2.MetricValue: - """Converts matrices into ConfusionMatrixAtThresholds proto. - - If precision or recall are undefined then 1.0 and 0.0 will be used. - - Returns: - A MetricValue proto containing a ConfusionMatrixAtThresholds proto. + """A class representing a set of binary confusion matrices at thresholds. 
+ + For each threshold, in addition to the count of examples per prediction and + label, this class also contains a sample of raw examples. Threshold values are + sorted, and the entries within tp[i], tn[i], fp[i], and fn[i] correspond to + thresholds[i]. """ - result = metrics_for_slice_pb2.MetricValue() - tp, fp = np.array(self.tp), np.array(self.fp) - tn, fn = np.array(self.tn), np.array(self.fn) - predicted_positives, labeled_positives = tp + fp, tp + fn - predicated_negatives, labeled_negatives = tn + fn, tn + fp - - precision = np.divide( - tp, - predicted_positives, - out=np.ones_like(predicted_positives), - where=(predicted_positives > 0), - ) - recall = np.divide( - tp, - labeled_positives, - out=np.zeros_like(labeled_positives), - where=(labeled_positives > 0), - ) - f1 = 2 * precision * recall / (precision + recall) - accuracy = (tp + tn) / (tp + tn + fp + fn) - false_positive_rate = fp / labeled_negatives - false_omission_rate = fn / predicated_negatives - confusion_matrix_at_thresholds_proto = result.confusion_matrix_at_thresholds - for i, threshold in enumerate(self.thresholds): - confusion_matrix_at_thresholds_proto.matrices.add( - threshold=round(threshold, 6), - true_positives=tp[i], - false_positives=fp[i], - true_negatives=tn[i], - false_negatives=fn[i], - precision=precision[i], - recall=recall[i], - false_positive_rate=false_positive_rate[i], - false_omission_rate=false_omission_rate[i], - f1=f1[i], - accuracy=accuracy[i], - ) - return result + def _apply_binary_op_elementwise( + self, other: "Matrices", op: Callable[[float, float], float] + ): + """Applies an operator elementwise on self and `other` matrices.""" + tp, tn, fp, fn = [], [], [], [] + self_idx, other_idx = 0, 0 + merged_thresholds = [] + while True: + if ( + self_idx < len(self.thresholds) + and other_idx < len(other.thresholds) + and self.thresholds[self_idx] == other.thresholds[other_idx] + ): + # threshold present in both, advance both indices + merged_thresholds.append(self.thresholds[self_idx]) + tp.append(op(self.tp[self_idx], other.tp[other_idx])) + tn.append(op(self.tn[self_idx], other.tn[other_idx])) + fp.append(op(self.fp[self_idx], other.fp[other_idx])) + fn.append(op(self.fn[self_idx], other.fn[other_idx])) + self_idx += 1 + other_idx += 1 + elif self_idx < len(self.thresholds) and ( + other_idx >= len(other.thresholds) + or self.thresholds[self_idx] < other.thresholds[other_idx] + ): + # threshold present in self but missing from other, use default values + # for other and advance self_idx + merged_thresholds.append(self.thresholds[self_idx]) + tp.append(op(self.tp[self_idx], 0)) + tn.append(op(self.tn[self_idx], 0)) + fp.append(op(self.fp[self_idx], 0)) + fn.append(op(self.fn[self_idx], 0)) + self_idx += 1 + elif other_idx < len(other.thresholds) and ( + self_idx >= len(self.thresholds) + or other.thresholds[self_idx] < self.thresholds[other_idx] + ): + # threshold present in other but missing from self, use default values + # for self and advance other_idx + merged_thresholds.append(other.thresholds[other_idx]) + tp.append(op(0, other.tp[other_idx])) + tn.append(op(0, other.tn[other_idx])) + fp.append(op(0, other.fp[other_idx])) + fn.append(op(0, other.fn[other_idx])) + other_idx += 1 + else: + assert self_idx >= len(self.thresholds) and other_idx >= len( + other.thresholds + ) + break + return Matrices(thresholds=merged_thresholds, tp=tp, tn=tn, fp=fp, fn=fn) + + def _apply_binary_op_broadcast( + self, other: float, op: Callable[[float, float], float] + ): + """Applies an operator on each 
element and the provided float.""" + return Matrices( + thresholds=self.thresholds, + tp=[op(tp, other) for tp in self.tp], + tn=[op(tn, other) for tn in self.tn], + fp=[op(fp, other) for fp in self.fp], + fn=[op(fn, other) for fn in self.fn], + ) + + def to_proto(self) -> metrics_for_slice_pb2.MetricValue: + """Converts matrices into ConfusionMatrixAtThresholds proto. + + If precision or recall are undefined then 1.0 and 0.0 will be used. + + Returns + ------- + A MetricValue proto containing a ConfusionMatrixAtThresholds proto. + """ + result = metrics_for_slice_pb2.MetricValue() + tp, fp = np.array(self.tp), np.array(self.fp) + tn, fn = np.array(self.tn), np.array(self.fn) + predicted_positives, labeled_positives = tp + fp, tp + fn + predicated_negatives, labeled_negatives = tn + fn, tn + fp + + precision = np.divide( + tp, + predicted_positives, + out=np.ones_like(predicted_positives), + where=(predicted_positives > 0), + ) + recall = np.divide( + tp, + labeled_positives, + out=np.zeros_like(labeled_positives), + where=(labeled_positives > 0), + ) + f1 = 2 * precision * recall / (precision + recall) + accuracy = (tp + tn) / (tp + tn + fp + fn) + false_positive_rate = fp / labeled_negatives + false_omission_rate = fn / predicated_negatives + confusion_matrix_at_thresholds_proto = result.confusion_matrix_at_thresholds + for i, threshold in enumerate(self.thresholds): + confusion_matrix_at_thresholds_proto.matrices.add( + threshold=round(threshold, 6), + true_positives=tp[i], + false_positives=fp[i], + true_negatives=tn[i], + false_negatives=fn[i], + precision=precision[i], + recall=recall[i], + false_positive_rate=false_positive_rate[i], + false_omission_rate=false_omission_rate[i], + f1=f1[i], + accuracy=accuracy[i], + ) + return result _EPSILON = 1e-7 def _interpolated_thresholds(num_thresholds: int) -> List[float]: - """Returns thresholds interpolated over a range equal to num_thresholds.""" - # The interpolation strategy used here matches that used by keras for AUC. - thresholds = [ - (i + 1) * 1.0 / (num_thresholds - 1) for i in range(num_thresholds - 2) - ] - return [-_EPSILON] + thresholds + [1.0 + _EPSILON] + """Returns thresholds interpolated over a range equal to num_thresholds.""" + # The interpolation strategy used here matches that used by keras for AUC. + thresholds = [ + (i + 1) * 1.0 / (num_thresholds - 1) for i in range(num_thresholds - 2) + ] + return [-_EPSILON] + thresholds + [1.0 + _EPSILON] def binary_confusion_matrices( @@ -188,8 +216,8 @@ def binary_confusion_matrices( thresholds: Optional[List[float]] = None, name: Optional[str] = None, eval_config: Optional[config_pb2.EvalConfig] = None, - model_name: str = '', - output_name: str = '', + model_name: str = "", + output_name: str = "", sub_key: Optional[metric_types.SubKey] = None, aggregation_type: Optional[metric_types.AggregationType] = None, class_weights: Optional[Dict[int, float]] = None, @@ -204,295 +232,315 @@ def binary_confusion_matrices( example_ids_count: Optional[int] = None, fractional_labels: bool = True, ) -> metric_types.MetricComputations: - """Returns metric computations for computing binary confusion matrices. - - Args: - num_thresholds: Number of thresholds to use. Thresholds will be calculated - using linear interpolation between 0.0 and 1.0 with equidistant values and - bondardaries at -epsilon and 1.0+epsilon. Values must be > 0. Only one of - num_thresholds or thresholds should be used. If used, num_thresholds must - be > 1. - thresholds: A specific set of thresholds to use. 
The caller is responsible - for marking the boundaries with +/-epsilon if desired. Only one of - num_thresholds or thresholds should be used. For metrics computed at top k - this may be a single negative threshold value (i.e. -inf). - name: Metric name containing binary_confusion_matrices.Matrices. - eval_config: Eval config. - model_name: Optional model name (if multi-model evaluation). - output_name: Optional output name (if multi-output model type). - sub_key: Optional sub key. - aggregation_type: Optional aggregation type. - class_weights: Optional class weights to apply to multi-class / multi-label - labels and predictions prior to flattening (when micro averaging is used). - example_weighted: True if example weights should be applied. - use_histogram: If true, matrices will be derived from calibration - histograms. - extract_label_prediction_and_weight: User-provided function argument that - yields label, prediction, and example weights for use in calculations - (relevant only when use_histogram flag is not true). - preprocessors: User-provided preprocessor for including additional extracts - in StandardMetricInputs (relevant only when use_histogram flag is not - true). - examples_name: Metric name containing binary_confusion_matrices.Examples. - (relevant only when use_histogram flag is not true and example_id_key is - set). - example_id_key: Feature key containing example id (relevant only when - use_histogram flag is not true). - example_ids_count: Max number of example ids to be extracted for false - positives and false negatives (relevant only when use_histogram flag is - not true). - fractional_labels: If true, each incoming tuple of (label, prediction, and - example weight) will be split into two tuples as follows (where l, p, w - represent the resulting label, prediction, and example weight values): (1) - l = 0.0, p = prediction, and w = example_weight * (1.0 - label) (2) l = - 1.0, p = prediction, and w = example_weight * label If enabled, an - exception will be raised if labels are not within [0, 1]. The - implementation is such that tuples associated with a weight of zero are - not yielded. This means it is safe to enable fractional_labels even when - the labels only take on the values of 0.0 or 1.0. - - Raises: - ValueError: If both num_thresholds and thresholds are set at the same time. - """ - # TF v1 Keras AUC turns num_thresholds parameters into thresholds which - # circumvents sharing of settings. If the thresholds match the interpolated - # version of the thresholds then reset back to num_thresholds. - if thresholds: - if (not num_thresholds and - thresholds == _interpolated_thresholds(len(thresholds))): - num_thresholds = len(thresholds) - thresholds = None - elif (num_thresholds - in (DEFAULT_NUM_THRESHOLDS, _KERAS_DEFAULT_NUM_THRESHOLDS) and - len(thresholds) == num_thresholds - 2): - thresholds = None - if num_thresholds is not None and thresholds is not None: - raise ValueError( - 'only one of thresholds or num_thresholds can be set at a time: ' - f'num_thresholds={num_thresholds}, thresholds={thresholds}, ' - f'len(thresholds)={len(thresholds)})') - if num_thresholds is None and thresholds is None: - num_thresholds = DEFAULT_NUM_THRESHOLDS - if num_thresholds is not None: - if num_thresholds <= 1: - raise ValueError('num_thresholds must be > 1') - # The interpolation strategy used here matches that used by keras for AUC. 
- thresholds = _interpolated_thresholds(num_thresholds) - - if use_histogram is None: - use_histogram = ( - num_thresholds is not None or - (len(thresholds) == 1 and thresholds[0] < 0)) - - if use_histogram and (examples_name or example_id_key or example_ids_count): - raise ValueError('Example sampling is only performed when not using the ' - 'histogram computation. However, use_histogram is true ' - f'and one of examples_name ("{examples_name}"), ' - f'examples_id_key ("{example_id_key}"), ' - f'or example_ids_count ({example_ids_count}) was ' - 'provided, which will have no effect.') - - if examples_name and not (example_id_key and example_ids_count): - raise ValueError('examples_name provided but either example_id_key or ' - 'example_ids_count was not. Examples will only be ' - 'returned when both example_id_key and ' - 'example_ids_count are provided, and when the ' - 'non-histogram computation is used. ' - f'example_id_key: "{example_id_key}" ' - f'example_ids_count: {example_ids_count}') - - if name is None: - name_args = { - 'example_id_key': example_id_key, - 'example_ids_count': example_ids_count - } - if num_thresholds: - name_args['num_thresholds'] = num_thresholds - else: - name_args['thresholds'] = thresholds - if preprocessors: - name_args['preprocessors'] = tuple(p.name for p in preprocessors) - if class_weights: - name_args['class_weights'] = class_weights - - name = metric_util.generate_private_name_from_arguments( - BINARY_CONFUSION_MATRICES_NAME, **name_args) - examples_name = metric_util.generate_private_name_from_arguments( - BINARY_CONFUSION_EXAMPLES_NAME, **name_args) - - matrices_key = metric_types.MetricKey( - name=name, - model_name=model_name, - output_name=output_name, - sub_key=sub_key, - example_weighted=example_weighted) - examples_key = metric_types.MetricKey( - name=examples_name, - model_name=model_name, - output_name=output_name, - sub_key=sub_key, - example_weighted=example_weighted) - - computations = [] - if use_histogram: - # Use calibration histogram to calculate matrices. For efficiency (unless - # all predictions are matched - i.e. thresholds <= 0) we will assume that - # other metrics will make use of the calibration histogram and re-use the - # default histogram for the given model_name/output_name/sub_key. This is - # also required to get accurate counts at the threshold boundaries. If this - # becomes an issue, then calibration histogram can be updated to support - # non-linear boundaries. - # If used for object_detection, to distinguish between histograms with - # different specs, we generate a unique name for it. - - # For precision/recall_at_k were a single large negative threshold - # is used, we only need one bucket. Note that the histogram will - # actually have 2 buckets: one that we set (which handles - # predictions > -1.0) and a default catch-all bucket (i.e. bucket 0) - # that the histogram creates for large negative predictions (i.e. - # predictions <= -1.0). - num_buckets = 1 if len(thresholds) == 1 and thresholds[0] <= 0 else None - - computations = calibration_histogram.calibration_histogram( - eval_config=eval_config, - num_buckets=num_buckets, + """Returns metric computations for computing binary confusion matrices. + + Args: + ---- + num_thresholds: Number of thresholds to use. Thresholds will be calculated + using linear interpolation between 0.0 and 1.0 with equidistant values and + bondardaries at -epsilon and 1.0+epsilon. Values must be > 0. Only one of + num_thresholds or thresholds should be used. 
If used, num_thresholds must + be > 1. + thresholds: A specific set of thresholds to use. The caller is responsible + for marking the boundaries with +/-epsilon if desired. Only one of + num_thresholds or thresholds should be used. For metrics computed at top k + this may be a single negative threshold value (i.e. -inf). + name: Metric name containing binary_confusion_matrices.Matrices. + eval_config: Eval config. + model_name: Optional model name (if multi-model evaluation). + output_name: Optional output name (if multi-output model type). + sub_key: Optional sub key. + aggregation_type: Optional aggregation type. + class_weights: Optional class weights to apply to multi-class / multi-label + labels and predictions prior to flattening (when micro averaging is used). + example_weighted: True if example weights should be applied. + use_histogram: If true, matrices will be derived from calibration + histograms. + extract_label_prediction_and_weight: User-provided function argument that + yields label, prediction, and example weights for use in calculations + (relevant only when use_histogram flag is not true). + preprocessors: User-provided preprocessor for including additional extracts + in StandardMetricInputs (relevant only when use_histogram flag is not + true). + examples_name: Metric name containing binary_confusion_matrices.Examples. + (relevant only when use_histogram flag is not true and example_id_key is + set). + example_id_key: Feature key containing example id (relevant only when + use_histogram flag is not true). + example_ids_count: Max number of example ids to be extracted for false + positives and false negatives (relevant only when use_histogram flag is + not true). + fractional_labels: If true, each incoming tuple of (label, prediction, and + example weight) will be split into two tuples as follows (where l, p, w + represent the resulting label, prediction, and example weight values): (1) + l = 0.0, p = prediction, and w = example_weight * (1.0 - label) (2) l = + 1.0, p = prediction, and w = example_weight * label If enabled, an + exception will be raised if labels are not within [0, 1]. The + implementation is such that tuples associated with a weight of zero are + not yielded. This means it is safe to enable fractional_labels even when + the labels only take on the values of 0.0 or 1.0. + + Raises: + ------ + ValueError: If both num_thresholds and thresholds are set at the same time. + """ + # TF v1 Keras AUC turns num_thresholds parameters into thresholds which + # circumvents sharing of settings. If the thresholds match the interpolated + # version of the thresholds then reset back to num_thresholds. + if thresholds: + if not num_thresholds and thresholds == _interpolated_thresholds( + len(thresholds) + ): + num_thresholds = len(thresholds) + thresholds = None + elif ( + num_thresholds in (DEFAULT_NUM_THRESHOLDS, _KERAS_DEFAULT_NUM_THRESHOLDS) + and len(thresholds) == num_thresholds - 2 + ): + thresholds = None + if num_thresholds is not None and thresholds is not None: + raise ValueError( + "only one of thresholds or num_thresholds can be set at a time: " + f"num_thresholds={num_thresholds}, thresholds={thresholds}, " + f"len(thresholds)={len(thresholds)})" + ) + if num_thresholds is None and thresholds is None: + num_thresholds = DEFAULT_NUM_THRESHOLDS + if num_thresholds is not None: + if num_thresholds <= 1: + raise ValueError("num_thresholds must be > 1") + # The interpolation strategy used here matches that used by keras for AUC. 
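        # --- Editor's illustrative aside (not part of this patch) ---
        # A minimal worked example of the helper above, using the module-level
        # _EPSILON = 1e-7 defined earlier in this file:
        #
        #     >>> _interpolated_thresholds(5)
        #     [-1e-07, 0.25, 0.5, 0.75, 1.0000001]
        #
        # i.e. num_thresholds - 2 equidistant interior values plus the two
        # epsilon fence posts, so len(result) == num_thresholds.
        # -------------------------------------------------------------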
+ thresholds = _interpolated_thresholds(num_thresholds) + + if use_histogram is None: + use_histogram = num_thresholds is not None or ( + len(thresholds) == 1 and thresholds[0] < 0 + ) + + if use_histogram and (examples_name or example_id_key or example_ids_count): + raise ValueError( + "Example sampling is only performed when not using the " + "histogram computation. However, use_histogram is true " + f'and one of examples_name ("{examples_name}"), ' + f'examples_id_key ("{example_id_key}"), ' + f"or example_ids_count ({example_ids_count}) was " + "provided, which will have no effect." + ) + + if examples_name and not (example_id_key and example_ids_count): + raise ValueError( + "examples_name provided but either example_id_key or " + "example_ids_count was not. Examples will only be " + "returned when both example_id_key and " + "example_ids_count are provided, and when the " + "non-histogram computation is used. " + f'example_id_key: "{example_id_key}" ' + f"example_ids_count: {example_ids_count}" + ) + + if name is None: + name_args = { + "example_id_key": example_id_key, + "example_ids_count": example_ids_count, + } + if num_thresholds: + name_args["num_thresholds"] = num_thresholds + else: + name_args["thresholds"] = thresholds + if preprocessors: + name_args["preprocessors"] = tuple(p.name for p in preprocessors) + if class_weights: + name_args["class_weights"] = class_weights + + name = metric_util.generate_private_name_from_arguments( + BINARY_CONFUSION_MATRICES_NAME, **name_args + ) + examples_name = metric_util.generate_private_name_from_arguments( + BINARY_CONFUSION_EXAMPLES_NAME, **name_args + ) + + matrices_key = metric_types.MetricKey( + name=name, model_name=model_name, output_name=output_name, - preprocessors=preprocessors, sub_key=sub_key, - aggregation_type=aggregation_type, - class_weights=class_weights, - example_weighted=example_weighted) - input_metric_key = computations[-1].keys[-1] - output_metric_keys = [matrices_key] - else: - if bool(example_ids_count) != bool(example_id_key): - raise ValueError('Both of example_ids_count and example_id_key must be ' - f'set, but got example_id_key: "{example_id_key}" and ' - f'example_ids_count: {example_ids_count}.') - computations = _binary_confusion_matrix_computation( - eval_config=eval_config, - name=name, - thresholds=thresholds, + example_weighted=example_weighted, + ) + examples_key = metric_types.MetricKey( + name=examples_name, model_name=model_name, output_name=output_name, sub_key=sub_key, - extract_label_prediction_and_weight=extract_label_prediction_and_weight, - preprocessors=preprocessors, - example_id_key=example_id_key, - example_ids_count=example_ids_count, - aggregation_type=aggregation_type, - class_weights=class_weights, example_weighted=example_weighted, - enable_fractional_labels=fractional_labels, ) - input_metric_key = computations[-1].keys[-1] - # matrices_key is last for backwards compatibility with code that: - # 1) used this computation as an input for a derived computation - # 2) only accessed the matrix counts - # 3) used computations[-1].keys[-1] to access the input key - output_metric_keys = ([matrices_key] if not example_id_key else - [examples_key, matrices_key]) - - def result( - metrics: Dict[metric_types.MetricKey, Any] - ) -> Dict[metric_types.MetricKey, Union[Matrices, Examples]]: - """Returns binary confusion matrices.""" - matrices = None + + computations = [] if use_histogram: - if len(thresholds) == 1 and thresholds[0] < 0: - # This case is used when all positive prediction values 
are relevant - # matches (e.g. when calculating top_k for precision/recall where the - # non-top_k values are expected to have been set to float('-inf')). - histogram = metrics[input_metric_key] - else: - # Calibration histogram uses intervals of the form [start, end) where - # the prediction >= start. The confusion matrices want intervals of the - # form (start, end] where the prediction > start. Add a small epsilon so - # that >= checks don't match. This correction shouldn't be needed in - # practice but allows for correctness in small tests. - rebin_thresholds = [t + _EPSILON if t != 0 else t for t in thresholds] - if thresholds[0] >= 0: - # Add -epsilon bucket to account for differences in histogram vs - # confusion matrix intervals mentioned above. If the epsilon bucket is - # missing the false negatives and false positives will be 0 for the - # first threshold. - rebin_thresholds = [-_EPSILON] + rebin_thresholds - if thresholds[-1] < 1.0: - # If the last threshold < 1.0, then add a fence post at 1.0 + epsilon - # otherwise true negatives and true positives will be overcounted. - rebin_thresholds = rebin_thresholds + [1.0 + _EPSILON] - histogram = calibration_histogram.rebin(rebin_thresholds, - metrics[input_metric_key]) - matrices = _histogram_to_binary_confusion_matrices(thresholds, histogram) - return {matrices_key: matrices} + # Use calibration histogram to calculate matrices. For efficiency (unless + # all predictions are matched - i.e. thresholds <= 0) we will assume that + # other metrics will make use of the calibration histogram and re-use the + # default histogram for the given model_name/output_name/sub_key. This is + # also required to get accurate counts at the threshold boundaries. If this + # becomes an issue, then calibration histogram can be updated to support + # non-linear boundaries. + # If used for object_detection, to distinguish between histograms with + # different specs, we generate a unique name for it. + + # For precision/recall_at_k were a single large negative threshold + # is used, we only need one bucket. Note that the histogram will + # actually have 2 buckets: one that we set (which handles + # predictions > -1.0) and a default catch-all bucket (i.e. bucket 0) + # that the histogram creates for large negative predictions (i.e. + # predictions <= -1.0). + num_buckets = 1 if len(thresholds) == 1 and thresholds[0] <= 0 else None + + computations = calibration_histogram.calibration_histogram( + eval_config=eval_config, + num_buckets=num_buckets, + model_name=model_name, + output_name=output_name, + preprocessors=preprocessors, + sub_key=sub_key, + aggregation_type=aggregation_type, + class_weights=class_weights, + example_weighted=example_weighted, + ) + input_metric_key = computations[-1].keys[-1] + output_metric_keys = [matrices_key] else: - matrices, examples = _accumulator_to_matrices_and_examples( - thresholds, metrics[input_metric_key]) - result = {matrices_key: matrices} - if example_id_key: - result[examples_key] = examples - return result - - derived_computation = metric_types.DerivedMetricComputation( - keys=output_metric_keys, result=result) - computations.append(derived_computation) - return computations + if bool(example_ids_count) != bool(example_id_key): + raise ValueError( + "Both of example_ids_count and example_id_key must be " + f'set, but got example_id_key: "{example_id_key}" and ' + f"example_ids_count: {example_ids_count}." 
+ ) + computations = _binary_confusion_matrix_computation( + eval_config=eval_config, + name=name, + thresholds=thresholds, + model_name=model_name, + output_name=output_name, + sub_key=sub_key, + extract_label_prediction_and_weight=extract_label_prediction_and_weight, + preprocessors=preprocessors, + example_id_key=example_id_key, + example_ids_count=example_ids_count, + aggregation_type=aggregation_type, + class_weights=class_weights, + example_weighted=example_weighted, + enable_fractional_labels=fractional_labels, + ) + input_metric_key = computations[-1].keys[-1] + # matrices_key is last for backwards compatibility with code that: + # 1) used this computation as an input for a derived computation + # 2) only accessed the matrix counts + # 3) used computations[-1].keys[-1] to access the input key + output_metric_keys = ( + [matrices_key] if not example_id_key else [examples_key, matrices_key] + ) + + def result( + metrics: Dict[metric_types.MetricKey, Any], + ) -> Dict[metric_types.MetricKey, Union[Matrices, Examples]]: + """Returns binary confusion matrices.""" + matrices = None + if use_histogram: + if len(thresholds) == 1 and thresholds[0] < 0: + # This case is used when all positive prediction values are relevant + # matches (e.g. when calculating top_k for precision/recall where the + # non-top_k values are expected to have been set to float('-inf')). + histogram = metrics[input_metric_key] + else: + # Calibration histogram uses intervals of the form [start, end) where + # the prediction >= start. The confusion matrices want intervals of the + # form (start, end] where the prediction > start. Add a small epsilon so + # that >= checks don't match. This correction shouldn't be needed in + # practice but allows for correctness in small tests. + rebin_thresholds = [t + _EPSILON if t != 0 else t for t in thresholds] + if thresholds[0] >= 0: + # Add -epsilon bucket to account for differences in histogram vs + # confusion matrix intervals mentioned above. If the epsilon bucket is + # missing the false negatives and false positives will be 0 for the + # first threshold. + rebin_thresholds = [-_EPSILON] + rebin_thresholds + if thresholds[-1] < 1.0: + # If the last threshold < 1.0, then add a fence post at 1.0 + epsilon + # otherwise true negatives and true positives will be overcounted. 
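                    # --- Editor's illustrative aside (not part of this patch) ---
                    # How rebin_thresholds is built in this branch for a
                    # hypothetical thresholds = [0.25, 0.75] (with _EPSILON = 1e-7):
                    #
                    #     [0.25 + _EPSILON, 0.75 + _EPSILON]                            # shift non-zero values
                    #     [-_EPSILON, 0.25 + _EPSILON, 0.75 + _EPSILON]                 # thresholds[0] >= 0
                    #     [-_EPSILON, 0.25 + _EPSILON, 0.75 + _EPSILON, 1.0 + _EPSILON] # thresholds[-1] < 1.0
                    #
                    # The epsilon shift turns the histogram's [start, end) buckets
                    # into the (start, end] intervals the confusion matrix needs;
                    # the -eps and 1.0+eps fence posts keep counts correct at the
                    # boundaries.
                    # -------------------------------------------------------------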
+ rebin_thresholds = rebin_thresholds + [1.0 + _EPSILON] + histogram = calibration_histogram.rebin( + rebin_thresholds, metrics[input_metric_key] + ) + matrices = _histogram_to_binary_confusion_matrices(thresholds, histogram) + return {matrices_key: matrices} + else: + matrices, examples = _accumulator_to_matrices_and_examples( + thresholds, metrics[input_metric_key] + ) + result = {matrices_key: matrices} + if example_id_key: + result[examples_key] = examples + return result + + derived_computation = metric_types.DerivedMetricComputation( + keys=output_metric_keys, result=result + ) + computations.append(derived_computation) + return computations def _histogram_to_binary_confusion_matrices( - thresholds: List[float], - histogram: calibration_histogram.Histogram) -> Matrices: - """Converts histogram to binary confusion matrices.""" - # tp(i) - sum of positive labels >= bucket i - # fp(i) - sum of negative labels >= bucket i - # fn(i) - sum of positive labels < bucket i - # tn(i) - sum of negative labels < bucket i - n = len(histogram) - tp = [0.0] * n - fp = [0.0] * n - tn = [0.0] * n - fn = [0.0] * n - for i in range(n): - start = i - end = n - i - 1 - start_pos = histogram[start].weighted_labels - start_neg = ( - histogram[start].weighted_examples - histogram[start].weighted_labels) - end_pos = histogram[end].weighted_labels - end_neg = ( - histogram[end].weighted_examples - histogram[end].weighted_labels) - tp[end] = tp[end + 1] + end_pos if end < n - 1 else end_pos - fp[end] = fp[end + 1] + end_neg if end < n - 1 else end_neg - if start + 1 < n: - tn[start + 1] = tn[start] + start_neg - fn[start + 1] = fn[start] + start_pos - # Check if need to remove -epsilon bucket (or reset back to 1 bucket). - threshold_offset = 0 - if (thresholds[0] >= 0 or len(thresholds) == 1) and len(histogram) > 1: - threshold_offset = 1 - tp = tp[threshold_offset:threshold_offset + len(thresholds)] - fp = fp[threshold_offset:threshold_offset + len(thresholds)] - tn = tn[threshold_offset:threshold_offset + len(thresholds)] - fn = fn[threshold_offset:threshold_offset + len(thresholds)] - # We sum all values >= bucket i, but TP/FP values greater that 1.0 + EPSILON - # should be 0.0. The FN/TN above 1.0 + _EPSILON should also be adjusted to - # match the TP/FP values at the start. - for i, t in enumerate(thresholds): - if t >= 1.0 + _EPSILON: - tp[i] = 0.0 - fp[i] = 0.0 - fn[i] = tp[0] - tn[i] = fp[0] - return Matrices(thresholds, tp, tn, fp, fn) + thresholds: List[float], histogram: calibration_histogram.Histogram +) -> Matrices: + """Converts histogram to binary confusion matrices.""" + # tp(i) - sum of positive labels >= bucket i + # fp(i) - sum of negative labels >= bucket i + # fn(i) - sum of positive labels < bucket i + # tn(i) - sum of negative labels < bucket i + n = len(histogram) + tp = [0.0] * n + fp = [0.0] * n + tn = [0.0] * n + fn = [0.0] * n + for i in range(n): + start = i + end = n - i - 1 + start_pos = histogram[start].weighted_labels + start_neg = ( + histogram[start].weighted_examples - histogram[start].weighted_labels + ) + end_pos = histogram[end].weighted_labels + end_neg = histogram[end].weighted_examples - histogram[end].weighted_labels + tp[end] = tp[end + 1] + end_pos if end < n - 1 else end_pos + fp[end] = fp[end + 1] + end_neg if end < n - 1 else end_neg + if start + 1 < n: + tn[start + 1] = tn[start] + start_neg + fn[start + 1] = fn[start] + start_pos + # Check if need to remove -epsilon bucket (or reset back to 1 bucket). 
+ threshold_offset = 0 + if (thresholds[0] >= 0 or len(thresholds) == 1) and len(histogram) > 1: + threshold_offset = 1 + tp = tp[threshold_offset : threshold_offset + len(thresholds)] + fp = fp[threshold_offset : threshold_offset + len(thresholds)] + tn = tn[threshold_offset : threshold_offset + len(thresholds)] + fn = fn[threshold_offset : threshold_offset + len(thresholds)] + # We sum all values >= bucket i, but TP/FP values greater that 1.0 + EPSILON + # should be 0.0. The FN/TN above 1.0 + _EPSILON should also be adjusted to + # match the TP/FP values at the start. + for i, t in enumerate(thresholds): + if t >= 1.0 + _EPSILON: + tp[i] = 0.0 + fp[i] = 0.0 + fn[i] = tp[0] + tn[i] = fp[0] + return Matrices(thresholds, tp, tn, fp, fn) def _binary_confusion_matrix_computation( thresholds: List[float], name: Optional[str] = None, eval_config: Optional[config_pb2.EvalConfig] = None, - model_name: str = '', - output_name: str = '', + model_name: str = "", + output_name: str = "", sub_key: Optional[metric_types.SubKey] = None, extract_label_prediction_and_weight: Optional[ Callable[..., Any] @@ -505,150 +553,143 @@ def _binary_confusion_matrix_computation( example_weighted: bool = False, enable_fractional_labels: bool = True, ) -> metric_types.MetricComputations: - """Returns metric computations for computing binary confusion matrix.""" - if example_ids_count is None: - example_ids_count = bcm_computations.DEFAULT_NUM_EXAMPLE_IDS - - key = metric_types.MetricKey( - name=name, - model_name=model_name, - output_name=output_name, - sub_key=sub_key, - aggregation_type=aggregation_type, - example_weighted=example_weighted, - ) - - return [ - metric_types.MetricComputation( - keys=[key], - preprocessors=preprocessors, - combiner=_BinaryConfusionMatrixCombiner( - key=key, - eval_config=eval_config, - thresholds=thresholds, - extract_label_prediction_and_weight=extract_label_prediction_and_weight, - example_id_key=example_id_key, - example_ids_count=example_ids_count, - aggregation_type=aggregation_type, - class_weights=class_weights, - example_weighted=example_weighted, - enable_fractional_labels=enable_fractional_labels, - ), - ) - ] + """Returns metric computations for computing binary confusion matrix.""" + if example_ids_count is None: + example_ids_count = bcm_computations.DEFAULT_NUM_EXAMPLE_IDS - -class _BinaryConfusionMatrixCombiner(beam.CombineFn): - """Computes binary confusion matrix in TFMA.""" - - def __init__( - self, - key: metric_types.MetricKey, - eval_config: Optional[config_pb2.EvalConfig], - thresholds: List[float], - extract_label_prediction_and_weight: Callable[..., Any], - example_id_key: Optional[str], - example_ids_count: int, - aggregation_type: Optional[metric_types.AggregationType], - class_weights: Optional[Dict[int, float]], - example_weighted: bool, - enable_fractional_labels: bool, - ): - self._key = key - self._eval_config = eval_config - self._extract_label_prediction_and_weight = ( - extract_label_prediction_and_weight - ) - self._example_id_key = example_id_key - self._aggregation_type = aggregation_type - self._class_weights = class_weights - self._example_weighted = example_weighted - self._enable_fractional_labels = enable_fractional_labels - - self._binary_confusion_matrices = bcm_computations.BinaryConfusionMatrices( - thresholds=thresholds, - example_ids_count=example_ids_count, - enable_fractional_labels=enable_fractional_labels, + key = metric_types.MetricKey( + name=name, + model_name=model_name, + output_name=output_name, + sub_key=sub_key, + 
aggregation_type=aggregation_type, + example_weighted=example_weighted, ) - def create_accumulator(self) -> MatrixAccumulator: - return self._binary_confusion_matrices.create_accumulator() - - def add_input( - self, - accumulator: MatrixAccumulator, - element: metric_types.StandardMetricInputs, - ) -> MatrixAccumulator: - example_id = None - if self._example_id_key and self._example_id_key in element.features: - example_id = element.features[self._example_id_key] - - labels = [] - predictions = [] - example_weights = [] - - for ( - label, - prediction, - example_weight, - ) in self._extract_label_prediction_and_weight( - element, - eval_config=self._eval_config, - model_name=self._key.model_name, - output_name=self._key.output_name, - sub_key=self._key.sub_key, - fractional_labels=self._enable_fractional_labels, - flatten=True, - aggregation_type=self._aggregation_type, - class_weights=self._class_weights, - example_weighted=self._example_weighted, - ): - example_weights.append(metric_util.safe_to_scalar(example_weight)) - labels.append(metric_util.safe_to_scalar(label)) - predictions.append(metric_util.safe_to_scalar(prediction)) - - return self._binary_confusion_matrices.add_input( - accumulator=accumulator, - labels=labels, - predictions=predictions, - example_weights=example_weights, - example_id=example_id, - ) + return [ + metric_types.MetricComputation( + keys=[key], + preprocessors=preprocessors, + combiner=_BinaryConfusionMatrixCombiner( + key=key, + eval_config=eval_config, + thresholds=thresholds, + extract_label_prediction_and_weight=extract_label_prediction_and_weight, + example_id_key=example_id_key, + example_ids_count=example_ids_count, + aggregation_type=aggregation_type, + class_weights=class_weights, + example_weighted=example_weighted, + enable_fractional_labels=enable_fractional_labels, + ), + ) + ] - def merge_accumulators( - self, accumulators: Iterable[MatrixAccumulator] - ) -> MatrixAccumulator: - return self._binary_confusion_matrices.merge_accumulators(accumulators) - def extract_output( - self, accumulator: MatrixAccumulator - ) -> Dict[metric_types.MetricKey, MatrixAccumulator]: - return { - self._key: self._binary_confusion_matrices.extract_output(accumulator) - } +class _BinaryConfusionMatrixCombiner(beam.CombineFn): + """Computes binary confusion matrix in TFMA.""" + + def __init__( + self, + key: metric_types.MetricKey, + eval_config: Optional[config_pb2.EvalConfig], + thresholds: List[float], + extract_label_prediction_and_weight: Callable[..., Any], + example_id_key: Optional[str], + example_ids_count: int, + aggregation_type: Optional[metric_types.AggregationType], + class_weights: Optional[Dict[int, float]], + example_weighted: bool, + enable_fractional_labels: bool, + ): + self._key = key + self._eval_config = eval_config + self._extract_label_prediction_and_weight = extract_label_prediction_and_weight + self._example_id_key = example_id_key + self._aggregation_type = aggregation_type + self._class_weights = class_weights + self._example_weighted = example_weighted + self._enable_fractional_labels = enable_fractional_labels + + self._binary_confusion_matrices = bcm_computations.BinaryConfusionMatrices( + thresholds=thresholds, + example_ids_count=example_ids_count, + enable_fractional_labels=enable_fractional_labels, + ) + + def create_accumulator(self) -> MatrixAccumulator: + return self._binary_confusion_matrices.create_accumulator() + + def add_input( + self, + accumulator: MatrixAccumulator, + element: metric_types.StandardMetricInputs, + ) -> 
MatrixAccumulator: + example_id = None + if self._example_id_key and self._example_id_key in element.features: + example_id = element.features[self._example_id_key] + + labels = [] + predictions = [] + example_weights = [] + + for ( + label, + prediction, + example_weight, + ) in self._extract_label_prediction_and_weight( + element, + eval_config=self._eval_config, + model_name=self._key.model_name, + output_name=self._key.output_name, + sub_key=self._key.sub_key, + fractional_labels=self._enable_fractional_labels, + flatten=True, + aggregation_type=self._aggregation_type, + class_weights=self._class_weights, + example_weighted=self._example_weighted, + ): + example_weights.append(metric_util.safe_to_scalar(example_weight)) + labels.append(metric_util.safe_to_scalar(label)) + predictions.append(metric_util.safe_to_scalar(prediction)) + + return self._binary_confusion_matrices.add_input( + accumulator=accumulator, + labels=labels, + predictions=predictions, + example_weights=example_weights, + example_id=example_id, + ) + + def merge_accumulators( + self, accumulators: Iterable[MatrixAccumulator] + ) -> MatrixAccumulator: + return self._binary_confusion_matrices.merge_accumulators(accumulators) + + def extract_output( + self, accumulator: MatrixAccumulator + ) -> Dict[metric_types.MetricKey, MatrixAccumulator]: + return {self._key: self._binary_confusion_matrices.extract_output(accumulator)} def _accumulator_to_matrices_and_examples( thresholds: List[float], acc: MatrixAccumulator ) -> Tuple[Matrices, Examples]: - """Converts MatrixAccumulator to binary confusion matrices.""" - matrices = Matrices(thresholds=[], tp=[], tn=[], fp=[], fn=[]) - examples = Examples( - thresholds=[], - tp_examples=[], - tn_examples=[], - fp_examples=[], - fn_examples=[]) - for threshold in thresholds: - matrices.thresholds.append(threshold) - matrices.tp.append(acc[threshold].matrix.tp) - matrices.tn.append(acc[threshold].matrix.tn) - matrices.fp.append(acc[threshold].matrix.fp) - matrices.fn.append(acc[threshold].matrix.fn) - - examples.thresholds.append(threshold) - examples.tp_examples.append(acc[threshold].tp_examples) - examples.tn_examples.append(acc[threshold].tn_examples) - examples.fp_examples.append(acc[threshold].fp_examples) - examples.fn_examples.append(acc[threshold].fn_examples) - return matrices, examples + """Converts MatrixAccumulator to binary confusion matrices.""" + matrices = Matrices(thresholds=[], tp=[], tn=[], fp=[], fn=[]) + examples = Examples( + thresholds=[], tp_examples=[], tn_examples=[], fp_examples=[], fn_examples=[] + ) + for threshold in thresholds: + matrices.thresholds.append(threshold) + matrices.tp.append(acc[threshold].matrix.tp) + matrices.tn.append(acc[threshold].matrix.tn) + matrices.fp.append(acc[threshold].matrix.fp) + matrices.fn.append(acc[threshold].matrix.fn) + + examples.thresholds.append(threshold) + examples.tp_examples.append(acc[threshold].tp_examples) + examples.tn_examples.append(acc[threshold].tn_examples) + examples.fp_examples.append(acc[threshold].fp_examples) + examples.fn_examples.append(acc[threshold].fn_examples) + return matrices, examples diff --git a/tensorflow_model_analysis/metrics/binary_confusion_matrices_test.py b/tensorflow_model_analysis/metrics/binary_confusion_matrices_test.py index 818d0198da..037242fe33 100644 --- a/tensorflow_model_analysis/metrics/binary_confusion_matrices_test.py +++ b/tensorflow_model_analysis/metrics/binary_confusion_matrices_test.py @@ -13,563 +13,562 @@ # limitations under the License. 
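A minimal standalone sketch (plain Python with hypothetical bucket values, not the TFMA API) of the cumulative-sum conversion exercised by the tests below: each calibration-histogram bucket carries (weighted_labels, weighted_examples); tp/fp are summed from the top of the histogram down and tn/fn from the bottom up, mirroring the loop in _histogram_to_binary_confusion_matrices above.

```
from typing import List, Tuple


def histogram_to_confusion_counts(
    buckets: List[Tuple[float, float]],
) -> Tuple[List[float], List[float], List[float], List[float]]:
    """Returns per-threshold (tp, fp, tn, fn) from (labels, examples) buckets."""
    n = len(buckets)
    tp, fp, tn, fn = [0.0] * n, [0.0] * n, [0.0] * n, [0.0] * n
    # tp/fp accumulate from the top bucket down: everything in bucket i or
    # above is predicted positive at threshold i.
    for i in reversed(range(n)):
        pos = buckets[i][0]
        neg = buckets[i][1] - buckets[i][0]
        tp[i] = pos + (tp[i + 1] if i + 1 < n else 0.0)
        fp[i] = neg + (fp[i + 1] if i + 1 < n else 0.0)
    # tn/fn accumulate from the bottom bucket up: everything below bucket i
    # is predicted negative at threshold i.
    for i in range(1, n):
        pos = buckets[i - 1][0]
        neg = buckets[i - 1][1] - buckets[i - 1][0]
        fn[i] = fn[i - 1] + pos
        tn[i] = tn[i - 1] + neg
    return tp, fp, tn, fn


# Two buckets of one example each: a negative predicted at 0.0 (bucket 0,
# i.e. the -epsilon bucket) and a positive predicted at 0.9 (bucket 1, the
# bucket for predictions above threshold 0.5).
print(histogram_to_confusion_counts([(0.0, 1.0), (1.0, 1.0)]))
# ([1.0, 1.0], [1.0, 0.0], [0.0, 1.0], [0.0, 0.0])
```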
"""Tests for binary confusion matrices.""" -from absl.testing import parameterized import apache_beam as beam -from apache_beam.testing import util import numpy as np import tensorflow as tf -from tensorflow_model_analysis.metrics import binary_confusion_matrices -from tensorflow_model_analysis.metrics import metric_types -from tensorflow_model_analysis.metrics import metric_util +from absl.testing import parameterized +from apache_beam.testing import util + +from tensorflow_model_analysis.metrics import ( + binary_confusion_matrices, + metric_types, + metric_util, +) from tensorflow_model_analysis.utils import test_util class BinaryConfusionMatricesTest( test_util.TensorflowModelAnalysisTest, parameterized.TestCase ): - - @parameterized.named_parameters( - { - 'testcase_name': '_empty', - 'left': binary_confusion_matrices.Matrices( - thresholds=[], tp=[], tn=[], fp=[], fn=[] - ), - 'right': binary_confusion_matrices.Matrices( - thresholds=[], tp=[], tn=[], fp=[], fn=[] - ), - 'expected': binary_confusion_matrices.Matrices( - thresholds=[], tp=[], tn=[], fp=[], fn=[] - ), - }, - { - 'testcase_name': '_different_thresholds_left_lower', - 'left': binary_confusion_matrices.Matrices( - thresholds=[0.5, 0.6], tp=[5, 6], tn=[5, 6], fp=[5, 6], fn=[5, 6] - ), - 'right': binary_confusion_matrices.Matrices( - thresholds=[0.6, 0.7], tp=[6, 7], tn=[6, 7], fp=[6, 7], fn=[6, 7] - ), - 'expected': binary_confusion_matrices.Matrices( - thresholds=[0.5, 0.6, 0.7], - tp=[5, 12, 7], - tn=[5, 12, 7], - fp=[5, 12, 7], - fn=[5, 12, 7], - ), - }, - { - 'testcase_name': '_different_thresholds_right_lower', - 'left': binary_confusion_matrices.Matrices( - thresholds=[0.6, 0.7], tp=[6, 7], tn=[6, 7], fp=[6, 7], fn=[6, 7] - ), - 'right': binary_confusion_matrices.Matrices( - thresholds=[0.5, 0.6], tp=[5, 6], tn=[5, 6], fp=[5, 6], fn=[5, 6] - ), - 'expected': binary_confusion_matrices.Matrices( - thresholds=[0.5, 0.6, 0.7], - tp=[5, 12, 7], - tn=[5, 12, 7], - fp=[5, 12, 7], - fn=[5, 12, 7], - ), - }, - { - 'testcase_name': '_different_thresholds_one_empty', - 'left': binary_confusion_matrices.Matrices( - thresholds=[0.5, 0.6], tp=[5, 6], tn=[5, 6], fp=[5, 6], fn=[5, 6] - ), - 'right': binary_confusion_matrices.Matrices( - thresholds=[], tp=[], tn=[], fp=[], fn=[] - ), - 'expected': binary_confusion_matrices.Matrices( - thresholds=[0.5, 0.6], tp=[5, 6], tn=[5, 6], fp=[5, 6], fn=[5, 6] - ), - }, - { - 'testcase_name': '_broadcast', - 'left': binary_confusion_matrices.Matrices( - thresholds=[0.5, 0.6], tp=[5, 6], tn=[5, 6], fp=[5, 6], fn=[5, 6] - ), - 'right': 1, - 'expected': binary_confusion_matrices.Matrices( - thresholds=[0.5, 0.6], tp=[6, 7], tn=[6, 7], fp=[6, 7], fn=[6, 7] - ), - }, - ) - def testAddBinaryConfusionMatrices(self, left, right, expected): - self.assertEqual(expected, left + right) - - @parameterized.named_parameters( - ( - 'using_num_thresholds', - { - 'num_thresholds': 3, - }, - binary_confusion_matrices.Matrices( - thresholds=[-1e-7, 0.5, 1.0 + 1e-7], - tp=[2.0, 1.0, 0.0], - fp=[2.0, 0.0, 0.0], - tn=[0.0, 2.0, 2.0], - fn=[0.0, 1.0, 2.0], - ), - ), - ( - 'single_threshold', - { - 'thresholds': [0.5], - 'use_histogram': True, - }, - binary_confusion_matrices.Matrices( - thresholds=[0.5], tp=[1.0], fp=[0.0], tn=[2.0], fn=[1.0] - ), - ), - ( - 'inner_thresholds', - { - 'thresholds': [0.25, 0.75], - 'use_histogram': True, - }, - binary_confusion_matrices.Matrices( - thresholds=[0.25, 0.75], - tp=[2.0, 1.0], - fp=[1.0, 0.0], - tn=[1.0, 2.0], - fn=[0.0, 1.0], - ), - ), - ( - 'boundary_thresholds', - { - 
'thresholds': [0.0, 1.0], - 'use_histogram': True, - }, - binary_confusion_matrices.Matrices( - thresholds=[0.0, 1.0], - tp=[2.0, 0.0], - fp=[2.0, 0.0], - tn=[0.0, 2.0], - fn=[0.0, 2.0], - ), - ), - ( - 'left_boundary', - { - 'thresholds': [0.0, 0.5], - 'use_histogram': True, - }, - binary_confusion_matrices.Matrices( - thresholds=[0.0, 0.5], - tp=[2.0, 1.0], - fp=[2.0, 0.0], - tn=[0.0, 2.0], - fn=[0.0, 1.0], - ), - ), - ( - 'right_boundary', - { - 'thresholds': [0.5, 1.0], - 'use_histogram': True, - }, - binary_confusion_matrices.Matrices( - thresholds=[0.5, 1.0], - tp=[1.0, 0.0], - fp=[0.0, 0.0], - tn=[2.0, 2.0], - fn=[1.0, 2.0], - ), - ), - ) - def testBinaryConfusionMatrices(self, kwargs, expected_matrices): - computations = binary_confusion_matrices.binary_confusion_matrices(**kwargs) - histogram = computations[0] - matrices = computations[1] - - example1 = { - 'labels': np.array([0.0]), - 'predictions': np.array([0.0]), - 'example_weights': np.array([1.0]), - } - example2 = { - 'labels': np.array([0.0]), - 'predictions': np.array([0.5]), - 'example_weights': np.array([1.0]), - } - example3 = { - 'labels': np.array([1.0]), - 'predictions': np.array([0.3]), - 'example_weights': np.array([1.0]), - } - example4 = { - 'labels': np.array([1.0]), - 'predictions': np.array([0.9]), - 'example_weights': np.array([1.0]), - } - - with beam.Pipeline() as pipeline: - # pylint: disable=no-value-for-parameter - result = ( - pipeline - | 'Create' >> beam.Create([example1, example2, example3, example4]) - | - 'Process' >> beam.Map(metric_util.to_standard_metric_inputs) - | 'AddSlice' >> beam.Map(lambda x: ((), x)) - | 'ComputeHistogram' >> beam.CombinePerKey(histogram.combiner) - | 'ComputeMatrices' >> beam.Map( - lambda x: (x[0], matrices.result(x[1])))) # pyformat: disable - - # pylint: enable=no-value-for-parameter - - def check_result(got): - try: - self.assertLen(got, 1) - got_slice_key, got_metrics = got[0] - self.assertEqual(got_slice_key, ()) - self.assertLen(got_metrics, 1) - name_args = {'example_id_key': None, 'example_ids_count': None} - if 'num_thresholds' in kwargs: - thresholds = binary_confusion_matrices._interpolated_thresholds( - kwargs['num_thresholds'] - ) - name_args['num_thresholds'] = kwargs['num_thresholds'] - else: - thresholds = kwargs['thresholds'] - name_args['thresholds'] = thresholds - - name = metric_util.generate_private_name_from_arguments( - binary_confusion_matrices.BINARY_CONFUSION_MATRICES_NAME, - **name_args - ) - - matrices_key = metric_types.MetricKey(name=name) - self.assertIn(matrices_key, got_metrics) - got_matrices = got_metrics[matrices_key] - self.assertEqual(got_matrices, expected_matrices) - - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that(result, check_result, label='result') - - @parameterized.named_parameters( - { - 'testcase_name': 'using_num_thresholds', - 'kwargs': { - 'num_thresholds': 3, - 'use_histogram': False, - }, - 'expected_metrics': { - metric_types.MetricKey( - '_binary_confusion_matrices:num_thresholds=3' - ): binary_confusion_matrices.Matrices( - thresholds=[-1e-7, 0.5, 1.0 + 1e-7], - tp=[2.0, 1.0, 0.0], - fp=[2.0, 0.0, 0.0], - tn=[0.0, 2.0, 2.0], - fn=[0.0, 1.0, 2.0], - ), - }, - }, - { - 'testcase_name': 'single_threshold', - 'kwargs': { - 'thresholds': [0.5], - }, - 'expected_metrics': { - metric_types.MetricKey( - '_binary_confusion_matrices:thresholds=[0.5]' - ): binary_confusion_matrices.Matrices( - thresholds=[0.5], tp=[1.0], fp=[0.0], tn=[2.0], fn=[1.0] - ) - }, - }, - { - 
'testcase_name': 'multiple_thresholds', - 'kwargs': { - 'thresholds': [0.25, 0.75], - }, - 'expected_metrics': { - metric_types.MetricKey( - '_binary_confusion_matrices:thresholds=[0.25, 0.75]' - ): binary_confusion_matrices.Matrices( - thresholds=[0.25, 0.75], - tp=[2.0, 1.0], - fp=[1.0, 0.0], - tn=[1.0, 2.0], - fn=[0.0, 1.0], - ), - }, - }, - { - 'testcase_name': 'with_example_ids', - 'kwargs': { - 'thresholds': [0.1, 0.9], - 'example_id_key': 'example_id_key', - 'example_ids_count': 2, - }, - 'expected_metrics': { - metric_types.MetricKey( - '_binary_confusion_matrices:example_id_key=example_id_key,' - 'example_ids_count=2,thresholds=[0.1, 0.9]' - ): binary_confusion_matrices.Matrices( - thresholds=[0.1, 0.9], - tp=[2.0, 0.0], - fp=[1.0, 0.0], - tn=[1.0, 2.0], - fn=[0.0, 2.0], - ), - metric_types.MetricKey( - '_binary_confusion_examples:example_id_key=example_id_key,' - 'example_ids_count=2,thresholds=[0.1, 0.9]' - ): binary_confusion_matrices.Examples( - thresholds=[0.1, 0.9], - tp_examples=[['id_3', 'id_4'], []], - tn_examples=[['id_1'], ['id_1', 'id_2']], - fp_examples=[['id_2'], []], - fn_examples=[[], ['id_3', 'id_4']], - ), - }, - }, - ) - def testBinaryConfusionMatrices_noHistograms(self, kwargs, expected_metrics): - computations = binary_confusion_matrices.binary_confusion_matrices(**kwargs) - histogram = computations[0] - matrices = computations[1] - - example1 = { - 'labels': np.array([0.0]), - 'predictions': np.array([0.0]), - 'example_weights': np.array([1.0]), - 'features': { - 'example_id_key': np.array(['id_1']), + @parameterized.named_parameters( + { + "testcase_name": "_empty", + "left": binary_confusion_matrices.Matrices( + thresholds=[], tp=[], tn=[], fp=[], fn=[] + ), + "right": binary_confusion_matrices.Matrices( + thresholds=[], tp=[], tn=[], fp=[], fn=[] + ), + "expected": binary_confusion_matrices.Matrices( + thresholds=[], tp=[], tn=[], fp=[], fn=[] + ), }, - } - example2 = { - 'labels': np.array([0.0]), - 'predictions': np.array([0.5]), - 'example_weights': np.array([1.0]), - 'features': { - 'example_id_key': np.array(['id_2']), + { + "testcase_name": "_different_thresholds_left_lower", + "left": binary_confusion_matrices.Matrices( + thresholds=[0.5, 0.6], tp=[5, 6], tn=[5, 6], fp=[5, 6], fn=[5, 6] + ), + "right": binary_confusion_matrices.Matrices( + thresholds=[0.6, 0.7], tp=[6, 7], tn=[6, 7], fp=[6, 7], fn=[6, 7] + ), + "expected": binary_confusion_matrices.Matrices( + thresholds=[0.5, 0.6, 0.7], + tp=[5, 12, 7], + tn=[5, 12, 7], + fp=[5, 12, 7], + fn=[5, 12, 7], + ), }, - } - example3 = { - 'labels': np.array([1.0]), - 'predictions': np.array([0.3]), - 'example_weights': np.array([1.0]), - 'features': { - 'example_id_key': np.array(['id_3']), + { + "testcase_name": "_different_thresholds_right_lower", + "left": binary_confusion_matrices.Matrices( + thresholds=[0.6, 0.7], tp=[6, 7], tn=[6, 7], fp=[6, 7], fn=[6, 7] + ), + "right": binary_confusion_matrices.Matrices( + thresholds=[0.5, 0.6], tp=[5, 6], tn=[5, 6], fp=[5, 6], fn=[5, 6] + ), + "expected": binary_confusion_matrices.Matrices( + thresholds=[0.5, 0.6, 0.7], + tp=[5, 12, 7], + tn=[5, 12, 7], + fp=[5, 12, 7], + fn=[5, 12, 7], + ), }, - } - example4 = { - 'labels': np.array([1.0]), - 'predictions': np.array([0.9]), - 'example_weights': np.array([1.0]), - 'features': { - 'example_id_key': np.array(['id_4']), + { + "testcase_name": "_different_thresholds_one_empty", + "left": binary_confusion_matrices.Matrices( + thresholds=[0.5, 0.6], tp=[5, 6], tn=[5, 6], fp=[5, 6], fn=[5, 6] + ), + "right": 
binary_confusion_matrices.Matrices( + thresholds=[], tp=[], tn=[], fp=[], fn=[] + ), + "expected": binary_confusion_matrices.Matrices( + thresholds=[0.5, 0.6], tp=[5, 6], tn=[5, 6], fp=[5, 6], fn=[5, 6] + ), }, - } - - with beam.Pipeline() as pipeline: - # pylint: disable=no-value-for-parameter - result = ( - pipeline - | 'Create' >> beam.Create([example1, example2, example3, example4]) - | - 'Process' >> beam.Map(metric_util.to_standard_metric_inputs) - | 'AddSlice' >> beam.Map(lambda x: ((), x)) - | 'ComputeHistogram' >> beam.CombinePerKey(histogram.combiner) - | 'ComputeMatrices' >> beam.Map( - lambda x: (x[0], matrices.result(x[1])))) # pyformat: disable - - # pylint: enable=no-value-for-parameter - - def check_result(got): - try: - self.assertLen(got, 1) - got_slice_key, got_metrics = got[0] - self.assertEqual(got_slice_key, ()) - self.assertEqual(got_metrics, expected_metrics) - - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that(result, check_result, label='result') - - def testBinaryConfusionMatricesExampleIdKeyNotInFeatures(self): - thresholds = [0.1, 0.9] - example_id_key = 'example_id_key' - example_ids_count = 2 - - # Example Ids are empty because 'example_id_key' will not be defined in - # 'features' - expected_metrics = { - metric_types.MetricKey( - '_binary_confusion_matrices:example_id_key=example_id_key,' - 'example_ids_count=2,thresholds=[0.1, 0.9]' - ): binary_confusion_matrices.Matrices( - thresholds=[0.1, 0.9], - tp=[2.0, 0.0], - fp=[1.0, 0.0], - tn=[1.0, 2.0], - fn=[0.0, 2.0], + { + "testcase_name": "_broadcast", + "left": binary_confusion_matrices.Matrices( + thresholds=[0.5, 0.6], tp=[5, 6], tn=[5, 6], fp=[5, 6], fn=[5, 6] + ), + "right": 1, + "expected": binary_confusion_matrices.Matrices( + thresholds=[0.5, 0.6], tp=[6, 7], tn=[6, 7], fp=[6, 7], fn=[6, 7] + ), + }, + ) + def testAddBinaryConfusionMatrices(self, left, right, expected): + self.assertEqual(expected, left + right) + + @parameterized.named_parameters( + ( + "using_num_thresholds", + { + "num_thresholds": 3, + }, + binary_confusion_matrices.Matrices( + thresholds=[-1e-7, 0.5, 1.0 + 1e-7], + tp=[2.0, 1.0, 0.0], + fp=[2.0, 0.0, 0.0], + tn=[0.0, 2.0, 2.0], + fn=[0.0, 1.0, 2.0], + ), ), - metric_types.MetricKey( - '_binary_confusion_examples:example_id_key=example_id_key,' - 'example_ids_count=2,thresholds=[0.1, 0.9]' - ): binary_confusion_matrices.Examples( - thresholds=[0.1, 0.9], - tp_examples=[[], []], - tn_examples=[[], []], - fp_examples=[[], []], - fn_examples=[[], []], + ( + "single_threshold", + { + "thresholds": [0.5], + "use_histogram": True, + }, + binary_confusion_matrices.Matrices( + thresholds=[0.5], tp=[1.0], fp=[0.0], tn=[2.0], fn=[1.0] + ), + ), + ( + "inner_thresholds", + { + "thresholds": [0.25, 0.75], + "use_histogram": True, + }, + binary_confusion_matrices.Matrices( + thresholds=[0.25, 0.75], + tp=[2.0, 1.0], + fp=[1.0, 0.0], + tn=[1.0, 2.0], + fn=[0.0, 1.0], + ), + ), + ( + "boundary_thresholds", + { + "thresholds": [0.0, 1.0], + "use_histogram": True, + }, + binary_confusion_matrices.Matrices( + thresholds=[0.0, 1.0], + tp=[2.0, 0.0], + fp=[2.0, 0.0], + tn=[0.0, 2.0], + fn=[0.0, 2.0], + ), + ), + ( + "left_boundary", + { + "thresholds": [0.0, 0.5], + "use_histogram": True, + }, + binary_confusion_matrices.Matrices( + thresholds=[0.0, 0.5], + tp=[2.0, 1.0], + fp=[2.0, 0.0], + tn=[0.0, 2.0], + fn=[0.0, 1.0], + ), + ), + ( + "right_boundary", + { + "thresholds": [0.5, 1.0], + "use_histogram": True, + }, + binary_confusion_matrices.Matrices( + 
thresholds=[0.5, 1.0], + tp=[1.0, 0.0], + fp=[0.0, 0.0], + tn=[2.0, 2.0], + fn=[1.0, 2.0], + ), ), - } - - computations = binary_confusion_matrices.binary_confusion_matrices( - thresholds=thresholds, - example_id_key=example_id_key, - example_ids_count=example_ids_count, ) - histogram = computations[0] - matrices = computations[1] - - examples = [ - { - 'labels': np.array([0.0]), - 'predictions': np.array([0.0]), - 'example_weights': np.array([1.0]), - 'features': {}, - }, + def testBinaryConfusionMatrices(self, kwargs, expected_matrices): + computations = binary_confusion_matrices.binary_confusion_matrices(**kwargs) + histogram = computations[0] + matrices = computations[1] + + example1 = { + "labels": np.array([0.0]), + "predictions": np.array([0.0]), + "example_weights": np.array([1.0]), + } + example2 = { + "labels": np.array([0.0]), + "predictions": np.array([0.5]), + "example_weights": np.array([1.0]), + } + example3 = { + "labels": np.array([1.0]), + "predictions": np.array([0.3]), + "example_weights": np.array([1.0]), + } + example4 = { + "labels": np.array([1.0]), + "predictions": np.array([0.9]), + "example_weights": np.array([1.0]), + } + + with beam.Pipeline() as pipeline: + # pylint: disable=no-value-for-parameter + result = ( + pipeline + | "Create" >> beam.Create([example1, example2, example3, example4]) + | "Process" >> beam.Map(metric_util.to_standard_metric_inputs) + | "AddSlice" >> beam.Map(lambda x: ((), x)) + | "ComputeHistogram" >> beam.CombinePerKey(histogram.combiner) + | "ComputeMatrices" >> beam.Map(lambda x: (x[0], matrices.result(x[1]))) + ) # pyformat: disable + + # pylint: enable=no-value-for-parameter + + def check_result(got): + try: + self.assertLen(got, 1) + got_slice_key, got_metrics = got[0] + self.assertEqual(got_slice_key, ()) + self.assertLen(got_metrics, 1) + name_args = {"example_id_key": None, "example_ids_count": None} + if "num_thresholds" in kwargs: + thresholds = binary_confusion_matrices._interpolated_thresholds( + kwargs["num_thresholds"] + ) + name_args["num_thresholds"] = kwargs["num_thresholds"] + else: + thresholds = kwargs["thresholds"] + name_args["thresholds"] = thresholds + + name = metric_util.generate_private_name_from_arguments( + binary_confusion_matrices.BINARY_CONFUSION_MATRICES_NAME, + **name_args, + ) + + matrices_key = metric_types.MetricKey(name=name) + self.assertIn(matrices_key, got_metrics) + got_matrices = got_metrics[matrices_key] + self.assertEqual(got_matrices, expected_matrices) + + except AssertionError as err: + raise util.BeamAssertException(err) + + util.assert_that(result, check_result, label="result") + + @parameterized.named_parameters( { - 'labels': np.array([0.0]), - 'predictions': np.array([0.5]), - 'example_weights': np.array([1.0]), - 'features': {}, + "testcase_name": "using_num_thresholds", + "kwargs": { + "num_thresholds": 3, + "use_histogram": False, + }, + "expected_metrics": { + metric_types.MetricKey( + "_binary_confusion_matrices:num_thresholds=3" + ): binary_confusion_matrices.Matrices( + thresholds=[-1e-7, 0.5, 1.0 + 1e-7], + tp=[2.0, 1.0, 0.0], + fp=[2.0, 0.0, 0.0], + tn=[0.0, 2.0, 2.0], + fn=[0.0, 1.0, 2.0], + ), + }, }, { - 'labels': np.array([1.0]), - 'predictions': np.array([0.3]), - 'example_weights': np.array([1.0]), - 'features': {}, + "testcase_name": "single_threshold", + "kwargs": { + "thresholds": [0.5], + }, + "expected_metrics": { + metric_types.MetricKey( + "_binary_confusion_matrices:thresholds=[0.5]" + ): binary_confusion_matrices.Matrices( + thresholds=[0.5], tp=[1.0], 
fp=[0.0], tn=[2.0], fn=[1.0] + ) + }, }, { - 'labels': np.array([1.0]), - 'predictions': np.array([0.9]), - 'example_weights': np.array([1.0]), - 'features': {}, + "testcase_name": "multiple_thresholds", + "kwargs": { + "thresholds": [0.25, 0.75], + }, + "expected_metrics": { + metric_types.MetricKey( + "_binary_confusion_matrices:thresholds=[0.25, 0.75]" + ): binary_confusion_matrices.Matrices( + thresholds=[0.25, 0.75], + tp=[2.0, 1.0], + fp=[1.0, 0.0], + tn=[1.0, 2.0], + fn=[0.0, 1.0], + ), + }, }, { - # Ensure that empty inputs are handled safely. - 'labels': np.array([]), - 'predictions': np.array([]), - 'example_weights': np.array([]), - 'features': {}, + "testcase_name": "with_example_ids", + "kwargs": { + "thresholds": [0.1, 0.9], + "example_id_key": "example_id_key", + "example_ids_count": 2, + }, + "expected_metrics": { + metric_types.MetricKey( + "_binary_confusion_matrices:example_id_key=example_id_key," + "example_ids_count=2,thresholds=[0.1, 0.9]" + ): binary_confusion_matrices.Matrices( + thresholds=[0.1, 0.9], + tp=[2.0, 0.0], + fp=[1.0, 0.0], + tn=[1.0, 2.0], + fn=[0.0, 2.0], + ), + metric_types.MetricKey( + "_binary_confusion_examples:example_id_key=example_id_key," + "example_ids_count=2,thresholds=[0.1, 0.9]" + ): binary_confusion_matrices.Examples( + thresholds=[0.1, 0.9], + tp_examples=[["id_3", "id_4"], []], + tn_examples=[["id_1"], ["id_1", "id_2"]], + fp_examples=[["id_2"], []], + fn_examples=[[], ["id_3", "id_4"]], + ), + }, }, - ] - - with beam.Pipeline() as pipeline: - # pylint: disable=no-value-for-parameter - result = ( - pipeline - | 'Create' >> beam.Create(examples) - | 'Process' >> beam.Map(metric_util.to_standard_metric_inputs) - | 'AddSlice' >> beam.Map(lambda x: ((), x)) - | 'ComputeHistogram' >> beam.CombinePerKey(histogram.combiner) - | 'ComputeMatrices' >> beam.Map( - lambda x: (x[0], matrices.result(x[1])))) # pyformat: disable - - # pylint: enable=no-value-for-parameter - - def check_result(got): - try: - self.assertLen(got, 1) - got_slice_key, got_metrics = got[0] - self.assertEqual(got_slice_key, ()) - self.assertEqual(got_metrics, expected_metrics) - - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that(result, check_result, label='result') - - def testBinaryConfusionMatricesTopK(self): - computations = binary_confusion_matrices.binary_confusion_matrices( - thresholds=[float('-inf')], - sub_key=metric_types.SubKey(top_k=3), - use_histogram=True, ) - histogram = computations[0] - matrices = computations[1] - - example1 = { - 'labels': np.array([2]), - 'predictions': np.array([0.1, 0.2, 0.1, 0.25, 0.35]), - 'example_weights': np.array([1.0]), - } - example2 = { - 'labels': np.array([1]), - 'predictions': np.array([0.2, 0.3, 0.05, 0.15, 0.3]), - 'example_weights': np.array([1.0]), - } - example3 = { - 'labels': np.array([3]), - 'predictions': np.array([0.01, 0.2, 0.09, 0.5, 0.2]), - 'example_weights': np.array([1.0]), - } - example4 = { - 'labels': np.array([4]), - 'predictions': np.array([0.3, 0.2, 0.05, 0.4, 0.05]), - 'example_weights': np.array([1.0]), - } - - with beam.Pipeline() as pipeline: - # pylint: disable=no-value-for-parameter - result = ( - pipeline - | 'Create' >> beam.Create([example1, example2, example3, example4]) - | - 'Process' >> beam.Map(metric_util.to_standard_metric_inputs) - | 'AddSlice' >> beam.Map(lambda x: ((), x)) - | 'ComputeHistogram' >> beam.CombinePerKey(histogram.combiner) - | 'ComputeMatrices' >> beam.Map( - lambda x: (x[0], matrices.result(x[1])))) # pyformat: disable - - # 
pylint: enable=no-value-for-parameter - - def check_result(got): - try: - self.assertLen(got, 1) - got_slice_key, got_metrics = got[0] - self.assertEqual(got_slice_key, ()) - self.assertLen(got_metrics, 1) - thresholds = [float('-inf')] - name = metric_util.generate_private_name_from_arguments( - binary_confusion_matrices.BINARY_CONFUSION_MATRICES_NAME, - thresholds=thresholds, - example_id_key=None, - example_ids_count=None, - ) - key = metric_types.MetricKey( - name=name, sub_key=metric_types.SubKey(top_k=3) - ) - self.assertIn(key, got_metrics) - got_matrices = got_metrics[key] - self.assertEqual( - got_matrices, - binary_confusion_matrices.Matrices( - thresholds=[float('-inf')], - tp=[2.0], - fp=[10.0], - tn=[6.0], - fn=[2.0], - ), - ) - - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that(result, check_result, label='result') - - -if __name__ == '__main__': - tf.test.main() + def testBinaryConfusionMatrices_noHistograms(self, kwargs, expected_metrics): + computations = binary_confusion_matrices.binary_confusion_matrices(**kwargs) + histogram = computations[0] + matrices = computations[1] + + example1 = { + "labels": np.array([0.0]), + "predictions": np.array([0.0]), + "example_weights": np.array([1.0]), + "features": { + "example_id_key": np.array(["id_1"]), + }, + } + example2 = { + "labels": np.array([0.0]), + "predictions": np.array([0.5]), + "example_weights": np.array([1.0]), + "features": { + "example_id_key": np.array(["id_2"]), + }, + } + example3 = { + "labels": np.array([1.0]), + "predictions": np.array([0.3]), + "example_weights": np.array([1.0]), + "features": { + "example_id_key": np.array(["id_3"]), + }, + } + example4 = { + "labels": np.array([1.0]), + "predictions": np.array([0.9]), + "example_weights": np.array([1.0]), + "features": { + "example_id_key": np.array(["id_4"]), + }, + } + + with beam.Pipeline() as pipeline: + # pylint: disable=no-value-for-parameter + result = ( + pipeline + | "Create" >> beam.Create([example1, example2, example3, example4]) + | "Process" >> beam.Map(metric_util.to_standard_metric_inputs) + | "AddSlice" >> beam.Map(lambda x: ((), x)) + | "ComputeHistogram" >> beam.CombinePerKey(histogram.combiner) + | "ComputeMatrices" >> beam.Map(lambda x: (x[0], matrices.result(x[1]))) + ) # pyformat: disable + + # pylint: enable=no-value-for-parameter + + def check_result(got): + try: + self.assertLen(got, 1) + got_slice_key, got_metrics = got[0] + self.assertEqual(got_slice_key, ()) + self.assertEqual(got_metrics, expected_metrics) + + except AssertionError as err: + raise util.BeamAssertException(err) + + util.assert_that(result, check_result, label="result") + + def testBinaryConfusionMatricesExampleIdKeyNotInFeatures(self): + thresholds = [0.1, 0.9] + example_id_key = "example_id_key" + example_ids_count = 2 + + # Example Ids are empty because 'example_id_key' will not be defined in + # 'features' + expected_metrics = { + metric_types.MetricKey( + "_binary_confusion_matrices:example_id_key=example_id_key," + "example_ids_count=2,thresholds=[0.1, 0.9]" + ): binary_confusion_matrices.Matrices( + thresholds=[0.1, 0.9], + tp=[2.0, 0.0], + fp=[1.0, 0.0], + tn=[1.0, 2.0], + fn=[0.0, 2.0], + ), + metric_types.MetricKey( + "_binary_confusion_examples:example_id_key=example_id_key," + "example_ids_count=2,thresholds=[0.1, 0.9]" + ): binary_confusion_matrices.Examples( + thresholds=[0.1, 0.9], + tp_examples=[[], []], + tn_examples=[[], []], + fp_examples=[[], []], + fn_examples=[[], []], + ), + } + + computations = 
binary_confusion_matrices.binary_confusion_matrices( + thresholds=thresholds, + example_id_key=example_id_key, + example_ids_count=example_ids_count, + ) + histogram = computations[0] + matrices = computations[1] + + examples = [ + { + "labels": np.array([0.0]), + "predictions": np.array([0.0]), + "example_weights": np.array([1.0]), + "features": {}, + }, + { + "labels": np.array([0.0]), + "predictions": np.array([0.5]), + "example_weights": np.array([1.0]), + "features": {}, + }, + { + "labels": np.array([1.0]), + "predictions": np.array([0.3]), + "example_weights": np.array([1.0]), + "features": {}, + }, + { + "labels": np.array([1.0]), + "predictions": np.array([0.9]), + "example_weights": np.array([1.0]), + "features": {}, + }, + { + # Ensure that empty inputs are handled safely. + "labels": np.array([]), + "predictions": np.array([]), + "example_weights": np.array([]), + "features": {}, + }, + ] + + with beam.Pipeline() as pipeline: + # pylint: disable=no-value-for-parameter + result = ( + pipeline + | "Create" >> beam.Create(examples) + | "Process" >> beam.Map(metric_util.to_standard_metric_inputs) + | "AddSlice" >> beam.Map(lambda x: ((), x)) + | "ComputeHistogram" >> beam.CombinePerKey(histogram.combiner) + | "ComputeMatrices" >> beam.Map(lambda x: (x[0], matrices.result(x[1]))) + ) # pyformat: disable + + # pylint: enable=no-value-for-parameter + + def check_result(got): + try: + self.assertLen(got, 1) + got_slice_key, got_metrics = got[0] + self.assertEqual(got_slice_key, ()) + self.assertEqual(got_metrics, expected_metrics) + + except AssertionError as err: + raise util.BeamAssertException(err) + + util.assert_that(result, check_result, label="result") + + def testBinaryConfusionMatricesTopK(self): + computations = binary_confusion_matrices.binary_confusion_matrices( + thresholds=[float("-inf")], + sub_key=metric_types.SubKey(top_k=3), + use_histogram=True, + ) + histogram = computations[0] + matrices = computations[1] + + example1 = { + "labels": np.array([2]), + "predictions": np.array([0.1, 0.2, 0.1, 0.25, 0.35]), + "example_weights": np.array([1.0]), + } + example2 = { + "labels": np.array([1]), + "predictions": np.array([0.2, 0.3, 0.05, 0.15, 0.3]), + "example_weights": np.array([1.0]), + } + example3 = { + "labels": np.array([3]), + "predictions": np.array([0.01, 0.2, 0.09, 0.5, 0.2]), + "example_weights": np.array([1.0]), + } + example4 = { + "labels": np.array([4]), + "predictions": np.array([0.3, 0.2, 0.05, 0.4, 0.05]), + "example_weights": np.array([1.0]), + } + + with beam.Pipeline() as pipeline: + # pylint: disable=no-value-for-parameter + result = ( + pipeline + | "Create" >> beam.Create([example1, example2, example3, example4]) + | "Process" >> beam.Map(metric_util.to_standard_metric_inputs) + | "AddSlice" >> beam.Map(lambda x: ((), x)) + | "ComputeHistogram" >> beam.CombinePerKey(histogram.combiner) + | "ComputeMatrices" >> beam.Map(lambda x: (x[0], matrices.result(x[1]))) + ) # pyformat: disable + + # pylint: enable=no-value-for-parameter + + def check_result(got): + try: + self.assertLen(got, 1) + got_slice_key, got_metrics = got[0] + self.assertEqual(got_slice_key, ()) + self.assertLen(got_metrics, 1) + thresholds = [float("-inf")] + name = metric_util.generate_private_name_from_arguments( + binary_confusion_matrices.BINARY_CONFUSION_MATRICES_NAME, + thresholds=thresholds, + example_id_key=None, + example_ids_count=None, + ) + key = metric_types.MetricKey( + name=name, sub_key=metric_types.SubKey(top_k=3) + ) + self.assertIn(key, got_metrics) + got_matrices = 
got_metrics[key] + self.assertEqual( + got_matrices, + binary_confusion_matrices.Matrices( + thresholds=[float("-inf")], + tp=[2.0], + fp=[10.0], + tn=[6.0], + fn=[2.0], + ), + ) + + except AssertionError as err: + raise util.BeamAssertException(err) + + util.assert_that(result, check_result, label="result") + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow_model_analysis/metrics/bleu.py b/tensorflow_model_analysis/metrics/bleu.py index 82446984b8..2864de76e0 100644 --- a/tensorflow_model_analysis/metrics/bleu.py +++ b/tensorflow_model_analysis/metrics/bleu.py @@ -20,386 +20,397 @@ import dataclasses from typing import Iterable, Optional, Sequence -from absl import logging import apache_beam as beam import numpy as np import sacrebleu.metrics as sacrebleu -from tensorflow_model_analysis.metrics import metric_types -from tensorflow_model_analysis.metrics import metric_util -from tensorflow_model_analysis.proto import config_pb2 +from absl import logging +from tensorflow_model_analysis.metrics import metric_types, metric_util +from tensorflow_model_analysis.proto import config_pb2 -_BLEU_NAME_DEFAULT = 'BLEU' +_BLEU_NAME_DEFAULT = "BLEU" # TODO: b/287700355 - Add __slots__ to _Accumulator @dataclasses.dataclass class _Accumulator: - """Accumulator for _BleuCombiner. - - Attributes: - matching_ngrams: A list containing the number of matching n-grams between - the hypothesis and the reference for each n. This should be initialized as - np.zeros(max_ngram_order). - total_ngrams: A list containing the total number of n-grams for each n. Like - 'matching_ngrams', this should be initialized as an - np.zeros(max_ngram_order). - hyp_len: The number of unigrams (words) in the hypothesis. - ref_len: The number of unigrams (words) in the reference. - - matching_ngrams[n - 1] = number of matching n-grams for n > 0 - matching_ngrams[0] = number of matching unigrams - matching_ngrams[1] = number of matching bigrams - ... - - total_ngrams[n - 1] = ( - max(number of n-grams in hyp, number of n-grams in ref) for n > 0 - ) - total_ngrams[] follows same pattern as matching_ngrams[] - - For hypotheses and references, ending punctuation (periods, exclamation - points, etc.) count as their own unigram. - For example, 'Google.' has 2 unigrams: 'Google' and '.'. - """ - - matching_ngrams: np.ndarray - total_ngrams: np.ndarray - hyp_len: int = 0 - ref_len: int = 0 - - def __eq__(self, other): - return ( - np.array_equal(self.matching_ngrams, other.matching_ngrams) - and np.array_equal(self.total_ngrams, other.total_ngrams) - and self.hyp_len == other.hyp_len - and self.ref_len == other.ref_len + """Accumulator for _BleuCombiner. + + Attributes + ---------- + matching_ngrams: A list containing the number of matching n-grams between + the hypothesis and the reference for each n. This should be initialized as + np.zeros(max_ngram_order). + total_ngrams: A list containing the total number of n-grams for each n. Like + 'matching_ngrams', this should be initialized as an + np.zeros(max_ngram_order). + hyp_len: The number of unigrams (words) in the hypothesis. + ref_len: The number of unigrams (words) in the reference. + + matching_ngrams[n - 1] = number of matching n-grams for n > 0 + matching_ngrams[0] = number of matching unigrams + matching_ngrams[1] = number of matching bigrams + ... 
+ + total_ngrams[n - 1] = ( + max(number of n-grams in hyp, number of n-grams in ref) for n > 0 ) + total_ngrams[] follows same pattern as matching_ngrams[] + + For hypotheses and references, ending punctuation (periods, exclamation + points, etc.) count as their own unigram. + For example, 'Google.' has 2 unigrams: 'Google' and '.'. + """ + + matching_ngrams: np.ndarray + total_ngrams: np.ndarray + hyp_len: int = 0 + ref_len: int = 0 + + def __eq__(self, other): + return ( + np.array_equal(self.matching_ngrams, other.matching_ngrams) + and np.array_equal(self.total_ngrams, other.total_ngrams) + and self.hyp_len == other.hyp_len + and self.ref_len == other.ref_len + ) # TODO: b/287700355 - Add __slots__ to this dataclass. @dataclasses.dataclass class _RefInfo: - ngrams: collections.Counter[dict[tuple[str], int]] # n-grams and counts - lens: list[int] # lengths + ngrams: collections.Counter[dict[tuple[str], int]] # n-grams and counts + lens: list[int] # lengths def _find_closest_ref_len(hyp_len: int, ref_lens: list[int]) -> int: - """Given a hypothesis length and a list of reference lengths, returns the closest reference length. - - Args: - hyp_len: The hypothesis length. - ref_lens: A list of reference lengths. The closest reference length. - - Returns: - The closest reference length, or -1 if ref_lens is empty. - """ - ref_lens_arr = np.array(ref_lens) - return ref_lens_arr[np.argmin(abs(ref_lens_arr - hyp_len))] - - -class _BleuCombiner(beam.CombineFn): - """Computes BLEU Score.""" - - def __init__( - self, - eval_config: config_pb2.EvalConfig, - model_name: str, - output_name: str, - key: metric_types.MetricKey, - **bleu_kwargs, - ): - """Initializes BLEU Combiner. + """Given a hypothesis length and a list of reference lengths, returns the closest reference length. Args: - eval_config: Eval config. - model_name: The model for which to compute these metrics. - output_name: The output name for which to compute these metrics. - key: MetricKey for extract_output(). - **bleu_kwargs: kwargs to initialize BLEU Metric. Possible options include - 'lowercase' (If True, lowercases the input, enabling - case-insensitivity.), 'force' (If True, insists that the tokenized input - is detokenized.), 'tokenize' (Tokenization method to use for BLEU. - Possible values are 'none' (No tokenization), 'zh' (Chinese - tokenization), '13a' (mimics the mteval-v13a from Moses), and 'intl' - (International tokenization, mimics the mteval-v14 script from Moses).), - 'smooth_method' (The smoothing method to use. Possible values are 'none' - (no smoothing), 'floor' (increment zero counts), 'add-k' (increment - num/denom by k for n>1), and 'exp' (exponential decay).), 'smooth_value' - (The smoothing value. Only valid when smoothmethod='floor' or - smooth_method='add-k'.), and 'effective_order' (If True, stops including - n-gram orders for which precision is 0. This should be True if - sentence-level BLEU will be computed.). - """ - self.eval_config = eval_config - self.model_name = model_name - self.output_name = output_name - self.key = key - self.bleu_metric = sacrebleu.BLEU(**bleu_kwargs) - - def _extract_statistics_for_empty_reference( - self, hypotheses: Sequence[str] - ) -> list[_Accumulator]: - """Returns sentence-level statistics when there are no references. - - Args: - hypotheses: A sequence of hypothesis strings. + ---- + hyp_len: The hypothesis length. + ref_lens: A list of reference lengths. The closest reference length. Returns: - A list of _Accumulators of segment statistics. 
+ ------- + The closest reference length, or -1 if ref_lens is empty. """ - sum_hyp_len = 0 - for hypothesis in hypotheses: - _, hyp_len = sacrebleu.helpers.extract_all_word_ngrams( - hypothesis, 1, self.bleu_metric.max_ngram_order - ) - sum_hyp_len += hyp_len + ref_lens_arr = np.array(ref_lens) + return ref_lens_arr[np.argmin(abs(ref_lens_arr - hyp_len))] - # No n-grams. - matching_ngrams = np.zeros(self.bleu_metric.max_ngram_order, dtype=int) - total_ngrams = np.zeros(self.bleu_metric.max_ngram_order, dtype=int) - return [ - _Accumulator( - matching_ngrams=matching_ngrams, - total_ngrams=total_ngrams, - hyp_len=sum_hyp_len, +class _BleuCombiner(beam.CombineFn): + """Computes BLEU Score.""" + + def __init__( + self, + eval_config: config_pb2.EvalConfig, + model_name: str, + output_name: str, + key: metric_types.MetricKey, + **bleu_kwargs, + ): + """Initializes BLEU Combiner. + + Args: + ---- + eval_config: Eval config. + model_name: The model for which to compute these metrics. + output_name: The output name for which to compute these metrics. + key: MetricKey for extract_output(). + **bleu_kwargs: kwargs to initialize BLEU Metric. Possible options include + 'lowercase' (If True, lowercases the input, enabling + case-insensitivity.), 'force' (If True, insists that the tokenized input + is detokenized.), 'tokenize' (Tokenization method to use for BLEU. + Possible values are 'none' (No tokenization), 'zh' (Chinese + tokenization), '13a' (mimics the mteval-v13a from Moses), and 'intl' + (International tokenization, mimics the mteval-v14 script from Moses).), + 'smooth_method' (The smoothing method to use. Possible values are 'none' + (no smoothing), 'floor' (increment zero counts), 'add-k' (increment + num/denom by k for n>1), and 'exp' (exponential decay).), 'smooth_value' + (The smoothing value. Only valid when smoothmethod='floor' or + smooth_method='add-k'.), and 'effective_order' (If True, stops including + n-gram orders for which precision is 0. This should be True if + sentence-level BLEU will be computed.). + """ + self.eval_config = eval_config + self.model_name = model_name + self.output_name = output_name + self.key = key + self.bleu_metric = sacrebleu.BLEU(**bleu_kwargs) + + def _extract_statistics_for_empty_reference( + self, hypotheses: Sequence[str] + ) -> list[_Accumulator]: + """Returns sentence-level statistics when there are no references. + + Args: + ---- + hypotheses: A sequence of hypothesis strings. + + Returns: + ------- + A list of _Accumulators of segment statistics. + """ + sum_hyp_len = 0 + for hypothesis in hypotheses: + _, hyp_len = sacrebleu.helpers.extract_all_word_ngrams( + hypothesis, 1, self.bleu_metric.max_ngram_order + ) + sum_hyp_len += hyp_len + + # No n-grams. + matching_ngrams = np.zeros(self.bleu_metric.max_ngram_order, dtype=int) + total_ngrams = np.zeros(self.bleu_metric.max_ngram_order, dtype=int) + + return [ + _Accumulator( + matching_ngrams=matching_ngrams, + total_ngrams=total_ngrams, + hyp_len=sum_hyp_len, + ) + ] + + def _preprocess_segment(self, sentence: str) -> str: + """Given a sentence, lowercases (optionally) and tokenizes it.""" + if self.bleu_metric.lowercase: + sentence = sentence.lower() + return self.bleu_metric.tokenizer(sentence.rstrip()) + + def _extract_reference_info(self, refs: Sequence[str]) -> _RefInfo: + """Given a list of reference segments, extract the n-grams and reference lengths. + + The latter will be useful when comparing hypothesis and reference lengths. + + Args: + ---- + refs: A sequence of strings. 
+ + Returns: + ------- + A _RefInfo() with reference ngrams and lengths. + """ + refs = iter(refs) + + final_ngrams, ref_len = sacrebleu.helpers.extract_all_word_ngrams( + next(refs), 1, self.bleu_metric.max_ngram_order + ) + ref_lens = [ref_len] + + for ref in refs: + # Extract n-grams for this ref. + new_ngrams, ref_len = sacrebleu.helpers.extract_all_word_ngrams( + ref, 1, self.bleu_metric.max_ngram_order + ) + + ref_lens.append(ref_len) + + # Merge counts across multiple references. + # The below loop is faster than 'final_ngrams |= new_ngrams'. + for ngram, count in new_ngrams.items(): + final_ngrams[ngram] = max(final_ngrams[ngram], count) + + return _RefInfo(ngrams=final_ngrams, lens=ref_lens) + + def _extract_reference_ngrams_and_lens( + self, references: Sequence[Sequence[str]] + ) -> list[_RefInfo]: + """Given the full set of document references, extract segment n-grams and lens.""" + ref_data = [] + + # Iterate through all references. + for refs in zip(*references): + # Remove undefined references and seperate ngrams. + lines = [ + self._preprocess_segment(line) for line in refs if line is not None + ] + + # Get n-grams data. + ref_data.append(self._extract_reference_info(lines)) + + return ref_data + + def _compute_segment_statistics( + self, + hypothesis: str, + ref_info: _RefInfo, + ) -> _Accumulator: + """Given a (pre-processed) hypothesis sentence and already computed reference n-grams & lengths, returns the best match statistics across the references. + + Args: + ---- + hypothesis: Hypothesis sentence. + ref_info: _RefInfo containing the counter with all n-grams and counts, and + the list of reference lengths. + + Returns: + ------- + An _Accumulator with match statistics. + """ + # Extract n-grams for the hypothesis. + hyp_ngrams, hyp_len = sacrebleu.helpers.extract_all_word_ngrams( + hypothesis, 1, self.bleu_metric.max_ngram_order ) - ] - - def _preprocess_segment(self, sentence: str) -> str: - """Given a sentence, lowercases (optionally) and tokenizes it.""" - if self.bleu_metric.lowercase: - sentence = sentence.lower() - return self.bleu_metric.tokenizer(sentence.rstrip()) - - def _extract_reference_info(self, refs: Sequence[str]) -> _RefInfo: - """Given a list of reference segments, extract the n-grams and reference lengths. - - The latter will be useful when comparing hypothesis and reference lengths. - - Args: - refs: A sequence of strings. - - Returns: - A _RefInfo() with reference ngrams and lengths. - """ - refs = iter(refs) - - final_ngrams, ref_len = sacrebleu.helpers.extract_all_word_ngrams( - next(refs), 1, self.bleu_metric.max_ngram_order - ) - ref_lens = [ref_len] - - for ref in refs: - # Extract n-grams for this ref. - new_ngrams, ref_len = sacrebleu.helpers.extract_all_word_ngrams( - ref, 1, self.bleu_metric.max_ngram_order - ) - - ref_lens.append(ref_len) - - # Merge counts across multiple references. - # The below loop is faster than 'final_ngrams |= new_ngrams'. - for ngram, count in new_ngrams.items(): - final_ngrams[ngram] = max(final_ngrams[ngram], count) - - return _RefInfo(ngrams=final_ngrams, lens=ref_lens) - - def _extract_reference_ngrams_and_lens( - self, references: Sequence[Sequence[str]] - ) -> list[_RefInfo]: - """Given the full set of document references, extract segment n-grams and lens.""" - ref_data = [] - - # Iterate through all references. - for refs in zip(*references): - # Remove undefined references and seperate ngrams. - lines = [ - self._preprocess_segment(line) for line in refs if line is not None - ] - - # Get n-grams data. 
- ref_data.append(self._extract_reference_info(lines)) - - return ref_data - - def _compute_segment_statistics( - self, - hypothesis: str, - ref_info: _RefInfo, - ) -> _Accumulator: - """Given a (pre-processed) hypothesis sentence and already computed reference n-grams & lengths, returns the best match statistics across the references. - - Args: - hypothesis: Hypothesis sentence. - ref_info: _RefInfo containing the counter with all n-grams and counts, and - the list of reference lengths. - - Returns: - An _Accumulator with match statistics. - """ - # Extract n-grams for the hypothesis. - hyp_ngrams, hyp_len = sacrebleu.helpers.extract_all_word_ngrams( - hypothesis, 1, self.bleu_metric.max_ngram_order - ) - - ref_len = _find_closest_ref_len(hyp_len, ref_info.lens) - # Count the stats. - # Although counter has its internal & and | operators, this is faster. - matching_ngrams = np.zeros(self.bleu_metric.max_ngram_order, dtype=int) - total_ngrams = np.zeros(self.bleu_metric.max_ngram_order, dtype=int) + ref_len = _find_closest_ref_len(hyp_len, ref_info.lens) - for hyp_ngram, hyp_count in hyp_ngrams.items(): - # n-gram order. - n = len(hyp_ngram) - 1 + # Count the stats. + # Although counter has its internal & and | operators, this is faster. + matching_ngrams = np.zeros(self.bleu_metric.max_ngram_order, dtype=int) + total_ngrams = np.zeros(self.bleu_metric.max_ngram_order, dtype=int) - # Count hypothesis n-grams. - total_ngrams[n] += hyp_count + for hyp_ngram, hyp_count in hyp_ngrams.items(): + # n-gram order. + n = len(hyp_ngram) - 1 - # Count matched n-grams. - ref_ngrams = ref_info.ngrams - if hyp_ngram in ref_ngrams: - matching_ngrams[n] += min(hyp_count, ref_ngrams[hyp_ngram]) + # Count hypothesis n-grams. + total_ngrams[n] += hyp_count - return _Accumulator( - matching_ngrams=matching_ngrams, - total_ngrams=total_ngrams, - hyp_len=hyp_len, - ref_len=ref_len, - ) + # Count matched n-grams. + ref_ngrams = ref_info.ngrams + if hyp_ngram in ref_ngrams: + matching_ngrams[n] += min(hyp_count, ref_ngrams[hyp_ngram]) - def _extract_corpus_statistics( - self, - hypotheses: Sequence[str], - references: Sequence[Sequence[str]], - ) -> list[_Accumulator]: - """Reads the corpus and returns sentence-level match statistics for faster re-computations esp during statistical tests. - - Args: - hypotheses: A sequence of hypothesis strings. - references: A sequence of reference documents with document being defined - as a sequence of reference strings of shape (batch_size_of_references x - batch_size_of_hypotheses). - - Returns: - A list of _Accumulators of segment statistics. - """ - if np.all((np.array(references) == [''])): - # Empty Reference. - return self._extract_statistics_for_empty_reference(hypotheses) + return _Accumulator( + matching_ngrams=matching_ngrams, + total_ngrams=total_ngrams, + hyp_len=hyp_len, + ref_len=ref_len, + ) - stats = [] - tok_count = 0 + def _extract_corpus_statistics( + self, + hypotheses: Sequence[str], + references: Sequence[Sequence[str]], + ) -> list[_Accumulator]: + """Reads the corpus and returns sentence-level match statistics for faster re-computations esp during statistical tests. + + Args: + ---- + hypotheses: A sequence of hypothesis strings. + references: A sequence of reference documents with document being defined + as a sequence of reference strings of shape (batch_size_of_references x + batch_size_of_hypotheses). + + Returns: + ------- + A list of _Accumulators of segment statistics. + """ + if np.all(np.array(references) == [""]): + # Empty Reference. 
+ return self._extract_statistics_for_empty_reference(hypotheses) + + stats = [] + tok_count = 0 + + # Extract the new 'stats'. + for hyp, ref_kwargs in zip( + hypotheses, self._extract_reference_ngrams_and_lens(references) + ): + # Check for already-tokenized input problem. + if not self.bleu_metric._force and hyp.endswith(" ."): # pylint:disable=protected-access + tok_count += 1 + + # Collect stats. + stats.append( + self._compute_segment_statistics( + self._preprocess_segment(hyp), ref_kwargs + ) + ) + + if tok_count >= 100: + logging.warning("That's 100 lines that end in a tokenized period (' .')") + logging.warning( + "It looks like you forgot to detokenize your test data, which may" + " hurt your score." + ) + logging.warning( + "If you insist your data is detokenized, or don't care, you can" + " suppress this message with the 'force' parameter." + ) + + return stats + + def _compute_score_from_accumulator( + self, accumulator: _Accumulator + ) -> sacrebleu.BLEUScore: + """Computes the final score from already aggregated statistics. + + Args: + ---- + accumulator: An accumulator containing segment-level statistics. + + Returns: + ------- + A 'BLEUScore' object. + """ + bleu_metric = self.bleu_metric + + # TODO: b/319702245 - Resolve the issue below in compute_bleu(). + # We need to convert the np.ndarray's to a lists here. + # If we leave it as a np.ndarray of ints, then sacrebleu will not be able to + # add decimal smooth values to the stats list within compute_bleu(). + # If we convert it to an np.ndarray of floats, then sacrebleu will not be + # able to propely set BLEUScore._verbose because there is no format code 'd' + # for floats. + return self.bleu_metric.compute_bleu( + correct=accumulator.matching_ngrams.tolist(), + total=accumulator.total_ngrams.tolist(), + sys_len=accumulator.hyp_len, + ref_len=accumulator.ref_len, + smooth_method=bleu_metric.smooth_method, + smooth_value=bleu_metric.smooth_value, + effective_order=bleu_metric.effective_order, + max_ngram_order=bleu_metric.max_ngram_order, + ) - # Extract the new 'stats'. - for hyp, ref_kwargs in zip( - hypotheses, self._extract_reference_ngrams_and_lens(references) - ): - # Check for already-tokenized input problem. - if not self.bleu_metric._force and hyp.endswith(' .'): # pylint:disable=protected-access - tok_count += 1 - - # Collect stats. - stats.append( - self._compute_segment_statistics( - self._preprocess_segment(hyp), ref_kwargs - ) - ) - - if tok_count >= 100: - logging.warning("That's 100 lines that end in a tokenized period (' .')") - logging.warning( - 'It looks like you forgot to detokenize your test data, which may' - ' hurt your score.' - ) - logging.warning( - "If you insist your data is detokenized, or don't care, you can" - " suppress this message with the 'force' parameter." - ) - - return stats - - def _compute_score_from_accumulator( - self, accumulator: _Accumulator - ) -> sacrebleu.BLEUScore: - """Computes the final score from already aggregated statistics. + def create_accumulator(self): + return _Accumulator( + matching_ngrams=np.zeros(self.bleu_metric.max_ngram_order, dtype=int), + total_ngrams=np.zeros(self.bleu_metric.max_ngram_order, dtype=int), + ) - Args: - accumulator: An accumulator containing segment-level statistics. 
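compute_bleu then reduces the aggregated counts to a single score: a brevity penalty times the geometric mean of the per-order precisions. A hand computation of that textbook formula with invented corpus-level counts (sacrebleu's smoothing and effective-order handling are deliberately left out):

import math

correct = [9, 6, 4, 2]   # matched n-grams per order (invented)
total = [12, 10, 8, 6]   # hypothesis n-grams per order (invented)
sys_len, ref_len = 12, 13

precisions = [c / t for c, t in zip(correct, total)]
geo_mean = math.exp(sum(math.log(p) for p in precisions) / len(precisions))
brevity_penalty = 1.0 if sys_len > ref_len else math.exp(1.0 - ref_len / sys_len)
bleu = 100.0 * brevity_penalty * geo_mean
print(round(bleu, 2))  # roughly 48.15 for these counts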
+ def add_input( + self, + accumulator: _Accumulator, + element: metric_types.StandardMetricInputs, + ) -> _Accumulator: + # references = labels, hypotheses = predictions + references, hypotheses, _ = next( + metric_util.to_label_prediction_example_weight( + element, + eval_config=self.eval_config, + model_name=self.model_name, + output_name=self.output_name, + example_weighted=False, # Example weights not honored. + flatten=False, + squeeze=False, + ) + ) - Returns: - A 'BLEUScore' object. - """ - bleu_metric = self.bleu_metric - - # TODO: b/319702245 - Resolve the issue below in compute_bleu(). - # We need to convert the np.ndarray's to a lists here. - # If we leave it as a np.ndarray of ints, then sacrebleu will not be able to - # add decimal smooth values to the stats list within compute_bleu(). - # If we convert it to an np.ndarray of floats, then sacrebleu will not be - # able to propely set BLEUScore._verbose because there is no format code 'd' - # for floats. - return self.bleu_metric.compute_bleu( - correct=accumulator.matching_ngrams.tolist(), - total=accumulator.total_ngrams.tolist(), - sys_len=accumulator.hyp_len, - ref_len=accumulator.ref_len, - smooth_method=bleu_metric.smooth_method, - smooth_value=bleu_metric.smooth_value, - effective_order=bleu_metric.effective_order, - max_ngram_order=bleu_metric.max_ngram_order, - ) + corpus_stats = self._extract_corpus_statistics(hypotheses, references) + corpus_stats.append(accumulator) - def create_accumulator(self): - return _Accumulator( - matching_ngrams=np.zeros(self.bleu_metric.max_ngram_order, dtype=int), - total_ngrams=np.zeros(self.bleu_metric.max_ngram_order, dtype=int), - ) + return self.merge_accumulators(corpus_stats) - def add_input( - self, - accumulator: _Accumulator, - element: metric_types.StandardMetricInputs, - ) -> _Accumulator: - # references = labels, hypotheses = predictions - references, hypotheses, _ = next( - metric_util.to_label_prediction_example_weight( - element, - eval_config=self.eval_config, - model_name=self.model_name, - output_name=self.output_name, - example_weighted=False, # Example weights not honored. 
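In add_input above, predictions supply the hypotheses and labels supply the references, so each extract pairs a batch of hypothesis strings with reference lists of shape (num_reference_sets x num_hypotheses), the same layout the _EXAMPLES dictionaries in the accompanying test use. A minimal element with invented sentences:

from tensorflow_model_analysis import constants

# labels[r][h] is the r-th reference for the h-th hypothesis.
example = {
    constants.LABELS_KEY: [
        ["the cat sat on the mat", "dogs bark loudly"],  # reference set 1
        ["a cat sat on a mat", "the dogs are barking"],  # reference set 2
    ],
    constants.PREDICTIONS_KEY: [
        "the cat sat on the mat",  # hypothesis 1
        "the dogs bark loudly",    # hypothesis 2
    ],
}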
- flatten=False, - squeeze=False, - ) - ) + def merge_accumulators(self, accumulators: Iterable[_Accumulator]) -> _Accumulator: + accumulators = iter(accumulators) + result = next(accumulators) + for accumulator in accumulators: + result.hyp_len += accumulator.hyp_len + result.ref_len += accumulator.ref_len + result.matching_ngrams = np.sum( + [result.matching_ngrams, accumulator.matching_ngrams], axis=0 + ) + result.total_ngrams = np.sum( + [result.total_ngrams, accumulator.total_ngrams], axis=0 + ) + return result - corpus_stats = self._extract_corpus_statistics(hypotheses, references) - corpus_stats.append(accumulator) - - return self.merge_accumulators(corpus_stats) - - def merge_accumulators( - self, accumulators: Iterable[_Accumulator] - ) -> _Accumulator: - accumulators = iter(accumulators) - result = next(accumulators) - for accumulator in accumulators: - result.hyp_len += accumulator.hyp_len - result.ref_len += accumulator.ref_len - result.matching_ngrams = np.sum( - [result.matching_ngrams, accumulator.matching_ngrams], axis=0 - ) - result.total_ngrams = np.sum( - [result.total_ngrams, accumulator.total_ngrams], axis=0 - ) - return result - - def extract_output( - self, accumulator: _Accumulator - ) -> dict[metric_types.MetricKey, sacrebleu.BLEUScore]: - return {self.key: self._compute_score_from_accumulator(accumulator)} + def extract_output( + self, accumulator: _Accumulator + ) -> dict[metric_types.MetricKey, sacrebleu.BLEUScore]: + return {self.key: self._compute_score_from_accumulator(accumulator)} def _bleu( @@ -414,55 +425,54 @@ def _bleu( smooth_value: float, effective_order: bool, ) -> metric_types.MetricComputations: - """Returns BLEU score.""" - key = metric_types.MetricKey(name=name) - return [ - metric_types.MetricComputation( - keys=[key], - preprocessors=None, - combiner=_BleuCombiner( - eval_config=eval_config, - model_name=model_name, - output_name=output_name, - lowercase=lowercase, - force=force, - tokenize=tokenize, - smooth_method=smooth_method, - smooth_value=smooth_value, - effective_order=effective_order, - key=key, - ), - ) - ] + """Returns BLEU score.""" + key = metric_types.MetricKey(name=name) + return [ + metric_types.MetricComputation( + keys=[key], + preprocessors=None, + combiner=_BleuCombiner( + eval_config=eval_config, + model_name=model_name, + output_name=output_name, + lowercase=lowercase, + force=force, + tokenize=tokenize, + smooth_method=smooth_method, + smooth_value=smooth_value, + effective_order=effective_order, + key=key, + ), + ) + ] class Bleu(metric_types.Metric): - """BLEU Metric.""" - - def __init__( - self, - name: Optional[str] = _BLEU_NAME_DEFAULT, - lowercase: Optional[bool] = False, - force: Optional[bool] = False, - tokenize: Optional[str] = '13a', - smooth_method: Optional[str] = 'exp', - smooth_value: Optional[float] = None, - use_effective_order: Optional[bool] = False, - ): - """Initializes BLEU Metric.""" - - # This is 'use_effective_order' and not 'effective_order' for backward - # compatibility for old style BLEU API access (< 1.4.11) - super().__init__( - metric_util.merge_per_key_computations(_bleu), - name=name, - lowercase=lowercase, - force=force, - tokenize=tokenize, - smooth_method=smooth_method, - smooth_value=smooth_value, - effective_order=use_effective_order, - ) + """BLEU Metric.""" + + def __init__( + self, + name: Optional[str] = _BLEU_NAME_DEFAULT, + lowercase: Optional[bool] = False, + force: Optional[bool] = False, + tokenize: Optional[str] = "13a", + smooth_method: Optional[str] = "exp", + 
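merge_accumulators above sums the raw counts element-wise, which is what makes the final result a corpus-level BLEU rather than an average of per-sentence scores. A small sketch with two invented per-segment accumulators represented as plain dicts:

import numpy as np

seg1 = {"matching": np.array([4, 3, 2, 1]), "total": np.array([4, 3, 2, 1]),
        "hyp_len": 4, "ref_len": 4}
seg2 = {"matching": np.array([2, 0, 0, 0]), "total": np.array([4, 3, 2, 1]),
        "hyp_len": 4, "ref_len": 5}

# Element-wise sums, mirroring merge_accumulators.
merged = {
    "matching": seg1["matching"] + seg2["matching"],  # [6, 3, 2, 1]
    "total": seg1["total"] + seg2["total"],           # [8, 6, 4, 2]
    "hyp_len": seg1["hyp_len"] + seg2["hyp_len"],     # 8
    "ref_len": seg1["ref_len"] + seg2["ref_len"],     # 9
}
print(merged)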
smooth_value: Optional[float] = None, + use_effective_order: Optional[bool] = False, + ): + """Initializes BLEU Metric.""" + # This is 'use_effective_order' and not 'effective_order' for backward + # compatibility for old style BLEU API access (< 1.4.11) + super().__init__( + metric_util.merge_per_key_computations(_bleu), + name=name, + lowercase=lowercase, + force=force, + tokenize=tokenize, + smooth_method=smooth_method, + smooth_value=smooth_value, + effective_order=use_effective_order, + ) metric_types.register_metric(Bleu) diff --git a/tensorflow_model_analysis/metrics/bleu_test.py b/tensorflow_model_analysis/metrics/bleu_test.py index 0cac537787..f3ee22b2d4 100644 --- a/tensorflow_model_analysis/metrics/bleu_test.py +++ b/tensorflow_model_analysis/metrics/bleu_test.py @@ -13,557 +13,549 @@ # limitations under the License. """Tests for BLEU metric.""" -from absl.testing import parameterized import apache_beam as beam -from apache_beam.testing import util import numpy as np import tensorflow as tf -import tensorflow_model_analysis as tfma -from tensorflow_model_analysis.proto import config_pb2 +from absl.testing import parameterized +from apache_beam.testing import util +from google.protobuf import text_format + from tensorflow_model_analysis import constants from tensorflow_model_analysis.evaluators import metrics_plots_and_validations_evaluator -from tensorflow_model_analysis.metrics import bleu -from tensorflow_model_analysis.metrics import metric_types -from tensorflow_model_analysis.metrics import metric_util +from tensorflow_model_analysis.metrics import bleu, metric_types, metric_util +from tensorflow_model_analysis.proto import config_pb2 from tensorflow_model_analysis.slicer import slicer_lib as slicer from tensorflow_model_analysis.utils import test_util -from google.protobuf import text_format - _Accumulator = bleu._Accumulator _EXAMPLES = { - 'perfect_score': { + "perfect_score": { constants.LABELS_KEY: [ - ['hello there general kenobi', 'Avengers! Assemble.'], - ['may the force be with you', 'I am Iron Man.'], + ["hello there general kenobi", "Avengers! 
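Because Bleu is registered via metric_types.register_metric above, it can also be requested by class name from an EvalConfig instead of being instantiated directly; the end-to-end test at the bottom of this file exercises that path. A minimal config along those lines (the field values here are illustrative, not copied from the test):

from google.protobuf import text_format

from tensorflow_model_analysis.proto import config_pb2

eval_config = text_format.Parse(
    """
    model_specs {
      label_key: "labels"
      prediction_key: "predictions"
    }
    metrics_specs {
      metrics { class_name: "Bleu" }
    }
    """,
    config_pb2.EvalConfig(),
)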
Assemble."], + ["may the force be with you", "I am Iron Man."], ], constants.PREDICTIONS_KEY: [ - 'hello there general kenobi', - 'I am Iron Man.', + "hello there general kenobi", + "I am Iron Man.", ], }, - 'imperfect_score': { + "imperfect_score": { constants.LABELS_KEY: [ [ - 'The dog bit the man.', - 'It was not unexpected.', - 'The man bit him first.', + "The dog bit the man.", + "It was not unexpected.", + "The man bit him first.", ], [ - 'The dog had bit the man.', - 'No one was surprised.', - 'The man had bitten the dog.', + "The dog had bit the man.", + "No one was surprised.", + "The man had bitten the dog.", ], ], constants.PREDICTIONS_KEY: [ - 'The dog bit the man.', + "The dog bit the man.", "It wasn't surprising.", - 'The man had just bitten him.', + "The man had just bitten him.", ], }, - 'zero_score': { - constants.LABELS_KEY: [['So BLEU', 'will be 0.'], ['Foo.', 'Bar.']], - constants.PREDICTIONS_KEY: ['No matching text', 'in this test'], + "zero_score": { + constants.LABELS_KEY: [["So BLEU", "will be 0."], ["Foo.", "Bar."]], + constants.PREDICTIONS_KEY: ["No matching text", "in this test"], }, } def _get_result(pipeline, examples, combiner): - return ( - pipeline - | 'Create' >> beam.Create(examples) - | 'Process' >> beam.Map(metric_util.to_standard_metric_inputs) - | 'AddSlice' >> beam.Map(lambda x: ((), x)) - | 'ComputeBleu' >> beam.CombinePerKey(combiner) - ) + return ( + pipeline + | "Create" >> beam.Create(examples) + | "Process" >> beam.Map(metric_util.to_standard_metric_inputs) + | "AddSlice" >> beam.Map(lambda x: ((), x)) + | "ComputeBleu" >> beam.CombinePerKey(combiner) + ) class FindClosestRefLenTest( test_util.TensorflowModelAnalysisTest, parameterized.TestCase ): + @parameterized.parameters((0, 2), (5, 4), (10, 10)) + def test_find_closest_ref_len(self, target, expected_closest): + candidates = [2, 4, 6, 8, 10] + self.assertEqual( + expected_closest, bleu._find_closest_ref_len(target, candidates) + ) + - @parameterized.parameters((0, 2), (5, 4), (10, 10)) - def test_find_closest_ref_len(self, target, expected_closest): - candidates = [2, 4, 6, 8, 10] - self.assertEqual( - expected_closest, bleu._find_closest_ref_len(target, candidates) +class BleuTest(test_util.TensorflowModelAnalysisTest, parameterized.TestCase): + def _check_got(self, got, expected_key): + self.assertLen(got, 1) + got_slice_key, got_metrics = got[0] + self.assertEqual(got_slice_key, ()) + self.assertIn(expected_key, got_metrics) + return got_metrics + + @parameterized.parameters( + ("perfect_score", 100), + ("imperfect_score", 48.53), + ("zero_score", 0), ) + def test_bleu_default(self, examples_key, expected_score): + key = metric_types.MetricKey(name=bleu._BLEU_NAME_DEFAULT) + computation = bleu.Bleu().computations()[0] + with beam.Pipeline() as pipeline: + result = _get_result( + pipeline, [_EXAMPLES[examples_key]], computation.combiner + ) + + def check_result(got): + try: + got_metrics = self._check_got(got, key) + self.assertAlmostEqual( + expected_score, got_metrics[key].score, places=2 + ) + + except AssertionError as err: + raise util.BeamAssertException(err) + + util.assert_that(result, check_result, label="result") + + @parameterized.parameters( + ("perfect_score", 100), + ("imperfect_score", 48.53), + ("zero_score", 0), + ) + def test_bleu_name(self, examples_key, expected_score): + custom_name = "custom_name_set_by_caller" + key = metric_types.MetricKey(name=custom_name) + computation = bleu.Bleu(name=custom_name).computations()[0] -class 
BleuTest(test_util.TensorflowModelAnalysisTest, parameterized.TestCase): + with beam.Pipeline() as pipeline: + result = _get_result( + pipeline, [_EXAMPLES[examples_key]], computation.combiner + ) + + def check_result(got): + try: + got_metrics = self._check_got(got, key) + self.assertAlmostEqual( + expected_score, got_metrics[key].score, places=2 + ) + + except AssertionError as err: + raise util.BeamAssertException(err) + + util.assert_that(result, check_result, label="result") + + @parameterized.named_parameters( + ( + "case-sensitive", + _EXAMPLES["perfect_score"][constants.LABELS_KEY], + [ + prediction.upper() + for prediction in _EXAMPLES["perfect_score"][constants.PREDICTIONS_KEY] + ], + False, + 7.58, + ), + ( + "case-insensitive", + _EXAMPLES["perfect_score"][constants.LABELS_KEY], + [ + prediction.upper() + for prediction in _EXAMPLES["perfect_score"][constants.PREDICTIONS_KEY] + ], + True, + 100, + ), + ) + def test_bleu_lowercase(self, labels, predictions, lowercase, expected_score): + example = { + constants.LABELS_KEY: labels, + constants.PREDICTIONS_KEY: predictions, + } + key = metric_types.MetricKey(name=bleu._BLEU_NAME_DEFAULT) + computation = bleu.Bleu(lowercase=lowercase).computations()[0] - def _check_got(self, got, expected_key): - self.assertLen(got, 1) - got_slice_key, got_metrics = got[0] - self.assertEqual(got_slice_key, ()) - self.assertIn(expected_key, got_metrics) - return got_metrics - - @parameterized.parameters( - ('perfect_score', 100), - ('imperfect_score', 48.53), - ('zero_score', 0), - ) - def test_bleu_default(self, examples_key, expected_score): - key = metric_types.MetricKey(name=bleu._BLEU_NAME_DEFAULT) - computation = bleu.Bleu().computations()[0] - - with beam.Pipeline() as pipeline: - result = _get_result( - pipeline, [_EXAMPLES[examples_key]], computation.combiner - ) - - def check_result(got): - try: - got_metrics = self._check_got(got, key) - self.assertAlmostEqual( - expected_score, got_metrics[key].score, places=2 - ) - - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that(result, check_result, label='result') - - @parameterized.parameters( - ('perfect_score', 100), - ('imperfect_score', 48.53), - ('zero_score', 0), - ) - def test_bleu_name(self, examples_key, expected_score): - custom_name = 'custom_name_set_by_caller' - key = metric_types.MetricKey(name=custom_name) - computation = bleu.Bleu(name=custom_name).computations()[0] - - with beam.Pipeline() as pipeline: - result = _get_result( - pipeline, [_EXAMPLES[examples_key]], computation.combiner - ) - - def check_result(got): - try: - got_metrics = self._check_got(got, key) - self.assertAlmostEqual( - expected_score, got_metrics[key].score, places=2 - ) - - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that(result, check_result, label='result') - - @parameterized.named_parameters( - ( - 'case-sensitive', - _EXAMPLES['perfect_score'][constants.LABELS_KEY], - [ - prediction.upper() - for prediction in _EXAMPLES['perfect_score'][ - constants.PREDICTIONS_KEY - ] - ], - False, - 7.58, - ), - ( - 'case-insensitive', - _EXAMPLES['perfect_score'][constants.LABELS_KEY], - [ - prediction.upper() - for prediction in _EXAMPLES['perfect_score'][ - constants.PREDICTIONS_KEY - ] - ], - True, - 100, - ), - ) - def test_bleu_lowercase(self, labels, predictions, lowercase, expected_score): - example = { - constants.LABELS_KEY: labels, - constants.PREDICTIONS_KEY: predictions, - } - key = 
metric_types.MetricKey(name=bleu._BLEU_NAME_DEFAULT) - computation = bleu.Bleu(lowercase=lowercase).computations()[0] - - with beam.Pipeline() as pipeline: - result = _get_result(pipeline, [example], computation.combiner) - - def check_result(got): - try: - got_metrics = self._check_got(got, key) - self.assertAlmostEqual( - expected_score, got_metrics[key].score, places=2 - ) - - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that(result, check_result, label='result') - - @parameterized.parameters( - ('perfect_score', 'none', 100), - ('imperfect_score', 'none', 49.19), - ('zero_score', 'none', 0), - ('perfect_score', 'zh', 100), - ('imperfect_score', 'zh', 48.53), - ('zero_score', 'zh', 0), - ('perfect_score', 'intl', 100), - ('imperfect_score', 'intl', 43.92), - ('zero_score', 'intl', 0), - ) - def test_bleu_tokenize(self, examples_key, tokenizer, expected_score): - key = metric_types.MetricKey(name=bleu._BLEU_NAME_DEFAULT) - computation = bleu.Bleu(tokenize=tokenizer).computations()[0] - - with beam.Pipeline() as pipeline: - result = _get_result( - pipeline, [_EXAMPLES[examples_key]], computation.combiner - ) - - def check_result(got): - try: - got_metrics = self._check_got(got, key) - self.assertAlmostEqual( - expected_score, got_metrics[key].score, places=2 - ) - - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that(result, check_result, label='result') - - def test_bleu_invalid_tokenizer(self): - invalid_tokenizer = 'invalid_tokenizer_name' - bleu_metric = bleu.Bleu(tokenize=invalid_tokenizer) - - with self.assertRaisesRegex(KeyError, invalid_tokenizer): - bleu_metric.computations() - - @parameterized.parameters( - # Perfect score is always perfect - ('perfect_score', (100,) * 3 * 5), - ( - 'imperfect_score', - # smooth_methods = 'none' or 'floor' - (48.53,) * 2 * 5 - + ( # smooth_method = 'add-k' - 48.53, # smooth_value = 0 - 50.74, # smooth_value = 0.5 - 52.70, # smooth_value = 1 - 43.05, # smooth_value = -1 - 56.03, # smooth_value = 2 - ), - ), - ) - def test_bleu_smoothing(self, examples_key, expected_scores): - smooth_methods = ('none', 'floor', 'add-k') - smooth_values = (0, 0.5, 1, -1, 2) - key = metric_types.MetricKey(name=bleu._BLEU_NAME_DEFAULT) - - for method_counter, smooth_method in enumerate(smooth_methods): - for value_counter, smooth_value in enumerate(smooth_values): - computation = bleu.Bleu( - smooth_method=smooth_method, smooth_value=smooth_value - ).computations()[0] with beam.Pipeline() as pipeline: - result = _get_result( - pipeline, [_EXAMPLES[examples_key]], computation.combiner - ) - - def check_result( - got, - inner_len=len(smooth_values), - outer_counter=method_counter, - inner_counter=value_counter, - ): - try: - got_metrics = self._check_got(got, key) - self.assertAlmostEqual( - expected_scores[inner_len * outer_counter + inner_counter], - got_metrics[key].score, - places=2, - ) - - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that(result, check_result, label='result') - - def test_bleu_invalid_smooth_method(self): - invalid_smooth_method = 'invalid_smooth_method_name' - smooth_values = (0, 0.5, 1) - - for smooth_value in smooth_values: - bleu_metric = bleu.Bleu( - smooth_method=invalid_smooth_method, smooth_value=smooth_value - ) - with self.assertRaisesRegex(AssertionError, 'Unknown smooth_method '): - bleu_metric.computations() - - @parameterized.parameters( - ('perfect_score', 100), - ('imperfect_score', 48.53), - ('zero_score', 0), - ) - 
def test_bleu_use_effective_order(self, examples_key, expected_score): - key = metric_types.MetricKey(name=bleu._BLEU_NAME_DEFAULT) - computation = bleu.Bleu(use_effective_order=True).computations()[0] - - with beam.Pipeline() as pipeline: - result = _get_result( - pipeline, [_EXAMPLES[examples_key]], computation.combiner - ) - - def check_result(got): - try: - got_metrics = self._check_got(got, key) - self.assertAlmostEqual( - expected_score, got_metrics[key].score, places=2 - ) - - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that(result, check_result, label='result') - - @parameterized.parameters( - ('perfect_score', 100), - ('imperfect_score', 48.53), - ('zero_score', 0), - ) - def test_bleu_multiple_examples(self, examples_key, expected_score): - combined_example = _EXAMPLES[examples_key] - list_of_examples = [] - - # Convert combined_example into a list of multiple examples - for i, prediction in enumerate(combined_example['predictions']): - list_of_examples.append({ - constants.LABELS_KEY: np.expand_dims( - np.array(combined_example['labels'])[:, i], axis=1 - ), - constants.PREDICTIONS_KEY: [prediction], - }) - - key = metric_types.MetricKey(name=bleu._BLEU_NAME_DEFAULT) - computation = bleu.Bleu().computations()[0] - - with beam.Pipeline() as pipeline: - result = _get_result(pipeline, list_of_examples, computation.combiner) - - def check_result(got): - try: - got_metrics = self._check_got(got, key) - self.assertAlmostEqual( - expected_score, got_metrics[key].score, places=2 - ) - - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that(result, check_result, label='result') - - @parameterized.parameters( - ([''], ['']), - ([''], _EXAMPLES['perfect_score'][constants.PREDICTIONS_KEY]), - (_EXAMPLES['perfect_score'][constants.LABELS_KEY], ['']), - ) - def test_bleu_empty_label_or_prediction(self, labels, predictions): - example = { - constants.LABELS_KEY: labels, - constants.PREDICTIONS_KEY: predictions, - } - expected_score = 0 - key = metric_types.MetricKey(name=bleu._BLEU_NAME_DEFAULT) - computation = bleu.Bleu().computations()[0] - - with beam.Pipeline() as pipeline: - result = _get_result(pipeline, [example], computation.combiner) - - def check_result(got): - try: - got_metrics = self._check_got(got, key) - self.assertAlmostEqual(expected_score, got_metrics[key].score) - - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that(result, check_result, label='result') - - @parameterized.parameters( - ( - 'perfect_score', - [ - _Accumulator( - matching_ngrams=np.array([4, 3, 2, 1]), - total_ngrams=np.array([4, 3, 2, 1]), - hyp_len=4, - ref_len=4, - ), - _Accumulator( - matching_ngrams=np.array([5, 4, 3, 2]), - total_ngrams=np.array([5, 4, 3, 2]), - hyp_len=5, - ref_len=5, - ), - ], - ), - ( - 'imperfect_score', - [ - _Accumulator( - matching_ngrams=[6, 5, 4, 3], - total_ngrams=[6, 5, 4, 3], - hyp_len=6, - ref_len=6, - ), - _Accumulator( - matching_ngrams=[2, 0, 0, 0], - total_ngrams=[4, 3, 2, 1], - hyp_len=4, - ref_len=5, - ), - _Accumulator( - matching_ngrams=[6, 2, 1, 0], - total_ngrams=[7, 6, 5, 4], - hyp_len=7, - ref_len=7, - ), - ], - ), - ( - 'zero_score', - [ - _Accumulator( - matching_ngrams=[0, 0, 0, 0], - total_ngrams=[3, 2, 1, 0], - hyp_len=3, - ref_len=2, - ), - _Accumulator( - matching_ngrams=[0, 0, 0, 0], - total_ngrams=[3, 2, 1, 0], - hyp_len=3, - ref_len=4, - ), - ], - ), - ) - def test_bleu_extract_corpus_statistics(self, examples_key, expected_accs): - examples = 
_EXAMPLES[examples_key] - actual_accs = bleu._BleuCombiner( - None, '', '', None - )._extract_corpus_statistics( - examples[constants.PREDICTIONS_KEY], examples[constants.LABELS_KEY] + result = _get_result(pipeline, [example], computation.combiner) + + def check_result(got): + try: + got_metrics = self._check_got(got, key) + self.assertAlmostEqual( + expected_score, got_metrics[key].score, places=2 + ) + + except AssertionError as err: + raise util.BeamAssertException(err) + + util.assert_that(result, check_result, label="result") + + @parameterized.parameters( + ("perfect_score", "none", 100), + ("imperfect_score", "none", 49.19), + ("zero_score", "none", 0), + ("perfect_score", "zh", 100), + ("imperfect_score", "zh", 48.53), + ("zero_score", "zh", 0), + ("perfect_score", "intl", 100), + ("imperfect_score", "intl", 43.92), + ("zero_score", "intl", 0), ) + def test_bleu_tokenize(self, examples_key, tokenizer, expected_score): + key = metric_types.MetricKey(name=bleu._BLEU_NAME_DEFAULT) + computation = bleu.Bleu(tokenize=tokenizer).computations()[0] - for expected_acc, actual_acc in zip(expected_accs, actual_accs): - # Use __eq__() in _Accumulator(). - self.assertEqual(expected_acc, actual_acc) - - @parameterized.parameters( - ( - # Merge a non-empty _Accumulator() and an empty _Accumulator(). - [ - _Accumulator( - matching_ngrams=[6, 5, 4, 3], - total_ngrams=[6, 5, 4, 3], - hyp_len=6, - ref_len=6, - ), - _Accumulator( - matching_ngrams=[0, 0, 0, 0], - total_ngrams=[0, 0, 0, 0], - hyp_len=0, - ref_len=0, - ), - ], - _Accumulator( - matching_ngrams=[6, 5, 4, 3], - total_ngrams=[6, 5, 4, 3], - hyp_len=6, - ref_len=6, - ), - ), - ( - # Merge two non-empty _Accumulator()'s. - [ - _Accumulator( - matching_ngrams=[6, 5, 4, 3], - total_ngrams=[6, 5, 4, 3], - hyp_len=6, - ref_len=6, - ), - _Accumulator( - matching_ngrams=[6, 5, 4, 3], - total_ngrams=[6, 5, 4, 3], - hyp_len=6, - ref_len=6, - ), - ], - _Accumulator( - matching_ngrams=[12, 10, 8, 6], - total_ngrams=[12, 10, 8, 6], - hyp_len=12, - ref_len=12, - ), - ), - ( - # Merge two emtpy _Accumulaor()'s. - [ - _Accumulator( - matching_ngrams=[0, 0, 0, 0], - total_ngrams=[0, 0, 0, 0], - hyp_len=0, - ref_len=0, - ), - _Accumulator( - matching_ngrams=[0, 0, 0, 0], - total_ngrams=[0, 0, 0, 0], - hyp_len=0, - ref_len=0, - ), - ], - _Accumulator( - matching_ngrams=[0, 0, 0, 0], - total_ngrams=[0, 0, 0, 0], - hyp_len=0, - ref_len=0, - ), - ), - ( - # Call merge_accumulators() with one _Accumulator(). 
- [ - _Accumulator( - matching_ngrams=[14, 7, 5, 3], - total_ngrams=[17, 14, 11, 8], - hyp_len=17, - ref_len=18, - ) - ], - _Accumulator( - matching_ngrams=[14, 7, 5, 3], - total_ngrams=[17, 14, 11, 8], - hyp_len=17, - ref_len=18, - ), - ), - ) - def test_bleu_merge_accumulators(self, accs_list, expected_merged_acc): - actual_merged_acc = bleu._BleuCombiner( - None, '', '', None - ).merge_accumulators(accs_list) - - self.assertEqual(expected_merged_acc, actual_merged_acc) + with beam.Pipeline() as pipeline: + result = _get_result( + pipeline, [_EXAMPLES[examples_key]], computation.combiner + ) + + def check_result(got): + try: + got_metrics = self._check_got(got, key) + self.assertAlmostEqual( + expected_score, got_metrics[key].score, places=2 + ) + + except AssertionError as err: + raise util.BeamAssertException(err) + + util.assert_that(result, check_result, label="result") + + def test_bleu_invalid_tokenizer(self): + invalid_tokenizer = "invalid_tokenizer_name" + bleu_metric = bleu.Bleu(tokenize=invalid_tokenizer) + + with self.assertRaisesRegex(KeyError, invalid_tokenizer): + bleu_metric.computations() + + @parameterized.parameters( + # Perfect score is always perfect + ("perfect_score", (100,) * 3 * 5), + ( + "imperfect_score", + # smooth_methods = 'none' or 'floor' + (48.53,) * 2 * 5 + + ( # smooth_method = 'add-k' + 48.53, # smooth_value = 0 + 50.74, # smooth_value = 0.5 + 52.70, # smooth_value = 1 + 43.05, # smooth_value = -1 + 56.03, # smooth_value = 2 + ), + ), + ) + def test_bleu_smoothing(self, examples_key, expected_scores): + smooth_methods = ("none", "floor", "add-k") + smooth_values = (0, 0.5, 1, -1, 2) + key = metric_types.MetricKey(name=bleu._BLEU_NAME_DEFAULT) + + for method_counter, smooth_method in enumerate(smooth_methods): + for value_counter, smooth_value in enumerate(smooth_values): + computation = bleu.Bleu( + smooth_method=smooth_method, smooth_value=smooth_value + ).computations()[0] + with beam.Pipeline() as pipeline: + result = _get_result( + pipeline, [_EXAMPLES[examples_key]], computation.combiner + ) + + def check_result( + got, + inner_len=len(smooth_values), + outer_counter=method_counter, + inner_counter=value_counter, + ): + try: + got_metrics = self._check_got(got, key) + self.assertAlmostEqual( + expected_scores[ + inner_len * outer_counter + inner_counter + ], + got_metrics[key].score, + places=2, + ) + + except AssertionError as err: + raise util.BeamAssertException(err) + + util.assert_that(result, check_result, label="result") + + def test_bleu_invalid_smooth_method(self): + invalid_smooth_method = "invalid_smooth_method_name" + smooth_values = (0, 0.5, 1) + + for smooth_value in smooth_values: + bleu_metric = bleu.Bleu( + smooth_method=invalid_smooth_method, smooth_value=smooth_value + ) + with self.assertRaisesRegex(AssertionError, "Unknown smooth_method "): + bleu_metric.computations() + + @parameterized.parameters( + ("perfect_score", 100), + ("imperfect_score", 48.53), + ("zero_score", 0), + ) + def test_bleu_use_effective_order(self, examples_key, expected_score): + key = metric_types.MetricKey(name=bleu._BLEU_NAME_DEFAULT) + computation = bleu.Bleu(use_effective_order=True).computations()[0] + with beam.Pipeline() as pipeline: + result = _get_result( + pipeline, [_EXAMPLES[examples_key]], computation.combiner + ) + + def check_result(got): + try: + got_metrics = self._check_got(got, key) + self.assertAlmostEqual( + expected_score, got_metrics[key].score, places=2 + ) + + except AssertionError as err: + raise util.BeamAssertException(err) + 
+ util.assert_that(result, check_result, label="result") + + @parameterized.parameters( + ("perfect_score", 100), + ("imperfect_score", 48.53), + ("zero_score", 0), + ) + def test_bleu_multiple_examples(self, examples_key, expected_score): + combined_example = _EXAMPLES[examples_key] + list_of_examples = [] + + # Convert combined_example into a list of multiple examples + for i, prediction in enumerate(combined_example["predictions"]): + list_of_examples.append( + { + constants.LABELS_KEY: np.expand_dims( + np.array(combined_example["labels"])[:, i], axis=1 + ), + constants.PREDICTIONS_KEY: [prediction], + } + ) + + key = metric_types.MetricKey(name=bleu._BLEU_NAME_DEFAULT) + computation = bleu.Bleu().computations()[0] -class BleuEnd2EndTest(parameterized.TestCase): + with beam.Pipeline() as pipeline: + result = _get_result(pipeline, list_of_examples, computation.combiner) + + def check_result(got): + try: + got_metrics = self._check_got(got, key) + self.assertAlmostEqual( + expected_score, got_metrics[key].score, places=2 + ) + + except AssertionError as err: + raise util.BeamAssertException(err) + + util.assert_that(result, check_result, label="result") + + @parameterized.parameters( + ([""], [""]), + ([""], _EXAMPLES["perfect_score"][constants.PREDICTIONS_KEY]), + (_EXAMPLES["perfect_score"][constants.LABELS_KEY], [""]), + ) + def test_bleu_empty_label_or_prediction(self, labels, predictions): + example = { + constants.LABELS_KEY: labels, + constants.PREDICTIONS_KEY: predictions, + } + expected_score = 0 + key = metric_types.MetricKey(name=bleu._BLEU_NAME_DEFAULT) + computation = bleu.Bleu().computations()[0] + + with beam.Pipeline() as pipeline: + result = _get_result(pipeline, [example], computation.combiner) - def test_bleu_end_2_end(self): - # Same test as BleuTest.testBleuDefault with 'imperfect_score' - eval_config = text_format.Parse( - """ + def check_result(got): + try: + got_metrics = self._check_got(got, key) + self.assertAlmostEqual(expected_score, got_metrics[key].score) + + except AssertionError as err: + raise util.BeamAssertException(err) + + util.assert_that(result, check_result, label="result") + + @parameterized.parameters( + ( + "perfect_score", + [ + _Accumulator( + matching_ngrams=np.array([4, 3, 2, 1]), + total_ngrams=np.array([4, 3, 2, 1]), + hyp_len=4, + ref_len=4, + ), + _Accumulator( + matching_ngrams=np.array([5, 4, 3, 2]), + total_ngrams=np.array([5, 4, 3, 2]), + hyp_len=5, + ref_len=5, + ), + ], + ), + ( + "imperfect_score", + [ + _Accumulator( + matching_ngrams=[6, 5, 4, 3], + total_ngrams=[6, 5, 4, 3], + hyp_len=6, + ref_len=6, + ), + _Accumulator( + matching_ngrams=[2, 0, 0, 0], + total_ngrams=[4, 3, 2, 1], + hyp_len=4, + ref_len=5, + ), + _Accumulator( + matching_ngrams=[6, 2, 1, 0], + total_ngrams=[7, 6, 5, 4], + hyp_len=7, + ref_len=7, + ), + ], + ), + ( + "zero_score", + [ + _Accumulator( + matching_ngrams=[0, 0, 0, 0], + total_ngrams=[3, 2, 1, 0], + hyp_len=3, + ref_len=2, + ), + _Accumulator( + matching_ngrams=[0, 0, 0, 0], + total_ngrams=[3, 2, 1, 0], + hyp_len=3, + ref_len=4, + ), + ], + ), + ) + def test_bleu_extract_corpus_statistics(self, examples_key, expected_accs): + examples = _EXAMPLES[examples_key] + actual_accs = bleu._BleuCombiner(None, "", "", None)._extract_corpus_statistics( + examples[constants.PREDICTIONS_KEY], examples[constants.LABELS_KEY] + ) + + for expected_acc, actual_acc in zip(expected_accs, actual_accs): + # Use __eq__() in _Accumulator(). 
+ self.assertEqual(expected_acc, actual_acc) + + @parameterized.parameters( + ( + # Merge a non-empty _Accumulator() and an empty _Accumulator(). + [ + _Accumulator( + matching_ngrams=[6, 5, 4, 3], + total_ngrams=[6, 5, 4, 3], + hyp_len=6, + ref_len=6, + ), + _Accumulator( + matching_ngrams=[0, 0, 0, 0], + total_ngrams=[0, 0, 0, 0], + hyp_len=0, + ref_len=0, + ), + ], + _Accumulator( + matching_ngrams=[6, 5, 4, 3], + total_ngrams=[6, 5, 4, 3], + hyp_len=6, + ref_len=6, + ), + ), + ( + # Merge two non-empty _Accumulator()'s. + [ + _Accumulator( + matching_ngrams=[6, 5, 4, 3], + total_ngrams=[6, 5, 4, 3], + hyp_len=6, + ref_len=6, + ), + _Accumulator( + matching_ngrams=[6, 5, 4, 3], + total_ngrams=[6, 5, 4, 3], + hyp_len=6, + ref_len=6, + ), + ], + _Accumulator( + matching_ngrams=[12, 10, 8, 6], + total_ngrams=[12, 10, 8, 6], + hyp_len=12, + ref_len=12, + ), + ), + ( + # Merge two emtpy _Accumulaor()'s. + [ + _Accumulator( + matching_ngrams=[0, 0, 0, 0], + total_ngrams=[0, 0, 0, 0], + hyp_len=0, + ref_len=0, + ), + _Accumulator( + matching_ngrams=[0, 0, 0, 0], + total_ngrams=[0, 0, 0, 0], + hyp_len=0, + ref_len=0, + ), + ], + _Accumulator( + matching_ngrams=[0, 0, 0, 0], + total_ngrams=[0, 0, 0, 0], + hyp_len=0, + ref_len=0, + ), + ), + ( + # Call merge_accumulators() with one _Accumulator(). + [ + _Accumulator( + matching_ngrams=[14, 7, 5, 3], + total_ngrams=[17, 14, 11, 8], + hyp_len=17, + ref_len=18, + ) + ], + _Accumulator( + matching_ngrams=[14, 7, 5, 3], + total_ngrams=[17, 14, 11, 8], + hyp_len=17, + ref_len=18, + ), + ), + ) + def test_bleu_merge_accumulators(self, accs_list, expected_merged_acc): + actual_merged_acc = bleu._BleuCombiner(None, "", "", None).merge_accumulators( + accs_list + ) + + self.assertEqual(expected_merged_acc, actual_merged_acc) + + +class BleuEnd2EndTest(parameterized.TestCase): + def test_bleu_end_2_end(self): + # Same test as BleuTest.testBleuDefault with 'imperfect_score' + eval_config = text_format.Parse( + """ model_specs { label_key: "labels" prediction_key: "predictions" @@ -574,66 +566,66 @@ def test_bleu_end_2_end(self): } } """, - config_pb2.EvalConfig(), - ) + config_pb2.EvalConfig(), + ) + + example1 = { + constants.SLICE_KEY_TYPES_KEY: slicer.slice_keys_to_numpy_array([()]), + constants.FEATURES_KEY: None, + constants.LABELS_KEY: [ + ["The dog bit the man."], + ["The dog had bit the man."], + ], + constants.PREDICTIONS_KEY: ["The dog bit the man."], + } + example2 = { + constants.SLICE_KEY_TYPES_KEY: slicer.slice_keys_to_numpy_array([()]), + constants.FEATURES_KEY: None, + constants.LABELS_KEY: [ + ["It was not unexpected."], + ["No one was surprised."], + ], + constants.PREDICTIONS_KEY: ["It wasn't surprising."], + } + example3 = { + constants.SLICE_KEY_TYPES_KEY: slicer.slice_keys_to_numpy_array([()]), + constants.FEATURES_KEY: None, + constants.LABELS_KEY: [ + ["The man bit him first."], + ["The man had bitten the dog."], + ], + constants.PREDICTIONS_KEY: ["The man had just bitten him."], + } - example1 = { - constants.SLICE_KEY_TYPES_KEY: slicer.slice_keys_to_numpy_array([()]), - constants.FEATURES_KEY: None, - constants.LABELS_KEY: [ - ['The dog bit the man.'], - ['The dog had bit the man.'], - ], - constants.PREDICTIONS_KEY: ['The dog bit the man.'], - } - example2 = { - constants.SLICE_KEY_TYPES_KEY: slicer.slice_keys_to_numpy_array([()]), - constants.FEATURES_KEY: None, - constants.LABELS_KEY: [ - ['It was not unexpected.'], - ['No one was surprised.'], - ], - constants.PREDICTIONS_KEY: ["It wasn't surprising."], - } - example3 = { - 
constants.SLICE_KEY_TYPES_KEY: slicer.slice_keys_to_numpy_array([()]), - constants.FEATURES_KEY: None, - constants.LABELS_KEY: [ - ['The man bit him first.'], - ['The man had bitten the dog.'], - ], - constants.PREDICTIONS_KEY: ['The man had just bitten him.'], - } - - expected_score = 48.53 - key = metric_types.MetricKey(name=bleu._BLEU_NAME_DEFAULT) - - with beam.Pipeline() as pipeline: - result = ( - pipeline - | 'LoadData' >> beam.Create([example1, example2, example3]) - | 'ExtractEval' - >> metrics_plots_and_validations_evaluator.MetricsPlotsAndValidationsEvaluator( - eval_config=eval_config - ).ptransform - ) - - def check_result(got): - try: - self.assertLen(got, 1) - got_slice_key, got_metrics = got[0] - self.assertEqual(got_slice_key, ()) - self.assertIn(key, got_metrics.keys()) - self.assertAlmostEqual( - expected_score, got_metrics[key].score, places=2 - ) - - except AssertionError as err: - raise util.BeamAssertException(err) - - self.assertIn('metrics', result) - util.assert_that(result['metrics'], check_result, label='result') - - -if __name__ == '__main__': - tf.test.main() + expected_score = 48.53 + key = metric_types.MetricKey(name=bleu._BLEU_NAME_DEFAULT) + + with beam.Pipeline() as pipeline: + result = ( + pipeline + | "LoadData" >> beam.Create([example1, example2, example3]) + | "ExtractEval" + >> metrics_plots_and_validations_evaluator.MetricsPlotsAndValidationsEvaluator( + eval_config=eval_config + ).ptransform + ) + + def check_result(got): + try: + self.assertLen(got, 1) + got_slice_key, got_metrics = got[0] + self.assertEqual(got_slice_key, ()) + self.assertIn(key, got_metrics.keys()) + self.assertAlmostEqual( + expected_score, got_metrics[key].score, places=2 + ) + + except AssertionError as err: + raise util.BeamAssertException(err) + + self.assertIn("metrics", result) + util.assert_that(result["metrics"], check_result, label="result") + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow_model_analysis/metrics/calibration.py b/tensorflow_model_analysis/metrics/calibration.py index 91121ea03a..ded3b423a0 100644 --- a/tensorflow_model_analysis/metrics/calibration.py +++ b/tensorflow_model_analysis/metrics/calibration.py @@ -17,28 +17,27 @@ import apache_beam as beam import numpy as np -from tensorflow_model_analysis.metrics import metric_types -from tensorflow_model_analysis.metrics import metric_util + +from tensorflow_model_analysis.metrics import metric_types, metric_util from tensorflow_model_analysis.proto import config_pb2 -CALIBRATION_NAME = 'calibration' -MEAN_LABEL_NAME = 'mean_label' -MEAN_PREDICTION_NAME = 'mean_prediction' -_WEIGHTED_LABELS_PREDICTIONS_EXAMPLES_NAME = ( - '_weighted_labels_predictions_examples') +CALIBRATION_NAME = "calibration" +MEAN_LABEL_NAME = "mean_label" +MEAN_PREDICTION_NAME = "mean_prediction" +_WEIGHTED_LABELS_PREDICTIONS_EXAMPLES_NAME = "_weighted_labels_predictions_examples" class MeanLabel(metric_types.Metric): - """Mean label.""" + """Mean label.""" - def __init__(self, name: str = MEAN_LABEL_NAME): - """Initializes mean label. + def __init__(self, name: str = MEAN_LABEL_NAME): + """Initializes mean label. - Args: - name: Metric name. - """ - super().__init__( - metric_util.merge_per_key_computations(_mean_label), name=name) + Args: + ---- + name: Metric name. 
+ """ + super().__init__(metric_util.merge_per_key_computations(_mean_label), name=name) metric_types.register_metric(MeanLabel) @@ -47,59 +46,65 @@ def __init__(self, name: str = MEAN_LABEL_NAME): def _mean_label( name: str = MEAN_LABEL_NAME, eval_config: Optional[config_pb2.EvalConfig] = None, - model_name: str = '', - output_name: str = '', + model_name: str = "", + output_name: str = "", sub_key: Optional[metric_types.SubKey] = None, aggregation_type: Optional[metric_types.AggregationType] = None, class_weights: Optional[Dict[int, float]] = None, - example_weighted: bool = False) -> metric_types.MetricComputations: - """Returns metric computations for mean label.""" - key = metric_types.MetricKey( - name=name, - model_name=model_name, - output_name=output_name, - sub_key=sub_key, - example_weighted=example_weighted) - - # Make sure weighted_labels_predictions_examples are calculated. - computations = _weighted_labels_predictions_examples( - eval_config=eval_config, - model_name=model_name, - output_name=output_name, - sub_key=sub_key, - aggregation_type=aggregation_type, - class_weights=class_weights, - example_weighted=example_weighted) - weighted_labels_predictions_key = computations[-1].keys[-1] - - def result( - metrics: Dict[metric_types.MetricKey, Any] - ) -> Dict[metric_types.MetricKey, Any]: - """Returns mean label.""" - metric = metrics[weighted_labels_predictions_key] - if np.isclose(metric.total_weighted_examples, 0.0): - value = float('nan') - else: - value = metric.total_weighted_labels / metric.total_weighted_examples - return {key: value} - - derived_computation = metric_types.DerivedMetricComputation( - keys=[key], result=result) - computations.append(derived_computation) - return computations + example_weighted: bool = False, +) -> metric_types.MetricComputations: + """Returns metric computations for mean label.""" + key = metric_types.MetricKey( + name=name, + model_name=model_name, + output_name=output_name, + sub_key=sub_key, + example_weighted=example_weighted, + ) + + # Make sure weighted_labels_predictions_examples are calculated. + computations = _weighted_labels_predictions_examples( + eval_config=eval_config, + model_name=model_name, + output_name=output_name, + sub_key=sub_key, + aggregation_type=aggregation_type, + class_weights=class_weights, + example_weighted=example_weighted, + ) + weighted_labels_predictions_key = computations[-1].keys[-1] + + def result( + metrics: Dict[metric_types.MetricKey, Any], + ) -> Dict[metric_types.MetricKey, Any]: + """Returns mean label.""" + metric = metrics[weighted_labels_predictions_key] + if np.isclose(metric.total_weighted_examples, 0.0): + value = float("nan") + else: + value = metric.total_weighted_labels / metric.total_weighted_examples + return {key: value} + + derived_computation = metric_types.DerivedMetricComputation( + keys=[key], result=result + ) + computations.append(derived_computation) + return computations class MeanPrediction(metric_types.Metric): - """Mean prediction.""" + """Mean prediction.""" - def __init__(self, name: str = MEAN_PREDICTION_NAME): - """Initializes mean prediction. + def __init__(self, name: str = MEAN_PREDICTION_NAME): + """Initializes mean prediction. - Args: - name: Metric name. - """ - super().__init__( - metric_util.merge_per_key_computations(_mean_prediction), name=name) + Args: + ---- + name: Metric name. 
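The derived computation above is simply a weighted mean of the labels, total_weighted_labels / total_weighted_examples, with NaN when the example weights sum to zero. A numeric sketch with invented values:

import numpy as np

labels = np.array([1.0, 0.0, 1.0, 1.0])
example_weights = np.array([2.0, 1.0, 1.0, 0.0])

total_weighted_labels = float(np.sum(labels * example_weights))  # 3.0
total_weighted_examples = float(np.sum(example_weights))         # 4.0

mean_label = (
    float("nan")
    if np.isclose(total_weighted_examples, 0.0)
    else total_weighted_labels / total_weighted_examples
)
print(mean_label)  # 0.75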
+ """ + super().__init__( + metric_util.merge_per_key_computations(_mean_prediction), name=name + ) metric_types.register_metric(MeanPrediction) @@ -108,63 +113,69 @@ def __init__(self, name: str = MEAN_PREDICTION_NAME): def _mean_prediction( name: str = MEAN_PREDICTION_NAME, eval_config: Optional[config_pb2.EvalConfig] = None, - model_name: str = '', - output_name: str = '', + model_name: str = "", + output_name: str = "", sub_key: Optional[metric_types.SubKey] = None, aggregation_type: Optional[metric_types.AggregationType] = None, class_weights: Optional[Dict[int, float]] = None, - example_weighted: bool = False) -> metric_types.MetricComputations: - """Returns metric computations for mean prediction.""" - key = metric_types.MetricKey( - name=name, - model_name=model_name, - output_name=output_name, - sub_key=sub_key, - example_weighted=example_weighted) - - # Make sure weighted_labels_predictions_examples are calculated. - computations = _weighted_labels_predictions_examples( - eval_config=eval_config, - model_name=model_name, - output_name=output_name, - sub_key=sub_key, - aggregation_type=aggregation_type, - class_weights=class_weights, - example_weighted=example_weighted) - weighted_labels_predictions_key = computations[-1].keys[-1] - - def result( - metrics: Dict[metric_types.MetricKey, Any] - ) -> Dict[metric_types.MetricKey, Any]: - """Returns mean prediction.""" - metric = metrics[weighted_labels_predictions_key] - if np.isclose(metric.total_weighted_examples, 0.0): - value = float('nan') - else: - value = metric.total_weighted_predictions / metric.total_weighted_examples - return {key: value} - - derived_computation = metric_types.DerivedMetricComputation( - keys=[key], result=result) - computations.append(derived_computation) - return computations + example_weighted: bool = False, +) -> metric_types.MetricComputations: + """Returns metric computations for mean prediction.""" + key = metric_types.MetricKey( + name=name, + model_name=model_name, + output_name=output_name, + sub_key=sub_key, + example_weighted=example_weighted, + ) + + # Make sure weighted_labels_predictions_examples are calculated. + computations = _weighted_labels_predictions_examples( + eval_config=eval_config, + model_name=model_name, + output_name=output_name, + sub_key=sub_key, + aggregation_type=aggregation_type, + class_weights=class_weights, + example_weighted=example_weighted, + ) + weighted_labels_predictions_key = computations[-1].keys[-1] + + def result( + metrics: Dict[metric_types.MetricKey, Any], + ) -> Dict[metric_types.MetricKey, Any]: + """Returns mean prediction.""" + metric = metrics[weighted_labels_predictions_key] + if np.isclose(metric.total_weighted_examples, 0.0): + value = float("nan") + else: + value = metric.total_weighted_predictions / metric.total_weighted_examples + return {key: value} + + derived_computation = metric_types.DerivedMetricComputation( + keys=[key], result=result + ) + computations.append(derived_computation) + return computations class Calibration(metric_types.Metric): - """Calibration. + """Calibration. - Calibration in this context is defined as the total weighted predictions / - total weighted labels. - """ + Calibration in this context is defined as the total weighted predictions / + total weighted labels. + """ - def __init__(self, name: str = CALIBRATION_NAME): - """Initializes calibration. + def __init__(self, name: str = CALIBRATION_NAME): + """Initializes calibration. - Args: - name: Metric name. 
- """ - super().__init__( - metric_util.merge_per_key_computations(_calibration), name=name) + Args: + ---- + name: Metric name. + """ + super().__init__( + metric_util.merge_per_key_computations(_calibration), name=name + ) metric_types.register_metric(Calibration) @@ -173,127 +184,145 @@ def __init__(self, name: str = CALIBRATION_NAME): def _calibration( name: str = CALIBRATION_NAME, eval_config: Optional[config_pb2.EvalConfig] = None, - model_name: str = '', - output_name: str = '', + model_name: str = "", + output_name: str = "", sub_key: Optional[metric_types.SubKey] = None, aggregation_type: Optional[metric_types.AggregationType] = None, class_weights: Optional[Dict[int, float]] = None, - example_weighted: bool = False) -> metric_types.MetricComputations: - """Returns metric computations for calibration.""" - key = metric_types.MetricKey( - name=name, - model_name=model_name, - output_name=output_name, - sub_key=sub_key, - example_weighted=example_weighted) - - # Make sure weighted_labels_predictions_examples are calculated. - computations = _weighted_labels_predictions_examples( - eval_config=eval_config, - model_name=model_name, - output_name=output_name, - sub_key=sub_key, - aggregation_type=aggregation_type, - class_weights=class_weights, - example_weighted=example_weighted) - weighted_labels_predictions_key = computations[-1].keys[-1] - - def result( - metrics: Dict[metric_types.MetricKey, Any] - ) -> Dict[metric_types.MetricKey, Any]: - """Returns calibration.""" - metric = metrics[weighted_labels_predictions_key] - if np.isclose(metric.total_weighted_labels, 0.0): - value = float('nan') - else: - value = metric.total_weighted_predictions / metric.total_weighted_labels - - return {key: value} - - derived_computation = metric_types.DerivedMetricComputation( - keys=[key], result=result) - computations.append(derived_computation) - return computations + example_weighted: bool = False, +) -> metric_types.MetricComputations: + """Returns metric computations for calibration.""" + key = metric_types.MetricKey( + name=name, + model_name=model_name, + output_name=output_name, + sub_key=sub_key, + example_weighted=example_weighted, + ) + + # Make sure weighted_labels_predictions_examples are calculated. 
+ computations = _weighted_labels_predictions_examples( + eval_config=eval_config, + model_name=model_name, + output_name=output_name, + sub_key=sub_key, + aggregation_type=aggregation_type, + class_weights=class_weights, + example_weighted=example_weighted, + ) + weighted_labels_predictions_key = computations[-1].keys[-1] + + def result( + metrics: Dict[metric_types.MetricKey, Any], + ) -> Dict[metric_types.MetricKey, Any]: + """Returns calibration.""" + metric = metrics[weighted_labels_predictions_key] + if np.isclose(metric.total_weighted_labels, 0.0): + value = float("nan") + else: + value = metric.total_weighted_predictions / metric.total_weighted_labels + + return {key: value} + + derived_computation = metric_types.DerivedMetricComputation( + keys=[key], result=result + ) + computations.append(derived_computation) + return computations def _weighted_labels_predictions_examples( name: str = _WEIGHTED_LABELS_PREDICTIONS_EXAMPLES_NAME, eval_config: Optional[config_pb2.EvalConfig] = None, - model_name: str = '', - output_name: str = '', + model_name: str = "", + output_name: str = "", sub_key: Optional[metric_types.SubKey] = None, aggregation_type: Optional[metric_types.AggregationType] = None, class_weights: Optional[Dict[int, float]] = None, - example_weighted: bool = False) -> metric_types.MetricComputations: - """Returns metric computations for weighted labels, predictions, and examples. - - Args: - name: Metric name. - eval_config: Eval config. - model_name: Optional model name (if multi-model evaluation). - output_name: Optional output name (if multi-output model type). - sub_key: Optional sub key. - aggregation_type: Optional aggregation type. - class_weights: Optional class weights to apply to multi-class / multi-label - labels and predictions prior to flattening (when micro averaging is used). - example_weighted: True if example weights should be applied. - """ - key = metric_types.MetricKey( - name=name, - model_name=model_name, - output_name=output_name, - sub_key=sub_key, - example_weighted=example_weighted) - return [ - metric_types.MetricComputation( - keys=[key], - preprocessors=None, # Use default - combiner=_WeightedLabelsPredictionsExamplesCombiner( - key, - eval_config=eval_config, - aggregation_type=aggregation_type, - class_weights=class_weights, - example_weighted=example_weighted)) - ] + example_weighted: bool = False, +) -> metric_types.MetricComputations: + """Returns metric computations for weighted labels, predictions, and examples. + + Args: + ---- + name: Metric name. + eval_config: Eval config. + model_name: Optional model name (if multi-model evaluation). + output_name: Optional output name (if multi-output model type). + sub_key: Optional sub key. + aggregation_type: Optional aggregation type. + class_weights: Optional class weights to apply to multi-class / multi-label + labels and predictions prior to flattening (when micro averaging is used). + example_weighted: True if example weights should be applied. 
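Calibration above is the ratio of weighted predictions to weighted labels (NaN when the weighted label total is zero), so a perfectly calibrated slice lands at 1.0 and an over-predicting one above it. A numeric sketch with invented values:

import numpy as np

labels = np.array([1.0, 0.0, 1.0, 1.0])
predictions = np.array([0.9, 0.2, 0.7, 0.6])
example_weights = np.array([1.0, 1.0, 1.0, 1.0])

total_weighted_labels = float(np.sum(labels * example_weights))            # 3.0
total_weighted_predictions = float(np.sum(predictions * example_weights))  # 2.4

calibration = (
    float("nan")
    if np.isclose(total_weighted_labels, 0.0)
    else total_weighted_predictions / total_weighted_labels
)
print(calibration)  # 0.8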
+ """ + key = metric_types.MetricKey( + name=name, + model_name=model_name, + output_name=output_name, + sub_key=sub_key, + example_weighted=example_weighted, + ) + return [ + metric_types.MetricComputation( + keys=[key], + preprocessors=None, # Use default + combiner=_WeightedLabelsPredictionsExamplesCombiner( + key, + eval_config=eval_config, + aggregation_type=aggregation_type, + class_weights=class_weights, + example_weighted=example_weighted, + ), + ) + ] class _WeightedLabelsPredictionsExamples: - """Total weighted labels, predictions, and examples.""" - __slots__ = [ - 'total_weighted_labels', 'total_weighted_predictions', - 'total_weighted_examples' - ] + """Total weighted labels, predictions, and examples.""" - def __init__(self): - """Initializes accumulator.""" - self.total_weighted_labels = 0.0 - self.total_weighted_predictions = 0.0 - self.total_weighted_examples = 0.0 + __slots__ = [ + "total_weighted_labels", + "total_weighted_predictions", + "total_weighted_examples", + ] + + def __init__(self): + """Initializes accumulator.""" + self.total_weighted_labels = 0.0 + self.total_weighted_predictions = 0.0 + self.total_weighted_examples = 0.0 class _WeightedLabelsPredictionsExamplesCombiner(beam.CombineFn): - """Computes weighted labels, predictions, and examples.""" - - def __init__(self, key: metric_types.MetricKey, - eval_config: Optional[config_pb2.EvalConfig], - aggregation_type: Optional[metric_types.AggregationType], - class_weights: Optional[Dict[int, - float]], example_weighted: bool): - self._key = key - self._eval_config = eval_config - self._aggregation_type = aggregation_type - self._class_weights = class_weights - self._example_weighted = example_weighted - - def create_accumulator(self) -> _WeightedLabelsPredictionsExamples: - return _WeightedLabelsPredictionsExamples() - - def add_input( - self, accumulator: _WeightedLabelsPredictionsExamples, - element: metric_types.StandardMetricInputs - ) -> _WeightedLabelsPredictionsExamples: - for label, prediction, example_weight in ( - metric_util.to_label_prediction_example_weight( + """Computes weighted labels, predictions, and examples.""" + + def __init__( + self, + key: metric_types.MetricKey, + eval_config: Optional[config_pb2.EvalConfig], + aggregation_type: Optional[metric_types.AggregationType], + class_weights: Optional[Dict[int, float]], + example_weighted: bool, + ): + self._key = key + self._eval_config = eval_config + self._aggregation_type = aggregation_type + self._class_weights = class_weights + self._example_weighted = example_weighted + + def create_accumulator(self) -> _WeightedLabelsPredictionsExamples: + return _WeightedLabelsPredictionsExamples() + + def add_input( + self, + accumulator: _WeightedLabelsPredictionsExamples, + element: metric_types.StandardMetricInputs, + ) -> _WeightedLabelsPredictionsExamples: + for ( + label, + prediction, + example_weight, + ) in metric_util.to_label_prediction_example_weight( element, eval_config=self._eval_config, model_name=self._key.model_name, @@ -302,38 +331,38 @@ def add_input( aggregation_type=self._aggregation_type, class_weights=self._class_weights, example_weighted=self._example_weighted, - allow_none=True)): - example_weight = float(example_weight) - accumulator.total_weighted_examples += example_weight - if label is not None and len(label): - if self._key.sub_key and self._key.sub_key.top_k is not None: - for i in range(self._key.sub_key.top_k): - weighted_label = label[i] * example_weight - else: - weighted_label = float(label) * example_weight - 
accumulator.total_weighted_labels += weighted_label - if prediction is not None and len(label): - if self._key.sub_key and self._key.sub_key.top_k is not None: - for i in range(self._key.sub_key.top_k): - weighted_prediction = prediction[i] * example_weight - else: - weighted_prediction = float(prediction) * example_weight - accumulator.total_weighted_predictions += weighted_prediction - return accumulator - - def merge_accumulators( - self, accumulators: Iterable[_WeightedLabelsPredictionsExamples] - ) -> _WeightedLabelsPredictionsExamples: - accumulators = iter(accumulators) - result = next(accumulators) - for accumulator in accumulators: - result.total_weighted_labels += accumulator.total_weighted_labels - result.total_weighted_predictions += ( - accumulator.total_weighted_predictions) - result.total_weighted_examples += accumulator.total_weighted_examples - return result - - def extract_output( - self, accumulator: _WeightedLabelsPredictionsExamples - ) -> Dict[metric_types.MetricKey, _WeightedLabelsPredictionsExamples]: - return {self._key: accumulator} + allow_none=True, + ): + example_weight = float(example_weight) + accumulator.total_weighted_examples += example_weight + if label is not None and len(label): + if self._key.sub_key and self._key.sub_key.top_k is not None: + for i in range(self._key.sub_key.top_k): + weighted_label = label[i] * example_weight + else: + weighted_label = float(label) * example_weight + accumulator.total_weighted_labels += weighted_label + if prediction is not None and len(label): + if self._key.sub_key and self._key.sub_key.top_k is not None: + for i in range(self._key.sub_key.top_k): + weighted_prediction = prediction[i] * example_weight + else: + weighted_prediction = float(prediction) * example_weight + accumulator.total_weighted_predictions += weighted_prediction + return accumulator + + def merge_accumulators( + self, accumulators: Iterable[_WeightedLabelsPredictionsExamples] + ) -> _WeightedLabelsPredictionsExamples: + accumulators = iter(accumulators) + result = next(accumulators) + for accumulator in accumulators: + result.total_weighted_labels += accumulator.total_weighted_labels + result.total_weighted_predictions += accumulator.total_weighted_predictions + result.total_weighted_examples += accumulator.total_weighted_examples + return result + + def extract_output( + self, accumulator: _WeightedLabelsPredictionsExamples + ) -> Dict[metric_types.MetricKey, _WeightedLabelsPredictionsExamples]: + return {self._key: accumulator} diff --git a/tensorflow_model_analysis/metrics/calibration_histogram.py b/tensorflow_model_analysis/metrics/calibration_histogram.py index 3bf4b8936e..96de0b4cb2 100644 --- a/tensorflow_model_analysis/metrics/calibration_histogram.py +++ b/tensorflow_model_analysis/metrics/calibration_histogram.py @@ -18,30 +18,31 @@ from typing import Dict, Iterable, List, Optional import apache_beam as beam -from tensorflow_model_analysis.metrics import metric_types -from tensorflow_model_analysis.metrics import metric_util + +from tensorflow_model_analysis.metrics import metric_types, metric_util from tensorflow_model_analysis.proto import config_pb2 -CALIBRATION_HISTOGRAM_NAME = '_calibration_histogram' +CALIBRATION_HISTOGRAM_NAME = "_calibration_histogram" DEFAULT_NUM_BUCKETS = 10000 @dataclasses.dataclass class Bucket: - """Bucket for calibration histogram.""" - bucket_id: int - weighted_labels: float - weighted_predictions: float - weighted_examples: float + """Bucket for calibration histogram.""" + + bucket_id: int + 
weighted_labels: float + weighted_predictions: float + weighted_examples: float - def merge(self, other: 'Bucket') -> None: - """Add a bucket with the same bucket_id, updating self in place.""" - if self.bucket_id != other.bucket_id: - raise ValueError('attempted to merge mismatched bucket_id') - self.weighted_labels += other.weighted_labels - self.weighted_predictions += other.weighted_predictions - self.weighted_examples += other.weighted_examples + def merge(self, other: "Bucket") -> None: + """Add a bucket with the same bucket_id, updating self in place.""" + if self.bucket_id != other.bucket_id: + raise ValueError("attempted to merge mismatched bucket_id") + self.weighted_labels += other.weighted_labels + self.weighted_predictions += other.weighted_predictions + self.weighted_examples += other.weighted_examples Histogram = List[Bucket] @@ -53,8 +54,8 @@ def calibration_histogram( right: Optional[float] = None, name: Optional[str] = None, eval_config: Optional[config_pb2.EvalConfig] = None, - model_name: str = '', - output_name: str = '', + model_name: str = "", + output_name: str = "", preprocessors: Optional[List[metric_types.Preprocessor]] = None, sub_key: Optional[metric_types.SubKey] = None, aggregation_type: Optional[metric_types.AggregationType] = None, @@ -63,83 +64,88 @@ def calibration_histogram( prediction_based_bucketing: bool = True, fractional_labels: Optional[bool] = None, ) -> metric_types.MetricComputations: - """Returns metric computations for calibration histogram. + """Returns metric computations for calibration histogram. - Args: - num_buckets: Number of buckets to use. Note that the actual number of - buckets will be num_buckets + 2 to account for the edge cases. - left: Start of predictions interval. - right: End of predictions interval. - name: Metric name. - eval_config: Eval config. - model_name: Optional model name (if multi-model evaluation). - output_name: Optional output name (if multi-output model type). - preprocessors: A tfma component for preprocessing data. - sub_key: Optional sub key. - aggregation_type: Optional aggregation type. - class_weights: Optional class weights to apply to multi-class / multi-label - labels and predictions prior to flattening (when micro averaging is used). - example_weighted: True if example weights should be applied. - prediction_based_bucketing: If true, create buckets based on predictions - else use labels to perform bucketing. - fractional_labels: If true, each incoming tuple of (label, prediction, and - example weight) will be split into two tuples as follows (where l, p, w - represent the resulting label, prediction, and example weight values): (1) - l = 0.0, p = prediction, and w = example_weight * (1.0 - label) (2) l = - 1.0, p = prediction, and w = example_weight * label If enabled, an - exception will be raised if labels are not within [0, 1]. The - implementation is such that tuples associated with a weight of zero are - not yielded. This means it is safe to enable fractional_labels even when - the labels only take on the values of 0.0 or 1.0. + Args: + ---- + num_buckets: Number of buckets to use. Note that the actual number of + buckets will be num_buckets + 2 to account for the edge cases. + left: Start of predictions interval. + right: End of predictions interval. + name: Metric name. + eval_config: Eval config. + model_name: Optional model name (if multi-model evaluation). + output_name: Optional output name (if multi-output model type). + preprocessors: A tfma component for preprocessing data. 
+ sub_key: Optional sub key. + aggregation_type: Optional aggregation type. + class_weights: Optional class weights to apply to multi-class / multi-label + labels and predictions prior to flattening (when micro averaging is used). + example_weighted: True if example weights should be applied. + prediction_based_bucketing: If true, create buckets based on predictions + else use labels to perform bucketing. + fractional_labels: If true, each incoming tuple of (label, prediction, and + example weight) will be split into two tuples as follows (where l, p, w + represent the resulting label, prediction, and example weight values): (1) + l = 0.0, p = prediction, and w = example_weight * (1.0 - label) (2) l = + 1.0, p = prediction, and w = example_weight * label If enabled, an + exception will be raised if labels are not within [0, 1]. The + implementation is such that tuples associated with a weight of zero are + not yielded. This means it is safe to enable fractional_labels even when + the labels only take on the values of 0.0 or 1.0. - Returns: - MetricComputations for computing the histogram(s). - """ - if num_buckets is None: - num_buckets = DEFAULT_NUM_BUCKETS - if left is None: - left = 0.0 - if right is None: - right = 1.0 - if fractional_labels is None: - fractional_labels = (left == 0.0 and right == 1.0) - if name is None: - name_args = { - 'num_buckets': num_buckets, - 'left': left, - 'right': right, - 'fractional_labels': fractional_labels, - 'prediction_based_bucketing': prediction_based_bucketing, - } - if preprocessors: - name_args['preprocessors'] = tuple(p.name for p in preprocessors) - if class_weights: - name_args['class_weights'] = class_weights - name = metric_util.generate_private_name_from_arguments( - CALIBRATION_HISTOGRAM_NAME, **name_args + Returns: + ------- + MetricComputations for computing the histogram(s). 
+ """ + if num_buckets is None: + num_buckets = DEFAULT_NUM_BUCKETS + if left is None: + left = 0.0 + if right is None: + right = 1.0 + if fractional_labels is None: + fractional_labels = left == 0.0 and right == 1.0 + if name is None: + name_args = { + "num_buckets": num_buckets, + "left": left, + "right": right, + "fractional_labels": fractional_labels, + "prediction_based_bucketing": prediction_based_bucketing, + } + if preprocessors: + name_args["preprocessors"] = tuple(p.name for p in preprocessors) + if class_weights: + name_args["class_weights"] = class_weights + name = metric_util.generate_private_name_from_arguments( + CALIBRATION_HISTOGRAM_NAME, **name_args + ) + key = metric_types.PlotKey( + name=name, + model_name=model_name, + output_name=output_name, + sub_key=sub_key, + example_weighted=example_weighted, ) - key = metric_types.PlotKey( - name=name, - model_name=model_name, - output_name=output_name, - sub_key=sub_key, - example_weighted=example_weighted) - return [ - metric_types.MetricComputation( - keys=[key], - preprocessors=preprocessors, - combiner=_CalibrationHistogramCombiner( - key=key, - eval_config=eval_config, - aggregation_type=aggregation_type, - class_weights=class_weights, - example_weighted=example_weighted, - num_buckets=num_buckets, - left=left, - right=right, - prediction_based_bucketing=prediction_based_bucketing, - fractional_labels=fractional_labels)) - ] + return [ + metric_types.MetricComputation( + keys=[key], + preprocessors=preprocessors, + combiner=_CalibrationHistogramCombiner( + key=key, + eval_config=eval_config, + aggregation_type=aggregation_type, + class_weights=class_weights, + example_weighted=example_weighted, + num_buckets=num_buckets, + left=left, + right=right, + prediction_based_bucketing=prediction_based_bucketing, + fractional_labels=fractional_labels, + ), + ) + ] # bucket_id to bucket. 
@@ -147,53 +153,62 @@ def calibration_histogram( class _CalibrationHistogramCombiner(beam.CombineFn): - """Creates histogram from labels, predictions, and example weights.""" + """Creates histogram from labels, predictions, and example weights.""" - def __init__(self, key: metric_types.PlotKey, - eval_config: Optional[config_pb2.EvalConfig], - aggregation_type: Optional[metric_types.AggregationType], - class_weights: Optional[Dict[int, - float]], example_weighted: bool, - num_buckets: int, left: float, right: float, - prediction_based_bucketing: bool, fractional_labels: bool): - self._key = key - self._eval_config = eval_config - self._aggregation_type = aggregation_type - self._class_weights = class_weights - self._example_weighted = example_weighted - self._num_buckets = num_buckets - self._left = left - self._range = right - left - self._fractional_labels = fractional_labels - self._prediction_based_bucketing = prediction_based_bucketing + def __init__( + self, + key: metric_types.PlotKey, + eval_config: Optional[config_pb2.EvalConfig], + aggregation_type: Optional[metric_types.AggregationType], + class_weights: Optional[Dict[int, float]], + example_weighted: bool, + num_buckets: int, + left: float, + right: float, + prediction_based_bucketing: bool, + fractional_labels: bool, + ): + self._key = key + self._eval_config = eval_config + self._aggregation_type = aggregation_type + self._class_weights = class_weights + self._example_weighted = example_weighted + self._num_buckets = num_buckets + self._left = left + self._range = right - left + self._fractional_labels = fractional_labels + self._prediction_based_bucketing = prediction_based_bucketing - def _bucket_index(self, prediction: float) -> int: - """Returns bucket index given prediction value. Values are truncated.""" - bucket_index = ( - (prediction - self._left) / self._range * self._num_buckets) + 1 - if bucket_index < 0: - return 0 - if bucket_index >= self._num_buckets + 1: - return self._num_buckets + 1 - return int(bucket_index) + def _bucket_index(self, prediction: float) -> int: + """Returns bucket index given prediction value. Values are truncated.""" + bucket_index = ((prediction - self._left) / self._range * self._num_buckets) + 1 + if bucket_index < 0: + return 0 + if bucket_index >= self._num_buckets + 1: + return self._num_buckets + 1 + return int(bucket_index) - def create_accumulator(self) -> _CalibrationHistogramCombinerAcctype: - # The number of accumulator (histogram) buckets is variable and depends on - # the number of distinct intervals that are matched during calls to - # add_inputs. This allows the histogram size to start small and gradually - # grow size during calls to merge until reaching the final histogram. - return {} + def create_accumulator(self) -> _CalibrationHistogramCombinerAcctype: + # The number of accumulator (histogram) buckets is variable and depends on + # the number of distinct intervals that are matched during calls to + # add_inputs. This allows the histogram size to start small and gradually + # grow size during calls to merge until reaching the final histogram. + return {} - def add_input( - self, accumulator: _CalibrationHistogramCombinerAcctype, - element: metric_types.StandardMetricInputs - ) -> _CalibrationHistogramCombinerAcctype: - # Note that in the case of top_k, if the aggregation type is not set then - # the non-top_k predictions will be set to float('-inf'), but the labels - # will remain unchanged. 
If aggregation type is set then both the - # predictions and labels will be truncated to only the top_k values. - for label, prediction, example_weight in ( - metric_util.to_label_prediction_example_weight( + def add_input( + self, + accumulator: _CalibrationHistogramCombinerAcctype, + element: metric_types.StandardMetricInputs, + ) -> _CalibrationHistogramCombinerAcctype: + # Note that in the case of top_k, if the aggregation type is not set then + # the non-top_k predictions will be set to float('-inf'), but the labels + # will remain unchanged. If aggregation type is set then both the + # predictions and labels will be truncated to only the top_k values. + for ( + label, + prediction, + example_weight, + ) in metric_util.to_label_prediction_example_weight( element, eval_config=self._eval_config, model_name=self._key.model_name, @@ -203,103 +218,119 @@ def add_input( flatten=True, aggregation_type=self._aggregation_type, class_weights=self._class_weights, - example_weighted=self._example_weighted)): - example_weight = float(example_weight) - label = float(label) - prediction = float(prediction) - weighted_label = label * example_weight - weighted_prediction = prediction * example_weight - if self._prediction_based_bucketing: - bucket_index = self._bucket_index(prediction) - else: - bucket_index = self._bucket_index(label) - if bucket_index not in accumulator: - accumulator[bucket_index] = Bucket(bucket_index, weighted_label, - weighted_prediction, example_weight) - else: - existing_bucket = accumulator[bucket_index] - existing_bucket.weighted_labels += weighted_label - existing_bucket.weighted_predictions += weighted_prediction - existing_bucket.weighted_examples += example_weight - return accumulator + example_weighted=self._example_weighted, + ): + example_weight = float(example_weight) + label = float(label) + prediction = float(prediction) + weighted_label = label * example_weight + weighted_prediction = prediction * example_weight + if self._prediction_based_bucketing: + bucket_index = self._bucket_index(prediction) + else: + bucket_index = self._bucket_index(label) + if bucket_index not in accumulator: + accumulator[bucket_index] = Bucket( + bucket_index, weighted_label, weighted_prediction, example_weight + ) + else: + existing_bucket = accumulator[bucket_index] + existing_bucket.weighted_labels += weighted_label + existing_bucket.weighted_predictions += weighted_prediction + existing_bucket.weighted_examples += example_weight + return accumulator - def merge_accumulators( - self, accumulators: Iterable[_CalibrationHistogramCombinerAcctype] - ) -> _CalibrationHistogramCombinerAcctype: - it = iter(accumulators) - result = next(it) - for acc in it: - for bucket_id, bucket in acc.items(): - if bucket_id not in result: - new_bucket = Bucket(bucket_id, 0.0, 0.0, 0.0) - new_bucket.merge(bucket) - result[bucket_id] = new_bucket - else: - result[bucket_id].merge(bucket) - return result + def merge_accumulators( + self, accumulators: Iterable[_CalibrationHistogramCombinerAcctype] + ) -> _CalibrationHistogramCombinerAcctype: + it = iter(accumulators) + result = next(it) + for acc in it: + for bucket_id, bucket in acc.items(): + if bucket_id not in result: + new_bucket = Bucket(bucket_id, 0.0, 0.0, 0.0) + new_bucket.merge(bucket) + result[bucket_id] = new_bucket + else: + result[bucket_id].merge(bucket) + return result - def extract_output( - self, accumulator: _CalibrationHistogramCombinerAcctype - ) -> Dict[metric_types.PlotKey, Histogram]: - accumulator = list(accumulator.values()) - 
accumulator.sort(key=operator.attrgetter('bucket_id')) - return {self._key: accumulator} + def extract_output( + self, accumulator: _CalibrationHistogramCombinerAcctype + ) -> Dict[metric_types.PlotKey, Histogram]: + accumulator = list(accumulator.values()) + accumulator.sort(key=operator.attrgetter("bucket_id")) + return {self._key: accumulator} -def rebin(thresholds: List[float], - histogram: Histogram, - num_buckets: int = DEFAULT_NUM_BUCKETS, - left: float = 0.0, - right: float = 1.0) -> Histogram: - """Applies new thresholds to an existing calibration histogram. +def rebin( + thresholds: List[float], + histogram: Histogram, + num_buckets: int = DEFAULT_NUM_BUCKETS, + left: float = 0.0, + right: float = 1.0, +) -> Histogram: + """Applies new thresholds to an existing calibration histogram. - Args: - thresholds: New thresholds to apply to the histogram. Must be in sorted - order, but need not be evenly spaced. - histogram: Existing calibration histogram. - num_buckets: Number of buckets in existing histogram. - left: Left boundary for existing histogram. - right: Right boundary for existing histogram. + Args: + ---- + thresholds: New thresholds to apply to the histogram. Must be in sorted + order, but need not be evenly spaced. + histogram: Existing calibration histogram. + num_buckets: Number of buckets in existing histogram. + left: Left boundary for existing histogram. + right: Right boundary for existing histogram. - Returns: - A histogram of len(thresholds) where the buckets with IDs (0, 1, 2, ...) - correspond to the intervals: - [thresholds[0], thresholds[1]), ... [thresholds[i], thresholds[i+1]) - Any values in buckets -inf or +inf will be added to the start and end - thresholds respectively. Unlike the input histogram empty buckets will be - returned. - """ - buckets = [] - offset = 0 - total_weighted_labels = 0.0 - total_weighted_predictions = 0.0 - total_weighted_examples = 0.0 - for bucket in histogram: - if bucket.bucket_id == 0: - pred = float('-inf') - elif bucket.bucket_id >= num_buckets + 1: - pred = float('inf') - else: - pred = (bucket.bucket_id - 1) / num_buckets * (right - left) + left - if offset + 1 < len(thresholds) and pred >= thresholds[offset + 1]: - buckets.append( - Bucket(offset, total_weighted_labels, total_weighted_predictions, - total_weighted_examples)) - offset += 1 - total_weighted_labels = 0.0 - total_weighted_predictions = 0.0 - total_weighted_examples = 0.0 - while offset + 1 < len(thresholds) and pred >= thresholds[offset + 1]: + Returns: + ------- + A histogram of len(thresholds) where the buckets with IDs (0, 1, 2, ...) + correspond to the intervals: + [thresholds[0], thresholds[1]), ... [thresholds[i], thresholds[i+1]) + Any values in buckets -inf or +inf will be added to the start and end + thresholds respectively. Unlike the input histogram empty buckets will be + returned. 
+ """ + buckets = [] + offset = 0 + total_weighted_labels = 0.0 + total_weighted_predictions = 0.0 + total_weighted_examples = 0.0 + for bucket in histogram: + if bucket.bucket_id == 0: + pred = float("-inf") + elif bucket.bucket_id >= num_buckets + 1: + pred = float("inf") + else: + pred = (bucket.bucket_id - 1) / num_buckets * (right - left) + left + if offset + 1 < len(thresholds) and pred >= thresholds[offset + 1]: + buckets.append( + Bucket( + offset, + total_weighted_labels, + total_weighted_predictions, + total_weighted_examples, + ) + ) + offset += 1 + total_weighted_labels = 0.0 + total_weighted_predictions = 0.0 + total_weighted_examples = 0.0 + while offset + 1 < len(thresholds) and pred >= thresholds[offset + 1]: + buckets.append(Bucket(offset, 0.0, 0.0, 0.0)) + offset += 1 + total_weighted_labels += bucket.weighted_labels + total_weighted_predictions += bucket.weighted_predictions + total_weighted_examples += bucket.weighted_examples + buckets.append( + Bucket( + offset, + total_weighted_labels, + total_weighted_predictions, + total_weighted_examples, + ) + ) + offset += 1 + while offset < len(thresholds): buckets.append(Bucket(offset, 0.0, 0.0, 0.0)) offset += 1 - total_weighted_labels += bucket.weighted_labels - total_weighted_predictions += bucket.weighted_predictions - total_weighted_examples += bucket.weighted_examples - buckets.append( - Bucket(offset, total_weighted_labels, total_weighted_predictions, - total_weighted_examples)) - offset += 1 - while offset < len(thresholds): - buckets.append(Bucket(offset, 0.0, 0.0, 0.0)) - offset += 1 - return buckets + return buckets diff --git a/tensorflow_model_analysis/metrics/calibration_histogram_test.py b/tensorflow_model_analysis/metrics/calibration_histogram_test.py index f131cfc64b..f9c70f447f 100644 --- a/tensorflow_model_analysis/metrics/calibration_histogram_test.py +++ b/tensorflow_model_analysis/metrics/calibration_histogram_test.py @@ -14,409 +14,472 @@ """Tests for calibration histogram.""" import dataclasses + import apache_beam as beam -from apache_beam.testing import util import numpy as np import tensorflow as tf -from tensorflow_model_analysis.metrics import calibration_histogram -from tensorflow_model_analysis.metrics import metric_types -from tensorflow_model_analysis.metrics import metric_util +from apache_beam.testing import util + +from tensorflow_model_analysis.metrics import ( + calibration_histogram, + metric_types, + metric_util, +) from tensorflow_model_analysis.utils import test_util class CalibrationHistogramTest(test_util.TensorflowModelAnalysisTest): + def testCalibrationHistogram(self): + histogram = calibration_histogram.calibration_histogram(example_weighted=True)[ + 0 + ] - def testCalibrationHistogram(self): - histogram = calibration_histogram.calibration_histogram( - example_weighted=True)[0] - - example1 = { - 'labels': np.array([0.0]), - 'predictions': np.array([0.2]), - 'example_weights': np.array([1.0]) - } - example2 = { - 'labels': np.array([1.0]), - 'predictions': np.array([0.8]), - 'example_weights': np.array([2.0]) - } - example3 = { - 'labels': np.array([0.0]), - 'predictions': np.array([0.5]), - 'example_weights': np.array([3.0]) - } - example4 = { - 'labels': np.array([1.0]), - 'predictions': np.array([-0.1]), - 'example_weights': np.array([4.0]) - } - example5 = { - 'labels': np.array([1.0]), - 'predictions': np.array([0.5]), - 'example_weights': np.array([5.0]) - } - example6 = { - 'labels': np.array([1.0]), - 'predictions': np.array([0.8]), - 'example_weights': np.array([6.0]) - 
} - example7 = { - 'labels': np.array([0.0]), - 'predictions': np.array([0.2]), - 'example_weights': np.array([7.0]) - } - example8 = { - 'labels': np.array([1.0]), - 'predictions': np.array([1.1]), - 'example_weights': np.array([8.0]) - } + example1 = { + "labels": np.array([0.0]), + "predictions": np.array([0.2]), + "example_weights": np.array([1.0]), + } + example2 = { + "labels": np.array([1.0]), + "predictions": np.array([0.8]), + "example_weights": np.array([2.0]), + } + example3 = { + "labels": np.array([0.0]), + "predictions": np.array([0.5]), + "example_weights": np.array([3.0]), + } + example4 = { + "labels": np.array([1.0]), + "predictions": np.array([-0.1]), + "example_weights": np.array([4.0]), + } + example5 = { + "labels": np.array([1.0]), + "predictions": np.array([0.5]), + "example_weights": np.array([5.0]), + } + example6 = { + "labels": np.array([1.0]), + "predictions": np.array([0.8]), + "example_weights": np.array([6.0]), + } + example7 = { + "labels": np.array([0.0]), + "predictions": np.array([0.2]), + "example_weights": np.array([7.0]), + } + example8 = { + "labels": np.array([1.0]), + "predictions": np.array([1.1]), + "example_weights": np.array([8.0]), + } - with beam.Pipeline() as pipeline: - # pylint: disable=no-value-for-parameter - result = ( - pipeline - | 'Create' >> beam.Create([ - example1, example2, example3, example4, example5, example6, - example7, example8 - ]) - | 'Process' >> beam.Map(metric_util.to_standard_metric_inputs) - | 'AddSlice' >> beam.Map(lambda x: ((), x)) - | 'ComputeHistogram' >> beam.CombinePerKey(histogram.combiner)) + with beam.Pipeline() as pipeline: + # pylint: disable=no-value-for-parameter + result = ( + pipeline + | "Create" + >> beam.Create( + [ + example1, + example2, + example3, + example4, + example5, + example6, + example7, + example8, + ] + ) + | "Process" >> beam.Map(metric_util.to_standard_metric_inputs) + | "AddSlice" >> beam.Map(lambda x: ((), x)) + | "ComputeHistogram" >> beam.CombinePerKey(histogram.combiner) + ) - # pylint: enable=no-value-for-parameter + # pylint: enable=no-value-for-parameter - def check_result(got): - try: - self.assertLen(got, 1) - got_slice_key, got_plots = got[0] - self.assertEqual(got_slice_key, ()) - self.assertLen(got_plots, 1) - key = metric_types.PlotKey( - ( - '_calibration_histogram:fractional_labels=True,left=0.0,' - 'num_buckets=10000,prediction_based_bucketing=True,right=1.0' - ), - example_weighted=True, - ) - self.assertIn(key, got_plots) - got_histogram = got_plots[key] - self.assertLen(got_histogram, 5) - self.assertEqual( - got_histogram[0], - calibration_histogram.Bucket( - bucket_id=0, - weighted_labels=1.0 * 4.0, - weighted_predictions=-0.1 * 4.0, - weighted_examples=4.0)) - self.assertEqual( - got_histogram[1], - calibration_histogram.Bucket( - bucket_id=2001, - weighted_labels=0.0 + 0.0, - weighted_predictions=0.2 + 7 * 0.2, - weighted_examples=1.0 + 7.0)) - self.assertEqual( - got_histogram[2], - calibration_histogram.Bucket( - bucket_id=5001, - weighted_labels=1.0 * 5.0, - weighted_predictions=0.5 * 3.0 + 0.5 * 5.0, - weighted_examples=3.0 + 5.0)) - self.assertEqual( - got_histogram[3], - calibration_histogram.Bucket( - bucket_id=8001, - weighted_labels=1.0 * 2.0 + 1.0 * 6.0, - weighted_predictions=0.8 * 2.0 + 0.8 * 6.0, - weighted_examples=2.0 + 6.0)) - self.assertEqual( - got_histogram[4], - calibration_histogram.Bucket( - bucket_id=10001, - weighted_labels=1.0 * 8.0, - weighted_predictions=1.1 * 8.0, - weighted_examples=8.0)) + def check_result(got): + try: + 
self.assertLen(got, 1) + got_slice_key, got_plots = got[0] + self.assertEqual(got_slice_key, ()) + self.assertLen(got_plots, 1) + key = metric_types.PlotKey( + ( + "_calibration_histogram:fractional_labels=True,left=0.0," + "num_buckets=10000,prediction_based_bucketing=True,right=1.0" + ), + example_weighted=True, + ) + self.assertIn(key, got_plots) + got_histogram = got_plots[key] + self.assertLen(got_histogram, 5) + self.assertEqual( + got_histogram[0], + calibration_histogram.Bucket( + bucket_id=0, + weighted_labels=1.0 * 4.0, + weighted_predictions=-0.1 * 4.0, + weighted_examples=4.0, + ), + ) + self.assertEqual( + got_histogram[1], + calibration_histogram.Bucket( + bucket_id=2001, + weighted_labels=0.0 + 0.0, + weighted_predictions=0.2 + 7 * 0.2, + weighted_examples=1.0 + 7.0, + ), + ) + self.assertEqual( + got_histogram[2], + calibration_histogram.Bucket( + bucket_id=5001, + weighted_labels=1.0 * 5.0, + weighted_predictions=0.5 * 3.0 + 0.5 * 5.0, + weighted_examples=3.0 + 5.0, + ), + ) + self.assertEqual( + got_histogram[3], + calibration_histogram.Bucket( + bucket_id=8001, + weighted_labels=1.0 * 2.0 + 1.0 * 6.0, + weighted_predictions=0.8 * 2.0 + 0.8 * 6.0, + weighted_examples=2.0 + 6.0, + ), + ) + self.assertEqual( + got_histogram[4], + calibration_histogram.Bucket( + bucket_id=10001, + weighted_labels=1.0 * 8.0, + weighted_predictions=1.1 * 8.0, + weighted_examples=8.0, + ), + ) - except AssertionError as err: - raise util.BeamAssertException(err) + except AssertionError as err: + raise util.BeamAssertException(err) - util.assert_that(result, check_result, label='result') + util.assert_that(result, check_result, label="result") - def testCalibrationHistogramWithK(self): - histogram = calibration_histogram.calibration_histogram( - sub_key=metric_types.SubKey(k=2), example_weighted=True)[0] + def testCalibrationHistogramWithK(self): + histogram = calibration_histogram.calibration_histogram( + sub_key=metric_types.SubKey(k=2), example_weighted=True + )[0] - example1 = { - 'labels': np.array([2]), - 'predictions': np.array([0.2, 0.05, 0.1, 0.05]), - 'example_weights': np.array([1.0]) - } - example2 = { - 'labels': np.array([2]), - 'predictions': np.array([0.7, 0.1, 0.8, 0.5]), - 'example_weights': np.array([2.0]) - } - example3 = { - 'labels': np.array([3]), - 'predictions': np.array([0.1, 0.5, 0.3, 0.4]), - 'example_weights': np.array([3.0]) - } - example4 = { - 'labels': np.array([0]), - 'predictions': np.array([-0.1, -0.2, -0.7, -0.4]), - 'example_weights': np.array([4.0]) - } - example5 = { - 'labels': np.array([1]), - 'predictions': np.array([0.3, 0.5, 0.0, 0.4]), - 'example_weights': np.array([5.0]) - } - example6 = { - 'labels': np.array([2]), - 'predictions': np.array([0.1, 0.1, 0.8, 0.7]), - 'example_weights': np.array([6.0]) - } - example7 = { - 'labels': np.array([2]), - 'predictions': np.array([0.0, 0.2, 0.1, 0.0]), - 'example_weights': np.array([7.0]) - } - example8 = { - 'labels': np.array([0]), - 'predictions': np.array([1.1, 0.3, 1.05, 0.2]), - 'example_weights': np.array([8.0]) - } + example1 = { + "labels": np.array([2]), + "predictions": np.array([0.2, 0.05, 0.1, 0.05]), + "example_weights": np.array([1.0]), + } + example2 = { + "labels": np.array([2]), + "predictions": np.array([0.7, 0.1, 0.8, 0.5]), + "example_weights": np.array([2.0]), + } + example3 = { + "labels": np.array([3]), + "predictions": np.array([0.1, 0.5, 0.3, 0.4]), + "example_weights": np.array([3.0]), + } + example4 = { + "labels": np.array([0]), + "predictions": np.array([-0.1, -0.2, -0.7, 
-0.4]), + "example_weights": np.array([4.0]), + } + example5 = { + "labels": np.array([1]), + "predictions": np.array([0.3, 0.5, 0.0, 0.4]), + "example_weights": np.array([5.0]), + } + example6 = { + "labels": np.array([2]), + "predictions": np.array([0.1, 0.1, 0.8, 0.7]), + "example_weights": np.array([6.0]), + } + example7 = { + "labels": np.array([2]), + "predictions": np.array([0.0, 0.2, 0.1, 0.0]), + "example_weights": np.array([7.0]), + } + example8 = { + "labels": np.array([0]), + "predictions": np.array([1.1, 0.3, 1.05, 0.2]), + "example_weights": np.array([8.0]), + } - with beam.Pipeline() as pipeline: - # pylint: disable=no-value-for-parameter - result = ( - pipeline - | 'Create' >> beam.Create([ - example1, example2, example3, example4, example5, example6, - example7, example8 - ]) - | 'Process' >> beam.Map(metric_util.to_standard_metric_inputs) - | 'AddSlice' >> beam.Map(lambda x: ((), x)) - | 'ComputeHistogram' >> beam.CombinePerKey(histogram.combiner)) + with beam.Pipeline() as pipeline: + # pylint: disable=no-value-for-parameter + result = ( + pipeline + | "Create" + >> beam.Create( + [ + example1, + example2, + example3, + example4, + example5, + example6, + example7, + example8, + ] + ) + | "Process" >> beam.Map(metric_util.to_standard_metric_inputs) + | "AddSlice" >> beam.Map(lambda x: ((), x)) + | "ComputeHistogram" >> beam.CombinePerKey(histogram.combiner) + ) - # pylint: enable=no-value-for-parameter + # pylint: enable=no-value-for-parameter - def check_result(got): - try: - self.assertLen(got, 1) - got_slice_key, got_plots = got[0] - self.assertEqual(got_slice_key, ()) - self.assertLen(got_plots, 1) - key = metric_types.PlotKey( - name=( - '_calibration_histogram:fractional_labels=True,left=0.0,' - 'num_buckets=10000,prediction_based_bucketing=True,right=1.0' - ), - sub_key=metric_types.SubKey(k=2), - example_weighted=True, - ) - self.assertIn(key, got_plots) - got_histogram = got_plots[key] - self.assertLen(got_histogram, 5) - self.assertEqual( - got_histogram[0], - calibration_histogram.Bucket( - bucket_id=0, - weighted_labels=0.0 * 4.0, - weighted_predictions=-0.2 * 4.0, - weighted_examples=4.0)) - self.assertEqual( - got_histogram[1], - calibration_histogram.Bucket( - bucket_id=1001, - weighted_labels=1.0 + 7 * 1.0, - weighted_predictions=0.1 + 7 * 0.1, - weighted_examples=1.0 + 7.0)) - self.assertEqual( - got_histogram[2], - calibration_histogram.Bucket( - bucket_id=4001, - weighted_labels=1.0 * 3.0 + 0.0 * 5.0, - weighted_predictions=0.4 * 3.0 + 0.4 * 5.0, - weighted_examples=3.0 + 5.0)) - self.assertEqual( - got_histogram[3], - calibration_histogram.Bucket( - bucket_id=7001, - weighted_labels=0.0 * 2.0 + 0.0 * 6.0, - weighted_predictions=0.7 * 2.0 + 0.7 * 6.0, - weighted_examples=2.0 + 6.0)) - self.assertEqual( - got_histogram[4], - calibration_histogram.Bucket( - bucket_id=10001, - weighted_labels=0.0 * 8.0, - weighted_predictions=1.05 * 8.0, - weighted_examples=8.0)) + def check_result(got): + try: + self.assertLen(got, 1) + got_slice_key, got_plots = got[0] + self.assertEqual(got_slice_key, ()) + self.assertLen(got_plots, 1) + key = metric_types.PlotKey( + name=( + "_calibration_histogram:fractional_labels=True,left=0.0," + "num_buckets=10000,prediction_based_bucketing=True,right=1.0" + ), + sub_key=metric_types.SubKey(k=2), + example_weighted=True, + ) + self.assertIn(key, got_plots) + got_histogram = got_plots[key] + self.assertLen(got_histogram, 5) + self.assertEqual( + got_histogram[0], + calibration_histogram.Bucket( + bucket_id=0, + weighted_labels=0.0 
* 4.0, + weighted_predictions=-0.2 * 4.0, + weighted_examples=4.0, + ), + ) + self.assertEqual( + got_histogram[1], + calibration_histogram.Bucket( + bucket_id=1001, + weighted_labels=1.0 + 7 * 1.0, + weighted_predictions=0.1 + 7 * 0.1, + weighted_examples=1.0 + 7.0, + ), + ) + self.assertEqual( + got_histogram[2], + calibration_histogram.Bucket( + bucket_id=4001, + weighted_labels=1.0 * 3.0 + 0.0 * 5.0, + weighted_predictions=0.4 * 3.0 + 0.4 * 5.0, + weighted_examples=3.0 + 5.0, + ), + ) + self.assertEqual( + got_histogram[3], + calibration_histogram.Bucket( + bucket_id=7001, + weighted_labels=0.0 * 2.0 + 0.0 * 6.0, + weighted_predictions=0.7 * 2.0 + 0.7 * 6.0, + weighted_examples=2.0 + 6.0, + ), + ) + self.assertEqual( + got_histogram[4], + calibration_histogram.Bucket( + bucket_id=10001, + weighted_labels=0.0 * 8.0, + weighted_predictions=1.05 * 8.0, + weighted_examples=8.0, + ), + ) - except AssertionError as err: - raise util.BeamAssertException(err) + except AssertionError as err: + raise util.BeamAssertException(err) - util.assert_that(result, check_result, label='result') + util.assert_that(result, check_result, label="result") - def testTopKCalibrationHistogramWithTopK(self): - histogram = calibration_histogram.calibration_histogram( - sub_key=metric_types.SubKey(top_k=2), example_weighted=True)[0] + def testTopKCalibrationHistogramWithTopK(self): + histogram = calibration_histogram.calibration_histogram( + sub_key=metric_types.SubKey(top_k=2), example_weighted=True + )[0] - example1 = { - 'labels': np.array([2]), - 'predictions': np.array([0.2, 0.05, 0.5, 0.05]), - 'example_weights': np.array([1.0]) - } - example2 = { - 'labels': np.array([2]), - 'predictions': np.array([0.8, 0.1, 0.8, 0.5]), - 'example_weights': np.array([2.0]) - } - example3 = { - 'labels': np.array([3]), - 'predictions': np.array([0.2, 0.5, 0.1, 0.1]), - 'example_weights': np.array([3.0]) - } - example4 = { - 'labels': np.array([0]), - 'predictions': np.array([-0.1, 1.1, -0.7, -0.4]), - 'example_weights': np.array([4.0]) - } + example1 = { + "labels": np.array([2]), + "predictions": np.array([0.2, 0.05, 0.5, 0.05]), + "example_weights": np.array([1.0]), + } + example2 = { + "labels": np.array([2]), + "predictions": np.array([0.8, 0.1, 0.8, 0.5]), + "example_weights": np.array([2.0]), + } + example3 = { + "labels": np.array([3]), + "predictions": np.array([0.2, 0.5, 0.1, 0.1]), + "example_weights": np.array([3.0]), + } + example4 = { + "labels": np.array([0]), + "predictions": np.array([-0.1, 1.1, -0.7, -0.4]), + "example_weights": np.array([4.0]), + } - with beam.Pipeline() as pipeline: - # pylint: disable=no-value-for-parameter - result = ( - pipeline - | 'Create' >> beam.Create([example1, example2, example3, example4]) - | 'Process' >> beam.Map(metric_util.to_standard_metric_inputs) - | 'AddSlice' >> beam.Map(lambda x: ((), x)) - | 'ComputeHistogram' >> beam.CombinePerKey(histogram.combiner)) + with beam.Pipeline() as pipeline: + # pylint: disable=no-value-for-parameter + result = ( + pipeline + | "Create" >> beam.Create([example1, example2, example3, example4]) + | "Process" >> beam.Map(metric_util.to_standard_metric_inputs) + | "AddSlice" >> beam.Map(lambda x: ((), x)) + | "ComputeHistogram" >> beam.CombinePerKey(histogram.combiner) + ) - # pylint: enable=no-value-for-parameter + # pylint: enable=no-value-for-parameter - def check_result(got): - try: - self.assertLen(got, 1) - got_slice_key, got_plots = got[0] - self.assertEqual(got_slice_key, ()) - self.assertLen(got_plots, 1) - key = 
metric_types.PlotKey( - name=( - '_calibration_histogram:fractional_labels=True,left=0.0,' - 'num_buckets=10000,prediction_based_bucketing=True,right=1.0' - ), - sub_key=metric_types.SubKey(top_k=2), - example_weighted=True, - ) - self.assertIn(key, got_plots) - got_histogram = got_plots[key] - self.assertLen(got_histogram, 5) - self.assertEqual( - got_histogram[0], - calibration_histogram.Bucket( - bucket_id=0, - weighted_labels=3.0 + 4.0, - weighted_predictions=(2 * 1.0 * float('-inf') + - 2 * 2.0 * float('-inf') + - 2 * 3.0 * float('-inf') + - 2 * 4.0 * float('-inf') + -0.1 * 4.0), - weighted_examples=(1.0 * 2.0 + 2.0 * 2.0 + 3.0 * 2.0 + - 4.0 * 3.0))) - self.assertEqual( - got_histogram[1], - calibration_histogram.Bucket( - bucket_id=2001, - weighted_labels=0.0 + 0.0, - weighted_predictions=0.2 + 3 * 0.2, - weighted_examples=1.0 + 3.0)) - self.assertEqual( - got_histogram[2], - calibration_histogram.Bucket( - bucket_id=5001, - weighted_labels=1.0 + 0.0 * 3.0, - weighted_predictions=0.5 * 1.0 + 0.5 * 3.0, - weighted_examples=1.0 + 3.0)) - self.assertEqual( - got_histogram[3], - calibration_histogram.Bucket( - bucket_id=8001, - weighted_labels=0.0 * 2.0 + 1.0 * 2.0, - weighted_predictions=0.8 * 2.0 + 0.8 * 2.0, - weighted_examples=2.0 + 2.0)) - self.assertEqual( - got_histogram[4], - calibration_histogram.Bucket( - bucket_id=10001, - weighted_labels=0.0 * 4.0, - weighted_predictions=1.1 * 4.0, - weighted_examples=4.0)) + def check_result(got): + try: + self.assertLen(got, 1) + got_slice_key, got_plots = got[0] + self.assertEqual(got_slice_key, ()) + self.assertLen(got_plots, 1) + key = metric_types.PlotKey( + name=( + "_calibration_histogram:fractional_labels=True,left=0.0," + "num_buckets=10000,prediction_based_bucketing=True,right=1.0" + ), + sub_key=metric_types.SubKey(top_k=2), + example_weighted=True, + ) + self.assertIn(key, got_plots) + got_histogram = got_plots[key] + self.assertLen(got_histogram, 5) + self.assertEqual( + got_histogram[0], + calibration_histogram.Bucket( + bucket_id=0, + weighted_labels=3.0 + 4.0, + weighted_predictions=( + 2 * 1.0 * float("-inf") + + 2 * 2.0 * float("-inf") + + 2 * 3.0 * float("-inf") + + 2 * 4.0 * float("-inf") + + -0.1 * 4.0 + ), + weighted_examples=( + 1.0 * 2.0 + 2.0 * 2.0 + 3.0 * 2.0 + 4.0 * 3.0 + ), + ), + ) + self.assertEqual( + got_histogram[1], + calibration_histogram.Bucket( + bucket_id=2001, + weighted_labels=0.0 + 0.0, + weighted_predictions=0.2 + 3 * 0.2, + weighted_examples=1.0 + 3.0, + ), + ) + self.assertEqual( + got_histogram[2], + calibration_histogram.Bucket( + bucket_id=5001, + weighted_labels=1.0 + 0.0 * 3.0, + weighted_predictions=0.5 * 1.0 + 0.5 * 3.0, + weighted_examples=1.0 + 3.0, + ), + ) + self.assertEqual( + got_histogram[3], + calibration_histogram.Bucket( + bucket_id=8001, + weighted_labels=0.0 * 2.0 + 1.0 * 2.0, + weighted_predictions=0.8 * 2.0 + 0.8 * 2.0, + weighted_examples=2.0 + 2.0, + ), + ) + self.assertEqual( + got_histogram[4], + calibration_histogram.Bucket( + bucket_id=10001, + weighted_labels=0.0 * 4.0, + weighted_predictions=1.1 * 4.0, + weighted_examples=4.0, + ), + ) - except AssertionError as err: - raise util.BeamAssertException(err) + except AssertionError as err: + raise util.BeamAssertException(err) - util.assert_that(result, check_result, label='result') + util.assert_that(result, check_result, label="result") - def testRebin(self): - # [Bucket(0, -1, -0.01), Bucket(1, 0, 0) ... 
Bucket(101, 101, 1.01)] - histogram = [calibration_histogram.Bucket(0, -1, -.01, 1.0)] - for i in range(100): - histogram.append(calibration_histogram.Bucket(i + 1, i, i * .01, 1.0)) - histogram.append(calibration_histogram.Bucket(101, 101, 1.01, 1.0)) - # [-1e-7, 0.0, 0.1, ..., 0.9, 1.0, 1.0+1e-7] - thresholds = [-1e-7] + [i * 1.0 / 10 for i in range(11)] + [1.0 + 1e-7] - got = calibration_histogram.rebin(thresholds, histogram, 100) + def testRebin(self): + # [Bucket(0, -1, -0.01), Bucket(1, 0, 0) ... Bucket(101, 101, 1.01)] + histogram = [calibration_histogram.Bucket(0, -1, -0.01, 1.0)] + for i in range(100): + histogram.append(calibration_histogram.Bucket(i + 1, i, i * 0.01, 1.0)) + histogram.append(calibration_histogram.Bucket(101, 101, 1.01, 1.0)) + # [-1e-7, 0.0, 0.1, ..., 0.9, 1.0, 1.0+1e-7] + thresholds = [-1e-7] + [i * 1.0 / 10 for i in range(11)] + [1.0 + 1e-7] + got = calibration_histogram.rebin(thresholds, histogram, 100) - # labels = (10 * (i-1)) + (1 + 2 + 3 + ... + 9) - expected = [ - calibration_histogram.Bucket(0, -1, -0.01, 1.0), - calibration_histogram.Bucket(1, 45.0, 0.45, 10.0), - calibration_histogram.Bucket(2, 145.0, 1.45, 10.0), - calibration_histogram.Bucket(3, 245.0, 2.45, 10.0), - calibration_histogram.Bucket(4, 345.0, 3.45, 10.0), - calibration_histogram.Bucket(5, 445.0, 4.45, 10.0), - calibration_histogram.Bucket(6, 545.0, 5.45, 10.0), - calibration_histogram.Bucket(7, 645.0, 6.45, 10.0), - calibration_histogram.Bucket(8, 745.0, 7.45, 10.0), - calibration_histogram.Bucket(9, 845.0, 8.45, 10.0), - calibration_histogram.Bucket(10, 945.0, 9.45, 10.0), - calibration_histogram.Bucket(11, 0.0, 0.0, 0.0), - calibration_histogram.Bucket(12, 101.0, 1.01, 1.0), - ] - self.assertLen(got, len(expected)) - for i in range(len(got)): - self.assertSequenceAlmostEqual( - dataclasses.astuple(got[i]), dataclasses.astuple(expected[i])) + # labels = (10 * (i-1)) + (1 + 2 + 3 + ... 
+ 9) + expected = [ + calibration_histogram.Bucket(0, -1, -0.01, 1.0), + calibration_histogram.Bucket(1, 45.0, 0.45, 10.0), + calibration_histogram.Bucket(2, 145.0, 1.45, 10.0), + calibration_histogram.Bucket(3, 245.0, 2.45, 10.0), + calibration_histogram.Bucket(4, 345.0, 3.45, 10.0), + calibration_histogram.Bucket(5, 445.0, 4.45, 10.0), + calibration_histogram.Bucket(6, 545.0, 5.45, 10.0), + calibration_histogram.Bucket(7, 645.0, 6.45, 10.0), + calibration_histogram.Bucket(8, 745.0, 7.45, 10.0), + calibration_histogram.Bucket(9, 845.0, 8.45, 10.0), + calibration_histogram.Bucket(10, 945.0, 9.45, 10.0), + calibration_histogram.Bucket(11, 0.0, 0.0, 0.0), + calibration_histogram.Bucket(12, 101.0, 1.01, 1.0), + ] + self.assertLen(got, len(expected)) + for i in range(len(got)): + self.assertSequenceAlmostEqual( + dataclasses.astuple(got[i]), dataclasses.astuple(expected[i]) + ) - def testRebinWithSparseData(self): - histogram = [ - calibration_histogram.Bucket(4, 5.0, .25, 5.0), # pred = .05 - calibration_histogram.Bucket(61, 60.0, 36.0, 60.0), # pred = .6 - calibration_histogram.Bucket(70, 69.0, 47.61, 69.0), # pred = .69 - calibration_histogram.Bucket(100, 99.0, 98.01, 99.0) # pred = .99 - ] - # [0, 0.1, ..., 0.9, 1.0] - thresholds = [i * 1.0 / 10 for i in range(0, 11)] - got = calibration_histogram.rebin(thresholds, histogram, 100) + def testRebinWithSparseData(self): + histogram = [ + calibration_histogram.Bucket(4, 5.0, 0.25, 5.0), # pred = .05 + calibration_histogram.Bucket(61, 60.0, 36.0, 60.0), # pred = .6 + calibration_histogram.Bucket(70, 69.0, 47.61, 69.0), # pred = .69 + calibration_histogram.Bucket(100, 99.0, 98.01, 99.0), # pred = .99 + ] + # [0, 0.1, ..., 0.9, 1.0] + thresholds = [i * 1.0 / 10 for i in range(0, 11)] + got = calibration_histogram.rebin(thresholds, histogram, 100) - expected = [ - calibration_histogram.Bucket(0, 5.0, 0.25, 5.0), - calibration_histogram.Bucket(1, 0.0, 0.0, 0.0), - calibration_histogram.Bucket(2, 0.0, 0.0, 0.0), - calibration_histogram.Bucket(3, 0.0, 0.0, 0.0), - calibration_histogram.Bucket(4, 0.0, 0.0, 0.0), - calibration_histogram.Bucket(5, 0.0, 0.0, 0.0), - calibration_histogram.Bucket(6, 129.0, 83.61, 129.0), - calibration_histogram.Bucket(7, 0.0, 0.0, 0.0), - calibration_histogram.Bucket(8, 0.0, 0.0, 0.0), - calibration_histogram.Bucket(9, 99.0, 98.01, 99.0), - calibration_histogram.Bucket(10, 0.0, 0.0, 0.0), - ] - self.assertLen(got, len(expected)) - for i in range(len(got)): - self.assertSequenceAlmostEqual( - dataclasses.astuple(got[i]), dataclasses.astuple(expected[i])) + expected = [ + calibration_histogram.Bucket(0, 5.0, 0.25, 5.0), + calibration_histogram.Bucket(1, 0.0, 0.0, 0.0), + calibration_histogram.Bucket(2, 0.0, 0.0, 0.0), + calibration_histogram.Bucket(3, 0.0, 0.0, 0.0), + calibration_histogram.Bucket(4, 0.0, 0.0, 0.0), + calibration_histogram.Bucket(5, 0.0, 0.0, 0.0), + calibration_histogram.Bucket(6, 129.0, 83.61, 129.0), + calibration_histogram.Bucket(7, 0.0, 0.0, 0.0), + calibration_histogram.Bucket(8, 0.0, 0.0, 0.0), + calibration_histogram.Bucket(9, 99.0, 98.01, 99.0), + calibration_histogram.Bucket(10, 0.0, 0.0, 0.0), + ] + self.assertLen(got, len(expected)) + for i in range(len(got)): + self.assertSequenceAlmostEqual( + dataclasses.astuple(got[i]), dataclasses.astuple(expected[i]) + ) -if __name__ == '__main__': - tf.test.main() +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow_model_analysis/metrics/calibration_plot.py b/tensorflow_model_analysis/metrics/calibration_plot.py index 
979217b7b2..4884a58ead 100644 --- a/tensorflow_model_analysis/metrics/calibration_plot.py +++ b/tensorflow_model_analysis/metrics/calibration_plot.py @@ -15,83 +15,90 @@ from typing import Any, Dict, List, Optional, Tuple, Union -from tensorflow_model_analysis.metrics import calibration_histogram -from tensorflow_model_analysis.metrics import metric_types -from tensorflow_model_analysis.metrics import metric_util -from tensorflow_model_analysis.proto import config_pb2 -from tensorflow_model_analysis.proto import metrics_for_slice_pb2 -from tensorflow_model_analysis.utils import model_util - from tensorflow_metadata.proto.v0 import schema_pb2 +from tensorflow_model_analysis.metrics import ( + calibration_histogram, + metric_types, + metric_util, +) +from tensorflow_model_analysis.proto import config_pb2, metrics_for_slice_pb2 +from tensorflow_model_analysis.utils import model_util + DEFAULT_NUM_BUCKETS = 1000 -CALIBRATION_PLOT_NAME = 'calibration_plot' +CALIBRATION_PLOT_NAME = "calibration_plot" class CalibrationPlot(metric_types.Metric): - """Calibration plot.""" - - def __init__(self, - num_buckets: int = DEFAULT_NUM_BUCKETS, - left: Optional[float] = None, - right: Optional[float] = None, - name: str = CALIBRATION_PLOT_NAME): - """Initializes calibration plot. - - Args: - num_buckets: Number of buckets to use when creating the plot. Defaults to - 1000. - left: Left boundary of plot. Defaults to 0.0 when a schema is not - provided. - right: Right boundary of plot. Defaults to 1.0 when a schema is not - provided. - name: Plot name. - """ - super().__init__( - metric_util.merge_per_key_computations(_calibration_plot), - num_buckets=num_buckets, - left=left, - right=right, - name=name) + """Calibration plot.""" + + def __init__( + self, + num_buckets: int = DEFAULT_NUM_BUCKETS, + left: Optional[float] = None, + right: Optional[float] = None, + name: str = CALIBRATION_PLOT_NAME, + ): + """Initializes calibration plot. + + Args: + ---- + num_buckets: Number of buckets to use when creating the plot. Defaults to + 1000. + left: Left boundary of plot. Defaults to 0.0 when a schema is not + provided. + right: Right boundary of plot. Defaults to 1.0 when a schema is not + provided. + name: Plot name. 
+ """ + super().__init__( + metric_util.merge_per_key_computations(_calibration_plot), + num_buckets=num_buckets, + left=left, + right=right, + name=name, + ) metric_types.register_metric(CalibrationPlot) def _find_label_domain( - eval_config: config_pb2.EvalConfig, schema: schema_pb2.Schema, - model_name: str, output_name: str + eval_config: config_pb2.EvalConfig, + schema: schema_pb2.Schema, + model_name: str, + output_name: str, ) -> Tuple[Optional[Union[int, float]], Optional[Union[int, float]]]: - """Find the min and max value for the label_key for this model / output.""" - model_spec = model_util.get_model_spec(eval_config, model_name) - if not model_spec: - return None, None - label_key = model_util.get_label_key(model_spec, output_name) - if not label_key: - return None, None - label_schema = None - for feature_schema in schema.feature: - if feature_schema.name == label_key: - label_schema = feature_schema - break - if label_schema is None: - return None, None - - # Find the domain - if label_schema.HasField('int_domain'): - label_domain = label_schema.int_domain - elif label_schema.HasField('float_domain'): - label_domain = label_schema.float_domain - else: - return None, None - - left, right = None, None - if label_domain.HasField('min'): - left = float(label_domain.min) - if label_domain.HasField('max'): - right = float(label_domain.max) - return left, right + """Find the min and max value for the label_key for this model / output.""" + model_spec = model_util.get_model_spec(eval_config, model_name) + if not model_spec: + return None, None + label_key = model_util.get_label_key(model_spec, output_name) + if not label_key: + return None, None + label_schema = None + for feature_schema in schema.feature: + if feature_schema.name == label_key: + label_schema = feature_schema + break + if label_schema is None: + return None, None + + # Find the domain + if label_schema.HasField("int_domain"): + label_domain = label_schema.int_domain + elif label_schema.HasField("float_domain"): + label_domain = label_schema.float_domain + else: + return None, None + + left, right = None, None + if label_domain.HasField("min"): + left = float(label_domain.min) + if label_domain.HasField("max"): + right = float(label_domain.max) + return left, right def _calibration_plot( @@ -101,88 +108,95 @@ def _calibration_plot( name: str = CALIBRATION_PLOT_NAME, eval_config: Optional[config_pb2.EvalConfig] = None, schema: Optional[schema_pb2.Schema] = None, - model_name: str = '', - output_name: str = '', + model_name: str = "", + output_name: str = "", sub_key: Optional[metric_types.SubKey] = None, aggregation_type: Optional[metric_types.AggregationType] = None, class_weights: Optional[Dict[int, float]] = None, - example_weighted: bool = False) -> metric_types.MetricComputations: - """Returns metric computations for calibration plot.""" - key = metric_types.PlotKey( - name=name, - model_name=model_name, - output_name=output_name, - sub_key=sub_key, - example_weighted=example_weighted) - - label_left, label_right = None, None - if (left is None or right is None) and eval_config and schema: - label_left, label_right = _find_label_domain(eval_config, schema, - model_name, output_name) - if left is None: - left = label_left if label_left is not None else 0.0 - if right is None: - right = label_right if label_right is not None else 1.0 - - # Make sure calibration histogram is calculated. 
Note we are using the default - # number of buckets assigned to the histogram instead of the value used for - # the plots just in case the computation is shared with other metrics and - # plots that need higher preicion. It will be downsampled later. - computations = calibration_histogram.calibration_histogram( - eval_config=eval_config, - model_name=model_name, - output_name=output_name, - sub_key=sub_key, - left=left, - right=right, - aggregation_type=aggregation_type, - class_weights=class_weights, - example_weighted=example_weighted) - histogram_key = computations[-1].keys[-1] - - def result( - metrics: Dict[metric_types.MetricKey, Any] - ) -> Dict[metric_types.MetricKey, Any]: - thresholds = [ - left + i * (right - left) / num_buckets for i in range(num_buckets + 1) - ] - thresholds = [float('-inf')] + thresholds - histogram = calibration_histogram.rebin( - thresholds, metrics[histogram_key], left=left, right=right) - return {key: _to_proto(thresholds, histogram)} - - derived_computation = metric_types.DerivedMetricComputation( - keys=[key], result=result) - computations.append(derived_computation) - return computations + example_weighted: bool = False, +) -> metric_types.MetricComputations: + """Returns metric computations for calibration plot.""" + key = metric_types.PlotKey( + name=name, + model_name=model_name, + output_name=output_name, + sub_key=sub_key, + example_weighted=example_weighted, + ) + + label_left, label_right = None, None + if (left is None or right is None) and eval_config and schema: + label_left, label_right = _find_label_domain( + eval_config, schema, model_name, output_name + ) + if left is None: + left = label_left if label_left is not None else 0.0 + if right is None: + right = label_right if label_right is not None else 1.0 + + # Make sure calibration histogram is calculated. Note we are using the default + # number of buckets assigned to the histogram instead of the value used for + # the plots just in case the computation is shared with other metrics and + # plots that need higher preicion. It will be downsampled later. + computations = calibration_histogram.calibration_histogram( + eval_config=eval_config, + model_name=model_name, + output_name=output_name, + sub_key=sub_key, + left=left, + right=right, + aggregation_type=aggregation_type, + class_weights=class_weights, + example_weighted=example_weighted, + ) + histogram_key = computations[-1].keys[-1] + + def result( + metrics: Dict[metric_types.MetricKey, Any], + ) -> Dict[metric_types.MetricKey, Any]: + thresholds = [ + left + i * (right - left) / num_buckets for i in range(num_buckets + 1) + ] + thresholds = [float("-inf")] + thresholds + histogram = calibration_histogram.rebin( + thresholds, metrics[histogram_key], left=left, right=right + ) + return {key: _to_proto(thresholds, histogram)} + + derived_computation = metric_types.DerivedMetricComputation( + keys=[key], result=result + ) + computations.append(derived_computation) + return computations def _to_proto( thresholds: List[float], histogram: calibration_histogram.Histogram ) -> metrics_for_slice_pb2.CalibrationHistogramBuckets: - """Converts histogram into CalibrationHistogramBuckets proto. - - Args: - thresholds: Thresholds associated with histogram buckets. - histogram: Calibration histogram. - - Returns: - A histogram in CalibrationHistogramBuckets proto format. 
- """ - pb = metrics_for_slice_pb2.CalibrationHistogramBuckets() - lower_threshold = float('-inf') - for i, bucket in enumerate(histogram): - if i >= len(thresholds) - 1: - upper_threshold = float('inf') - else: - upper_threshold = thresholds[i + 1] - pb.buckets.add( - lower_threshold_inclusive=lower_threshold, - upper_threshold_exclusive=upper_threshold, - total_weighted_label={'value': bucket.weighted_labels}, - total_weighted_refined_prediction={ - 'value': bucket.weighted_predictions - }, - num_weighted_examples={'value': bucket.weighted_examples}) - lower_threshold = upper_threshold - return pb + """Converts histogram into CalibrationHistogramBuckets proto. + + Args: + ---- + thresholds: Thresholds associated with histogram buckets. + histogram: Calibration histogram. + + Returns: + ------- + A histogram in CalibrationHistogramBuckets proto format. + """ + pb = metrics_for_slice_pb2.CalibrationHistogramBuckets() + lower_threshold = float("-inf") + for i, bucket in enumerate(histogram): + if i >= len(thresholds) - 1: + upper_threshold = float("inf") + else: + upper_threshold = thresholds[i + 1] + pb.buckets.add( + lower_threshold_inclusive=lower_threshold, + upper_threshold_exclusive=upper_threshold, + total_weighted_label={"value": bucket.weighted_labels}, + total_weighted_refined_prediction={"value": bucket.weighted_predictions}, + num_weighted_examples={"value": bucket.weighted_examples}, + ) + lower_threshold = upper_threshold + return pb diff --git a/tensorflow_model_analysis/metrics/calibration_plot_test.py b/tensorflow_model_analysis/metrics/calibration_plot_test.py index 9e25cb66ec..59887ece8e 100644 --- a/tensorflow_model_analysis/metrics/calibration_plot_test.py +++ b/tensorflow_model_analysis/metrics/calibration_plot_test.py @@ -13,99 +13,112 @@ # limitations under the License. 
"""Tests for calibration plot.""" -from absl.testing import parameterized import apache_beam as beam -from apache_beam.testing import util import numpy as np import tensorflow as tf -from tensorflow_model_analysis.metrics import calibration_plot -from tensorflow_model_analysis.metrics import metric_types -from tensorflow_model_analysis.metrics import metric_util -from tensorflow_model_analysis.proto import config_pb2 -from tensorflow_model_analysis.utils import test_util - +from absl.testing import parameterized +from apache_beam.testing import util from google.protobuf import text_format from tensorflow_metadata.proto.v0 import schema_pb2 +from tensorflow_model_analysis.metrics import ( + calibration_plot, + metric_types, + metric_util, +) +from tensorflow_model_analysis.proto import config_pb2 +from tensorflow_model_analysis.utils import test_util + class CalibrationPlotTest( test_util.TensorflowModelAnalysisTest, parameterized.TestCase ): + def testCalibrationPlot(self): + computations = calibration_plot.CalibrationPlot(num_buckets=10).computations( + example_weighted=True + ) + histogram = computations[0] + plot = computations[1] - def testCalibrationPlot(self): - computations = calibration_plot.CalibrationPlot( - num_buckets=10).computations(example_weighted=True) - histogram = computations[0] - plot = computations[1] - - example1 = { - 'labels': np.array([0.0]), - 'predictions': np.array([0.2]), - 'example_weights': np.array([1.0]) - } - example2 = { - 'labels': np.array([1.0]), - 'predictions': np.array([0.8]), - 'example_weights': np.array([2.0]) - } - example3 = { - 'labels': np.array([0.0]), - 'predictions': np.array([0.5]), - 'example_weights': np.array([3.0]) - } - example4 = { - 'labels': np.array([1.0]), - 'predictions': np.array([-0.1]), - 'example_weights': np.array([4.0]) - } - example5 = { - 'labels': np.array([1.0]), - 'predictions': np.array([0.5]), - 'example_weights': np.array([5.0]) - } - example6 = { - 'labels': np.array([1.0]), - 'predictions': np.array([0.8]), - 'example_weights': np.array([6.0]) - } - example7 = { - 'labels': np.array([0.0]), - 'predictions': np.array([0.2]), - 'example_weights': np.array([7.0]) - } - example8 = { - 'labels': np.array([1.0]), - 'predictions': np.array([1.1]), - 'example_weights': np.array([8.0]) - } + example1 = { + "labels": np.array([0.0]), + "predictions": np.array([0.2]), + "example_weights": np.array([1.0]), + } + example2 = { + "labels": np.array([1.0]), + "predictions": np.array([0.8]), + "example_weights": np.array([2.0]), + } + example3 = { + "labels": np.array([0.0]), + "predictions": np.array([0.5]), + "example_weights": np.array([3.0]), + } + example4 = { + "labels": np.array([1.0]), + "predictions": np.array([-0.1]), + "example_weights": np.array([4.0]), + } + example5 = { + "labels": np.array([1.0]), + "predictions": np.array([0.5]), + "example_weights": np.array([5.0]), + } + example6 = { + "labels": np.array([1.0]), + "predictions": np.array([0.8]), + "example_weights": np.array([6.0]), + } + example7 = { + "labels": np.array([0.0]), + "predictions": np.array([0.2]), + "example_weights": np.array([7.0]), + } + example8 = { + "labels": np.array([1.0]), + "predictions": np.array([1.1]), + "example_weights": np.array([8.0]), + } - with beam.Pipeline() as pipeline: - # pylint: disable=no-value-for-parameter - result = ( - pipeline - | 'Create' >> beam.Create([ - example1, example2, example3, example4, example5, example6, - example7, example8 - ]) - | 'Process' >> beam.Map(metric_util.to_standard_metric_inputs) - | 
'AddSlice' >> beam.Map(lambda x: ((), x)) - | 'ComputeHistogram' >> beam.CombinePerKey(histogram.combiner) - | 'ComputePlot' >> beam.Map(lambda x: (x[0], plot.result(x[1])))) + with beam.Pipeline() as pipeline: + # pylint: disable=no-value-for-parameter + result = ( + pipeline + | "Create" + >> beam.Create( + [ + example1, + example2, + example3, + example4, + example5, + example6, + example7, + example8, + ] + ) + | "Process" >> beam.Map(metric_util.to_standard_metric_inputs) + | "AddSlice" >> beam.Map(lambda x: ((), x)) + | "ComputeHistogram" >> beam.CombinePerKey(histogram.combiner) + | "ComputePlot" >> beam.Map(lambda x: (x[0], plot.result(x[1]))) + ) - # pylint: enable=no-value-for-parameter + # pylint: enable=no-value-for-parameter - def check_result(got): - try: - self.assertLen(got, 1) - got_slice_key, got_plots = got[0] - self.assertEqual(got_slice_key, ()) - self.assertLen(got_plots, 1) - key = metric_types.PlotKey( - name='calibration_plot', example_weighted=True) - self.assertIn(key, got_plots) - got_plot = got_plots[key] - self.assertProtoEquals( - """ + def check_result(got): + try: + self.assertLen(got, 1) + got_slice_key, got_plots = got[0] + self.assertEqual(got_slice_key, ()) + self.assertLen(got_plots, 1) + key = metric_types.PlotKey( + name="calibration_plot", example_weighted=True + ) + self.assertIn(key, got_plots) + got_plot = got_plots[key] + self.assertProtoEquals( + """ buckets { lower_threshold_inclusive: -inf upper_threshold_exclusive: 0.0 @@ -240,24 +253,25 @@ def check_result(got): value: 8.0 } } - """, got_plot) + """, + got_plot, + ) - except AssertionError as err: - raise util.BeamAssertException(err) + except AssertionError as err: + raise util.BeamAssertException(err) - util.assert_that(result, check_result, label='result') + util.assert_that(result, check_result, label="result") - @parameterized.named_parameters( - { - 'testcase_name': - 'int_single_model', - 'eval_config': - config_pb2.EvalConfig(model_specs=[ - config_pb2.ModelSpec(name='model1', label_key='label'), - ]), - 'schema': - text_format.Parse( - """ + @parameterized.named_parameters( + { + "testcase_name": "int_single_model", + "eval_config": config_pb2.EvalConfig( + model_specs=[ + config_pb2.ModelSpec(name="model1", label_key="label"), + ] + ), + "schema": text_format.Parse( + """ feature { name: "label" type: INT @@ -266,23 +280,23 @@ def check_result(got): max: 15 } } - """, schema_pb2.Schema()), - 'model_names': [''], - 'output_names': [''], - 'expected_left': - 5.0, - 'expected_range': - 10.0, - }, { - 'testcase_name': - 'int_single_model_right_only', - 'eval_config': - config_pb2.EvalConfig(model_specs=[ - config_pb2.ModelSpec(name='model1', label_key='label'), - ]), - 'schema': - text_format.Parse( - """ + """, + schema_pb2.Schema(), + ), + "model_names": [""], + "output_names": [""], + "expected_left": 5.0, + "expected_range": 10.0, + }, + { + "testcase_name": "int_single_model_right_only", + "eval_config": config_pb2.EvalConfig( + model_specs=[ + config_pb2.ModelSpec(name="model1", label_key="label"), + ] + ), + "schema": text_format.Parse( + """ feature { name: "label" type: INT @@ -290,65 +304,65 @@ def check_result(got): max: 15 } } - """, schema_pb2.Schema()), - 'model_names': [''], - 'output_names': [''], - 'expected_left': - 0.0, - 'expected_range': - 15.0, - }, { - 'testcase_name': - 'int_single_model_schema_missing_domain', - 'eval_config': - config_pb2.EvalConfig(model_specs=[ - config_pb2.ModelSpec(name='model1', label_key='label'), - ]), - 'schema': - text_format.Parse( 
- """ + """, + schema_pb2.Schema(), + ), + "model_names": [""], + "output_names": [""], + "expected_left": 0.0, + "expected_range": 15.0, + }, + { + "testcase_name": "int_single_model_schema_missing_domain", + "eval_config": config_pb2.EvalConfig( + model_specs=[ + config_pb2.ModelSpec(name="model1", label_key="label"), + ] + ), + "schema": text_format.Parse( + """ feature { name: "label" type: FLOAT } - """, schema_pb2.Schema()), - 'model_names': [''], - 'output_names': [''], - 'expected_left': - 0.0, - 'expected_range': - 1.0, - }, { - 'testcase_name': - 'int_single_model_schema_missing_label', - 'eval_config': - config_pb2.EvalConfig(model_specs=[ - config_pb2.ModelSpec(name='model1', label_key='label'), - ]), - 'schema': - text_format.Parse( - """ + """, + schema_pb2.Schema(), + ), + "model_names": [""], + "output_names": [""], + "expected_left": 0.0, + "expected_range": 1.0, + }, + { + "testcase_name": "int_single_model_schema_missing_label", + "eval_config": config_pb2.EvalConfig( + model_specs=[ + config_pb2.ModelSpec(name="model1", label_key="label"), + ] + ), + "schema": text_format.Parse( + """ feature { name: "other_feature" type: BYTES } - """, schema_pb2.Schema()), - 'model_names': [''], - 'output_names': [''], - 'expected_left': - 0.0, - 'expected_range': - 1.0, - }, { - 'testcase_name': - 'float_single_model', - 'eval_config': - config_pb2.EvalConfig(model_specs=[ - config_pb2.ModelSpec(name='model1', label_key='label'), - ]), - 'schema': - text_format.Parse( - """ + """, + schema_pb2.Schema(), + ), + "model_names": [""], + "output_names": [""], + "expected_left": 0.0, + "expected_range": 1.0, + }, + { + "testcase_name": "float_single_model", + "eval_config": config_pb2.EvalConfig( + model_specs=[ + config_pb2.ModelSpec(name="model1", label_key="label"), + ] + ), + "schema": text_format.Parse( + """ feature { name: "label" type: FLOAT @@ -357,29 +371,27 @@ def check_result(got): max: 15.0 } } - """, schema_pb2.Schema()), - 'model_names': [''], - 'output_names': [''], - 'expected_left': - 5.0, - 'expected_range': - 10.0 - }, { - 'testcase_name': - 'float_single_model_multiple_outputs', - 'eval_config': - config_pb2.EvalConfig(model_specs=[ - config_pb2.ModelSpec( - name='model1', - label_keys={ - 'output1': 'label1', - 'output2': 'label2' - }, - signature_name='default'), - ]), - 'schema': - text_format.Parse( - """ + """, + schema_pb2.Schema(), + ), + "model_names": [""], + "output_names": [""], + "expected_left": 5.0, + "expected_range": 10.0, + }, + { + "testcase_name": "float_single_model_multiple_outputs", + "eval_config": config_pb2.EvalConfig( + model_specs=[ + config_pb2.ModelSpec( + name="model1", + label_keys={"output1": "label1", "output2": "label2"}, + signature_name="default", + ), + ] + ), + "schema": text_format.Parse( + """ feature { name: "label2" type: FLOAT @@ -388,24 +400,24 @@ def check_result(got): max: 15.0 } } - """, schema_pb2.Schema()), - 'model_names': [''], - 'output_names': ['output2'], - 'expected_left': - 5.0, - 'expected_range': - 10.0 - }, { - 'testcase_name': - 'float_multiple_models', - 'eval_config': - config_pb2.EvalConfig(model_specs=[ - config_pb2.ModelSpec(name='model1', label_key='label1'), - config_pb2.ModelSpec(name='model2', label_key='label2') - ]), - 'schema': - text_format.Parse( - """ + """, + schema_pb2.Schema(), + ), + "model_names": [""], + "output_names": ["output2"], + "expected_left": 5.0, + "expected_range": 10.0, + }, + { + "testcase_name": "float_multiple_models", + "eval_config": config_pb2.EvalConfig( + model_specs=[ 
+ config_pb2.ModelSpec(name="model1", label_key="label1"), + config_pb2.ModelSpec(name="model2", label_key="label2"), + ] + ), + "schema": text_format.Parse( + """ feature { name: "label2" type: FLOAT @@ -414,27 +426,34 @@ def check_result(got): max: 15.0 } } - """, schema_pb2.Schema()), - 'model_names': ['model2'], - 'output_names': [''], - 'expected_left': - 5.0, - 'expected_range': - 10.0 - }) - def testCalibrationPlotWithSchema(self, eval_config, schema, model_names, - output_names, expected_left, - expected_range): - computations = calibration_plot.CalibrationPlot( - num_buckets=10).computations( + """, + schema_pb2.Schema(), + ), + "model_names": ["model2"], + "output_names": [""], + "expected_left": 5.0, + "expected_range": 10.0, + }, + ) + def testCalibrationPlotWithSchema( + self, + eval_config, + schema, + model_names, + output_names, + expected_left, + expected_range, + ): + computations = calibration_plot.CalibrationPlot(num_buckets=10).computations( eval_config=eval_config, schema=schema, model_names=model_names, - output_names=output_names) - histogram = computations[0] - self.assertEqual(expected_left, histogram.combiner._left) - self.assertEqual(expected_range, histogram.combiner._range) + output_names=output_names, + ) + histogram = computations[0] + self.assertEqual(expected_left, histogram.combiner._left) + self.assertEqual(expected_range, histogram.combiner._range) -if __name__ == '__main__': - tf.test.main() +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow_model_analysis/metrics/calibration_test.py b/tensorflow_model_analysis/metrics/calibration_test.py index f3d432b07b..193851f0e1 100644 --- a/tensorflow_model_analysis/metrics/calibration_test.py +++ b/tensorflow_model_analysis/metrics/calibration_test.py @@ -13,126 +13,137 @@ # limitations under the License. 
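A worked check, by hand, of the bucket arithmetic behind the expected calibration plot asserted above; the values follow directly from the test's example weights and the [lower_threshold_inclusive, upper_threshold_exclusive) convention in _to_proto:

# Bucket [0.2, 0.3): example1 (weight 1.0) and example7 (weight 7.0), both with label 0.0.
num_weighted_examples = 1.0 + 7.0                           # 8.0
total_weighted_label = 0.0 * 1.0 + 0.0 * 7.0                # 0.0
total_weighted_refined_prediction = 0.2 * 1.0 + 0.2 * 7.0   # 1.6
# Prediction -0.1 (weight 4.0) lands in the [-inf, 0.0) bucket and
# prediction 1.1 (weight 8.0) in the final [1.0, inf) bucket.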
"""Tests for calibration related metrics.""" -from absl.testing import parameterized import apache_beam as beam -from apache_beam.testing import util import numpy as np import tensorflow as tf -from tensorflow_model_analysis.metrics import calibration -from tensorflow_model_analysis.metrics import metric_util +from absl.testing import parameterized +from apache_beam.testing import util + +from tensorflow_model_analysis.metrics import calibration, metric_util from tensorflow_model_analysis.utils import test_util class CalibrationMetricsTest( test_util.TensorflowModelAnalysisTest, parameterized.TestCase ): - - @parameterized.named_parameters( - ('mean_label', calibration.MeanLabel(), 2.0 / 3.0), - ('mean_prediction', calibration.MeanPrediction(), (0.3 + 0.9) / 3.0), - ('calibration', calibration.Calibration(), (0.3 + 0.9) / 2.0)) - def testCalibrationMetricsWithoutWeights(self, metric, expected_value): - computations = metric.computations() - weighted_totals = computations[0] - metric = computations[1] - - example1 = { - 'labels': np.array([0.0]), - 'predictions': np.array([0.0]), - 'example_weights': np.array([1.0]), - } - example2 = { - 'labels': np.array([1.0]), - 'predictions': np.array([0.3]), - 'example_weights': np.array([1.0]), - } - example3 = { - 'labels': np.array([1.0]), - 'predictions': np.array([0.9]), - 'example_weights': None, # defaults to 1.0 - } - - with beam.Pipeline() as pipeline: - # pylint: disable=no-value-for-parameter - result = ( - pipeline - | 'Create' >> beam.Create([example1, example2, example3]) - | 'Process' >> beam.Map(metric_util.to_standard_metric_inputs) - | 'AddSlice' >> beam.Map(lambda x: ((), x)) - | 'ComputeWeightedTotals' >> beam.CombinePerKey( - weighted_totals.combiner) - | 'ComputeMetric' >> beam.Map(lambda x: (x[0], metric.result(x[1])))) - - # pylint: enable=no-value-for-parameter - - def check_result(got): - try: - self.assertLen(got, 1) - got_slice_key, got_metrics = got[0] - self.assertEqual(got_slice_key, ()) - key = metric.keys[0] - self.assertDictElementsAlmostEqual( - got_metrics, {key: expected_value}, places=5) - - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that(result, check_result, label='result') - - @parameterized.named_parameters( - ('mean_label', calibration.MeanLabel(), 1.0 * 0.7 / 2.1), - ('mean_prediction', calibration.MeanPrediction(), - (1.0 * 0.5 + 0.7 * 0.7 + 0.5 * 0.9) / 2.1), - ('calibration', calibration.Calibration(), - (1.0 * 0.5 + 0.7 * 0.7 + 0.5 * 0.9) / (1.0 * 0.7))) - def testCalibrationMetricsWithWeights(self, metric, expected_value): - computations = metric.computations(example_weighted=True) - weighted_totals = computations[0] - metric = computations[1] - - example1 = { - 'labels': np.array([0.0]), - 'predictions': np.array([1.0]), - 'example_weights': np.array([0.5]), - } - example2 = { - 'labels': np.array([1.0]), - 'predictions': np.array([0.7]), - 'example_weights': np.array([0.7]), - } - example3 = { - 'labels': np.array([0.0]), - 'predictions': np.array([0.5]), - 'example_weights': np.array([0.9]), - } - - with beam.Pipeline() as pipeline: - # pylint: disable=no-value-for-parameter - result = ( - pipeline - | 'Create' >> beam.Create([example1, example2, example3]) - | 'Process' >> beam.Map(metric_util.to_standard_metric_inputs) - | 'AddSlice' >> beam.Map(lambda x: ((), x)) - | 'ComputeWeightedTotals' >> beam.CombinePerKey( - weighted_totals.combiner) - | 'ComputeMetric' >> beam.Map(lambda x: (x[0], metric.result(x[1])))) - - # pylint: enable=no-value-for-parameter - 
- def check_result(got): - try: - self.assertLen(got, 1) - got_slice_key, got_metrics = got[0] - self.assertEqual(got_slice_key, ()) - key = metric.keys[0] - self.assertDictElementsAlmostEqual( - got_metrics, {key: expected_value}, places=5) - - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that(result, check_result, label='result') - - -if __name__ == '__main__': - tf.test.main() + @parameterized.named_parameters( + ("mean_label", calibration.MeanLabel(), 2.0 / 3.0), + ("mean_prediction", calibration.MeanPrediction(), (0.3 + 0.9) / 3.0), + ("calibration", calibration.Calibration(), (0.3 + 0.9) / 2.0), + ) + def testCalibrationMetricsWithoutWeights(self, metric, expected_value): + computations = metric.computations() + weighted_totals = computations[0] + metric = computations[1] + + example1 = { + "labels": np.array([0.0]), + "predictions": np.array([0.0]), + "example_weights": np.array([1.0]), + } + example2 = { + "labels": np.array([1.0]), + "predictions": np.array([0.3]), + "example_weights": np.array([1.0]), + } + example3 = { + "labels": np.array([1.0]), + "predictions": np.array([0.9]), + "example_weights": None, # defaults to 1.0 + } + + with beam.Pipeline() as pipeline: + # pylint: disable=no-value-for-parameter + result = ( + pipeline + | "Create" >> beam.Create([example1, example2, example3]) + | "Process" >> beam.Map(metric_util.to_standard_metric_inputs) + | "AddSlice" >> beam.Map(lambda x: ((), x)) + | "ComputeWeightedTotals" + >> beam.CombinePerKey(weighted_totals.combiner) + | "ComputeMetric" >> beam.Map(lambda x: (x[0], metric.result(x[1]))) + ) + + # pylint: enable=no-value-for-parameter + + def check_result(got): + try: + self.assertLen(got, 1) + got_slice_key, got_metrics = got[0] + self.assertEqual(got_slice_key, ()) + key = metric.keys[0] + self.assertDictElementsAlmostEqual( + got_metrics, {key: expected_value}, places=5 + ) + + except AssertionError as err: + raise util.BeamAssertException(err) + + util.assert_that(result, check_result, label="result") + + @parameterized.named_parameters( + ("mean_label", calibration.MeanLabel(), 1.0 * 0.7 / 2.1), + ( + "mean_prediction", + calibration.MeanPrediction(), + (1.0 * 0.5 + 0.7 * 0.7 + 0.5 * 0.9) / 2.1, + ), + ( + "calibration", + calibration.Calibration(), + (1.0 * 0.5 + 0.7 * 0.7 + 0.5 * 0.9) / (1.0 * 0.7), + ), + ) + def testCalibrationMetricsWithWeights(self, metric, expected_value): + computations = metric.computations(example_weighted=True) + weighted_totals = computations[0] + metric = computations[1] + + example1 = { + "labels": np.array([0.0]), + "predictions": np.array([1.0]), + "example_weights": np.array([0.5]), + } + example2 = { + "labels": np.array([1.0]), + "predictions": np.array([0.7]), + "example_weights": np.array([0.7]), + } + example3 = { + "labels": np.array([0.0]), + "predictions": np.array([0.5]), + "example_weights": np.array([0.9]), + } + + with beam.Pipeline() as pipeline: + # pylint: disable=no-value-for-parameter + result = ( + pipeline + | "Create" >> beam.Create([example1, example2, example3]) + | "Process" >> beam.Map(metric_util.to_standard_metric_inputs) + | "AddSlice" >> beam.Map(lambda x: ((), x)) + | "ComputeWeightedTotals" + >> beam.CombinePerKey(weighted_totals.combiner) + | "ComputeMetric" >> beam.Map(lambda x: (x[0], metric.result(x[1]))) + ) + + # pylint: enable=no-value-for-parameter + + def check_result(got): + try: + self.assertLen(got, 1) + got_slice_key, got_metrics = got[0] + self.assertEqual(got_slice_key, ()) + key = metric.keys[0] + 
self.assertDictElementsAlmostEqual( + got_metrics, {key: expected_value}, places=5 + ) + + except AssertionError as err: + raise util.BeamAssertException(err) + + util.assert_that(result, check_result, label="result") + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow_model_analysis/metrics/confusion_matrix_metrics.py b/tensorflow_model_analysis/metrics/confusion_matrix_metrics.py index d65b53e159..c9fc5d1091 100644 --- a/tensorflow_model_analysis/metrics/confusion_matrix_metrics.py +++ b/tensorflow_model_analysis/metrics/confusion_matrix_metrics.py @@ -20,144 +20,151 @@ from typing import Any, Callable, Dict, List, Optional, Union, overload import numpy as np -from tensorflow_model_analysis.metrics import binary_confusion_matrices -from tensorflow_model_analysis.metrics import metric_types -from tensorflow_model_analysis.metrics import metric_util + +from tensorflow_model_analysis.metrics import ( + binary_confusion_matrices, + metric_types, + metric_util, +) from tensorflow_model_analysis.proto import config_pb2 -AUC_NAME = 'auc' -AUC_PRECISION_RECALL_NAME = 'auc_precision_recall' -SENSITIVITY_AT_SPECIFICITY_NAME = 'sensitivity_at_specificity' -SPECIFICITY_AT_SENSITIVITY_NAME = 'specificity_at_sensitivity' -PRECISION_AT_RECALL_NAME = 'precision_at_recall' -RECALL_AT_PRECISION_NAME = 'recall_at_precision' -RECALL_AT_FALSE_POSITIVE_RATE_NAME = 'recall_at_false_positive_rate' -TRUE_POSITIVES_NAME = 'true_positives' -TP_NAME = 'tp' -TRUE_NEGATIVES_NAME = 'true_negatives' -TN_NAME = 'tn' -FALSE_POSITIVES_NAME = 'false_positives' -FP_NAME = 'fp' -FALSE_NEGATIVES_NAME = 'false_negatives' -FN_NAME = 'fn' -BINARY_ACCURACY_NAME = 'binary_accuracy' -PRECISION_NAME = 'precision' -PPV_NAME = 'ppv' -RECALL_NAME = 'recall' -TPR_NAME = 'tpr' -SPECIFICITY_NAME = 'specificity' -TNR_NAME = 'tnr' -FALL_OUT_NAME = 'fall_out' -FPR_NAME = 'fpr' -MISS_RATE_NAME = 'miss_rate' -FNR_NAME = 'fnr' -NEGATIVE_PREDICTIVE_VALUE_NAME = 'negative_predictive_value' -NPV_NAME = 'npv' -FALSE_DISCOVERY_RATE_NAME = 'false_discovery_rate' -FALSE_OMISSION_RATE_NAME = 'false_omission_rate' -PREVALENCE_NAME = 'prevalence' -PREVALENCE_THRESHOLD_NAME = 'prevalence_threshold' -THREAT_SCORE_NAME = 'threat_score' -BALANCED_ACCURACY_NAME = 'balanced_accuracy' -F1_SCORE_NAME = 'f1_score' -MATTHEWS_CORRELATION_COEFFICIENT_NAME = 'matthews_correlation_coefficient' -FOWLKES_MALLOWS_INDEX_NAME = 'fowlkes_mallows_index' -INFORMEDNESS_NAME = 'informedness' -MARKEDNESS_NAME = 'markedness' -POSITIVE_LIKELIHOOD_RATIO_NAME = 'positive_likelihood_ratio' -NEGATIVE_LIKELIHOOD_RATIO_NAME = 'negative_likelihood_ratio' -DIAGNOSTIC_ODDS_RATIO_NAME = 'diagnostic_odds_ratio' -PREDICTED_POSITIVE_RATE_NAME = 'predicted_positive_rate' -CONFUSION_MATRIX_AT_THRESHOLDS_NAME = 'confusion_matrix_at_thresholds' -FALSE_POSITIVE_FEATURE_SAMPLER_NAME = 'false_positive_feature_sampler' -FALSE_NEGATIVE_FEATURE_SAMPLER_NAME = 'false_negative_feature_sampler' -AVERAGE_PRECISION_NAME = 'average_precision' -MAX_RECALL_NAME = 'max_recall' -THRESHOLD_AT_RECALL_NAME = 'threshold_at_recall' +AUC_NAME = "auc" +AUC_PRECISION_RECALL_NAME = "auc_precision_recall" +SENSITIVITY_AT_SPECIFICITY_NAME = "sensitivity_at_specificity" +SPECIFICITY_AT_SENSITIVITY_NAME = "specificity_at_sensitivity" +PRECISION_AT_RECALL_NAME = "precision_at_recall" +RECALL_AT_PRECISION_NAME = "recall_at_precision" +RECALL_AT_FALSE_POSITIVE_RATE_NAME = "recall_at_false_positive_rate" +TRUE_POSITIVES_NAME = "true_positives" +TP_NAME = "tp" +TRUE_NEGATIVES_NAME = "true_negatives" +TN_NAME = 
"tn" +FALSE_POSITIVES_NAME = "false_positives" +FP_NAME = "fp" +FALSE_NEGATIVES_NAME = "false_negatives" +FN_NAME = "fn" +BINARY_ACCURACY_NAME = "binary_accuracy" +PRECISION_NAME = "precision" +PPV_NAME = "ppv" +RECALL_NAME = "recall" +TPR_NAME = "tpr" +SPECIFICITY_NAME = "specificity" +TNR_NAME = "tnr" +FALL_OUT_NAME = "fall_out" +FPR_NAME = "fpr" +MISS_RATE_NAME = "miss_rate" +FNR_NAME = "fnr" +NEGATIVE_PREDICTIVE_VALUE_NAME = "negative_predictive_value" +NPV_NAME = "npv" +FALSE_DISCOVERY_RATE_NAME = "false_discovery_rate" +FALSE_OMISSION_RATE_NAME = "false_omission_rate" +PREVALENCE_NAME = "prevalence" +PREVALENCE_THRESHOLD_NAME = "prevalence_threshold" +THREAT_SCORE_NAME = "threat_score" +BALANCED_ACCURACY_NAME = "balanced_accuracy" +F1_SCORE_NAME = "f1_score" +MATTHEWS_CORRELATION_COEFFICIENT_NAME = "matthews_correlation_coefficient" +FOWLKES_MALLOWS_INDEX_NAME = "fowlkes_mallows_index" +INFORMEDNESS_NAME = "informedness" +MARKEDNESS_NAME = "markedness" +POSITIVE_LIKELIHOOD_RATIO_NAME = "positive_likelihood_ratio" +NEGATIVE_LIKELIHOOD_RATIO_NAME = "negative_likelihood_ratio" +DIAGNOSTIC_ODDS_RATIO_NAME = "diagnostic_odds_ratio" +PREDICTED_POSITIVE_RATE_NAME = "predicted_positive_rate" +CONFUSION_MATRIX_AT_THRESHOLDS_NAME = "confusion_matrix_at_thresholds" +FALSE_POSITIVE_FEATURE_SAMPLER_NAME = "false_positive_feature_sampler" +FALSE_NEGATIVE_FEATURE_SAMPLER_NAME = "false_negative_feature_sampler" +AVERAGE_PRECISION_NAME = "average_precision" +MAX_RECALL_NAME = "max_recall" +THRESHOLD_AT_RECALL_NAME = "threshold_at_recall" class AUCCurve(enum.Enum): - ROC = 'ROC' - PR = 'PR' + ROC = "ROC" + PR = "PR" class AUCSummationMethod(enum.Enum): - INTERPOLATION = 'interpolation' - MAJORING = 'majoring' - MINORING = 'minoring' + INTERPOLATION = "interpolation" + MAJORING = "majoring" + MINORING = "minoring" -def _divide_only_positive_denominator( - numerator: float, denominator: float -) -> float: - """Returns division only when denominator is positive, otherwise nan.""" - return numerator / denominator if denominator > 0 else float('nan') +def _divide_only_positive_denominator(numerator: float, denominator: float) -> float: + """Returns division only when denominator is positive, otherwise nan.""" + return numerator / denominator if denominator > 0 else float("nan") def _pos_sqrt(value: float) -> float: - """Returns sqrt of value or raises ValueError if negative.""" - if value < 0: - raise ValueError('Attempt to take sqrt of negative value: {}'.format(value)) - return math.sqrt(value) + """Returns sqrt of value or raises ValueError if negative.""" + if value < 0: + raise ValueError(f"Attempt to take sqrt of negative value: {value}") + return math.sqrt(value) def _validate_and_update_sub_key( - metric_name: str, model_name: str, output_name: str, - sub_key: metric_types.SubKey, top_k: Optional[int], - class_id: Optional[int]) -> metric_types.SubKey: - """Validates and updates sub key. - - This function validates that the top_k and class_id settings that are - determined by the MetricsSpec.binarize are compatible and do not overlap with - any settings provided by MetricConfigs. - - Args: - metric_name: Metric name. - model_name: Model name. - output_name: Output name. - sub_key: Sub key (from MetricsSpec). - top_k: Top k setting (from MetricConfig). - class_id: Class ID setting (from MetricConfig). - - Returns: - Updated sub-key if top_k or class_id params are used. - - Raises: - ValueError: If validation fails. 
- """ - - if sub_key is None: - if top_k is None and class_id is None: - return None - else: - sub_key = metric_types.SubKey() - if top_k is not None: - if sub_key.top_k is not None: - raise ValueError( - f'Metric {metric_name} is configured with overlapping settings. ' - f'The metric was initialized with top_k={top_k}, but the ' - f'metric was defined in a spec using sub_key={sub_key}, ' - f'model_name={model_name}, output_name={output_name}\n\n' - 'Binarization related settings can be configured in either the' - 'metrics_spec or the metric, but not both. Either remove the top_k ' - 'setting from this metric or remove the metrics_spec.binarize ' - 'settings.' - ) - sub_key = sub_key._replace(top_k=top_k) - if class_id is not None: - if sub_key.class_id is not None: - raise ValueError( - f'Metric {metric_name} is configured with overlapping settings. ' - f'The metric was initialized with class_id={class_id}, but the ' - f'metric was defined in a spec using sub_key={sub_key}, ' - f'model_name={model_name}, output_name={output_name}\n\n' - 'Binarization related settings can be configured in either the' - 'metrics_spec or the metric, but not both. Either remove the class_id' - ' setting from this metric or remove the metrics_spec.binarize ' - 'settings.' - ) - sub_key = sub_key._replace(class_id=class_id) - return sub_key + metric_name: str, + model_name: str, + output_name: str, + sub_key: metric_types.SubKey, + top_k: Optional[int], + class_id: Optional[int], +) -> metric_types.SubKey: + """Validates and updates sub key. + + This function validates that the top_k and class_id settings that are + determined by the MetricsSpec.binarize are compatible and do not overlap with + any settings provided by MetricConfigs. + + Args: + ---- + metric_name: Metric name. + model_name: Model name. + output_name: Output name. + sub_key: Sub key (from MetricsSpec). + top_k: Top k setting (from MetricConfig). + class_id: Class ID setting (from MetricConfig). + + Returns: + ------- + Updated sub-key if top_k or class_id params are used. + + Raises: + ------ + ValueError: If validation fails. + """ + if sub_key is None: + if top_k is None and class_id is None: + return None + else: + sub_key = metric_types.SubKey() + if top_k is not None: + if sub_key.top_k is not None: + raise ValueError( + f"Metric {metric_name} is configured with overlapping settings. " + f"The metric was initialized with top_k={top_k}, but the " + f"metric was defined in a spec using sub_key={sub_key}, " + f"model_name={model_name}, output_name={output_name}\n\n" + "Binarization related settings can be configured in either the" + "metrics_spec or the metric, but not both. Either remove the top_k " + "setting from this metric or remove the metrics_spec.binarize " + "settings." + ) + sub_key = sub_key._replace(top_k=top_k) + if class_id is not None: + if sub_key.class_id is not None: + raise ValueError( + f"Metric {metric_name} is configured with overlapping settings. " + f"The metric was initialized with class_id={class_id}, but the " + f"metric was defined in a spec using sub_key={sub_key}, " + f"model_name={model_name}, output_name={output_name}\n\n" + "Binarization related settings can be configured in either the" + "metrics_spec or the metric, but not both. Either remove the class_id" + " setting from this metric or remove the metrics_spec.binarize " + "settings." 
+ ) + sub_key = sub_key._replace(class_id=class_id) + return sub_key @overload @@ -166,8 +173,7 @@ def _find_max_under_constraint( dependent: np.ndarray, values: float, operator: Callable[[np.ndarray, np.ndarray], np.ndarray] = np.greater_equal, -) -> float: - ... +) -> float: ... @overload @@ -176,2493 +182,2688 @@ def _find_max_under_constraint( dependent: np.ndarray, values: List[float], operator: Callable[[np.ndarray, np.ndarray], np.ndarray] = np.greater_equal, -) -> np.ndarray: - ... +) -> np.ndarray: ... def _find_max_under_constraint( constrained, dependent, values, operator=np.greater_equal ): - """Returns the maximum of dependent that satisfies contraints. - - Args: - constrained: Over these values the constraint is specified. A rank-1 np - array. - dependent: From these values the maximum that satiesfies the constraint is - selected. Values in this array and in `constrained` are linked by having - the same threshold at each position, hence this array must have the same - shape. - values: A list of the value constraints. - operator: A numpy logic functions. Default is greater_equal. - - Returns: - Maximal dependent value, if no value satiesfies the constraint 0.0. - """ - result = [] - for value in np.array([values] if isinstance(values, float) else values): - feasible = np.where(operator(constrained, value)) - gathered = np.take(dependent, feasible) - if gathered.size > 0: - result.append( - float(np.where(np.size(feasible) > 0, np.nanmax(gathered), 0.0))) - else: - # If the gathered is empty, return 0.0 assuming all NaNs are 0.0 - result.append(0.0) - if isinstance(values, float): - return result[0] - else: - return np.array(result) - - -class ConfusionMatrixMetricBase(metric_types.Metric, metaclass=abc.ABCMeta): - """Base for confusion matrix metrics.""" - - def __init__(self, - thresholds: Optional[Union[float, List[float]]] = None, - num_thresholds: Optional[int] = None, - top_k: Optional[int] = None, - class_id: Optional[int] = None, - name: Optional[str] = None, - preprocessors: Optional[List[metric_types.Preprocessor]] = None, - **kwargs): - """Initializes confusion matrix metric. + """Returns the maximum of dependent that satisfies contraints. Args: - thresholds: (Optional) Thresholds to use for calculating the matrices. Use - one of either thresholds or num_thresholds. - num_thresholds: (Optional) Number of thresholds to use for calculating the - matrices. Use one of either thresholds or num_thresholds. - top_k: (Optional) Used with a multi-class model to specify that the top-k - values should be used to compute the confusion matrix. The net effect is - that the non-top-k values are set to -inf and the matrix is then - constructed from the average TP, FP, TN, FN across the classes. When - top_k is used, metrics_specs.binarize settings must not be present. Only - one of class_id or top_k should be configured. When top_k is set, the - default thresholds are [float('-inf')]. - class_id: (Optional) Used with a multi-class model to specify which class - to compute the confusion matrix for. When class_id is used, - metrics_specs.binarize settings must not be present. Only one of - class_id or top_k should be configured. - name: (Optional) Metric name. - preprocessors: User-provided preprocessor for including additional - extracts in StandardMetricInputs (relevant only when use_histogram flag - is not true). 
- **kwargs: (Optional) Additional args to pass along to init (and eventually - on to _metric_computation and _metric_value) + ---- + constrained: Over these values the constraint is specified. A rank-1 np + array. + dependent: From these values the maximum that satiesfies the constraint is + selected. Values in this array and in `constrained` are linked by having + the same threshold at each position, hence this array must have the same + shape. + values: A list of the value constraints. + operator: A numpy logic functions. Default is greater_equal. + + Returns: + ------- + Maximal dependent value, if no value satiesfies the constraint 0.0. """ - super().__init__( - metric_util.merge_per_key_computations(self._metric_computations), - thresholds=thresholds, - num_thresholds=num_thresholds, - top_k=top_k, - class_id=class_id, - name=name, - **kwargs) - - def _default_threshold(self) -> Optional[float]: - """Returns default threshold if thresholds or num_thresholds unset.""" - return None - - def get_config(self) -> Dict[str, Any]: - """Returns serializable config.""" - # Not all subclasses of ConfusionMatrixMetric support all the __init__ - # parameters as part of their __init__, to avoid deserialization issues - # where an unsupported parameter is passed to the subclass, filter out any - # parameters that are None. - kwargs = copy.copy(self.kwargs) - for arg in ('thresholds', 'num_thresholds', 'top_k', 'class_id'): - if kwargs[arg] is None: - del kwargs[arg] - return kwargs - - @abc.abstractmethod - def _metric_value( - self, - key: metric_types.MetricKey, - matrices: binary_confusion_matrices.Matrices, # pytype: disable=signature-mismatch # always-use-return-annotations - ) -> Union[float, np.ndarray]: - """Returns metric value associated with matrices. - - Subclasses may override this method. Any additional kwargs passed to - __init__ will be forwarded along to this call. - - Args: - key: Metric key. - matrices: Computed binary confusion matrices. 
- """ - raise NotImplementedError('Must be implemented in subclasses.') - - def _metric_computations( - self, - thresholds: Optional[Union[float, List[float]]] = None, - num_thresholds: Optional[int] = None, - top_k: Optional[int] = None, - class_id: Optional[int] = None, - name: Optional[str] = None, - eval_config: Optional[config_pb2.EvalConfig] = None, - model_name: str = '', - output_name: str = '', - sub_key: Optional[metric_types.SubKey] = None, - aggregation_type: Optional[metric_types.AggregationType] = None, - class_weights: Optional[Dict[int, float]] = None, - example_weighted: bool = False, - preprocessors: Optional[List[metric_types.Preprocessor]] = None, - metric_key: Optional[metric_types.MetricKey] = None, - **kwargs, - ) -> metric_types.MetricComputations: - """Returns computations for confusion matrix metric.""" - sub_key = _validate_and_update_sub_key(name, model_name, output_name, - sub_key, top_k, class_id) - if metric_key: - key = metric_key + result = [] + for value in np.array([values] if isinstance(values, float) else values): + feasible = np.where(operator(constrained, value)) + gathered = np.take(dependent, feasible) + if gathered.size > 0: + result.append( + float(np.where(np.size(feasible) > 0, np.nanmax(gathered), 0.0)) + ) + else: + # If the gathered is empty, return 0.0 assuming all NaNs are 0.0 + result.append(0.0) + if isinstance(values, float): + return result[0] else: - key = metric_types.MetricKey( - name=name, - model_name=model_name, - output_name=output_name, - sub_key=sub_key, - example_weighted=example_weighted, - aggregation_type=aggregation_type, - ) - - if num_thresholds is None and thresholds is None: - # If top_k set, then use -inf as the default threshold setting. - if sub_key and sub_key.top_k: - thresholds = [float('-inf')] - elif self._default_threshold() is not None: - thresholds = [self._default_threshold()] - if isinstance(thresholds, float): - thresholds = [thresholds] - - # Make sure matrices are calculated. - matrices_computations = binary_confusion_matrices.binary_confusion_matrices( - num_thresholds=num_thresholds, - thresholds=thresholds, - eval_config=eval_config, - model_name=model_name, - output_name=output_name, - sub_key=sub_key, - aggregation_type=aggregation_type, - class_weights=class_weights, - example_weighted=example_weighted, - preprocessors=preprocessors) - matrices_key = matrices_computations[-1].keys[-1] - - def result( - metrics: Dict[metric_types.MetricKey, Any] - ) -> Dict[metric_types.MetricKey, Union[float, np.ndarray]]: - value = self._metric_value( - key=key, matrices=metrics[matrices_key], **kwargs) - return {key: value} - - derived_computation = metric_types.DerivedMetricComputation( - keys=[key], result=result) - computations = matrices_computations - computations.append(derived_computation) - return computations - + return np.array(result) -class ConfusionMatrixMetric(ConfusionMatrixMetricBase): - """Base for confusion matrix metrics.""" - def __init__(self, - thresholds: Optional[Union[float, List[float]]] = None, - num_thresholds: Optional[int] = None, - top_k: Optional[int] = None, - class_id: Optional[int] = None, - name: Optional[str] = None, - **kwargs): - """Initializes confusion matrix metric. 
+class ConfusionMatrixMetricBase(metric_types.Metric, metaclass=abc.ABCMeta): + """Base for confusion matrix metrics.""" + + def __init__( + self, + thresholds: Optional[Union[float, List[float]]] = None, + num_thresholds: Optional[int] = None, + top_k: Optional[int] = None, + class_id: Optional[int] = None, + name: Optional[str] = None, + preprocessors: Optional[List[metric_types.Preprocessor]] = None, + **kwargs, + ): + """Initializes confusion matrix metric. + + Args: + ---- + thresholds: (Optional) Thresholds to use for calculating the matrices. Use + one of either thresholds or num_thresholds. + num_thresholds: (Optional) Number of thresholds to use for calculating the + matrices. Use one of either thresholds or num_thresholds. + top_k: (Optional) Used with a multi-class model to specify that the top-k + values should be used to compute the confusion matrix. The net effect is + that the non-top-k values are set to -inf and the matrix is then + constructed from the average TP, FP, TN, FN across the classes. When + top_k is used, metrics_specs.binarize settings must not be present. Only + one of class_id or top_k should be configured. When top_k is set, the + default thresholds are [float('-inf')]. + class_id: (Optional) Used with a multi-class model to specify which class + to compute the confusion matrix for. When class_id is used, + metrics_specs.binarize settings must not be present. Only one of + class_id or top_k should be configured. + name: (Optional) Metric name. + preprocessors: User-provided preprocessor for including additional + extracts in StandardMetricInputs (relevant only when use_histogram flag + is not true). + **kwargs: (Optional) Additional args to pass along to init (and eventually + on to _metric_computation and _metric_value) + """ + super().__init__( + metric_util.merge_per_key_computations(self._metric_computations), + thresholds=thresholds, + num_thresholds=num_thresholds, + top_k=top_k, + class_id=class_id, + name=name, + **kwargs, + ) + + def _default_threshold(self) -> Optional[float]: + """Returns default threshold if thresholds or num_thresholds unset.""" + return None + + def get_config(self) -> Dict[str, Any]: + """Returns serializable config.""" + # Not all subclasses of ConfusionMatrixMetric support all the __init__ + # parameters as part of their __init__, to avoid deserialization issues + # where an unsupported parameter is passed to the subclass, filter out any + # parameters that are None. + kwargs = copy.copy(self.kwargs) + for arg in ("thresholds", "num_thresholds", "top_k", "class_id"): + if kwargs[arg] is None: + del kwargs[arg] + return kwargs + + @abc.abstractmethod + def _metric_value( + self, + key: metric_types.MetricKey, + matrices: binary_confusion_matrices.Matrices, # pytype: disable=signature-mismatch # always-use-return-annotations + ) -> Union[float, np.ndarray]: + """Returns metric value associated with matrices. + + Subclasses may override this method. Any additional kwargs passed to + __init__ will be forwarded along to this call. + + Args: + ---- + key: Metric key. + matrices: Computed binary confusion matrices. 
+ """ + raise NotImplementedError("Must be implemented in subclasses.") + + def _metric_computations( + self, + thresholds: Optional[Union[float, List[float]]] = None, + num_thresholds: Optional[int] = None, + top_k: Optional[int] = None, + class_id: Optional[int] = None, + name: Optional[str] = None, + eval_config: Optional[config_pb2.EvalConfig] = None, + model_name: str = "", + output_name: str = "", + sub_key: Optional[metric_types.SubKey] = None, + aggregation_type: Optional[metric_types.AggregationType] = None, + class_weights: Optional[Dict[int, float]] = None, + example_weighted: bool = False, + preprocessors: Optional[List[metric_types.Preprocessor]] = None, + metric_key: Optional[metric_types.MetricKey] = None, + **kwargs, + ) -> metric_types.MetricComputations: + """Returns computations for confusion matrix metric.""" + sub_key = _validate_and_update_sub_key( + name, model_name, output_name, sub_key, top_k, class_id + ) + if metric_key: + key = metric_key + else: + key = metric_types.MetricKey( + name=name, + model_name=model_name, + output_name=output_name, + sub_key=sub_key, + example_weighted=example_weighted, + aggregation_type=aggregation_type, + ) + + if num_thresholds is None and thresholds is None: + # If top_k set, then use -inf as the default threshold setting. + if sub_key and sub_key.top_k: + thresholds = [float("-inf")] + elif self._default_threshold() is not None: + thresholds = [self._default_threshold()] + if isinstance(thresholds, float): + thresholds = [thresholds] + + # Make sure matrices are calculated. + matrices_computations = binary_confusion_matrices.binary_confusion_matrices( + num_thresholds=num_thresholds, + thresholds=thresholds, + eval_config=eval_config, + model_name=model_name, + output_name=output_name, + sub_key=sub_key, + aggregation_type=aggregation_type, + class_weights=class_weights, + example_weighted=example_weighted, + preprocessors=preprocessors, + ) + matrices_key = matrices_computations[-1].keys[-1] + + def result( + metrics: Dict[metric_types.MetricKey, Any], + ) -> Dict[metric_types.MetricKey, Union[float, np.ndarray]]: + value = self._metric_value( + key=key, matrices=metrics[matrices_key], **kwargs + ) + return {key: value} + + derived_computation = metric_types.DerivedMetricComputation( + keys=[key], result=result + ) + computations = matrices_computations + computations.append(derived_computation) + return computations - Args: - thresholds: (Optional) Thresholds to use for calculating the matrices. Use - one of either thresholds or num_thresholds. - num_thresholds: (Optional) Number of thresholds to use for calculating the - matrices. Use one of either thresholds or num_thresholds. - top_k: (Optional) Used with a multi-class model to specify that the top-k - values should be used to compute the confusion matrix. The net effect is - that the non-top-k values are set to -inf and the matrix is then - constructed from the average TP, FP, TN, FN across the classes. When - top_k is used, metrics_specs.binarize settings must not be present. Only - one of class_id or top_k should be configured. When top_k is set, the - default thresholds are [float('-inf')]. - class_id: (Optional) Used with a multi-class model to specify which class - to compute the confusion matrix for. When class_id is used, - metrics_specs.binarize settings must not be present. Only one of - class_id or top_k should be configured. - name: (Optional) Metric name. 
- **kwargs: (Optional) Additional args to pass along to init (and eventually - on to _metric_computation and _metric_value) - """ - super().__init__( - thresholds=thresholds, - num_thresholds=num_thresholds, - top_k=top_k, - class_id=class_id, - name=name, - **kwargs) - - def _default_threshold(self) -> float: - """Returns default threshold if thresholds or num_thresholds unset.""" - return 0.5 - - @abc.abstractmethod - def result(self, tp: float, tn: float, fp: float, fn: float) -> float: - """Function for computing metric value from TP, TN, FP, FN values.""" - raise NotImplementedError('Must be implemented in subclasses.') - - def _metric_value( - self, key: metric_types.MetricKey, - matrices: binary_confusion_matrices.Matrices) -> Union[float, np.ndarray]: - """Returns metric value associated with matrices. - - Subclasses may override this method. Any additional kwargs passed to - __init__ will be forwarded along to this call. Note that since this method - is the only one that calls the result method, subclasses that override this - method are not required to provide an implementation for the result method. - Args: - key: Metric key. - matrices: Computed binary confusion matrices. - """ - values = [] - for i in range(len(matrices.thresholds)): - values.append( - self.result(matrices.tp[i], matrices.tn[i], matrices.fp[i], - matrices.fn[i])) - return values[0] if len(matrices.thresholds) == 1 else np.array(values) +class ConfusionMatrixMetric(ConfusionMatrixMetricBase): + """Base for confusion matrix metrics.""" + + def __init__( + self, + thresholds: Optional[Union[float, List[float]]] = None, + num_thresholds: Optional[int] = None, + top_k: Optional[int] = None, + class_id: Optional[int] = None, + name: Optional[str] = None, + **kwargs, + ): + """Initializes confusion matrix metric. + + Args: + ---- + thresholds: (Optional) Thresholds to use for calculating the matrices. Use + one of either thresholds or num_thresholds. + num_thresholds: (Optional) Number of thresholds to use for calculating the + matrices. Use one of either thresholds or num_thresholds. + top_k: (Optional) Used with a multi-class model to specify that the top-k + values should be used to compute the confusion matrix. The net effect is + that the non-top-k values are set to -inf and the matrix is then + constructed from the average TP, FP, TN, FN across the classes. When + top_k is used, metrics_specs.binarize settings must not be present. Only + one of class_id or top_k should be configured. When top_k is set, the + default thresholds are [float('-inf')]. + class_id: (Optional) Used with a multi-class model to specify which class + to compute the confusion matrix for. When class_id is used, + metrics_specs.binarize settings must not be present. Only one of + class_id or top_k should be configured. + name: (Optional) Metric name. 
+ **kwargs: (Optional) Additional args to pass along to init (and eventually + on to _metric_computation and _metric_value) + """ + super().__init__( + thresholds=thresholds, + num_thresholds=num_thresholds, + top_k=top_k, + class_id=class_id, + name=name, + **kwargs, + ) + + def _default_threshold(self) -> float: + """Returns default threshold if thresholds or num_thresholds unset.""" + return 0.5 + + @abc.abstractmethod + def result(self, tp: float, tn: float, fp: float, fn: float) -> float: + """Function for computing metric value from TP, TN, FP, FN values.""" + raise NotImplementedError("Must be implemented in subclasses.") + + def _metric_value( + self, key: metric_types.MetricKey, matrices: binary_confusion_matrices.Matrices + ) -> Union[float, np.ndarray]: + """Returns metric value associated with matrices. + + Subclasses may override this method. Any additional kwargs passed to + __init__ will be forwarded along to this call. Note that since this method + is the only one that calls the result method, subclasses that override this + method are not required to provide an implementation for the result method. + + Args: + ---- + key: Metric key. + matrices: Computed binary confusion matrices. + """ + values = [] + for i in range(len(matrices.thresholds)): + values.append( + self.result( + matrices.tp[i], matrices.tn[i], matrices.fp[i], matrices.fn[i] + ) + ) + return values[0] if len(matrices.thresholds) == 1 else np.array(values) class AUC(ConfusionMatrixMetricBase): - """Approximates the AUC (Area under the curve) of the ROC or PR curves. - - The AUC (Area under the curve) of the ROC (Receiver operating - characteristic; default) or PR (Precision Recall) curves are quality measures - of binary classifiers. Unlike the accuracy, and like cross-entropy - losses, ROC-AUC and PR-AUC evaluate all the operational points of a model. - - This class approximates AUCs using a Riemann sum. During the metric - accumulation phase, predictions are accumulated within predefined buckets - by value. The AUC is then computed by interpolating per-bucket averages. These - buckets define the evaluated operational points. - - This metric uses `true_positives`, `true_negatives`, `false_positives` and - `false_negatives` to compute the AUC. To discretize the AUC curve, a linearly - spaced set of thresholds is used to compute pairs of recall and precision - values. The area under the ROC-curve is therefore computed using the height of - the recall values by the false positive rate, while the area under the - PR-curve is the computed using the height of the precision values by the - recall. - - This value is ultimately returned as `auc`, an idempotent operation that - computes the area under a discretized curve of precision versus recall values - (computed using the aforementioned variables). The `num_thresholds` variable - controls the degree of discretization with larger numbers of thresholds more - closely approximating the true AUC. The quality of the approximation may vary - dramatically depending on `num_thresholds`. The `thresholds` parameter can be - used to manually specify thresholds which split the predictions more evenly. - - For a best approximation of the real AUC, `predictions` should be distributed - approximately uniformly in the range [0, 1]. The quality of the AUC - approximation may be poor if this is not the case. Setting `summation_method` - to 'minoring' or 'majoring' can help quantify the error in the approximation - by providing lower or upper bound estimate of the AUC. 
- - If `sample_weight` is `None`, weights default to 1. - Use `sample_weight` of 0 to mask values. - """ - - def __init__(self, - num_thresholds: Optional[int] = None, - curve: str = 'ROC', - summation_method: str = 'interpolation', - name: Optional[str] = None, - thresholds: Optional[Union[float, List[float]]] = None, - top_k: Optional[int] = None, - class_id: Optional[int] = None): - """Initializes AUC metric. - - Args: - num_thresholds: (Optional) Defaults to 10000. The number of thresholds to - use when discretizing the roc curve. Values must be > 1. - curve: (Optional) Specifies the name of the curve to be computed, 'ROC' - [default] or 'PR' for the Precision-Recall-curve. - summation_method: (Optional) Specifies the [Riemann summation method]( - https://en.wikipedia.org/wiki/Riemann_sum) used. 'interpolation' - (default) applies mid-point summation scheme for `ROC`. For PR-AUC, - interpolates (true/false) positives but not the ratio that is - precision (see Davis & Goadrich 2006 for details); 'minoring' applies - left summation for increasing intervals and right summation for - decreasing intervals; 'majoring' does the opposite. - name: (Optional) string name of the metric instance. - thresholds: (Optional) A list of floating point values to use as the - thresholds for discretizing the curve. If set, the `num_thresholds` - parameter is ignored. Values should be in [0, 1]. Endpoint thresholds - equal to {-epsilon, 1+epsilon} for a small positive epsilon value will - be automatically included with these to correctly handle predictions - equal to exactly 0 or 1. - top_k: (Optional) Used with a multi-class model to specify that the top-k - values should be used to compute the confusion matrix. The net effect is - that the non-top-k values are set to -inf and the matrix is then - constructed from the average TP, FP, TN, FN across the classes. When - top_k is used, metrics_specs.binarize settings must not be present. Only - one of class_id or top_k should be configured. When top_k is set, the - default thresholds are [float('-inf')]. - class_id: (Optional) Used with a multi-class model to specify which class - to compute the confusion matrix for. When class_id is used, - metrics_specs.binarize settings must not be present. Only one of - class_id or top_k should be configured. + """Approximates the AUC (Area under the curve) of the ROC or PR curves. + + The AUC (Area under the curve) of the ROC (Receiver operating + characteristic; default) or PR (Precision Recall) curves are quality measures + of binary classifiers. Unlike the accuracy, and like cross-entropy + losses, ROC-AUC and PR-AUC evaluate all the operational points of a model. + + This class approximates AUCs using a Riemann sum. During the metric + accumulation phase, predictions are accumulated within predefined buckets + by value. The AUC is then computed by interpolating per-bucket averages. These + buckets define the evaluated operational points. + + This metric uses `true_positives`, `true_negatives`, `false_positives` and + `false_negatives` to compute the AUC. To discretize the AUC curve, a linearly + spaced set of thresholds is used to compute pairs of recall and precision + values. The area under the ROC-curve is therefore computed using the height of + the recall values by the false positive rate, while the area under the + PR-curve is the computed using the height of the precision values by the + recall. 
+ + This value is ultimately returned as `auc`, an idempotent operation that + computes the area under a discretized curve of precision versus recall values + (computed using the aforementioned variables). The `num_thresholds` variable + controls the degree of discretization with larger numbers of thresholds more + closely approximating the true AUC. The quality of the approximation may vary + dramatically depending on `num_thresholds`. The `thresholds` parameter can be + used to manually specify thresholds which split the predictions more evenly. + + For a best approximation of the real AUC, `predictions` should be distributed + approximately uniformly in the range [0, 1]. The quality of the AUC + approximation may be poor if this is not the case. Setting `summation_method` + to 'minoring' or 'majoring' can help quantify the error in the approximation + by providing lower or upper bound estimate of the AUC. + + If `sample_weight` is `None`, weights default to 1. + Use `sample_weight` of 0 to mask values. """ - super().__init__( - num_thresholds=num_thresholds, - thresholds=thresholds, - curve=curve, - summation_method=summation_method, - name=name, - top_k=top_k, - class_id=class_id) - - def _default_name(self) -> str: - return AUC_NAME - - def _metric_value(self, curve: str, summation_method: str, - key: metric_types.MetricKey, - matrices: binary_confusion_matrices.Matrices) -> float: - del key - curve = AUCCurve(curve.upper()) - summation_method = AUCSummationMethod(summation_method.lower()) - num_thresholds = len(matrices.thresholds) - tp, tn = np.array(matrices.tp), np.array(matrices.tn) - fp, fn = np.array(matrices.fp), np.array(matrices.fn) - if (curve == AUCCurve.PR and - summation_method == AUCSummationMethod.INTERPOLATION): - dtp = tp[:num_thresholds - 1] - tp[1:] - p = tp + fp - dp = p[:num_thresholds - 1] - p[1:] - prec_slope = dtp / np.maximum(dp, 0) - intercept = tp[1:] - prec_slope * p[1:] - safe_p_ratio = np.where( - np.logical_and(p[:num_thresholds - 1] > 0, p[1:] > 0), - p[:num_thresholds - 1] / np.maximum(p[1:], 0), np.ones_like(p[1:])) - pr_auc_increment = ( - prec_slope * (dtp + intercept * np.log(safe_p_ratio)) / - np.maximum(tp[1:] + fn[1:], 0)) - return np.nansum(pr_auc_increment) - - # Set `x` and `y` values for the curves based on `curve` config. - recall = tp / (tp + fn) - if curve == AUCCurve.ROC: - fp_rate = fp / (fp + tn) - x = fp_rate - y = recall - elif curve == AUCCurve.PR: - precision = tp / (tp + fp) - x = recall - y = precision - - # Find the rectangle heights based on `summation_method`. - if summation_method == AUCSummationMethod.INTERPOLATION: - heights = (y[:num_thresholds - 1] + y[1:]) / 2. - elif summation_method == AUCSummationMethod.MINORING: - heights = np.minimum(y[:num_thresholds - 1], y[1:]) - elif summation_method == AUCSummationMethod.MAJORING: - heights = np.maximum(y[:num_thresholds - 1], y[1:]) - - # Sum up the areas of all the rectangles. - return np.nansum((x[:num_thresholds - 1] - x[1:]) * heights) + + def __init__( + self, + num_thresholds: Optional[int] = None, + curve: str = "ROC", + summation_method: str = "interpolation", + name: Optional[str] = None, + thresholds: Optional[Union[float, List[float]]] = None, + top_k: Optional[int] = None, + class_id: Optional[int] = None, + ): + """Initializes AUC metric. + + Args: + ---- + num_thresholds: (Optional) Defaults to 10000. The number of thresholds to + use when discretizing the roc curve. Values must be > 1. 
+ curve: (Optional) Specifies the name of the curve to be computed, 'ROC' + [default] or 'PR' for the Precision-Recall-curve. + summation_method: (Optional) Specifies the [Riemann summation method]( + https://en.wikipedia.org/wiki/Riemann_sum) used. 'interpolation' + (default) applies mid-point summation scheme for `ROC`. For PR-AUC, + interpolates (true/false) positives but not the ratio that is + precision (see Davis & Goadrich 2006 for details); 'minoring' applies + left summation for increasing intervals and right summation for + decreasing intervals; 'majoring' does the opposite. + name: (Optional) string name of the metric instance. + thresholds: (Optional) A list of floating point values to use as the + thresholds for discretizing the curve. If set, the `num_thresholds` + parameter is ignored. Values should be in [0, 1]. Endpoint thresholds + equal to {-epsilon, 1+epsilon} for a small positive epsilon value will + be automatically included with these to correctly handle predictions + equal to exactly 0 or 1. + top_k: (Optional) Used with a multi-class model to specify that the top-k + values should be used to compute the confusion matrix. The net effect is + that the non-top-k values are set to -inf and the matrix is then + constructed from the average TP, FP, TN, FN across the classes. When + top_k is used, metrics_specs.binarize settings must not be present. Only + one of class_id or top_k should be configured. When top_k is set, the + default thresholds are [float('-inf')]. + class_id: (Optional) Used with a multi-class model to specify which class + to compute the confusion matrix for. When class_id is used, + metrics_specs.binarize settings must not be present. Only one of + class_id or top_k should be configured. + """ + super().__init__( + num_thresholds=num_thresholds, + thresholds=thresholds, + curve=curve, + summation_method=summation_method, + name=name, + top_k=top_k, + class_id=class_id, + ) + + def _default_name(self) -> str: + return AUC_NAME + + def _metric_value( + self, + curve: str, + summation_method: str, + key: metric_types.MetricKey, + matrices: binary_confusion_matrices.Matrices, + ) -> float: + del key + curve = AUCCurve(curve.upper()) + summation_method = AUCSummationMethod(summation_method.lower()) + num_thresholds = len(matrices.thresholds) + tp, tn = np.array(matrices.tp), np.array(matrices.tn) + fp, fn = np.array(matrices.fp), np.array(matrices.fn) + if ( + curve == AUCCurve.PR + and summation_method == AUCSummationMethod.INTERPOLATION + ): + dtp = tp[: num_thresholds - 1] - tp[1:] + p = tp + fp + dp = p[: num_thresholds - 1] - p[1:] + prec_slope = dtp / np.maximum(dp, 0) + intercept = tp[1:] - prec_slope * p[1:] + safe_p_ratio = np.where( + np.logical_and(p[: num_thresholds - 1] > 0, p[1:] > 0), + p[: num_thresholds - 1] / np.maximum(p[1:], 0), + np.ones_like(p[1:]), + ) + pr_auc_increment = ( + prec_slope + * (dtp + intercept * np.log(safe_p_ratio)) + / np.maximum(tp[1:] + fn[1:], 0) + ) + return np.nansum(pr_auc_increment) + + # Set `x` and `y` values for the curves based on `curve` config. + recall = tp / (tp + fn) + if curve == AUCCurve.ROC: + fp_rate = fp / (fp + tn) + x = fp_rate + y = recall + elif curve == AUCCurve.PR: + precision = tp / (tp + fp) + x = recall + y = precision + + # Find the rectangle heights based on `summation_method`. 
+ if summation_method == AUCSummationMethod.INTERPOLATION: + heights = (y[: num_thresholds - 1] + y[1:]) / 2.0 + elif summation_method == AUCSummationMethod.MINORING: + heights = np.minimum(y[: num_thresholds - 1], y[1:]) + elif summation_method == AUCSummationMethod.MAJORING: + heights = np.maximum(y[: num_thresholds - 1], y[1:]) + + # Sum up the areas of all the rectangles. + return np.nansum((x[: num_thresholds - 1] - x[1:]) * heights) metric_types.register_metric(AUC) class AUCPrecisionRecall(AUC): - """Alias for AUC(curve='PR').""" - - def __init__(self, - num_thresholds: Optional[int] = None, - summation_method: str = 'interpolation', - name: Optional[str] = None, - thresholds: Optional[Union[float, List[float]]] = None, - top_k: Optional[int] = None, - class_id: Optional[int] = None): - """Initializes AUCPrecisionRecall metric. - - Args: - num_thresholds: (Optional) Defaults to 10000. The number of thresholds to - use when discretizing the roc curve. Values must be > 1. - summation_method: (Optional) Specifies the [Riemann summation method]( - https://en.wikipedia.org/wiki/Riemann_sum) used. 'interpolation' - interpolates (true/false) positives but not the ratio that is - precision (see Davis & Goadrich 2006 for details); 'minoring' applies - left summation for increasing intervals and right summation for - decreasing intervals; 'majoring' does the opposite. - name: (Optional) string name of the metric instance. - thresholds: (Optional) A list of floating point values to use as the - thresholds for discretizing the curve. If set, the `num_thresholds` - parameter is ignored. Values should be in [0, 1]. Endpoint thresholds - equal to {-epsilon, 1+epsilon} for a small positive epsilon value will - be automatically included with these to correctly handle predictions - equal to exactly 0 or 1. - top_k: (Optional) Used with a multi-class model to specify that the top-k - values should be used to compute the confusion matrix. The net effect is - that the non-top-k values are set to -inf and the matrix is then - constructed from the average TP, FP, TN, FN across the classes. When - top_k is used, metrics_specs.binarize settings must not be present. Only - one of class_id or top_k should be configured. When top_k is set, the - default thresholds are [float('-inf')]. - class_id: (Optional) Used with a multi-class model to specify which class - to compute the confusion matrix for. When class_id is used, - metrics_specs.binarize settings must not be present. Only one of - class_id or top_k should be configured. - """ - super().__init__( - num_thresholds=num_thresholds, - thresholds=thresholds, - curve='PR', - summation_method=summation_method, - name=name, - top_k=top_k, - class_id=class_id) - - def get_config(self) -> Dict[str, Any]: - """Returns serializable config.""" - # Remove the irrelevant 'curve' keyword inherited from parent class AUC(). - # This is needed when the __init__ of the child class has a different set of - # kwargs than that of its parent class. - result = super().get_config() - del result['curve'] - return result - - def _default_name(self) -> str: - return AUC_PRECISION_RECALL_NAME + """Alias for AUC(curve='PR').""" + + def __init__( + self, + num_thresholds: Optional[int] = None, + summation_method: str = "interpolation", + name: Optional[str] = None, + thresholds: Optional[Union[float, List[float]]] = None, + top_k: Optional[int] = None, + class_id: Optional[int] = None, + ): + """Initializes AUCPrecisionRecall metric. 
+ + Args: + ---- + num_thresholds: (Optional) Defaults to 10000. The number of thresholds to + use when discretizing the roc curve. Values must be > 1. + summation_method: (Optional) Specifies the [Riemann summation method]( + https://en.wikipedia.org/wiki/Riemann_sum) used. 'interpolation' + interpolates (true/false) positives but not the ratio that is + precision (see Davis & Goadrich 2006 for details); 'minoring' applies + left summation for increasing intervals and right summation for + decreasing intervals; 'majoring' does the opposite. + name: (Optional) string name of the metric instance. + thresholds: (Optional) A list of floating point values to use as the + thresholds for discretizing the curve. If set, the `num_thresholds` + parameter is ignored. Values should be in [0, 1]. Endpoint thresholds + equal to {-epsilon, 1+epsilon} for a small positive epsilon value will + be automatically included with these to correctly handle predictions + equal to exactly 0 or 1. + top_k: (Optional) Used with a multi-class model to specify that the top-k + values should be used to compute the confusion matrix. The net effect is + that the non-top-k values are set to -inf and the matrix is then + constructed from the average TP, FP, TN, FN across the classes. When + top_k is used, metrics_specs.binarize settings must not be present. Only + one of class_id or top_k should be configured. When top_k is set, the + default thresholds are [float('-inf')]. + class_id: (Optional) Used with a multi-class model to specify which class + to compute the confusion matrix for. When class_id is used, + metrics_specs.binarize settings must not be present. Only one of + class_id or top_k should be configured. + """ + super().__init__( + num_thresholds=num_thresholds, + thresholds=thresholds, + curve="PR", + summation_method=summation_method, + name=name, + top_k=top_k, + class_id=class_id, + ) + + def get_config(self) -> Dict[str, Any]: + """Returns serializable config.""" + # Remove the irrelevant 'curve' keyword inherited from parent class AUC(). + # This is needed when the __init__ of the child class has a different set of + # kwargs than that of its parent class. + result = super().get_config() + del result["curve"] + return result + + def _default_name(self) -> str: + return AUC_PRECISION_RECALL_NAME metric_types.register_metric(AUCPrecisionRecall) class SensitivityAtSpecificity(ConfusionMatrixMetricBase): - """Computes best sensitivity where specificity is >= specified value. + """Computes best sensitivity where specificity is >= specified value. - `Sensitivity` measures the proportion of actual positives that are correctly - identified as such (tp / (tp + fn)). - `Specificity` measures the proportion of actual negatives that are correctly - identified as such (tn / (tn + fp)). + `Sensitivity` measures the proportion of actual positives that are correctly + identified as such (tp / (tp + fn)). + `Specificity` measures the proportion of actual negatives that are correctly + identified as such (tn / (tn + fp)). - The threshold for the given specificity value is computed and used to evaluate - the corresponding sensitivity. + The threshold for the given specificity value is computed and used to evaluate + the corresponding sensitivity. - If `sample_weight` is `None`, weights default to 1. - Use `sample_weight` of 0 to mask values. + If `sample_weight` is `None`, weights default to 1. + Use `sample_weight` of 0 to mask values. 
- For additional information about specificity and sensitivity, see - [the following](https://en.wikipedia.org/wiki/Sensitivity_and_specificity). - """ - - def __init__(self, - specificity: Union[float, List[float]], - num_thresholds: Optional[int] = None, - class_id: Optional[int] = None, - name: Optional[str] = None, - top_k: Optional[int] = None): - """Initializes SensitivityAtSpecificity metric. - - - Args: - specificity: A scalar value in range `[0, 1]`. - num_thresholds: (Optional) Defaults to 1000. The number of thresholds to - use for matching the given specificity. - class_id: (Optional) Used with a multi-class model to specify which class - to compute the confusion matrix for. When class_id is used, - metrics_specs.binarize settings must not be present. Only one of - class_id or top_k should be configured. - name: (Optional) string name of the metric instance. - top_k: (Optional) Used with a multi-class model to specify that the top-k - values should be used to compute the confusion matrix. The net effect is - that the non-top-k values are set to -inf and the matrix is then - constructed from the average TP, FP, TN, FN across the classes. When - top_k is used, metrics_specs.binarize settings must not be present. Only - one of class_id or top_k should be configured. When top_k is set, the - default thresholds are [float('-inf')]. + For additional information about specificity and sensitivity, see + [the following](https://en.wikipedia.org/wiki/Sensitivity_and_specificity). """ - super().__init__( - num_thresholds=num_thresholds, - specificity=specificity, - class_id=class_id, - name=name, - top_k=top_k) - - def _default_name(self) -> str: - return SENSITIVITY_AT_SPECIFICITY_NAME - - def _metric_value( - self, specificity: Union[float, List[float]], key: metric_types.MetricKey, - matrices: binary_confusion_matrices.Matrices) -> Union[float, np.ndarray]: - del key - tp, tn = np.array(matrices.tp), np.array(matrices.tn) - fp, fn = np.array(matrices.fp), np.array(matrices.fn) - specificities = tn / (tn + fp) - sensitivities = tp / (tp + fn) - return _find_max_under_constraint(specificities, sensitivities, specificity) + + def __init__( + self, + specificity: Union[float, List[float]], + num_thresholds: Optional[int] = None, + class_id: Optional[int] = None, + name: Optional[str] = None, + top_k: Optional[int] = None, + ): + """Initializes SensitivityAtSpecificity metric. + + Args: + ---- + specificity: A scalar value in range `[0, 1]`. + num_thresholds: (Optional) Defaults to 1000. The number of thresholds to + use for matching the given specificity. + class_id: (Optional) Used with a multi-class model to specify which class + to compute the confusion matrix for. When class_id is used, + metrics_specs.binarize settings must not be present. Only one of + class_id or top_k should be configured. + name: (Optional) string name of the metric instance. + top_k: (Optional) Used with a multi-class model to specify that the top-k + values should be used to compute the confusion matrix. The net effect is + that the non-top-k values are set to -inf and the matrix is then + constructed from the average TP, FP, TN, FN across the classes. When + top_k is used, metrics_specs.binarize settings must not be present. Only + one of class_id or top_k should be configured. When top_k is set, the + default thresholds are [float('-inf')]. 
+ """ + super().__init__( + num_thresholds=num_thresholds, + specificity=specificity, + class_id=class_id, + name=name, + top_k=top_k, + ) + + def _default_name(self) -> str: + return SENSITIVITY_AT_SPECIFICITY_NAME + + def _metric_value( + self, + specificity: Union[float, List[float]], + key: metric_types.MetricKey, + matrices: binary_confusion_matrices.Matrices, + ) -> Union[float, np.ndarray]: + del key + tp, tn = np.array(matrices.tp), np.array(matrices.tn) + fp, fn = np.array(matrices.fp), np.array(matrices.fn) + specificities = tn / (tn + fp) + sensitivities = tp / (tp + fn) + return _find_max_under_constraint(specificities, sensitivities, specificity) metric_types.register_metric(SensitivityAtSpecificity) class SpecificityAtSensitivity(ConfusionMatrixMetricBase): - """Computes best specificity where sensitivity is >= specified value. - - `Sensitivity` measures the proportion of actual positives that are correctly - identified as such (tp / (tp + fn)). - `Specificity` measures the proportion of actual negatives that are correctly - identified as such (tn / (tn + fp)). - - The threshold for the given sensitivity value is computed and used to evaluate - the corresponding specificity. + """Computes best specificity where sensitivity is >= specified value. - If `sample_weight` is `None`, weights default to 1. - Use `sample_weight` of 0 to mask values. + `Sensitivity` measures the proportion of actual positives that are correctly + identified as such (tp / (tp + fn)). + `Specificity` measures the proportion of actual negatives that are correctly + identified as such (tn / (tn + fp)). - For additional information about specificity and sensitivity, see - [the following](https://en.wikipedia.org/wiki/Sensitivity_and_specificity). - """ + The threshold for the given sensitivity value is computed and used to evaluate + the corresponding specificity. - def __init__(self, - sensitivity: float, - num_thresholds: Optional[int] = None, - class_id: Optional[int] = None, - name: Optional[str] = None, - top_k: Optional[int] = None): - """Initializes SpecificityAtSensitivity metric. + If `sample_weight` is `None`, weights default to 1. + Use `sample_weight` of 0 to mask values. - - Args: - sensitivity: A scalar value or a list of scalar value in range `[0, 1]`. - num_thresholds: (Optional) Defaults to 1000. The number of thresholds to - use for matching the given sensitivity. - class_id: (Optional) Used with a multi-class model to specify which class - to compute the confusion matrix for. When class_id is used, - metrics_specs.binarize settings must not be present. Only one of - class_id or top_k should be configured. - name: (Optional) string name of the metric instance. - top_k: (Optional) Used with a multi-class model to specify that the top-k - values should be used to compute the confusion matrix. The net effect is - that the non-top-k values are set to -inf and the matrix is then - constructed from the average TP, FP, TN, FN across the classes. When - top_k is used, metrics_specs.binarize settings must not be present. Only - one of class_id or top_k should be configured. When top_k is set, the - default thresholds are [float('-inf')]. + For additional information about specificity and sensitivity, see + [the following](https://en.wikipedia.org/wiki/Sensitivity_and_specificity). 
""" - super().__init__( - num_thresholds=num_thresholds, - sensitivity=sensitivity, - class_id=class_id, - name=name, - top_k=top_k) - - def _default_name(self) -> str: - return SPECIFICITY_AT_SENSITIVITY_NAME - - def _metric_value( - self, sensitivity: Union[float, List[float]], key: metric_types.MetricKey, - matrices: binary_confusion_matrices.Matrices) -> Union[float, np.ndarray]: - del key - tp, tn = np.array(matrices.tp), np.array(matrices.tn) - fp, fn = np.array(matrices.fp), np.array(matrices.fn) - specificities = tn / (tn + fp) - sensitivities = tp / (tp + fn) - return _find_max_under_constraint(sensitivities, specificities, sensitivity) + + def __init__( + self, + sensitivity: float, + num_thresholds: Optional[int] = None, + class_id: Optional[int] = None, + name: Optional[str] = None, + top_k: Optional[int] = None, + ): + """Initializes SpecificityAtSensitivity metric. + + Args: + ---- + sensitivity: A scalar value or a list of scalar value in range `[0, 1]`. + num_thresholds: (Optional) Defaults to 1000. The number of thresholds to + use for matching the given sensitivity. + class_id: (Optional) Used with a multi-class model to specify which class + to compute the confusion matrix for. When class_id is used, + metrics_specs.binarize settings must not be present. Only one of + class_id or top_k should be configured. + name: (Optional) string name of the metric instance. + top_k: (Optional) Used with a multi-class model to specify that the top-k + values should be used to compute the confusion matrix. The net effect is + that the non-top-k values are set to -inf and the matrix is then + constructed from the average TP, FP, TN, FN across the classes. When + top_k is used, metrics_specs.binarize settings must not be present. Only + one of class_id or top_k should be configured. When top_k is set, the + default thresholds are [float('-inf')]. + """ + super().__init__( + num_thresholds=num_thresholds, + sensitivity=sensitivity, + class_id=class_id, + name=name, + top_k=top_k, + ) + + def _default_name(self) -> str: + return SPECIFICITY_AT_SENSITIVITY_NAME + + def _metric_value( + self, + sensitivity: Union[float, List[float]], + key: metric_types.MetricKey, + matrices: binary_confusion_matrices.Matrices, + ) -> Union[float, np.ndarray]: + del key + tp, tn = np.array(matrices.tp), np.array(matrices.tn) + fp, fn = np.array(matrices.fp), np.array(matrices.fn) + specificities = tn / (tn + fp) + sensitivities = tp / (tp + fn) + return _find_max_under_constraint(sensitivities, specificities, sensitivity) metric_types.register_metric(SpecificityAtSensitivity) class PrecisionAtRecall(ConfusionMatrixMetricBase): - """Computes best precision where recall is >= specified value. - - The threshold for the given recall value is computed and used to evaluate the - corresponding precision. - - If `sample_weight` is `None`, weights default to 1. - Use `sample_weight` of 0 to mask values. - """ - - def __init__(self, - recall: Union[float, List[float]], - thresholds: Optional[List[float]] = None, - num_thresholds: Optional[int] = None, - class_id: Optional[int] = None, - name: Optional[str] = None, - top_k: Optional[int] = None, - **kwargs): - """Initializes PrecisionAtRecall metric. + """Computes best precision where recall is >= specified value. + The threshold for the given recall value is computed and used to evaluate the + corresponding precision. - Args: - recall: A scalar or a list of scalar values in range `[0, 1]`. - thresholds: (Optional) Thresholds to use for calculating the matrices. 
Use - one of either thresholds or num_thresholds. - num_thresholds: (Optional) Defaults to 1000. The number of thresholds to - use for matching the given recall. - class_id: (Optional) Used with a multi-class model to specify which class - to compute the confusion matrix for. When class_id is used, - metrics_specs.binarize settings must not be present. Only one of - class_id or top_k should be configured. - name: (Optional) string name of the metric instance. - top_k: (Optional) Used with a multi-class model to specify that the top-k - values should be used to compute the confusion matrix. The net effect is - that the non-top-k values are set to -inf and the matrix is then - constructed from the average TP, FP, TN, FN across the classes. When - top_k is used, metrics_specs.binarize settings must not be present. Only - one of class_id or top_k should be configured. When top_k is set, the - default thresholds are [float('-inf')]. - **kwargs: (Optional) Additional args to pass along to init (and eventually - on to _metric_computation and _metric_value) + If `sample_weight` is `None`, weights default to 1. + Use `sample_weight` of 0 to mask values. """ - for r in [recall] if isinstance(recall, float) else recall: - if r < 0 or r > 1: - raise ValueError('Argument `recall` must be in the range [0, 1]. ' - f'Received: recall={r}') - - super().__init__( - thresholds=thresholds, - num_thresholds=num_thresholds, - recall=recall, - class_id=class_id, - name=name, - top_k=top_k, - **kwargs) - - def _default_name(self) -> str: - return PRECISION_AT_RECALL_NAME - - def _metric_value( - self, recall: Union[float, List[float]], key: metric_types.MetricKey, - matrices: binary_confusion_matrices.Matrices) -> Union[float, np.ndarray]: - del key - tp = np.array(matrices.tp) - fp, fn = np.array(matrices.fp), np.array(matrices.fn) - recalls = tp / (tp + fn) - precisions = tp / (tp + fp) - return _find_max_under_constraint(recalls, precisions, recall) + + def __init__( + self, + recall: Union[float, List[float]], + thresholds: Optional[List[float]] = None, + num_thresholds: Optional[int] = None, + class_id: Optional[int] = None, + name: Optional[str] = None, + top_k: Optional[int] = None, + **kwargs, + ): + """Initializes PrecisionAtRecall metric. + + Args: + ---- + recall: A scalar or a list of scalar values in range `[0, 1]`. + thresholds: (Optional) Thresholds to use for calculating the matrices. Use + one of either thresholds or num_thresholds. + num_thresholds: (Optional) Defaults to 1000. The number of thresholds to + use for matching the given recall. + class_id: (Optional) Used with a multi-class model to specify which class + to compute the confusion matrix for. When class_id is used, + metrics_specs.binarize settings must not be present. Only one of + class_id or top_k should be configured. + name: (Optional) string name of the metric instance. + top_k: (Optional) Used with a multi-class model to specify that the top-k + values should be used to compute the confusion matrix. The net effect is + that the non-top-k values are set to -inf and the matrix is then + constructed from the average TP, FP, TN, FN across the classes. When + top_k is used, metrics_specs.binarize settings must not be present. Only + one of class_id or top_k should be configured. When top_k is set, the + default thresholds are [float('-inf')]. 
+ **kwargs: (Optional) Additional args to pass along to init (and eventually + on to _metric_computation and _metric_value) + """ + for r in [recall] if isinstance(recall, float) else recall: + if r < 0 or r > 1: + raise ValueError( + "Argument `recall` must be in the range [0, 1]. " + f"Received: recall={r}" + ) + + super().__init__( + thresholds=thresholds, + num_thresholds=num_thresholds, + recall=recall, + class_id=class_id, + name=name, + top_k=top_k, + **kwargs, + ) + + def _default_name(self) -> str: + return PRECISION_AT_RECALL_NAME + + def _metric_value( + self, + recall: Union[float, List[float]], + key: metric_types.MetricKey, + matrices: binary_confusion_matrices.Matrices, + ) -> Union[float, np.ndarray]: + del key + tp = np.array(matrices.tp) + fp, fn = np.array(matrices.fp), np.array(matrices.fn) + recalls = tp / (tp + fn) + precisions = tp / (tp + fp) + return _find_max_under_constraint(recalls, precisions, recall) metric_types.register_metric(PrecisionAtRecall) class RecallAtPrecision(ConfusionMatrixMetricBase): - """Computes best recall where precision is >= specified value. - - For a given score-label-distribution the required precision might not - be achievable, in this case 0.0 is returned as recall. - - This metric creates three local variables, `true_positives`, `false_positives` - and `false_negatives` that are used to compute the recall at the given - precision. The threshold for the given precision value is computed and used to - evaluate the corresponding recall. + """Computes best recall where precision is >= specified value. - If `sample_weight` is `None`, weights default to 1. - Use `sample_weight` of 0 to mask values. - """ + For a given score-label-distribution the required precision might not + be achievable, in this case 0.0 is returned as recall. - def __init__(self, - precision: float, - num_thresholds: Optional[int] = None, - class_id: Optional[int] = None, - name: Optional[str] = None, - top_k: Optional[int] = None): - """Initializes RecallAtPrecision. + This metric creates three local variables, `true_positives`, `false_positives` + and `false_negatives` that are used to compute the recall at the given + precision. The threshold for the given precision value is computed and used to + evaluate the corresponding recall. - - Args: - precision: A scalar value in range `[0, 1]`. - num_thresholds: (Optional) Defaults to 1000. The number of thresholds to - use for matching the given precision. - class_id: (Optional) Used with a multi-class model to specify which class - to compute the confusion matrix for. When class_id is used, - metrics_specs.binarize settings must not be present. Only one of - class_id or top_k should be configured. - name: (Optional) string name of the metric instance. - top_k: (Optional) Used with a multi-class model to specify that the top-k - values should be used to compute the confusion matrix. The net effect is - that the non-top-k values are set to -inf and the matrix is then - constructed from the average TP, FP, TN, FN across the classes. When - top_k is used, metrics_specs.binarize settings must not be present. Only - one of class_id or top_k should be configured. When top_k is set, the - default thresholds are [float('-inf')]. + If `sample_weight` is `None`, weights default to 1. + Use `sample_weight` of 0 to mask values. """ - if precision < 0 or precision > 1: - raise ValueError('Argument `precision` must be in the range [0, 1]. 
' - f'Received: precision={precision}') - super().__init__( - num_thresholds=num_thresholds, - precision=precision, - class_id=class_id, - name=name, - top_k=top_k) - - def _default_name(self) -> str: - return RECALL_AT_PRECISION_NAME - - def _metric_value( - self, precision: Union[float, List[float]], key: metric_types.MetricKey, - matrices: binary_confusion_matrices.Matrices) -> Union[float, np.ndarray]: - del key - tp = np.array(matrices.tp) - fp, fn = np.array(matrices.fp), np.array(matrices.fn) - recalls = tp / (tp + fn) - precisions = tp / (tp + fp) - return _find_max_under_constraint(precisions, recalls, precision) + + def __init__( + self, + precision: float, + num_thresholds: Optional[int] = None, + class_id: Optional[int] = None, + name: Optional[str] = None, + top_k: Optional[int] = None, + ): + """Initializes RecallAtPrecision. + + Args: + ---- + precision: A scalar value in range `[0, 1]`. + num_thresholds: (Optional) Defaults to 1000. The number of thresholds to + use for matching the given precision. + class_id: (Optional) Used with a multi-class model to specify which class + to compute the confusion matrix for. When class_id is used, + metrics_specs.binarize settings must not be present. Only one of + class_id or top_k should be configured. + name: (Optional) string name of the metric instance. + top_k: (Optional) Used with a multi-class model to specify that the top-k + values should be used to compute the confusion matrix. The net effect is + that the non-top-k values are set to -inf and the matrix is then + constructed from the average TP, FP, TN, FN across the classes. When + top_k is used, metrics_specs.binarize settings must not be present. Only + one of class_id or top_k should be configured. When top_k is set, the + default thresholds are [float('-inf')]. + """ + if precision < 0 or precision > 1: + raise ValueError( + "Argument `precision` must be in the range [0, 1]. " + f"Received: precision={precision}" + ) + super().__init__( + num_thresholds=num_thresholds, + precision=precision, + class_id=class_id, + name=name, + top_k=top_k, + ) + + def _default_name(self) -> str: + return RECALL_AT_PRECISION_NAME + + def _metric_value( + self, + precision: Union[float, List[float]], + key: metric_types.MetricKey, + matrices: binary_confusion_matrices.Matrices, + ) -> Union[float, np.ndarray]: + del key + tp = np.array(matrices.tp) + fp, fn = np.array(matrices.fp), np.array(matrices.fn) + recalls = tp / (tp + fn) + precisions = tp / (tp + fp) + return _find_max_under_constraint(precisions, recalls, precision) metric_types.register_metric(RecallAtPrecision) class RecallAtFalsePositiveRate(ConfusionMatrixMetricBase): - """Computes best recall where false positive rate is <= specified value. - - For a given score-label-distribution the required false positive rate might - not be achievable, in this case 0.0 is returned as recall. - - This metric creates four local variables, `true_positives`, `false_positives`, - `false_negatives` and `true_negatives` that are used to compute the recall at - the given false positive rate. - The threshold for the given false positive rate value is - computed and used to evaluate the corresponding recall. - - If `sample_weight` is `None`, weights default to 1. - Use `sample_weight` of 0 to mask values. - """ - - def __init__( - self, - false_positive_rate: float, - num_thresholds: Optional[int] = None, - class_id: Optional[int] = None, - name: Optional[str] = None, - top_k: Optional[int] = None, - ): - """Initializes RecallAtFalsePositiveRate. 
+ """Computes best recall where false positive rate is <= specified value. - Args: - false_positive_rate: A scalar value in range `[0, 1]`. - num_thresholds: (Optional) Defaults to 1000. The number of thresholds to - use for matching the given precision. - class_id: (Optional) Used with a multi-class model to specify which class - to compute the confusion matrix for. When class_id is used, - metrics_specs.binarize settings must not be present. Only one of - class_id or top_k should be configured. - name: (Optional) string name of the metric instance. - top_k: (Optional) Used with a multi-class model to specify that the top-k - values should be used to compute the confusion matrix. The net effect is - that the non-top-k values are set to -inf and the matrix is then - constructed from the average TP, FP, TN, FN across the classes. When - top_k is used, metrics_specs.binarize settings must not be present. Only - one of class_id or top_k should be configured. When top_k is set, the - default thresholds are [float('-inf')]. + For a given score-label-distribution the required false positive rate might + not be achievable, in this case 0.0 is returned as recall. + + This metric creates four local variables, `true_positives`, `false_positives`, + `false_negatives` and `true_negatives` that are used to compute the recall at + the given false positive rate. + The threshold for the given false positive rate value is + computed and used to evaluate the corresponding recall. + + If `sample_weight` is `None`, weights default to 1. + Use `sample_weight` of 0 to mask values. """ - if false_positive_rate < 0 or false_positive_rate > 1: - raise ValueError( - 'Argument `false_positive_rate` must be in the range' - f'[0, 1]. Received: false_positive_rate={false_positive_rate}' - ) - super().__init__( - false_positive_rate=false_positive_rate, - num_thresholds=num_thresholds, - class_id=class_id, - name=name, - top_k=top_k, - ) - - def _default_name(self) -> str: - return RECALL_AT_FALSE_POSITIVE_RATE_NAME - - def _metric_value( - self, - false_positive_rate: Union[float, List[float]], - key: metric_types.MetricKey, - matrices: binary_confusion_matrices.Matrices, - ) -> Union[float, np.ndarray]: - del key - tp, tn = np.array(matrices.tp), np.array(matrices.tn) - fp, fn = np.array(matrices.fp), np.array(matrices.fn) - false_positive_rates = fp / (fp + tn) - recalls = tp / (tp + fn) - return _find_max_under_constraint( - constrained=false_positive_rates, - dependent=recalls, - values=false_positive_rate, - operator=np.less_equal, - ) + + def __init__( + self, + false_positive_rate: float, + num_thresholds: Optional[int] = None, + class_id: Optional[int] = None, + name: Optional[str] = None, + top_k: Optional[int] = None, + ): + """Initializes RecallAtFalsePositiveRate. + + Args: + ---- + false_positive_rate: A scalar value in range `[0, 1]`. + num_thresholds: (Optional) Defaults to 1000. The number of thresholds to + use for matching the given precision. + class_id: (Optional) Used with a multi-class model to specify which class + to compute the confusion matrix for. When class_id is used, + metrics_specs.binarize settings must not be present. Only one of + class_id or top_k should be configured. + name: (Optional) string name of the metric instance. + top_k: (Optional) Used with a multi-class model to specify that the top-k + values should be used to compute the confusion matrix. 
The net effect is + that the non-top-k values are set to -inf and the matrix is then + constructed from the average TP, FP, TN, FN across the classes. When + top_k is used, metrics_specs.binarize settings must not be present. Only + one of class_id or top_k should be configured. When top_k is set, the + default thresholds are [float('-inf')]. + """ + if false_positive_rate < 0 or false_positive_rate > 1: + raise ValueError( + "Argument `false_positive_rate` must be in the range" + f"[0, 1]. Received: false_positive_rate={false_positive_rate}" + ) + super().__init__( + false_positive_rate=false_positive_rate, + num_thresholds=num_thresholds, + class_id=class_id, + name=name, + top_k=top_k, + ) + + def _default_name(self) -> str: + return RECALL_AT_FALSE_POSITIVE_RATE_NAME + + def _metric_value( + self, + false_positive_rate: Union[float, List[float]], + key: metric_types.MetricKey, + matrices: binary_confusion_matrices.Matrices, + ) -> Union[float, np.ndarray]: + del key + tp, tn = np.array(matrices.tp), np.array(matrices.tn) + fp, fn = np.array(matrices.fp), np.array(matrices.fn) + false_positive_rates = fp / (fp + tn) + recalls = tp / (tp + fn) + return _find_max_under_constraint( + constrained=false_positive_rates, + dependent=recalls, + values=false_positive_rate, + operator=np.less_equal, + ) metric_types.register_metric(RecallAtFalsePositiveRate) class TruePositives(ConfusionMatrixMetric): - """Calculates the number of true positives. - - If `sample_weight` is given, calculates the sum of the weights of - true positives. This metric creates one local variable, `true_positives` - that is used to keep track of the number of true positives. - - If `sample_weight` is `None`, weights default to 1. - Use `sample_weight` of 0 to mask values. - """ + """Calculates the number of true positives. - def __init__(self, - thresholds: Optional[Union[float, List[float]]] = None, - name: Optional[str] = None, - top_k: Optional[int] = None, - class_id: Optional[int] = None): - """Initializes TruePositives metric. + If `sample_weight` is given, calculates the sum of the weights of + true positives. This metric creates one local variable, `true_positives` + that is used to keep track of the number of true positives. - Args: - thresholds: (Optional) Defaults to [0.5]. A float value or a python - list/tuple of float threshold values in [0, 1]. A threshold is compared - with prediction values to determine the truth value of predictions - (i.e., above the threshold is `true`, below is `false`). One metric - value is generated for each threshold value. - name: (Optional) Metric name. - top_k: (Optional) Used with a multi-class model to specify that the top-k - values should be used to compute the confusion matrix. The net effect is - that the non-top-k values are set to -inf and the matrix is then - constructed from the average TP, FP, TN, FN across the classes. When - top_k is used, metrics_specs.binarize settings must not be present. Only - one of class_id or top_k should be configured. When top_k is set, the - default thresholds are [float('-inf')]. - class_id: (Optional) Used with a multi-class model to specify which class - to compute the confusion matrix for. When class_id is used, - metrics_specs.binarize settings must not be present. Only one of - class_id or top_k should be configured. + If `sample_weight` is `None`, weights default to 1. + Use `sample_weight` of 0 to mask values. 
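The count metrics in this block (TruePositives/TrueNegatives/FalsePositives/FalseNegatives and their TP/TN/FP/FN aliases) do no arithmetic of their own: `result()` simply returns one entry of the thresholded, weighted confusion matrix. A small sketch of what those counts mean (example data only; the actual counts come from the binary_confusion_matrices computation):

```python
import numpy as np

# Example inputs; a sample_weight of 0 masks an example, per the docstrings.
labels = np.array([1, 0, 1, 0, 1])
preds = np.array([0.9, 0.6, 0.4, 0.2, 0.8])
weights = np.array([1.0, 1.0, 1.0, 1.0, 0.0])  # last example masked
threshold = 0.5

pred_pos = preds > threshold
tp = np.sum(weights * (pred_pos & (labels == 1)))
tn = np.sum(weights * (~pred_pos & (labels == 0)))
fp = np.sum(weights * (pred_pos & (labels == 0)))
fn = np.sum(weights * (~pred_pos & (labels == 1)))
print(tp, tn, fp, fn)  # 1.0 1.0 1.0 1.0
```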
""" - super().__init__( - thresholds=thresholds, name=name, top_k=top_k, class_id=class_id) - - def _default_name(self) -> str: - return TRUE_POSITIVES_NAME - def result(self, tp: float, tn: float, fp: float, fn: float) -> float: - return tp + def __init__( + self, + thresholds: Optional[Union[float, List[float]]] = None, + name: Optional[str] = None, + top_k: Optional[int] = None, + class_id: Optional[int] = None, + ): + """Initializes TruePositives metric. + + Args: + ---- + thresholds: (Optional) Defaults to [0.5]. A float value or a python + list/tuple of float threshold values in [0, 1]. A threshold is compared + with prediction values to determine the truth value of predictions + (i.e., above the threshold is `true`, below is `false`). One metric + value is generated for each threshold value. + name: (Optional) Metric name. + top_k: (Optional) Used with a multi-class model to specify that the top-k + values should be used to compute the confusion matrix. The net effect is + that the non-top-k values are set to -inf and the matrix is then + constructed from the average TP, FP, TN, FN across the classes. When + top_k is used, metrics_specs.binarize settings must not be present. Only + one of class_id or top_k should be configured. When top_k is set, the + default thresholds are [float('-inf')]. + class_id: (Optional) Used with a multi-class model to specify which class + to compute the confusion matrix for. When class_id is used, + metrics_specs.binarize settings must not be present. Only one of + class_id or top_k should be configured. + """ + super().__init__( + thresholds=thresholds, name=name, top_k=top_k, class_id=class_id + ) + + def _default_name(self) -> str: + return TRUE_POSITIVES_NAME + + def result(self, tp: float, tn: float, fp: float, fn: float) -> float: + return tp metric_types.register_metric(TruePositives) class TP(TruePositives): - """Alias for TruePositives.""" + """Alias for TruePositives.""" - def __init__(self, - thresholds: Optional[Union[float, List[float]]] = None, - name: Optional[str] = None, - top_k: Optional[int] = None, - class_id: Optional[int] = None): - """Initializes TP metric.""" - super().__init__( - thresholds=thresholds, name=name, top_k=top_k, class_id=class_id) + def __init__( + self, + thresholds: Optional[Union[float, List[float]]] = None, + name: Optional[str] = None, + top_k: Optional[int] = None, + class_id: Optional[int] = None, + ): + """Initializes TP metric.""" + super().__init__( + thresholds=thresholds, name=name, top_k=top_k, class_id=class_id + ) - def _default_name(self) -> str: - return TP_NAME + def _default_name(self) -> str: + return TP_NAME metric_types.register_metric(TP) class TrueNegatives(ConfusionMatrixMetric): - """Calculates the number of true negatives. - - If `sample_weight` is given, calculates the sum of the weights of true - negatives. - - If `sample_weight` is `None`, weights default to 1. - Use `sample_weight` of 0 to mask values. - """ + """Calculates the number of true negatives. - def __init__(self, - thresholds: Optional[Union[float, List[float]]] = None, - name: Optional[str] = None, - top_k: Optional[int] = None, - class_id: Optional[int] = None): - """Initializes TrueNegatives metric. + If `sample_weight` is given, calculates the sum of the weights of true + negatives. - Args: - thresholds: (Optional) Defaults to [0.5]. A float value or a python - list/tuple of float threshold values in [0, 1]. 
A threshold is compared - with prediction values to determine the truth value of predictions - (i.e., above the threshold is `true`, below is `false`). One metric - value is generated for each threshold value. - name: (Optional) Metric name. - top_k: (Optional) Used with a multi-class model to specify that the top-k - values should be used to compute the confusion matrix. The net effect is - that the non-top-k values are set to -inf and the matrix is then - constructed from the average TP, FP, TN, FN across the classes. When - top_k is used, metrics_specs.binarize settings must not be present. Only - one of class_id or top_k should be configured. When top_k is set, the - default thresholds are [float('-inf')]. - class_id: (Optional) Used with a multi-class model to specify which class - to compute the confusion matrix for. When class_id is used, - metrics_specs.binarize settings must not be present. Only one of - class_id or top_k should be configured. + If `sample_weight` is `None`, weights default to 1. + Use `sample_weight` of 0 to mask values. """ - super().__init__( - thresholds=thresholds, name=name, top_k=top_k, class_id=class_id) - - def _default_name(self) -> str: - return TRUE_NEGATIVES_NAME - def result(self, tp: float, tn: float, fp: float, fn: float) -> float: - return tn + def __init__( + self, + thresholds: Optional[Union[float, List[float]]] = None, + name: Optional[str] = None, + top_k: Optional[int] = None, + class_id: Optional[int] = None, + ): + """Initializes TrueNegatives metric. + + Args: + ---- + thresholds: (Optional) Defaults to [0.5]. A float value or a python + list/tuple of float threshold values in [0, 1]. A threshold is compared + with prediction values to determine the truth value of predictions + (i.e., above the threshold is `true`, below is `false`). One metric + value is generated for each threshold value. + name: (Optional) Metric name. + top_k: (Optional) Used with a multi-class model to specify that the top-k + values should be used to compute the confusion matrix. The net effect is + that the non-top-k values are set to -inf and the matrix is then + constructed from the average TP, FP, TN, FN across the classes. When + top_k is used, metrics_specs.binarize settings must not be present. Only + one of class_id or top_k should be configured. When top_k is set, the + default thresholds are [float('-inf')]. + class_id: (Optional) Used with a multi-class model to specify which class + to compute the confusion matrix for. When class_id is used, + metrics_specs.binarize settings must not be present. Only one of + class_id or top_k should be configured. 
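Stepping back from the individual classes, here is a hedged usage sketch showing how metrics from this module are typically combined into metric specs for an evaluation. The class names come from this file; `tfma.metrics.specs_from_metrics` is assumed to be the exported spec-building helper in this release, so treat the snippet as illustrative rather than authoritative:

```python
import tensorflow_model_analysis as tfma

# Illustrative only: several confusion-matrix metrics from this module.
metrics = [
    tfma.metrics.AUC(num_thresholds=1000),
    tfma.metrics.AUCPrecisionRecall(),
    tfma.metrics.BinaryAccuracy(threshold=0.5),
    tfma.metrics.TruePositives(thresholds=[0.25, 0.5, 0.75]),
    tfma.metrics.RecallAtPrecision(precision=0.9),
]

# Assumed helper for turning metric instances into MetricsSpec protos.
metrics_specs = tfma.metrics.specs_from_metrics(metrics)
```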
+ """ + super().__init__( + thresholds=thresholds, name=name, top_k=top_k, class_id=class_id + ) + + def _default_name(self) -> str: + return TRUE_NEGATIVES_NAME + + def result(self, tp: float, tn: float, fp: float, fn: float) -> float: + return tn metric_types.register_metric(TrueNegatives) class TN(TrueNegatives): - """Alias for TrueNegatives.""" + """Alias for TrueNegatives.""" - def __init__(self, - thresholds: Optional[Union[float, List[float]]] = None, - name: Optional[str] = None, - top_k: Optional[int] = None, - class_id: Optional[int] = None): - """Initializes TN metric.""" - super().__init__( - thresholds=thresholds, name=name, top_k=top_k, class_id=class_id) + def __init__( + self, + thresholds: Optional[Union[float, List[float]]] = None, + name: Optional[str] = None, + top_k: Optional[int] = None, + class_id: Optional[int] = None, + ): + """Initializes TN metric.""" + super().__init__( + thresholds=thresholds, name=name, top_k=top_k, class_id=class_id + ) - def _default_name(self) -> str: - return TN_NAME + def _default_name(self) -> str: + return TN_NAME metric_types.register_metric(TN) class FalsePositives(ConfusionMatrixMetric): - """Calculates the number of false positives. - - If `sample_weight` is given, calculates the sum of the weights of false - positives. + """Calculates the number of false positives. - If `sample_weight` is `None`, weights default to 1. - Use `sample_weight` of 0 to mask values. - """ + If `sample_weight` is given, calculates the sum of the weights of false + positives. - def __init__(self, - thresholds: Optional[Union[float, List[float]]] = None, - name: Optional[str] = None, - top_k: Optional[int] = None, - class_id: Optional[int] = None): - """Initializes FalsePositives metric. - - Args: - thresholds: (Optional) Defaults to [0.5]. A float value or a python - list/tuple of float threshold values in [0, 1]. A threshold is compared - with prediction values to determine the truth value of predictions - (i.e., above the threshold is `true`, below is `false`). One metric - value is generated for each threshold value. - name: (Optional) Metric name. - top_k: (Optional) Used with a multi-class model to specify that the top-k - values should be used to compute the confusion matrix. The net effect is - that the non-top-k values are set to -inf and the matrix is then - constructed from the average TP, FP, TN, FN across the classes. When - top_k is used, metrics_specs.binarize settings must not be present. Only - one of class_id or top_k should be configured. When top_k is set, the - default thresholds are [float('-inf')]. - class_id: (Optional) Used with a multi-class model to specify which class - to compute the confusion matrix for. When class_id is used, - metrics_specs.binarize settings must not be present. Only one of - class_id or top_k should be configured. + If `sample_weight` is `None`, weights default to 1. + Use `sample_weight` of 0 to mask values. """ - super().__init__( - thresholds=thresholds, name=name, top_k=top_k, class_id=class_id) - - def _default_name(self) -> str: - return FALSE_POSITIVES_NAME - def result(self, tp: float, tn: float, fp: float, fn: float) -> float: - return fp + def __init__( + self, + thresholds: Optional[Union[float, List[float]]] = None, + name: Optional[str] = None, + top_k: Optional[int] = None, + class_id: Optional[int] = None, + ): + """Initializes FalsePositives metric. + + Args: + ---- + thresholds: (Optional) Defaults to [0.5]. A float value or a python + list/tuple of float threshold values in [0, 1]. 
A threshold is compared + with prediction values to determine the truth value of predictions + (i.e., above the threshold is `true`, below is `false`). One metric + value is generated for each threshold value. + name: (Optional) Metric name. + top_k: (Optional) Used with a multi-class model to specify that the top-k + values should be used to compute the confusion matrix. The net effect is + that the non-top-k values are set to -inf and the matrix is then + constructed from the average TP, FP, TN, FN across the classes. When + top_k is used, metrics_specs.binarize settings must not be present. Only + one of class_id or top_k should be configured. When top_k is set, the + default thresholds are [float('-inf')]. + class_id: (Optional) Used with a multi-class model to specify which class + to compute the confusion matrix for. When class_id is used, + metrics_specs.binarize settings must not be present. Only one of + class_id or top_k should be configured. + """ + super().__init__( + thresholds=thresholds, name=name, top_k=top_k, class_id=class_id + ) + + def _default_name(self) -> str: + return FALSE_POSITIVES_NAME + + def result(self, tp: float, tn: float, fp: float, fn: float) -> float: + return fp metric_types.register_metric(FalsePositives) class FP(FalsePositives): - """Alias for FalsePositives.""" + """Alias for FalsePositives.""" - def __init__(self, - thresholds: Optional[Union[float, List[float]]] = None, - name: Optional[str] = None, - top_k: Optional[int] = None, - class_id: Optional[int] = None): - """Initializes FP metric.""" - super().__init__( - thresholds=thresholds, name=name, top_k=top_k, class_id=class_id) + def __init__( + self, + thresholds: Optional[Union[float, List[float]]] = None, + name: Optional[str] = None, + top_k: Optional[int] = None, + class_id: Optional[int] = None, + ): + """Initializes FP metric.""" + super().__init__( + thresholds=thresholds, name=name, top_k=top_k, class_id=class_id + ) - def _default_name(self) -> str: - return FP_NAME + def _default_name(self) -> str: + return FP_NAME metric_types.register_metric(FP) class FalseNegatives(ConfusionMatrixMetric): - """Calculates the number of false negatives. + """Calculates the number of false negatives. - If `sample_weight` is given, calculates the sum of the weights of false - negatives. + If `sample_weight` is given, calculates the sum of the weights of false + negatives. - If `sample_weight` is `None`, weights default to 1. - Use `sample_weight` of 0 to mask values. - """ - - def __init__(self, - thresholds: Optional[Union[float, List[float]]] = None, - name: Optional[str] = None, - top_k: Optional[int] = None, - class_id: Optional[int] = None): - """Initializes FalseNegatives metric. - - Args: - thresholds: (Optional) Defaults to [0.5]. A float value or a python - list/tuple of float threshold values in [0, 1]. A threshold is compared - with prediction values to determine the truth value of predictions - (i.e., above the threshold is `true`, below is `false`). One metric - value is generated for each threshold value. - name: (Optional) Metric name. - top_k: (Optional) Used with a multi-class model to specify that the top-k - values should be used to compute the confusion matrix. The net effect is - that the non-top-k values are set to -inf and the matrix is then - constructed from the average TP, FP, TN, FN across the classes. When - top_k is used, metrics_specs.binarize settings must not be present. Only - one of class_id or top_k should be configured. 
When top_k is set, the - default thresholds are [float('-inf')]. - class_id: (Optional) Used with a multi-class model to specify which class - to compute the confusion matrix for. When class_id is used, - metrics_specs.binarize settings must not be present. Only one of - class_id or top_k should be configured. + If `sample_weight` is `None`, weights default to 1. + Use `sample_weight` of 0 to mask values. """ - super().__init__( - thresholds=thresholds, name=name, top_k=top_k, class_id=class_id) - def _default_name(self) -> str: - return FALSE_NEGATIVES_NAME - - def result(self, tp: float, tn: float, fp: float, fn: float) -> float: - return fn + def __init__( + self, + thresholds: Optional[Union[float, List[float]]] = None, + name: Optional[str] = None, + top_k: Optional[int] = None, + class_id: Optional[int] = None, + ): + """Initializes FalseNegatives metric. + + Args: + ---- + thresholds: (Optional) Defaults to [0.5]. A float value or a python + list/tuple of float threshold values in [0, 1]. A threshold is compared + with prediction values to determine the truth value of predictions + (i.e., above the threshold is `true`, below is `false`). One metric + value is generated for each threshold value. + name: (Optional) Metric name. + top_k: (Optional) Used with a multi-class model to specify that the top-k + values should be used to compute the confusion matrix. The net effect is + that the non-top-k values are set to -inf and the matrix is then + constructed from the average TP, FP, TN, FN across the classes. When + top_k is used, metrics_specs.binarize settings must not be present. Only + one of class_id or top_k should be configured. When top_k is set, the + default thresholds are [float('-inf')]. + class_id: (Optional) Used with a multi-class model to specify which class + to compute the confusion matrix for. When class_id is used, + metrics_specs.binarize settings must not be present. Only one of + class_id or top_k should be configured. + """ + super().__init__( + thresholds=thresholds, name=name, top_k=top_k, class_id=class_id + ) + + def _default_name(self) -> str: + return FALSE_NEGATIVES_NAME + + def result(self, tp: float, tn: float, fp: float, fn: float) -> float: + return fn metric_types.register_metric(FalseNegatives) class FN(FalseNegatives): - """Alias for FalseNegatives.""" + """Alias for FalseNegatives.""" - def __init__(self, - thresholds: Optional[Union[float, List[float]]] = None, - name: Optional[str] = None, - top_k: Optional[int] = None, - class_id: Optional[int] = None): - """Initializes FN metric.""" - super().__init__( - thresholds=thresholds, name=name, top_k=top_k, class_id=class_id) + def __init__( + self, + thresholds: Optional[Union[float, List[float]]] = None, + name: Optional[str] = None, + top_k: Optional[int] = None, + class_id: Optional[int] = None, + ): + """Initializes FN metric.""" + super().__init__( + thresholds=thresholds, name=name, top_k=top_k, class_id=class_id + ) - def _default_name(self) -> str: - return FN_NAME + def _default_name(self) -> str: + return FN_NAME metric_types.register_metric(FN) class BinaryAccuracy(ConfusionMatrixMetric): - """Calculates how often predictions match binary labels. - - This metric computes the accuracy based on (TP + TN) / (TP + FP + TN + FN). + """Calculates how often predictions match binary labels. - If `sample_weight` is `None`, weights default to 1. - Use `sample_weight` of 0 to mask values. - """ + This metric computes the accuracy based on (TP + TN) / (TP + FP + TN + FN). 
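As a quick worked example of the formula just stated: with TP=40, TN=45, FP=5, FN=10, BinaryAccuracy is (40 + 45) / 100 = 0.85. The sketch below also shows the kind of zero-denominator guard that `_divide_only_positive_denominator` is assumed to provide (the real helper may return a different sentinel than 0.0):

```python
# Worked example of (TP + TN) / (TP + FP + TN + FN).
def binary_accuracy_sketch(tp: float, tn: float, fp: float, fn: float) -> float:
    denominator = tp + fp + tn + fn
    # Guard against an empty confusion matrix; the module's helper is assumed
    # to do something equivalent.
    return (tp + tn) / denominator if denominator > 0 else 0.0


print(binary_accuracy_sketch(tp=40.0, tn=45.0, fp=5.0, fn=10.0))  # 0.85
```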
- def __init__(self, - threshold: Optional[float] = None, - top_k: Optional[int] = None, - class_id: Optional[int] = None, - name: Optional[str] = None): - """Initializes BinaryAccuracy metric. - - Args: - threshold: (Optional) A float value in [0, 1]. The threshold is compared - with prediction values to determine the truth value of predictions - (i.e., above the threshold is `true`, below is `false`). If neither - threshold nor top_k are set, the default is to calculate with - `threshold=0.5`. - top_k: (Optional) Used with a multi-class model to specify that the top-k - values should be used to compute the confusion matrix. The net effect is - that the non-top-k values are set to -inf and the matrix is then - constructed from the average TP, FP, TN, FN across the classes. When - top_k is used, metrics_specs.binarize settings must not be present. Only - one of class_id or top_k should be configured. When top_k is set, the - default thresholds are [float('-inf')]. - class_id: (Optional) Used with a multi-class model to specify which class - to compute the confusion matrix for. When class_id is used, - metrics_specs.binarize settings must not be present. Only one of - class_id or top_k should be configured. - name: (Optional) string name of the metric instance. + If `sample_weight` is `None`, weights default to 1. + Use `sample_weight` of 0 to mask values. """ - super().__init__( - thresholds=threshold, top_k=top_k, class_id=class_id, name=name) - def _default_name(self) -> str: - return BINARY_ACCURACY_NAME - - def result(self, tp: float, tn: float, fp: float, fn: float) -> float: - return _divide_only_positive_denominator(tp + tn, tp + fp + tn + fn) + def __init__( + self, + threshold: Optional[float] = None, + top_k: Optional[int] = None, + class_id: Optional[int] = None, + name: Optional[str] = None, + ): + """Initializes BinaryAccuracy metric. + + Args: + ---- + threshold: (Optional) A float value in [0, 1]. The threshold is compared + with prediction values to determine the truth value of predictions + (i.e., above the threshold is `true`, below is `false`). If neither + threshold nor top_k are set, the default is to calculate with + `threshold=0.5`. + top_k: (Optional) Used with a multi-class model to specify that the top-k + values should be used to compute the confusion matrix. The net effect is + that the non-top-k values are set to -inf and the matrix is then + constructed from the average TP, FP, TN, FN across the classes. When + top_k is used, metrics_specs.binarize settings must not be present. Only + one of class_id or top_k should be configured. When top_k is set, the + default thresholds are [float('-inf')]. + class_id: (Optional) Used with a multi-class model to specify which class + to compute the confusion matrix for. When class_id is used, + metrics_specs.binarize settings must not be present. Only one of + class_id or top_k should be configured. + name: (Optional) string name of the metric instance. + """ + super().__init__( + thresholds=threshold, top_k=top_k, class_id=class_id, name=name + ) + + def _default_name(self) -> str: + return BINARY_ACCURACY_NAME + + def result(self, tp: float, tn: float, fp: float, fn: float) -> float: + return _divide_only_positive_denominator(tp + tn, tp + fp + tn + fn) metric_types.register_metric(BinaryAccuracy) class Precision(ConfusionMatrixMetric): - """Computes the precision of the predictions with respect to the labels. 
- - The metric uses true positives and false positives to compute precision by - dividing the true positives by the sum of true positives and false positives. - - If `sample_weight` is `None`, weights default to 1. - Use `sample_weight` of 0 to mask values. - """ + """Computes the precision of the predictions with respect to the labels. - def __init__(self, - thresholds: Optional[Union[float, List[float]]] = None, - top_k: Optional[int] = None, - class_id: Optional[int] = None, - name: Optional[str] = None, - **kwargs): - """Initializes Precision metric. + The metric uses true positives and false positives to compute precision by + dividing the true positives by the sum of true positives and false positives. - Args: - thresholds: (Optional) A float value or a python list/tuple of float - threshold values in [0, 1]. A threshold is compared with prediction - values to determine the truth value of predictions (i.e., above the - threshold is `true`, below is `false`). One metric value is generated - for each threshold value. If neither thresholds nor top_k are set, the - default is to calculate precision with `thresholds=0.5`. - top_k: (Optional) Used with a multi-class model to specify that the top-k - values should be used to compute the confusion matrix. The net effect is - that the non-top-k values are set to -inf and the matrix is then - constructed from the average TP, FP, TN, FN across the classes. When - top_k is used, metrics_specs.binarize settings must not be present. Only - one of class_id or top_k should be configured. When top_k is set, the - default thresholds are [float('-inf')]. - class_id: (Optional) Used with a multi-class model to specify which class - to compute the confusion matrix for. When class_id is used, - metrics_specs.binarize settings must not be present. Only one of - class_id or top_k should be configured. - name: (Optional) string name of the metric instance. - **kwargs: (Optional) Additional args to pass along to init (and eventually - on to _metric_computation and _metric_value). + If `sample_weight` is `None`, weights default to 1. + Use `sample_weight` of 0 to mask values. """ - super().__init__( - thresholds=thresholds, - top_k=top_k, - class_id=class_id, - name=name, - **kwargs) - - def _default_name(self) -> str: - return PRECISION_NAME - def result(self, tp: float, tn: float, fp: float, fn: float) -> float: - del tn, fn - return _divide_only_positive_denominator(tp, tp + fp) + def __init__( + self, + thresholds: Optional[Union[float, List[float]]] = None, + top_k: Optional[int] = None, + class_id: Optional[int] = None, + name: Optional[str] = None, + **kwargs, + ): + """Initializes Precision metric. + + Args: + ---- + thresholds: (Optional) A float value or a python list/tuple of float + threshold values in [0, 1]. A threshold is compared with prediction + values to determine the truth value of predictions (i.e., above the + threshold is `true`, below is `false`). One metric value is generated + for each threshold value. If neither thresholds nor top_k are set, the + default is to calculate precision with `thresholds=0.5`. + top_k: (Optional) Used with a multi-class model to specify that the top-k + values should be used to compute the confusion matrix. The net effect is + that the non-top-k values are set to -inf and the matrix is then + constructed from the average TP, FP, TN, FN across the classes. When + top_k is used, metrics_specs.binarize settings must not be present. Only + one of class_id or top_k should be configured. 
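
As the thresholds paragraph above says, a list of thresholds produces one precision value per threshold. A rough standalone illustration, not TFMA's computation path; `precision_at_thresholds` is hypothetical, and treating predictions strictly above the threshold as positive is an assumption about the boundary case:

```
from typing import List, Sequence


def precision_at_thresholds(
    labels: Sequence[int],
    predictions: Sequence[float],
    thresholds: Sequence[float] = (0.5,),
) -> List[float]:
    values = []
    for t in thresholds:
        tp = sum(1 for y, p in zip(labels, predictions) if p > t and y == 1)
        fp = sum(1 for y, p in zip(labels, predictions) if p > t and y == 0)
        # precision = TP / (TP + FP); NaN when nothing is predicted positive
        values.append(tp / (tp + fp) if tp + fp > 0 else float("nan"))
    return values


print(precision_at_thresholds([1, 0, 1, 0], [0.9, 0.8, 0.4, 0.1], [0.3, 0.5, 0.7]))
# [0.666..., 0.5, 0.5]
```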
When top_k is set, the + default thresholds are [float('-inf')]. + class_id: (Optional) Used with a multi-class model to specify which class + to compute the confusion matrix for. When class_id is used, + metrics_specs.binarize settings must not be present. Only one of + class_id or top_k should be configured. + name: (Optional) string name of the metric instance. + **kwargs: (Optional) Additional args to pass along to init (and eventually + on to _metric_computation and _metric_value). + """ + super().__init__( + thresholds=thresholds, top_k=top_k, class_id=class_id, name=name, **kwargs + ) + + def _default_name(self) -> str: + return PRECISION_NAME + + def result(self, tp: float, tn: float, fp: float, fn: float) -> float: + del tn, fn + return _divide_only_positive_denominator(tp, tp + fp) metric_types.register_metric(Precision) class PPV(Precision): - """Alias for Precision.""" + """Alias for Precision.""" - def __init__(self, - thresholds: Optional[Union[float, List[float]]] = None, - name: Optional[str] = None, - top_k: Optional[int] = None, - class_id: Optional[int] = None): - """Initializes PPV metric.""" - super().__init__( - thresholds=thresholds, name=name, top_k=top_k, class_id=class_id) + def __init__( + self, + thresholds: Optional[Union[float, List[float]]] = None, + name: Optional[str] = None, + top_k: Optional[int] = None, + class_id: Optional[int] = None, + ): + """Initializes PPV metric.""" + super().__init__( + thresholds=thresholds, name=name, top_k=top_k, class_id=class_id + ) - def _default_name(self) -> str: - return PPV_NAME + def _default_name(self) -> str: + return PPV_NAME metric_types.register_metric(PPV) class Recall(ConfusionMatrixMetric): - """Computes the recall of the predictions with respect to the labels. - - The metric uses true positives and false negatives to compute recall by - dividing the true positives by the sum of true positives and false negatives. + """Computes the recall of the predictions with respect to the labels. - If `sample_weight` is `None`, weights default to 1. - Use `sample_weight` of 0 to mask values. - """ + The metric uses true positives and false negatives to compute recall by + dividing the true positives by the sum of true positives and false negatives. - def __init__(self, - thresholds: Optional[Union[float, List[float]]] = None, - top_k: Optional[int] = None, - class_id: Optional[int] = None, - name: Optional[str] = None, - **kwargs): - """Initializes Recall metric. - - Args: - thresholds: (Optional) A float value or a python list/tuple of float - threshold values in [0, 1]. A threshold is compared with prediction - values to determine the truth value of predictions (i.e., above the - threshold is `true`, below is `false`). One metric value is generated - for each threshold value. If neither thresholds nor top_k are set, the - default is to calculate precision with `thresholds=0.5`. - top_k: (Optional) Used with a multi-class model to specify that the top-k - values should be used to compute the confusion matrix. The net effect is - that the non-top-k values are set to -inf and the matrix is then - constructed from the average TP, FP, TN, FN across the classes. When - top_k is used, metrics_specs.binarize settings must not be present. Only - one of class_id or top_k should be configured. When top_k is set, the - default thresholds are [float('-inf')]. - class_id: (Optional) Used with a multi-class model to specify which class - to compute the confusion matrix for. 
When class_id is used, - metrics_specs.binarize settings must not be present. Only one of - class_id or top_k should be configured. - name: (Optional) string name of the metric instance. - **kwargs: (Optional) Additional args to pass along to init (and eventually - on to _metric_computation and _metric_value) + If `sample_weight` is `None`, weights default to 1. + Use `sample_weight` of 0 to mask values. """ - super().__init__( - thresholds=thresholds, - top_k=top_k, - class_id=class_id, - name=name, - **kwargs) - - def _default_name(self) -> str: - return RECALL_NAME - def result(self, tp: float, tn: float, fp: float, fn: float) -> float: - del tn, fp - return _divide_only_positive_denominator(tp, tp + fn) + def __init__( + self, + thresholds: Optional[Union[float, List[float]]] = None, + top_k: Optional[int] = None, + class_id: Optional[int] = None, + name: Optional[str] = None, + **kwargs, + ): + """Initializes Recall metric. + + Args: + ---- + thresholds: (Optional) A float value or a python list/tuple of float + threshold values in [0, 1]. A threshold is compared with prediction + values to determine the truth value of predictions (i.e., above the + threshold is `true`, below is `false`). One metric value is generated + for each threshold value. If neither thresholds nor top_k are set, the + default is to calculate precision with `thresholds=0.5`. + top_k: (Optional) Used with a multi-class model to specify that the top-k + values should be used to compute the confusion matrix. The net effect is + that the non-top-k values are set to -inf and the matrix is then + constructed from the average TP, FP, TN, FN across the classes. When + top_k is used, metrics_specs.binarize settings must not be present. Only + one of class_id or top_k should be configured. When top_k is set, the + default thresholds are [float('-inf')]. + class_id: (Optional) Used with a multi-class model to specify which class + to compute the confusion matrix for. When class_id is used, + metrics_specs.binarize settings must not be present. Only one of + class_id or top_k should be configured. + name: (Optional) string name of the metric instance. 
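
The top_k paragraph repeated in these docstrings is easier to read with a concrete sketch: non-top-k scores are masked to -inf, the default threshold becomes -inf, and the confusion counts are averaged across classes. The sketch below is only a rough reading of that description (the function name and the averaging details are assumptions), not the extractor code TFMA actually runs:

```
import math
from typing import Sequence


def top_k_counts(labels: Sequence[int], scores: Sequence[float], k: int):
    """Average per-class TP/FP/TN/FN after masking non-top-k scores to -inf."""
    kept = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]
    masked = [s if i in kept else -math.inf for i, s in enumerate(scores)]
    threshold = -math.inf  # default threshold when top_k is set
    tp = fp = tn = fn = 0.0
    for i, score in enumerate(masked):
        predicted = score > threshold       # masked classes never fire
        actual = i in labels                # labels holds the positive class ids
        tp += predicted and actual
        fp += predicted and not actual
        fn += (not predicted) and actual
        tn += (not predicted) and (not actual)
    n = len(scores)
    return tp / n, fp / n, tn / n, fn / n


# One example: class 2 is the true label, top_k=2 over 4 classes.
print(top_k_counts(labels=[2], scores=[0.1, 0.2, 0.6, 0.1], k=2))
# (0.25, 0.25, 0.5, 0.0)
```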
+ **kwargs: (Optional) Additional args to pass along to init (and eventually + on to _metric_computation and _metric_value) + """ + super().__init__( + thresholds=thresholds, top_k=top_k, class_id=class_id, name=name, **kwargs + ) + + def _default_name(self) -> str: + return RECALL_NAME + + def result(self, tp: float, tn: float, fp: float, fn: float) -> float: + del tn, fp + return _divide_only_positive_denominator(tp, tp + fn) metric_types.register_metric(Recall) class TPR(Recall): - """Alias for Recall.""" + """Alias for Recall.""" - def __init__(self, - thresholds: Optional[Union[float, List[float]]] = None, - name: Optional[str] = None, - top_k: Optional[int] = None, - class_id: Optional[int] = None): - """Initializes TPR metric.""" - super().__init__( - thresholds=thresholds, name=name, top_k=top_k, class_id=class_id) + def __init__( + self, + thresholds: Optional[Union[float, List[float]]] = None, + name: Optional[str] = None, + top_k: Optional[int] = None, + class_id: Optional[int] = None, + ): + """Initializes TPR metric.""" + super().__init__( + thresholds=thresholds, name=name, top_k=top_k, class_id=class_id + ) - def _default_name(self) -> str: - return TPR_NAME + def _default_name(self) -> str: + return TPR_NAME metric_types.register_metric(TPR) class Specificity(ConfusionMatrixMetric): - """Specificity (TNR) or selectivity.""" - - def __init__(self, - thresholds: Optional[Union[float, List[float]]] = None, - name: Optional[str] = None, - top_k: Optional[int] = None, - class_id: Optional[int] = None): - """Initializes specificity metric. - - Args: - thresholds: (Optional) Thresholds to use for specificity. Defaults to - [0.5]. - name: (Optional) Metric name. - top_k: (Optional) Used with a multi-class model to specify that the top-k - values should be used to compute the confusion matrix. The net effect is - that the non-top-k values are set to -inf and the matrix is then - constructed from the average TP, FP, TN, FN across the classes. When - top_k is used, metrics_specs.binarize settings must not be present. Only - one of class_id or top_k should be configured. When top_k is set, the - default thresholds are [float('-inf')]. - class_id: (Optional) Used with a multi-class model to specify which class - to compute the confusion matrix for. When class_id is used, - metrics_specs.binarize settings must not be present. Only one of - class_id or top_k should be configured. - """ - super().__init__( - thresholds=thresholds, name=name, top_k=top_k, class_id=class_id) - - def _default_name(self) -> str: - return SPECIFICITY_NAME - - def result(self, tp: float, tn: float, fp: float, fn: float) -> float: - del tp, fn - return _divide_only_positive_denominator(tn, tn + fp) + """Specificity (TNR) or selectivity.""" + + def __init__( + self, + thresholds: Optional[Union[float, List[float]]] = None, + name: Optional[str] = None, + top_k: Optional[int] = None, + class_id: Optional[int] = None, + ): + """Initializes specificity metric. + + Args: + ---- + thresholds: (Optional) Thresholds to use for specificity. Defaults to + [0.5]. + name: (Optional) Metric name. + top_k: (Optional) Used with a multi-class model to specify that the top-k + values should be used to compute the confusion matrix. The net effect is + that the non-top-k values are set to -inf and the matrix is then + constructed from the average TP, FP, TN, FN across the classes. When + top_k is used, metrics_specs.binarize settings must not be present. Only + one of class_id or top_k should be configured. 
When top_k is set, the + default thresholds are [float('-inf')]. + class_id: (Optional) Used with a multi-class model to specify which class + to compute the confusion matrix for. When class_id is used, + metrics_specs.binarize settings must not be present. Only one of + class_id or top_k should be configured. + """ + super().__init__( + thresholds=thresholds, name=name, top_k=top_k, class_id=class_id + ) + + def _default_name(self) -> str: + return SPECIFICITY_NAME + + def result(self, tp: float, tn: float, fp: float, fn: float) -> float: + del tp, fn + return _divide_only_positive_denominator(tn, tn + fp) metric_types.register_metric(Specificity) class TNR(Specificity): - """Alias for Specificity.""" + """Alias for Specificity.""" - def __init__(self, - thresholds: Optional[Union[float, List[float]]] = None, - name: Optional[str] = None, - top_k: Optional[int] = None, - class_id: Optional[int] = None): - """Initializes TNR metric.""" - super().__init__( - thresholds=thresholds, name=name, top_k=top_k, class_id=class_id) + def __init__( + self, + thresholds: Optional[Union[float, List[float]]] = None, + name: Optional[str] = None, + top_k: Optional[int] = None, + class_id: Optional[int] = None, + ): + """Initializes TNR metric.""" + super().__init__( + thresholds=thresholds, name=name, top_k=top_k, class_id=class_id + ) - def _default_name(self) -> str: - return TNR_NAME + def _default_name(self) -> str: + return TNR_NAME metric_types.register_metric(TNR) class FallOut(ConfusionMatrixMetric): - """Fall-out (FPR).""" - - def __init__(self, - thresholds: Optional[Union[float, List[float]]] = None, - name: Optional[str] = None, - top_k: Optional[int] = None, - class_id: Optional[int] = None): - """Initializes fall-out metric. - - Args: - thresholds: (Optional) Thresholds to use for fall-out. Defaults to [0.5]. - name: (Optional) Metric name. - top_k: (Optional) Used with a multi-class model to specify that the top-k - values should be used to compute the confusion matrix. The net effect is - that the non-top-k values are set to -inf and the matrix is then - constructed from the average TP, FP, TN, FN across the classes. When - top_k is used, metrics_specs.binarize settings must not be present. Only - one of class_id or top_k should be configured. When top_k is set, the - default thresholds are [float('-inf')]. - class_id: (Optional) Used with a multi-class model to specify which class - to compute the confusion matrix for. When class_id is used, - metrics_specs.binarize settings must not be present. Only one of - class_id or top_k should be configured. - """ - super().__init__( - thresholds=thresholds, name=name, top_k=top_k, class_id=class_id) - - def _default_name(self) -> str: - return FALL_OUT_NAME - - def result(self, tp: float, tn: float, fp: float, fn: float) -> float: - del tp, fn - return _divide_only_positive_denominator(fp, fp + tn) + """Fall-out (FPR).""" + + def __init__( + self, + thresholds: Optional[Union[float, List[float]]] = None, + name: Optional[str] = None, + top_k: Optional[int] = None, + class_id: Optional[int] = None, + ): + """Initializes fall-out metric. + + Args: + ---- + thresholds: (Optional) Thresholds to use for fall-out. Defaults to [0.5]. + name: (Optional) Metric name. + top_k: (Optional) Used with a multi-class model to specify that the top-k + values should be used to compute the confusion matrix. The net effect is + that the non-top-k values are set to -inf and the matrix is then + constructed from the average TP, FP, TN, FN across the classes. 
When + top_k is used, metrics_specs.binarize settings must not be present. Only + one of class_id or top_k should be configured. When top_k is set, the + default thresholds are [float('-inf')]. + class_id: (Optional) Used with a multi-class model to specify which class + to compute the confusion matrix for. When class_id is used, + metrics_specs.binarize settings must not be present. Only one of + class_id or top_k should be configured. + """ + super().__init__( + thresholds=thresholds, name=name, top_k=top_k, class_id=class_id + ) + + def _default_name(self) -> str: + return FALL_OUT_NAME + + def result(self, tp: float, tn: float, fp: float, fn: float) -> float: + del tp, fn + return _divide_only_positive_denominator(fp, fp + tn) metric_types.register_metric(FallOut) class FPR(FallOut): - """Alias for FallOut.""" + """Alias for FallOut.""" - def __init__(self, - thresholds: Optional[Union[float, List[float]]] = None, - name: Optional[str] = None, - top_k: Optional[int] = None, - class_id: Optional[int] = None): - """Initializes FPR metric.""" - super().__init__( - thresholds=thresholds, name=name, top_k=top_k, class_id=class_id) + def __init__( + self, + thresholds: Optional[Union[float, List[float]]] = None, + name: Optional[str] = None, + top_k: Optional[int] = None, + class_id: Optional[int] = None, + ): + """Initializes FPR metric.""" + super().__init__( + thresholds=thresholds, name=name, top_k=top_k, class_id=class_id + ) - def _default_name(self) -> str: - return FPR_NAME + def _default_name(self) -> str: + return FPR_NAME metric_types.register_metric(FPR) class MissRate(ConfusionMatrixMetric): - """Miss rate (FNR).""" - - def __init__(self, - thresholds: Optional[Union[float, List[float]]] = None, - name: Optional[str] = None, - top_k: Optional[int] = None, - class_id: Optional[int] = None): - """Initializes miss rate metric. - - Args: - thresholds: (Optional) Thresholds to use for miss rate. Defaults to [0.5]. - name: (Optional) Metric name. - top_k: (Optional) Used with a multi-class model to specify that the top-k - values should be used to compute the confusion matrix. The net effect is - that the non-top-k values are set to -inf and the matrix is then - constructed from the average TP, FP, TN, FN across the classes. When - top_k is used, metrics_specs.binarize settings must not be present. Only - one of class_id or top_k should be configured. When top_k is set, the - default thresholds are [float('-inf')]. - class_id: (Optional) Used with a multi-class model to specify which class - to compute the confusion matrix for. When class_id is used, - metrics_specs.binarize settings must not be present. Only one of - class_id or top_k should be configured. - """ - super().__init__( - thresholds=thresholds, name=name, top_k=top_k, class_id=class_id) - - def _default_name(self) -> str: - return MISS_RATE_NAME - - def result(self, tp: float, tn: float, fp: float, fn: float) -> float: - del tn, fp - return _divide_only_positive_denominator(fn, fn + tp) + """Miss rate (FNR).""" + + def __init__( + self, + thresholds: Optional[Union[float, List[float]]] = None, + name: Optional[str] = None, + top_k: Optional[int] = None, + class_id: Optional[int] = None, + ): + """Initializes miss rate metric. + + Args: + ---- + thresholds: (Optional) Thresholds to use for miss rate. Defaults to [0.5]. + name: (Optional) Metric name. + top_k: (Optional) Used with a multi-class model to specify that the top-k + values should be used to compute the confusion matrix. 
The net effect is + that the non-top-k values are set to -inf and the matrix is then + constructed from the average TP, FP, TN, FN across the classes. When + top_k is used, metrics_specs.binarize settings must not be present. Only + one of class_id or top_k should be configured. When top_k is set, the + default thresholds are [float('-inf')]. + class_id: (Optional) Used with a multi-class model to specify which class + to compute the confusion matrix for. When class_id is used, + metrics_specs.binarize settings must not be present. Only one of + class_id or top_k should be configured. + """ + super().__init__( + thresholds=thresholds, name=name, top_k=top_k, class_id=class_id + ) + + def _default_name(self) -> str: + return MISS_RATE_NAME + + def result(self, tp: float, tn: float, fp: float, fn: float) -> float: + del tn, fp + return _divide_only_positive_denominator(fn, fn + tp) metric_types.register_metric(MissRate) class FNR(MissRate): - """Alias for MissRate.""" + """Alias for MissRate.""" - def __init__(self, - thresholds: Optional[Union[float, List[float]]] = None, - name: Optional[str] = None, - top_k: Optional[int] = None, - class_id: Optional[int] = None): - """Initializes FNR metric.""" - super().__init__( - thresholds=thresholds, name=name, top_k=top_k, class_id=class_id) + def __init__( + self, + thresholds: Optional[Union[float, List[float]]] = None, + name: Optional[str] = None, + top_k: Optional[int] = None, + class_id: Optional[int] = None, + ): + """Initializes FNR metric.""" + super().__init__( + thresholds=thresholds, name=name, top_k=top_k, class_id=class_id + ) - def _default_name(self) -> str: - return FNR_NAME + def _default_name(self) -> str: + return FNR_NAME metric_types.register_metric(FNR) class NegativePredictiveValue(ConfusionMatrixMetric): - """Negative predictive value (NPV).""" - - def __init__(self, - thresholds: Optional[Union[float, List[float]]] = None, - name: Optional[str] = None, - top_k: Optional[int] = None, - class_id: Optional[int] = None): - """Initializes negative predictive value. - - Args: - thresholds: (Optional) Thresholds to use. Defaults to [0.5]. - name: (Optional) Metric name. - top_k: (Optional) Used with a multi-class model to specify that the top-k - values should be used to compute the confusion matrix. The net effect is - that the non-top-k values are set to -inf and the matrix is then - constructed from the average TP, FP, TN, FN across the classes. When - top_k is used, metrics_specs.binarize settings must not be present. Only - one of class_id or top_k should be configured. When top_k is set, the - default thresholds are [float('-inf')]. - class_id: (Optional) Used with a multi-class model to specify which class - to compute the confusion matrix for. When class_id is used, - metrics_specs.binarize settings must not be present. Only one of - class_id or top_k should be configured. - """ - super().__init__( - thresholds=thresholds, name=name, top_k=top_k, class_id=class_id) - - def _default_name(self) -> str: - return NEGATIVE_PREDICTIVE_VALUE_NAME - - def result(self, tp: float, tn: float, fp: float, fn: float) -> float: - del tp, fp - return _divide_only_positive_denominator(tn, tn + fn) + """Negative predictive value (NPV).""" + + def __init__( + self, + thresholds: Optional[Union[float, List[float]]] = None, + name: Optional[str] = None, + top_k: Optional[int] = None, + class_id: Optional[int] = None, + ): + """Initializes negative predictive value. + + Args: + ---- + thresholds: (Optional) Thresholds to use. 
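
Specificity, fall-out and miss rate are the complementary rates of recall and specificity; the relations below follow directly from the `result` formulas above (a standalone recap, not TFMA code):

```
tp, tn, fp, fn = 40.0, 30.0, 20.0, 10.0

tpr = tp / (tp + fn)   # Recall / TPR
tnr = tn / (tn + fp)   # Specificity / TNR
fpr = fp / (fp + tn)   # FallOut / FPR
fnr = fn / (fn + tp)   # MissRate / FNR

assert abs(fpr - (1 - tnr)) < 1e-12  # fall-out is 1 - specificity
assert abs(fnr - (1 - tpr)) < 1e-12  # miss rate is 1 - recall
print(tpr, tnr, fpr, fnr)            # 0.8 0.6 0.4 0.2
```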
Defaults to [0.5]. + name: (Optional) Metric name. + top_k: (Optional) Used with a multi-class model to specify that the top-k + values should be used to compute the confusion matrix. The net effect is + that the non-top-k values are set to -inf and the matrix is then + constructed from the average TP, FP, TN, FN across the classes. When + top_k is used, metrics_specs.binarize settings must not be present. Only + one of class_id or top_k should be configured. When top_k is set, the + default thresholds are [float('-inf')]. + class_id: (Optional) Used with a multi-class model to specify which class + to compute the confusion matrix for. When class_id is used, + metrics_specs.binarize settings must not be present. Only one of + class_id or top_k should be configured. + """ + super().__init__( + thresholds=thresholds, name=name, top_k=top_k, class_id=class_id + ) + + def _default_name(self) -> str: + return NEGATIVE_PREDICTIVE_VALUE_NAME + + def result(self, tp: float, tn: float, fp: float, fn: float) -> float: + del tp, fp + return _divide_only_positive_denominator(tn, tn + fn) metric_types.register_metric(NegativePredictiveValue) class NPV(NegativePredictiveValue): - """Alias for NegativePredictiveValue.""" + """Alias for NegativePredictiveValue.""" - def __init__(self, - thresholds: Optional[Union[float, List[float]]] = None, - name: Optional[str] = None, - top_k: Optional[int] = None, - class_id: Optional[int] = None): - """Initializes PPV metric.""" - super().__init__( - thresholds=thresholds, name=name, top_k=top_k, class_id=class_id) + def __init__( + self, + thresholds: Optional[Union[float, List[float]]] = None, + name: Optional[str] = None, + top_k: Optional[int] = None, + class_id: Optional[int] = None, + ): + """Initializes PPV metric.""" + super().__init__( + thresholds=thresholds, name=name, top_k=top_k, class_id=class_id + ) - def _default_name(self) -> str: - return NPV_NAME + def _default_name(self) -> str: + return NPV_NAME metric_types.register_metric(NPV) class FalseDiscoveryRate(ConfusionMatrixMetric): - """False discovery rate (FDR).""" - - def __init__(self, - thresholds: Optional[Union[float, List[float]]] = None, - name: Optional[str] = None, - top_k: Optional[int] = None, - class_id: Optional[int] = None): - """Initializes false discovery rate. - - Args: - thresholds: (Optional) Thresholds to use. Defaults to [0.5]. - name: (Optional) Metric name. - top_k: (Optional) Used with a multi-class model to specify that the top-k - values should be used to compute the confusion matrix. The net effect is - that the non-top-k values are set to -inf and the matrix is then - constructed from the average TP, FP, TN, FN across the classes. When - top_k is used, metrics_specs.binarize settings must not be present. Only - one of class_id or top_k should be configured. When top_k is set, the - default thresholds are [float('-inf')]. - class_id: (Optional) Used with a multi-class model to specify which class - to compute the confusion matrix for. When class_id is used, - metrics_specs.binarize settings must not be present. Only one of - class_id or top_k should be configured. 
- """ - super().__init__( - thresholds=thresholds, name=name, top_k=top_k, class_id=class_id) - - def _default_name(self) -> str: - return FALSE_DISCOVERY_RATE_NAME - - def result(self, tp: float, tn: float, fp: float, fn: float) -> float: - del tn, fn - return _divide_only_positive_denominator(fp, fp + tp) + """False discovery rate (FDR).""" + + def __init__( + self, + thresholds: Optional[Union[float, List[float]]] = None, + name: Optional[str] = None, + top_k: Optional[int] = None, + class_id: Optional[int] = None, + ): + """Initializes false discovery rate. + + Args: + ---- + thresholds: (Optional) Thresholds to use. Defaults to [0.5]. + name: (Optional) Metric name. + top_k: (Optional) Used with a multi-class model to specify that the top-k + values should be used to compute the confusion matrix. The net effect is + that the non-top-k values are set to -inf and the matrix is then + constructed from the average TP, FP, TN, FN across the classes. When + top_k is used, metrics_specs.binarize settings must not be present. Only + one of class_id or top_k should be configured. When top_k is set, the + default thresholds are [float('-inf')]. + class_id: (Optional) Used with a multi-class model to specify which class + to compute the confusion matrix for. When class_id is used, + metrics_specs.binarize settings must not be present. Only one of + class_id or top_k should be configured. + """ + super().__init__( + thresholds=thresholds, name=name, top_k=top_k, class_id=class_id + ) + + def _default_name(self) -> str: + return FALSE_DISCOVERY_RATE_NAME + + def result(self, tp: float, tn: float, fp: float, fn: float) -> float: + del tn, fn + return _divide_only_positive_denominator(fp, fp + tp) metric_types.register_metric(FalseDiscoveryRate) class FalseOmissionRate(ConfusionMatrixMetric): - """False omission rate (FOR).""" - - def __init__(self, - thresholds: Optional[Union[float, List[float]]] = None, - name: Optional[str] = None, - top_k: Optional[int] = None, - class_id: Optional[int] = None): - """Initializes false omission rate. - - Args: - thresholds: (Optional) Thresholds to use. Defaults to [0.5]. - name: (Optional) Metric name. - top_k: (Optional) Used with a multi-class model to specify that the top-k - values should be used to compute the confusion matrix. The net effect is - that the non-top-k values are set to -inf and the matrix is then - constructed from the average TP, FP, TN, FN across the classes. When - top_k is used, metrics_specs.binarize settings must not be present. Only - one of class_id or top_k should be configured. When top_k is set, the - default thresholds are [float('-inf')]. - class_id: (Optional) Used with a multi-class model to specify which class - to compute the confusion matrix for. When class_id is used, - metrics_specs.binarize settings must not be present. Only one of - class_id or top_k should be configured. - """ - super().__init__( - thresholds=thresholds, name=name, top_k=top_k, class_id=class_id) - - def _default_name(self) -> str: - return FALSE_OMISSION_RATE_NAME - - def result(self, tp: float, tn: float, fp: float, fn: float) -> float: - del tp, fp - return _divide_only_positive_denominator(fn, fn + tn) + """False omission rate (FOR).""" + + def __init__( + self, + thresholds: Optional[Union[float, List[float]]] = None, + name: Optional[str] = None, + top_k: Optional[int] = None, + class_id: Optional[int] = None, + ): + """Initializes false omission rate. + + Args: + ---- + thresholds: (Optional) Thresholds to use. Defaults to [0.5]. 
+ name: (Optional) Metric name. + top_k: (Optional) Used with a multi-class model to specify that the top-k + values should be used to compute the confusion matrix. The net effect is + that the non-top-k values are set to -inf and the matrix is then + constructed from the average TP, FP, TN, FN across the classes. When + top_k is used, metrics_specs.binarize settings must not be present. Only + one of class_id or top_k should be configured. When top_k is set, the + default thresholds are [float('-inf')]. + class_id: (Optional) Used with a multi-class model to specify which class + to compute the confusion matrix for. When class_id is used, + metrics_specs.binarize settings must not be present. Only one of + class_id or top_k should be configured. + """ + super().__init__( + thresholds=thresholds, name=name, top_k=top_k, class_id=class_id + ) + + def _default_name(self) -> str: + return FALSE_OMISSION_RATE_NAME + + def result(self, tp: float, tn: float, fp: float, fn: float) -> float: + del tp, fp + return _divide_only_positive_denominator(fn, fn + tn) metric_types.register_metric(FalseOmissionRate) class Prevalence(ConfusionMatrixMetric): - """Prevalence.""" - - def __init__(self, - thresholds: Optional[Union[float, List[float]]] = None, - name: Optional[str] = None, - top_k: Optional[int] = None, - class_id: Optional[int] = None): - """Initializes prevalence. - - Args: - thresholds: (Optional) Thresholds to use. Defaults to [0.5]. - name: (Optional) Metric name. - top_k: (Optional) Used with a multi-class model to specify that the top-k - values should be used to compute the confusion matrix. The net effect is - that the non-top-k values are set to -inf and the matrix is then - constructed from the average TP, FP, TN, FN across the classes. When - top_k is used, metrics_specs.binarize settings must not be present. Only - one of class_id or top_k should be configured. When top_k is set, the - default thresholds are [float('-inf')]. - class_id: (Optional) Used with a multi-class model to specify which class - to compute the confusion matrix for. When class_id is used, - metrics_specs.binarize settings must not be present. Only one of - class_id or top_k should be configured. - """ - super().__init__( - thresholds=thresholds, name=name, top_k=top_k, class_id=class_id) - - def _default_name(self) -> str: - return PREVALENCE_NAME - - def result(self, tp: float, tn: float, fp: float, fn: float) -> float: - return _divide_only_positive_denominator(tp + fn, tp + tn + fp + fn) + """Prevalence.""" + + def __init__( + self, + thresholds: Optional[Union[float, List[float]]] = None, + name: Optional[str] = None, + top_k: Optional[int] = None, + class_id: Optional[int] = None, + ): + """Initializes prevalence. + + Args: + ---- + thresholds: (Optional) Thresholds to use. Defaults to [0.5]. + name: (Optional) Metric name. + top_k: (Optional) Used with a multi-class model to specify that the top-k + values should be used to compute the confusion matrix. The net effect is + that the non-top-k values are set to -inf and the matrix is then + constructed from the average TP, FP, TN, FN across the classes. When + top_k is used, metrics_specs.binarize settings must not be present. Only + one of class_id or top_k should be configured. When top_k is set, the + default thresholds are [float('-inf')]. + class_id: (Optional) Used with a multi-class model to specify which class + to compute the confusion matrix for. When class_id is used, + metrics_specs.binarize settings must not be present. 
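
Likewise, NPV, false discovery rate and false omission rate are the predicted-column counterparts of precision: FDR = 1 - precision and FOR = 1 - NPV. A standalone recap of the formulas above, with the same example counts:

```
tp, tn, fp, fn = 40.0, 30.0, 20.0, 10.0

ppv = tp / (tp + fp)    # Precision / PPV
npv = tn / (tn + fn)    # NegativePredictiveValue / NPV
fdr = fp / (fp + tp)    # FalseDiscoveryRate
for_ = fn / (fn + tn)   # FalseOmissionRate

assert abs(fdr - (1 - ppv)) < 1e-12
assert abs(for_ - (1 - npv)) < 1e-12
print(ppv, npv, fdr, for_)  # 0.666... 0.75 0.333... 0.25
```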
Only one of + class_id or top_k should be configured. + """ + super().__init__( + thresholds=thresholds, name=name, top_k=top_k, class_id=class_id + ) + + def _default_name(self) -> str: + return PREVALENCE_NAME + + def result(self, tp: float, tn: float, fp: float, fn: float) -> float: + return _divide_only_positive_denominator(tp + fn, tp + tn + fp + fn) metric_types.register_metric(Prevalence) class PrevalenceThreshold(ConfusionMatrixMetric): - """Prevalence threshold (PT).""" - - def __init__(self, - thresholds: Optional[Union[float, List[float]]] = None, - name: Optional[str] = None, - top_k: Optional[int] = None, - class_id: Optional[int] = None): - """Initializes prevalence threshold. - - Args: - thresholds: (Optional) Thresholds to use. Defaults to [0.5]. - name: (Optional) Metric name. - top_k: (Optional) Used with a multi-class model to specify that the top-k - values should be used to compute the confusion matrix. The net effect is - that the non-top-k values are set to -inf and the matrix is then - constructed from the average TP, FP, TN, FN across the classes. When - top_k is used, metrics_specs.binarize settings must not be present. Only - one of class_id or top_k should be configured. When top_k is set, the - default thresholds are [float('-inf')]. - class_id: (Optional) Used with a multi-class model to specify which class - to compute the confusion matrix for. When class_id is used, - metrics_specs.binarize settings must not be present. Only one of - class_id or top_k should be configured. - """ - super().__init__( - thresholds=thresholds, name=name, top_k=top_k, class_id=class_id) - - def _default_name(self) -> str: - return PREVALENCE_THRESHOLD_NAME - - def result(self, tp: float, tn: float, fp: float, fn: float) -> float: - tpr_denominator = tp + fn - tnr_denominator = tn + fp - if tpr_denominator > 0.0 and tnr_denominator > 0.0: - tpr = tp / tpr_denominator - tnr = tn / tnr_denominator - return (_pos_sqrt(tpr * (1 - tnr)) + tnr - 1) / (tpr + tnr - 1) - else: - return float('nan') + """Prevalence threshold (PT).""" + + def __init__( + self, + thresholds: Optional[Union[float, List[float]]] = None, + name: Optional[str] = None, + top_k: Optional[int] = None, + class_id: Optional[int] = None, + ): + """Initializes prevalence threshold. + + Args: + ---- + thresholds: (Optional) Thresholds to use. Defaults to [0.5]. + name: (Optional) Metric name. + top_k: (Optional) Used with a multi-class model to specify that the top-k + values should be used to compute the confusion matrix. The net effect is + that the non-top-k values are set to -inf and the matrix is then + constructed from the average TP, FP, TN, FN across the classes. When + top_k is used, metrics_specs.binarize settings must not be present. Only + one of class_id or top_k should be configured. When top_k is set, the + default thresholds are [float('-inf')]. + class_id: (Optional) Used with a multi-class model to specify which class + to compute the confusion matrix for. When class_id is used, + metrics_specs.binarize settings must not be present. Only one of + class_id or top_k should be configured. 
+ """ + super().__init__( + thresholds=thresholds, name=name, top_k=top_k, class_id=class_id + ) + + def _default_name(self) -> str: + return PREVALENCE_THRESHOLD_NAME + + def result(self, tp: float, tn: float, fp: float, fn: float) -> float: + tpr_denominator = tp + fn + tnr_denominator = tn + fp + if tpr_denominator > 0.0 and tnr_denominator > 0.0: + tpr = tp / tpr_denominator + tnr = tn / tnr_denominator + return (_pos_sqrt(tpr * (1 - tnr)) + tnr - 1) / (tpr + tnr - 1) + else: + return float("nan") metric_types.register_metric(PrevalenceThreshold) class ThreatScore(ConfusionMatrixMetric): - """Threat score or critical success index (TS or CSI).""" - - def __init__(self, - thresholds: Optional[Union[float, List[float]]] = None, - name: Optional[str] = None, - top_k: Optional[int] = None, - class_id: Optional[int] = None): - """Initializes threat score. - - Args: - thresholds: (Optional) Thresholds to use. Defaults to [0.5]. - name: (Optional) Metric name. - top_k: (Optional) Used with a multi-class model to specify that the top-k - values should be used to compute the confusion matrix. The net effect is - that the non-top-k values are set to -inf and the matrix is then - constructed from the average TP, FP, TN, FN across the classes. When - top_k is used, metrics_specs.binarize settings must not be present. Only - one of class_id or top_k should be configured. When top_k is set, the - default thresholds are [float('-inf')]. - class_id: (Optional) Used with a multi-class model to specify which class - to compute the confusion matrix for. When class_id is used, - metrics_specs.binarize settings must not be present. Only one of - class_id or top_k should be configured. - """ - super().__init__( - thresholds=thresholds, name=name, top_k=top_k, class_id=class_id) - - def _default_name(self) -> str: - return THREAT_SCORE_NAME - - def result(self, tp: float, tn: float, fp: float, fn: float) -> float: - del tn - return _divide_only_positive_denominator(tp, tp + fn + fp) + """Threat score or critical success index (TS or CSI).""" + + def __init__( + self, + thresholds: Optional[Union[float, List[float]]] = None, + name: Optional[str] = None, + top_k: Optional[int] = None, + class_id: Optional[int] = None, + ): + """Initializes threat score. + + Args: + ---- + thresholds: (Optional) Thresholds to use. Defaults to [0.5]. + name: (Optional) Metric name. + top_k: (Optional) Used with a multi-class model to specify that the top-k + values should be used to compute the confusion matrix. The net effect is + that the non-top-k values are set to -inf and the matrix is then + constructed from the average TP, FP, TN, FN across the classes. When + top_k is used, metrics_specs.binarize settings must not be present. Only + one of class_id or top_k should be configured. When top_k is set, the + default thresholds are [float('-inf')]. + class_id: (Optional) Used with a multi-class model to specify which class + to compute the confusion matrix for. When class_id is used, + metrics_specs.binarize settings must not be present. Only one of + class_id or top_k should be configured. 
+ """ + super().__init__( + thresholds=thresholds, name=name, top_k=top_k, class_id=class_id + ) + + def _default_name(self) -> str: + return THREAT_SCORE_NAME + + def result(self, tp: float, tn: float, fp: float, fn: float) -> float: + del tn + return _divide_only_positive_denominator(tp, tp + fn + fp) metric_types.register_metric(ThreatScore) class BalancedAccuracy(ConfusionMatrixMetric): - """Balanced accuracy (BA).""" - - def __init__(self, - thresholds: Optional[Union[float, List[float]]] = None, - name: Optional[str] = None, - top_k: Optional[int] = None, - class_id: Optional[int] = None): - """Initializes balanced accuracy. - - Args: - thresholds: (Optional) Thresholds to use. Defaults to [0.5]. - name: (Optional) Metric name. - top_k: (Optional) Used with a multi-class model to specify that the top-k - values should be used to compute the confusion matrix. The net effect is - that the non-top-k values are set to -inf and the matrix is then - constructed from the average TP, FP, TN, FN across the classes. When - top_k is used, metrics_specs.binarize settings must not be present. Only - one of class_id or top_k should be configured. When top_k is set, the - default thresholds are [float('-inf')]. - class_id: (Optional) Used with a multi-class model to specify which class - to compute the confusion matrix for. When class_id is used, - metrics_specs.binarize settings must not be present. Only one of - class_id or top_k should be configured. - """ - super().__init__( - thresholds=thresholds, name=name, top_k=top_k, class_id=class_id) - - def _default_name(self) -> str: - return BALANCED_ACCURACY_NAME - - def result(self, tp: float, tn: float, fp: float, fn: float) -> float: - tpr_denominator = tp + fn - tnr_denominator = tn + fp - if tpr_denominator > 0.0 and tnr_denominator > 0.0: - tpr = tp / tpr_denominator - tnr = tn / tnr_denominator - return (tpr + tnr) / 2 - else: - return float('nan') + """Balanced accuracy (BA).""" + + def __init__( + self, + thresholds: Optional[Union[float, List[float]]] = None, + name: Optional[str] = None, + top_k: Optional[int] = None, + class_id: Optional[int] = None, + ): + """Initializes balanced accuracy. + + Args: + ---- + thresholds: (Optional) Thresholds to use. Defaults to [0.5]. + name: (Optional) Metric name. + top_k: (Optional) Used with a multi-class model to specify that the top-k + values should be used to compute the confusion matrix. The net effect is + that the non-top-k values are set to -inf and the matrix is then + constructed from the average TP, FP, TN, FN across the classes. When + top_k is used, metrics_specs.binarize settings must not be present. Only + one of class_id or top_k should be configured. When top_k is set, the + default thresholds are [float('-inf')]. + class_id: (Optional) Used with a multi-class model to specify which class + to compute the confusion matrix for. When class_id is used, + metrics_specs.binarize settings must not be present. Only one of + class_id or top_k should be configured. 
+ """ + super().__init__( + thresholds=thresholds, name=name, top_k=top_k, class_id=class_id + ) + + def _default_name(self) -> str: + return BALANCED_ACCURACY_NAME + + def result(self, tp: float, tn: float, fp: float, fn: float) -> float: + tpr_denominator = tp + fn + tnr_denominator = tn + fp + if tpr_denominator > 0.0 and tnr_denominator > 0.0: + tpr = tp / tpr_denominator + tnr = tn / tnr_denominator + return (tpr + tnr) / 2 + else: + return float("nan") metric_types.register_metric(BalancedAccuracy) class F1Score(ConfusionMatrixMetric): - """F1 score.""" - - def __init__(self, - thresholds: Optional[Union[float, List[float]]] = None, - name: Optional[str] = None, - top_k: Optional[int] = None, - class_id: Optional[int] = None): - """Initializes F1 score. - - Args: - thresholds: (Optional) Thresholds to use. Defaults to [0.5]. - name: (Optional) Metric name. - top_k: (Optional) Used with a multi-class model to specify that the top-k - values should be used to compute the confusion matrix. The net effect is - that the non-top-k values are set to -inf and the matrix is then - constructed from the average TP, FP, TN, FN across the classes. When - top_k is used, metrics_specs.binarize settings must not be present. Only - one of class_id or top_k should be configured. When top_k is set, the - default thresholds are [float('-inf')]. - class_id: (Optional) Used with a multi-class model to specify which class - to compute the confusion matrix for. When class_id is used, - metrics_specs.binarize settings must not be present. Only one of - class_id or top_k should be configured. - """ - super().__init__( - thresholds=thresholds, name=name, top_k=top_k, class_id=class_id) - - def _default_name(self) -> str: - return F1_SCORE_NAME - - def result(self, tp: float, tn: float, fp: float, fn: float) -> float: - del tn - # This is the harmonic mean of precision and recall or the same as - # 2 * (precision * recall) / (precision + recall). - # See https://en.wikipedia.org/wiki/Confusion_matrix for more information. - return _divide_only_positive_denominator( - numerator=2 * tp, denominator=2 * tp + fp + fn - ) + """F1 score.""" + + def __init__( + self, + thresholds: Optional[Union[float, List[float]]] = None, + name: Optional[str] = None, + top_k: Optional[int] = None, + class_id: Optional[int] = None, + ): + """Initializes F1 score. + + Args: + ---- + thresholds: (Optional) Thresholds to use. Defaults to [0.5]. + name: (Optional) Metric name. + top_k: (Optional) Used with a multi-class model to specify that the top-k + values should be used to compute the confusion matrix. The net effect is + that the non-top-k values are set to -inf and the matrix is then + constructed from the average TP, FP, TN, FN across the classes. When + top_k is used, metrics_specs.binarize settings must not be present. Only + one of class_id or top_k should be configured. When top_k is set, the + default thresholds are [float('-inf')]. + class_id: (Optional) Used with a multi-class model to specify which class + to compute the confusion matrix for. When class_id is used, + metrics_specs.binarize settings must not be present. Only one of + class_id or top_k should be configured. 
+ """ + super().__init__( + thresholds=thresholds, name=name, top_k=top_k, class_id=class_id + ) + + def _default_name(self) -> str: + return F1_SCORE_NAME + + def result(self, tp: float, tn: float, fp: float, fn: float) -> float: + del tn + # This is the harmonic mean of precision and recall or the same as + # 2 * (precision * recall) / (precision + recall). + # See https://en.wikipedia.org/wiki/Confusion_matrix for more information. + return _divide_only_positive_denominator( + numerator=2 * tp, denominator=2 * tp + fp + fn + ) metric_types.register_metric(F1Score) class MatthewsCorrelationCoefficient(ConfusionMatrixMetric): - """Matthews corrrelation coefficient (MCC).""" - - def __init__(self, - thresholds: Optional[Union[float, List[float]]] = None, - name: Optional[str] = None, - top_k: Optional[int] = None, - class_id: Optional[int] = None): - """Initializes matthews corrrelation coefficient. - - Args: - thresholds: (Optional) Thresholds to use. Defaults to [0.5]. - name: (Optional) Metric name. - top_k: (Optional) Used with a multi-class model to specify that the top-k - values should be used to compute the confusion matrix. The net effect is - that the non-top-k values are set to -inf and the matrix is then - constructed from the average TP, FP, TN, FN across the classes. When - top_k is used, metrics_specs.binarize settings must not be present. Only - one of class_id or top_k should be configured. When top_k is set, the - default thresholds are [float('-inf')]. - class_id: (Optional) Used with a multi-class model to specify which class - to compute the confusion matrix for. When class_id is used, - metrics_specs.binarize settings must not be present. Only one of - class_id or top_k should be configured. - """ - super().__init__( - thresholds=thresholds, name=name, top_k=top_k, class_id=class_id) - - def _default_name(self) -> str: - return MATTHEWS_CORRELATION_COEFFICIENT_NAME - - def result(self, tp: float, tn: float, fp: float, fn: float) -> float: - return _divide_only_positive_denominator( - numerator=tp * tn - fp * fn, - denominator=_pos_sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)), - ) + """Matthews corrrelation coefficient (MCC).""" + + def __init__( + self, + thresholds: Optional[Union[float, List[float]]] = None, + name: Optional[str] = None, + top_k: Optional[int] = None, + class_id: Optional[int] = None, + ): + """Initializes matthews corrrelation coefficient. + + Args: + ---- + thresholds: (Optional) Thresholds to use. Defaults to [0.5]. + name: (Optional) Metric name. + top_k: (Optional) Used with a multi-class model to specify that the top-k + values should be used to compute the confusion matrix. The net effect is + that the non-top-k values are set to -inf and the matrix is then + constructed from the average TP, FP, TN, FN across the classes. When + top_k is used, metrics_specs.binarize settings must not be present. Only + one of class_id or top_k should be configured. When top_k is set, the + default thresholds are [float('-inf')]. + class_id: (Optional) Used with a multi-class model to specify which class + to compute the confusion matrix for. When class_id is used, + metrics_specs.binarize settings must not be present. Only one of + class_id or top_k should be configured. 
+ """ + super().__init__( + thresholds=thresholds, name=name, top_k=top_k, class_id=class_id + ) + + def _default_name(self) -> str: + return MATTHEWS_CORRELATION_COEFFICIENT_NAME + + def result(self, tp: float, tn: float, fp: float, fn: float) -> float: + return _divide_only_positive_denominator( + numerator=tp * tn - fp * fn, + denominator=_pos_sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)), + ) metric_types.register_metric(MatthewsCorrelationCoefficient) class FowlkesMallowsIndex(ConfusionMatrixMetric): - """Fowlkes-Mallows index (FM).""" - - def __init__(self, - thresholds: Optional[Union[float, List[float]]] = None, - name: Optional[str] = None, - top_k: Optional[int] = None, - class_id: Optional[int] = None): - """Initializes fowlkes-mallows index. - - Args: - thresholds: (Optional) Thresholds to use. Defaults to [0.5]. - name: (Optional) Metric name. - top_k: (Optional) Used with a multi-class model to specify that the top-k - values should be used to compute the confusion matrix. The net effect is - that the non-top-k values are set to -inf and the matrix is then - constructed from the average TP, FP, TN, FN across the classes. When - top_k is used, metrics_specs.binarize settings must not be present. Only - one of class_id or top_k should be configured. When top_k is set, the - default thresholds are [float('-inf')]. - class_id: (Optional) Used with a multi-class model to specify which class - to compute the confusion matrix for. When class_id is used, - metrics_specs.binarize settings must not be present. Only one of - class_id or top_k should be configured. - """ - super().__init__( - thresholds=thresholds, name=name, top_k=top_k, class_id=class_id) - - def _default_name(self) -> str: - return FOWLKES_MALLOWS_INDEX_NAME - - def result(self, tp: float, tn: float, fp: float, fn: float) -> float: - del tn - ppv_denominator = tp + fp - tpr_denominator = tp + fn - if ppv_denominator > 0.0 and tpr_denominator > 0.0: - ppv = tp / ppv_denominator - tnr = tp / tpr_denominator - return _pos_sqrt(ppv * tnr) - else: - return float('nan') + """Fowlkes-Mallows index (FM).""" + + def __init__( + self, + thresholds: Optional[Union[float, List[float]]] = None, + name: Optional[str] = None, + top_k: Optional[int] = None, + class_id: Optional[int] = None, + ): + """Initializes fowlkes-mallows index. + + Args: + ---- + thresholds: (Optional) Thresholds to use. Defaults to [0.5]. + name: (Optional) Metric name. + top_k: (Optional) Used with a multi-class model to specify that the top-k + values should be used to compute the confusion matrix. The net effect is + that the non-top-k values are set to -inf and the matrix is then + constructed from the average TP, FP, TN, FN across the classes. When + top_k is used, metrics_specs.binarize settings must not be present. Only + one of class_id or top_k should be configured. When top_k is set, the + default thresholds are [float('-inf')]. + class_id: (Optional) Used with a multi-class model to specify which class + to compute the confusion matrix for. When class_id is used, + metrics_specs.binarize settings must not be present. Only one of + class_id or top_k should be configured. 
+ """ + super().__init__( + thresholds=thresholds, name=name, top_k=top_k, class_id=class_id + ) + + def _default_name(self) -> str: + return FOWLKES_MALLOWS_INDEX_NAME + + def result(self, tp: float, tn: float, fp: float, fn: float) -> float: + del tn + ppv_denominator = tp + fp + tpr_denominator = tp + fn + if ppv_denominator > 0.0 and tpr_denominator > 0.0: + ppv = tp / ppv_denominator + tnr = tp / tpr_denominator + return _pos_sqrt(ppv * tnr) + else: + return float("nan") metric_types.register_metric(FowlkesMallowsIndex) class Informedness(ConfusionMatrixMetric): - """Informedness or bookmaker informedness (BM).""" - - def __init__(self, - thresholds: Optional[Union[float, List[float]]] = None, - name: Optional[str] = None, - top_k: Optional[int] = None, - class_id: Optional[int] = None): - """Initializes informedness. - - Args: - thresholds: (Optional) Thresholds to use. Defaults to [0.5]. - name: (Optional) Metric name. - top_k: (Optional) Used with a multi-class model to specify that the top-k - values should be used to compute the confusion matrix. The net effect is - that the non-top-k values are set to -inf and the matrix is then - constructed from the average TP, FP, TN, FN across the classes. When - top_k is used, metrics_specs.binarize settings must not be present. Only - one of class_id or top_k should be configured. When top_k is set, the - default thresholds are [float('-inf')]. - class_id: (Optional) Used with a multi-class model to specify which class - to compute the confusion matrix for. When class_id is used, - metrics_specs.binarize settings must not be present. Only one of - class_id or top_k should be configured. - """ - super().__init__( - thresholds=thresholds, name=name, top_k=top_k, class_id=class_id) - - def _default_name(self) -> str: - return INFORMEDNESS_NAME - - def result(self, tp: float, tn: float, fp: float, fn: float) -> float: - positives = tp + fn - negatives = tn + fp - if positives > 0.0 and negatives > 0.0: - tpr = tp / positives - tnr = tn / negatives - return tpr + tnr - 1 - else: - return float('nan') + """Informedness or bookmaker informedness (BM).""" + + def __init__( + self, + thresholds: Optional[Union[float, List[float]]] = None, + name: Optional[str] = None, + top_k: Optional[int] = None, + class_id: Optional[int] = None, + ): + """Initializes informedness. + + Args: + ---- + thresholds: (Optional) Thresholds to use. Defaults to [0.5]. + name: (Optional) Metric name. + top_k: (Optional) Used with a multi-class model to specify that the top-k + values should be used to compute the confusion matrix. The net effect is + that the non-top-k values are set to -inf and the matrix is then + constructed from the average TP, FP, TN, FN across the classes. When + top_k is used, metrics_specs.binarize settings must not be present. Only + one of class_id or top_k should be configured. When top_k is set, the + default thresholds are [float('-inf')]. + class_id: (Optional) Used with a multi-class model to specify which class + to compute the confusion matrix for. When class_id is used, + metrics_specs.binarize settings must not be present. Only one of + class_id or top_k should be configured. 
+ """ + super().__init__( + thresholds=thresholds, name=name, top_k=top_k, class_id=class_id + ) + + def _default_name(self) -> str: + return INFORMEDNESS_NAME + + def result(self, tp: float, tn: float, fp: float, fn: float) -> float: + positives = tp + fn + negatives = tn + fp + if positives > 0.0 and negatives > 0.0: + tpr = tp / positives + tnr = tn / negatives + return tpr + tnr - 1 + else: + return float("nan") metric_types.register_metric(Informedness) class Markedness(ConfusionMatrixMetric): - """Markedness (MK) or deltaP.""" - - def __init__(self, - thresholds: Optional[Union[float, List[float]]] = None, - name: Optional[str] = None, - top_k: Optional[int] = None, - class_id: Optional[int] = None): - """Initializes markedness. - - Args: - thresholds: (Optional) Thresholds to use. Defaults to [0.5]. - name: (Optional) Metric name. - top_k: (Optional) Used with a multi-class model to specify that the top-k - values should be used to compute the confusion matrix. The net effect is - that the non-top-k values are set to -inf and the matrix is then - constructed from the average TP, FP, TN, FN across the classes. When - top_k is used, metrics_specs.binarize settings must not be present. Only - one of class_id or top_k should be configured. When top_k is set, the - default thresholds are [float('-inf')]. - class_id: (Optional) Used with a multi-class model to specify which class - to compute the confusion matrix for. When class_id is used, - metrics_specs.binarize settings must not be present. Only one of - class_id or top_k should be configured. - """ - super().__init__( - thresholds=thresholds, name=name, top_k=top_k, class_id=class_id) - - def _default_name(self) -> str: - return MARKEDNESS_NAME - - def result(self, tp: float, tn: float, fp: float, fn: float) -> float: - ppv_denominator = tp + fp - npv_denominator = tn + fn - if ppv_denominator > 0.0 and npv_denominator > 0.0: - ppv = tp / ppv_denominator - npv = tn / npv_denominator - return ppv + npv - 1 - else: - return float('nan') + """Markedness (MK) or deltaP.""" + + def __init__( + self, + thresholds: Optional[Union[float, List[float]]] = None, + name: Optional[str] = None, + top_k: Optional[int] = None, + class_id: Optional[int] = None, + ): + """Initializes markedness. + + Args: + ---- + thresholds: (Optional) Thresholds to use. Defaults to [0.5]. + name: (Optional) Metric name. + top_k: (Optional) Used with a multi-class model to specify that the top-k + values should be used to compute the confusion matrix. The net effect is + that the non-top-k values are set to -inf and the matrix is then + constructed from the average TP, FP, TN, FN across the classes. When + top_k is used, metrics_specs.binarize settings must not be present. Only + one of class_id or top_k should be configured. When top_k is set, the + default thresholds are [float('-inf')]. + class_id: (Optional) Used with a multi-class model to specify which class + to compute the confusion matrix for. When class_id is used, + metrics_specs.binarize settings must not be present. Only one of + class_id or top_k should be configured. 
+ """ + super().__init__( + thresholds=thresholds, name=name, top_k=top_k, class_id=class_id + ) + + def _default_name(self) -> str: + return MARKEDNESS_NAME + + def result(self, tp: float, tn: float, fp: float, fn: float) -> float: + ppv_denominator = tp + fp + npv_denominator = tn + fn + if ppv_denominator > 0.0 and npv_denominator > 0.0: + ppv = tp / ppv_denominator + npv = tn / npv_denominator + return ppv + npv - 1 + else: + return float("nan") metric_types.register_metric(Markedness) class PositiveLikelihoodRatio(ConfusionMatrixMetric): - """Positive likelihood ratio (LR+).""" - - def __init__(self, - thresholds: Optional[Union[float, List[float]]] = None, - name: Optional[str] = None, - top_k: Optional[int] = None, - class_id: Optional[int] = None): - """Initializes positive likelihood ratio. - - Args: - thresholds: (Optional) Thresholds to use. Defaults to [0.5]. - name: (Optional) Metric name. - top_k: (Optional) Used with a multi-class model to specify that the top-k - values should be used to compute the confusion matrix. The net effect is - that the non-top-k values are set to -inf and the matrix is then - constructed from the average TP, FP, TN, FN across the classes. When - top_k is used, metrics_specs.binarize settings must not be present. Only - one of class_id or top_k should be configured. When top_k is set, the - default thresholds are [float('-inf')]. - class_id: (Optional) Used with a multi-class model to specify which class - to compute the confusion matrix for. When class_id is used, - metrics_specs.binarize settings must not be present. Only one of - class_id or top_k should be configured. - """ - super().__init__( - thresholds=thresholds, name=name, top_k=top_k, class_id=class_id) - - def _default_name(self) -> str: - return POSITIVE_LIKELIHOOD_RATIO_NAME - - def result(self, tp: float, tn: float, fp: float, fn: float) -> float: - tpr_denominator = tp + fn - fpr_denominator = fp + tn - if tpr_denominator > 0.0 and fpr_denominator > 0.0 and fp > 0.0: - tpr = tp / tpr_denominator - fpr = fp / fpr_denominator - return tpr / fpr - else: - return float('nan') + """Positive likelihood ratio (LR+).""" + + def __init__( + self, + thresholds: Optional[Union[float, List[float]]] = None, + name: Optional[str] = None, + top_k: Optional[int] = None, + class_id: Optional[int] = None, + ): + """Initializes positive likelihood ratio. + + Args: + ---- + thresholds: (Optional) Thresholds to use. Defaults to [0.5]. + name: (Optional) Metric name. + top_k: (Optional) Used with a multi-class model to specify that the top-k + values should be used to compute the confusion matrix. The net effect is + that the non-top-k values are set to -inf and the matrix is then + constructed from the average TP, FP, TN, FN across the classes. When + top_k is used, metrics_specs.binarize settings must not be present. Only + one of class_id or top_k should be configured. When top_k is set, the + default thresholds are [float('-inf')]. + class_id: (Optional) Used with a multi-class model to specify which class + to compute the confusion matrix for. When class_id is used, + metrics_specs.binarize settings must not be present. Only one of + class_id or top_k should be configured. 
+ """ + super().__init__( + thresholds=thresholds, name=name, top_k=top_k, class_id=class_id + ) + + def _default_name(self) -> str: + return POSITIVE_LIKELIHOOD_RATIO_NAME + + def result(self, tp: float, tn: float, fp: float, fn: float) -> float: + tpr_denominator = tp + fn + fpr_denominator = fp + tn + if tpr_denominator > 0.0 and fpr_denominator > 0.0 and fp > 0.0: + tpr = tp / tpr_denominator + fpr = fp / fpr_denominator + return tpr / fpr + else: + return float("nan") metric_types.register_metric(PositiveLikelihoodRatio) class NegativeLikelihoodRatio(ConfusionMatrixMetric): - """Negative likelihood ratio (LR-).""" - - def __init__(self, - thresholds: Optional[Union[float, List[float]]] = None, - name: Optional[str] = None, - top_k: Optional[int] = None, - class_id: Optional[int] = None): - """Initializes negative likelihood ratio. - - Args: - thresholds: (Optional) Thresholds to use. Defaults to [0.5]. - name: (Optional) Metric name. - top_k: (Optional) Used with a multi-class model to specify that the top-k - values should be used to compute the confusion matrix. The net effect is - that the non-top-k values are set to -inf and the matrix is then - constructed from the average TP, FP, TN, FN across the classes. When - top_k is used, metrics_specs.binarize settings must not be present. Only - one of class_id or top_k should be configured. When top_k is set, the - default thresholds are [float('-inf')]. - class_id: (Optional) Used with a multi-class model to specify which class - to compute the confusion matrix for. When class_id is used, - metrics_specs.binarize settings must not be present. Only one of - class_id or top_k should be configured. - """ - super().__init__( - thresholds=thresholds, name=name, top_k=top_k, class_id=class_id) - - def _default_name(self) -> str: - return NEGATIVE_LIKELIHOOD_RATIO_NAME - - def result(self, tp: float, tn: float, fp: float, fn: float) -> float: - fnr_denominator = fn + tp - tnr_denominator = tn + fp - if fnr_denominator > 0.0 and tnr_denominator > 0.0 and tn > 0.0: - fnr = fn / fnr_denominator - tnr = tn / tnr_denominator - return fnr / tnr - else: - return float('nan') + """Negative likelihood ratio (LR-).""" + + def __init__( + self, + thresholds: Optional[Union[float, List[float]]] = None, + name: Optional[str] = None, + top_k: Optional[int] = None, + class_id: Optional[int] = None, + ): + """Initializes negative likelihood ratio. + + Args: + ---- + thresholds: (Optional) Thresholds to use. Defaults to [0.5]. + name: (Optional) Metric name. + top_k: (Optional) Used with a multi-class model to specify that the top-k + values should be used to compute the confusion matrix. The net effect is + that the non-top-k values are set to -inf and the matrix is then + constructed from the average TP, FP, TN, FN across the classes. When + top_k is used, metrics_specs.binarize settings must not be present. Only + one of class_id or top_k should be configured. When top_k is set, the + default thresholds are [float('-inf')]. + class_id: (Optional) Used with a multi-class model to specify which class + to compute the confusion matrix for. When class_id is used, + metrics_specs.binarize settings must not be present. Only one of + class_id or top_k should be configured. 
+ """ + super().__init__( + thresholds=thresholds, name=name, top_k=top_k, class_id=class_id + ) + + def _default_name(self) -> str: + return NEGATIVE_LIKELIHOOD_RATIO_NAME + + def result(self, tp: float, tn: float, fp: float, fn: float) -> float: + fnr_denominator = fn + tp + tnr_denominator = tn + fp + if fnr_denominator > 0.0 and tnr_denominator > 0.0 and tn > 0.0: + fnr = fn / fnr_denominator + tnr = tn / tnr_denominator + return fnr / tnr + else: + return float("nan") metric_types.register_metric(NegativeLikelihoodRatio) class DiagnosticOddsRatio(ConfusionMatrixMetric): - """Diagnostic odds ratio (DOR).""" - - def __init__(self, - thresholds: Optional[Union[float, List[float]]] = None, - name: Optional[str] = None, - top_k: Optional[int] = None, - class_id: Optional[int] = None): - """Initializes diagnostic odds ratio. - - Args: - thresholds: (Optional) Thresholds to use. Defaults to [0.5]. - name: (Optional) Metric name. - top_k: (Optional) Used with a multi-class model to specify that the top-k - values should be used to compute the confusion matrix. The net effect is - that the non-top-k values are set to -inf and the matrix is then - constructed from the average TP, FP, TN, FN across the classes. When - top_k is used, metrics_specs.binarize settings must not be present. Only - one of class_id or top_k should be configured. When top_k is set, the - default thresholds are [float('-inf')]. - class_id: (Optional) Used with a multi-class model to specify which class - to compute the confusion matrix for. When class_id is used, - metrics_specs.binarize settings must not be present. Only one of - class_id or top_k should be configured. - """ - super().__init__( - thresholds=thresholds, name=name, top_k=top_k, class_id=class_id) - - def _default_name(self) -> str: - return DIAGNOSTIC_ODDS_RATIO_NAME - - def result(self, tp: float, tn: float, fp: float, fn: float) -> float: - if fn > 0.0 and fp > 0.0 and tn > 0.0: - return (tp / fn) / (fp / tn) - else: - return float('nan') + """Diagnostic odds ratio (DOR).""" + + def __init__( + self, + thresholds: Optional[Union[float, List[float]]] = None, + name: Optional[str] = None, + top_k: Optional[int] = None, + class_id: Optional[int] = None, + ): + """Initializes diagnostic odds ratio. + + Args: + ---- + thresholds: (Optional) Thresholds to use. Defaults to [0.5]. + name: (Optional) Metric name. + top_k: (Optional) Used with a multi-class model to specify that the top-k + values should be used to compute the confusion matrix. The net effect is + that the non-top-k values are set to -inf and the matrix is then + constructed from the average TP, FP, TN, FN across the classes. When + top_k is used, metrics_specs.binarize settings must not be present. Only + one of class_id or top_k should be configured. When top_k is set, the + default thresholds are [float('-inf')]. + class_id: (Optional) Used with a multi-class model to specify which class + to compute the confusion matrix for. When class_id is used, + metrics_specs.binarize settings must not be present. Only one of + class_id or top_k should be configured. 
+ """ + super().__init__( + thresholds=thresholds, name=name, top_k=top_k, class_id=class_id + ) + + def _default_name(self) -> str: + return DIAGNOSTIC_ODDS_RATIO_NAME + + def result(self, tp: float, tn: float, fp: float, fn: float) -> float: + if fn > 0.0 and fp > 0.0 and tn > 0.0: + return (tp / fn) / (fp / tn) + else: + return float("nan") metric_types.register_metric(DiagnosticOddsRatio) class PredictedPositiveRate(ConfusionMatrixMetric): - """Predicted positive rate.""" - - def __init__(self, - thresholds: Optional[Union[float, List[float]]] = None, - name: Optional[str] = None, - top_k: Optional[int] = None, - class_id: Optional[int] = None): - """Initializes predicted positive rate. - - Args: - thresholds: (Optional) Thresholds to use. Defaults to [0.5]. - name: (Optional) Metric name. - top_k: (Optional) Used with a multi-class model to specify that the top-k - values should be used to compute the confusion matrix. The net effect is - that the non-top-k values are set to -inf and the matrix is then - constructed from the average TP, FP, TN, FN across the classes. When - top_k is used, metrics_specs.binarize settings must not be present. Only - one of class_id or top_k should be configured. When top_k is set, the - default thresholds are [float('-inf')]. - class_id: (Optional) Used with a multi-class model to specify which class - to compute the confusion matrix for. When class_id is used, - metrics_specs.binarize settings must not be present. Only one of - class_id or top_k should be configured. - """ - super().__init__( - thresholds=thresholds, name=name, top_k=top_k, class_id=class_id) - - def _default_name(self) -> str: - return PREDICTED_POSITIVE_RATE_NAME - - def result(self, tp: float, tn: float, fp: float, fn: float) -> float: - predicted_positives = tp + fp - total_count = tp + fp + tn + fn - return _divide_only_positive_denominator(predicted_positives, total_count) + """Predicted positive rate.""" + + def __init__( + self, + thresholds: Optional[Union[float, List[float]]] = None, + name: Optional[str] = None, + top_k: Optional[int] = None, + class_id: Optional[int] = None, + ): + """Initializes predicted positive rate. + + Args: + ---- + thresholds: (Optional) Thresholds to use. Defaults to [0.5]. + name: (Optional) Metric name. + top_k: (Optional) Used with a multi-class model to specify that the top-k + values should be used to compute the confusion matrix. The net effect is + that the non-top-k values are set to -inf and the matrix is then + constructed from the average TP, FP, TN, FN across the classes. When + top_k is used, metrics_specs.binarize settings must not be present. Only + one of class_id or top_k should be configured. When top_k is set, the + default thresholds are [float('-inf')]. + class_id: (Optional) Used with a multi-class model to specify which class + to compute the confusion matrix for. When class_id is used, + metrics_specs.binarize settings must not be present. Only one of + class_id or top_k should be configured. 
+ """ + super().__init__( + thresholds=thresholds, name=name, top_k=top_k, class_id=class_id + ) + + def _default_name(self) -> str: + return PREDICTED_POSITIVE_RATE_NAME + + def result(self, tp: float, tn: float, fp: float, fn: float) -> float: + predicted_positives = tp + fp + total_count = tp + fp + tn + fn + return _divide_only_positive_denominator(predicted_positives, total_count) metric_types.register_metric(PredictedPositiveRate) -class ConfusionMatrixFeatureSamplerBase( - metric_types.Metric, metaclass=abc.ABCMeta -): - """Base class for metrics that sample features per confusion matrix case.""" - - def __init__( - self, - feature_key: str, - sample_size: int, - threshold: float, - name: Optional[str] = None, - top_k: Optional[int] = None, - class_id: Optional[int] = None, - ): - """Initializes confusion matrix samples at thresholds. - - Args: - feature_key: Feature key to sample. - sample_size: Number of samples to collect per confusion matrix case. - threshold: (Optional) Defaults to [0.5]. A float value in [0, 1]. A - threshold is compared with prediction values to determine the truth - value of predictions (i.e., above the threshold is `true`, below is - `false`). One metric value is generated for each threshold value. - name: (Optional) Metric name. - top_k: (Optional) Used with a multi-class model to specify that the top-k - values should be used to compute the confusion matrix. The net effect is - that the non-top-k values are set to -inf and the matrix is then - constructed from the average TP, FP, TN, FN across the classes. When - top_k is used, metrics_specs.binarize settings must not be present. Only - one of class_id or top_k should be configured. When top_k is set, the - default thresholds are [float('-inf')]. - class_id: (Optional) Used with a multi-class model to specify which class - to compute the confusion matrix for. When class_id is used, - metrics_specs.binarize settings must not be present. Only one of - class_id or top_k should be configured. - """ - super().__init__( - metric_util.merge_per_key_computations(self._metric_computations), - threshold=threshold, - feature_key=feature_key, - sample_size=sample_size, - name=name, - top_k=top_k, - class_id=class_id, - ) - - @abc.abstractmethod - def _get_samples( - self, examples: binary_confusion_matrices.Examples - ) -> np.ndarray: - """Returns the samples for the given examples. - - Note that the storage format for examples supports multiple thresholds, - however - this base class only supports a single threshold. This means that a typical - _get_samples implementation should index into the first element for each - confusion matrix case, as in examples.tp_examples[0]. - - Args: - examples: The binary_confusion_matrices.Examples NamedTuple object from - which to get the appropriate samples. 
- """ - - def _metric_computations( - self, - feature_key: str, - sample_size: int, - threshold: Optional[float] = None, - top_k: Optional[int] = None, - class_id: Optional[int] = None, - name: Optional[str] = None, - eval_config: Optional[config_pb2.EvalConfig] = None, - model_name: str = '', - output_name: str = '', - sub_key: Optional[metric_types.SubKey] = None, - aggregation_type: Optional[metric_types.AggregationType] = None, - class_weights: Optional[Dict[int, float]] = None, - example_weighted: bool = False, - ) -> metric_types.MetricComputations: - """Returns metric computations for confusion matrix at thresholds.""" - sub_key = _validate_and_update_sub_key( - name, model_name, output_name, sub_key, top_k, class_id - ) - - # Make sure matrices are calculated with examples - matrices_computations = binary_confusion_matrices.binary_confusion_matrices( - thresholds=[threshold], - example_id_key=feature_key, - example_ids_count=sample_size, - use_histogram=False, - preprocessors=[ - metric_types.FeaturePreprocessor(feature_keys=[feature_key]) - ], - eval_config=eval_config, - model_name=model_name, - output_name=output_name, - sub_key=sub_key, - aggregation_type=aggregation_type, - class_weights=class_weights, - example_weighted=example_weighted, - ) - examples_key = matrices_computations[-1].keys[0] - - output_key = metric_types.MetricKey( - name=name or self._default_name(), - model_name=model_name, - output_name=output_name, - sub_key=sub_key, - example_weighted=example_weighted, - aggregation_type=aggregation_type, - ) - - def result(metrics): - metrics[output_key] = self._get_samples(metrics[examples_key]) - return metrics - - derived_computation = metric_types.DerivedMetricComputation( - keys=[], result=result - ) - computations = matrices_computations - computations.append(derived_computation) - return computations +class ConfusionMatrixFeatureSamplerBase(metric_types.Metric, metaclass=abc.ABCMeta): + """Base class for metrics that sample features per confusion matrix case.""" + + def __init__( + self, + feature_key: str, + sample_size: int, + threshold: float, + name: Optional[str] = None, + top_k: Optional[int] = None, + class_id: Optional[int] = None, + ): + """Initializes confusion matrix samples at thresholds. + + Args: + ---- + feature_key: Feature key to sample. + sample_size: Number of samples to collect per confusion matrix case. + threshold: (Optional) Defaults to [0.5]. A float value in [0, 1]. A + threshold is compared with prediction values to determine the truth + value of predictions (i.e., above the threshold is `true`, below is + `false`). One metric value is generated for each threshold value. + name: (Optional) Metric name. + top_k: (Optional) Used with a multi-class model to specify that the top-k + values should be used to compute the confusion matrix. The net effect is + that the non-top-k values are set to -inf and the matrix is then + constructed from the average TP, FP, TN, FN across the classes. When + top_k is used, metrics_specs.binarize settings must not be present. Only + one of class_id or top_k should be configured. When top_k is set, the + default thresholds are [float('-inf')]. + class_id: (Optional) Used with a multi-class model to specify which class + to compute the confusion matrix for. When class_id is used, + metrics_specs.binarize settings must not be present. Only one of + class_id or top_k should be configured. 
+ """ + super().__init__( + metric_util.merge_per_key_computations(self._metric_computations), + threshold=threshold, + feature_key=feature_key, + sample_size=sample_size, + name=name, + top_k=top_k, + class_id=class_id, + ) + + @abc.abstractmethod + def _get_samples(self, examples: binary_confusion_matrices.Examples) -> np.ndarray: + """Returns the samples for the given examples. + + Note that the storage format for examples supports multiple thresholds, + however + this base class only supports a single threshold. This means that a typical + _get_samples implementation should index into the first element for each + confusion matrix case, as in examples.tp_examples[0]. + + Args: + ---- + examples: The binary_confusion_matrices.Examples NamedTuple object from + which to get the appropriate samples. + """ + + def _metric_computations( + self, + feature_key: str, + sample_size: int, + threshold: Optional[float] = None, + top_k: Optional[int] = None, + class_id: Optional[int] = None, + name: Optional[str] = None, + eval_config: Optional[config_pb2.EvalConfig] = None, + model_name: str = "", + output_name: str = "", + sub_key: Optional[metric_types.SubKey] = None, + aggregation_type: Optional[metric_types.AggregationType] = None, + class_weights: Optional[Dict[int, float]] = None, + example_weighted: bool = False, + ) -> metric_types.MetricComputations: + """Returns metric computations for confusion matrix at thresholds.""" + sub_key = _validate_and_update_sub_key( + name, model_name, output_name, sub_key, top_k, class_id + ) + + # Make sure matrices are calculated with examples + matrices_computations = binary_confusion_matrices.binary_confusion_matrices( + thresholds=[threshold], + example_id_key=feature_key, + example_ids_count=sample_size, + use_histogram=False, + preprocessors=[ + metric_types.FeaturePreprocessor(feature_keys=[feature_key]) + ], + eval_config=eval_config, + model_name=model_name, + output_name=output_name, + sub_key=sub_key, + aggregation_type=aggregation_type, + class_weights=class_weights, + example_weighted=example_weighted, + ) + examples_key = matrices_computations[-1].keys[0] + + output_key = metric_types.MetricKey( + name=name or self._default_name(), + model_name=model_name, + output_name=output_name, + sub_key=sub_key, + example_weighted=example_weighted, + aggregation_type=aggregation_type, + ) + + def result(metrics): + metrics[output_key] = self._get_samples(metrics[examples_key]) + return metrics + + derived_computation = metric_types.DerivedMetricComputation( + keys=[], result=result + ) + computations = matrices_computations + computations.append(derived_computation) + return computations class FalsePositiveFeatureSampler(ConfusionMatrixFeatureSamplerBase): - """False positive feature samples.""" - - def __init__( - self, - feature_key: str, - sample_size: int, - threshold: float = 0.5, - name: Optional[str] = None, - top_k: Optional[int] = None, - class_id: Optional[int] = None, - ): - """Initializes FalsePositiveFeatureSampler metric. - - Args: - feature_key: Feature key to sample. - sample_size: Number of samples to collect per confusion matrix case. - threshold: (Optional) Defaults to [0.5]. A float value in [0, 1]. A - threshold is compared with prediction values to determine the truth - value of predictions (i.e., above the threshold is `true`, below is - `false`). One metric value is generated for each threshold value. - name: (Optional) Metric name. 
- top_k: (Optional) Used with a multi-class model to specify that the top-k - values should be used to compute the confusion matrix. The net effect is - that the non-top-k values are set to -inf and the matrix is then - constructed from the average TP, FP, TN, FN across the classes. When - top_k is used, metrics_specs.binarize settings must not be present. Only - one of class_id or top_k should be configured. When top_k is set, the - default thresholds are [float('-inf')]. - class_id: (Optional) Used with a multi-class model to specify which class - to compute the confusion matrix for. When class_id is used, - metrics_specs.binarize settings must not be present. Only one of - class_id or top_k should be configured. - """ - super().__init__( - feature_key=feature_key, - sample_size=sample_size, - threshold=threshold, - name=name, - top_k=top_k, - class_id=class_id, - ) - - def _get_samples( - self, examples: binary_confusion_matrices.Examples - ) -> np.ndarray: - assert len(examples.fp_examples) == 1, 'Expected exactly one threshold' - result = np.concatenate(examples.fp_examples[0]) - return result - - def _default_name(self) -> str: - return FALSE_POSITIVE_FEATURE_SAMPLER_NAME + """False positive feature samples.""" + + def __init__( + self, + feature_key: str, + sample_size: int, + threshold: float = 0.5, + name: Optional[str] = None, + top_k: Optional[int] = None, + class_id: Optional[int] = None, + ): + """Initializes FalsePositiveFeatureSampler metric. + + Args: + ---- + feature_key: Feature key to sample. + sample_size: Number of samples to collect per confusion matrix case. + threshold: (Optional) Defaults to [0.5]. A float value in [0, 1]. A + threshold is compared with prediction values to determine the truth + value of predictions (i.e., above the threshold is `true`, below is + `false`). One metric value is generated for each threshold value. + name: (Optional) Metric name. + top_k: (Optional) Used with a multi-class model to specify that the top-k + values should be used to compute the confusion matrix. The net effect is + that the non-top-k values are set to -inf and the matrix is then + constructed from the average TP, FP, TN, FN across the classes. When + top_k is used, metrics_specs.binarize settings must not be present. Only + one of class_id or top_k should be configured. When top_k is set, the + default thresholds are [float('-inf')]. + class_id: (Optional) Used with a multi-class model to specify which class + to compute the confusion matrix for. When class_id is used, + metrics_specs.binarize settings must not be present. Only one of + class_id or top_k should be configured. + """ + super().__init__( + feature_key=feature_key, + sample_size=sample_size, + threshold=threshold, + name=name, + top_k=top_k, + class_id=class_id, + ) + + def _get_samples(self, examples: binary_confusion_matrices.Examples) -> np.ndarray: + assert len(examples.fp_examples) == 1, "Expected exactly one threshold" + result = np.concatenate(examples.fp_examples[0]) + return result + + def _default_name(self) -> str: + return FALSE_POSITIVE_FEATURE_SAMPLER_NAME metric_types.register_metric(FalsePositiveFeatureSampler) class FalseNegativeFeatureSampler(ConfusionMatrixFeatureSamplerBase): - """False negative feature samples.""" - - def __init__( - self, - feature_key: str, - sample_size: int, - threshold: float = 0.5, - name: Optional[str] = None, - top_k: Optional[int] = None, - class_id: Optional[int] = None, - ): - """Initializes FalseNegativeFeatureSampler metric. 
- - Args: - feature_key: Feature key to sample. - sample_size: Number of samples to collect per confusion matrix case. - threshold: (Optional) Defaults to [0.5]. A float value in [0, 1]. A - threshold is compared with prediction values to determine the truth - value of predictions (i.e., above the threshold is `true`, below is - `false`). One metric value is generated for each threshold value. - name: (Optional) Metric name. - top_k: (Optional) Used with a multi-class model to specify that the top-k - values should be used to compute the confusion matrix. The net effect is - that the non-top-k values are set to -inf and the matrix is then - constructed from the average TP, FP, TN, FN across the classes. When - top_k is used, metrics_specs.binarize settings must not be present. Only - one of class_id or top_k should be configured. When top_k is set, the - default thresholds are [float('-inf')]. - class_id: (Optional) Used with a multi-class model to specify which class - to compute the confusion matrix for. When class_id is used, - metrics_specs.binarize settings must not be present. Only one of - class_id or top_k should be configured. - """ - super().__init__( - feature_key=feature_key, - sample_size=sample_size, - threshold=threshold, - name=name, - top_k=top_k, - class_id=class_id, - ) - - def _get_samples( - self, examples: binary_confusion_matrices.Examples - ) -> np.ndarray: - assert len(examples.fp_examples) == 1, 'Expected exactly one threshold' - result = np.concatenate(examples.fn_examples[0]) - return result - - def _default_name(self) -> str: - return FALSE_NEGATIVE_FEATURE_SAMPLER_NAME + """False negative feature samples.""" + + def __init__( + self, + feature_key: str, + sample_size: int, + threshold: float = 0.5, + name: Optional[str] = None, + top_k: Optional[int] = None, + class_id: Optional[int] = None, + ): + """Initializes FalseNegativeFeatureSampler metric. + + Args: + ---- + feature_key: Feature key to sample. + sample_size: Number of samples to collect per confusion matrix case. + threshold: (Optional) Defaults to [0.5]. A float value in [0, 1]. A + threshold is compared with prediction values to determine the truth + value of predictions (i.e., above the threshold is `true`, below is + `false`). One metric value is generated for each threshold value. + name: (Optional) Metric name. + top_k: (Optional) Used with a multi-class model to specify that the top-k + values should be used to compute the confusion matrix. The net effect is + that the non-top-k values are set to -inf and the matrix is then + constructed from the average TP, FP, TN, FN across the classes. When + top_k is used, metrics_specs.binarize settings must not be present. Only + one of class_id or top_k should be configured. When top_k is set, the + default thresholds are [float('-inf')]. + class_id: (Optional) Used with a multi-class model to specify which class + to compute the confusion matrix for. When class_id is used, + metrics_specs.binarize settings must not be present. Only one of + class_id or top_k should be configured. 
+ """
+ super().__init__(
+ feature_key=feature_key,
+ sample_size=sample_size,
+ threshold=threshold,
+ name=name,
+ top_k=top_k,
+ class_id=class_id,
+ )
+
+ def _get_samples(self, examples: binary_confusion_matrices.Examples) -> np.ndarray:
+ assert len(examples.fn_examples) == 1, "Expected exactly one threshold"
+ result = np.concatenate(examples.fn_examples[0])
+ return result
+
+ def _default_name(self) -> str:
+ return FALSE_NEGATIVE_FEATURE_SAMPLER_NAME


 metric_types.register_metric(FalseNegativeFeatureSampler)


 class ConfusionMatrixAtThresholds(metric_types.Metric):
- """Confusion matrix at thresholds."""
-
- def __init__(self,
- thresholds: List[float],
- name: Optional[str] = None,
- top_k: Optional[int] = None,
- class_id: Optional[int] = None):
- """Initializes confusion matrix at thresholds.
-
- Args:
- thresholds: Thresholds to use for confusion matrix.
- name: (Optional) Metric name.
- top_k: (Optional) Used with a multi-class model to specify that the top-k
- values should be used to compute the confusion matrix. The net effect is
- that the non-top-k values are set to -inf and the matrix is then
- constructed from the average TP, FP, TN, FN across the classes. When
- top_k is used, metrics_specs.binarize settings must not be present. Only
- one of class_id or top_k should be configured. When top_k is set, the
- default thresholds are [float('-inf')].
- class_id: (Optional) Used with a multi-class model to specify which class
- to compute the confusion matrix for. When class_id is used,
- metrics_specs.binarize settings must not be present. Only one of
- class_id or top_k should be configured.
- """
- super().__init__(
- metric_util.merge_per_key_computations(self._metric_computations),
- thresholds=thresholds,
- name=name,
- top_k=top_k,
- class_id=class_id)
-
- def _default_name(self) -> str:
- return CONFUSION_MATRIX_AT_THRESHOLDS_NAME
-
- def _metric_computations(
- self,
- thresholds: List[float],
- top_k: Optional[int] = None,
- class_id: Optional[int] = None,
- name: Optional[str] = None,
- eval_config: Optional[config_pb2.EvalConfig] = None,
- model_name: str = '',
- output_name: str = '',
- sub_key: Optional[metric_types.SubKey] = None,
- aggregation_type: Optional[metric_types.AggregationType] = None,
- class_weights: Optional[Dict[int, float]] = None,
- example_weighted: bool = False) -> metric_types.MetricComputations:
- """Returns metric computations for confusion matrix at thresholds."""
- sub_key = _validate_and_update_sub_key(name, model_name, output_name,
- sub_key, top_k, class_id)
- key = metric_types.MetricKey(
- name=name,
- model_name=model_name,
- output_name=output_name,
- sub_key=sub_key,
- example_weighted=example_weighted,
- aggregation_type=aggregation_type)
-
- # Make sure matrices are calculated.
- matrices_computations = binary_confusion_matrices.binary_confusion_matrices( - eval_config=eval_config, - model_name=model_name, - output_name=output_name, - sub_key=sub_key, - aggregation_type=aggregation_type, - class_weights=class_weights, - thresholds=thresholds, - example_weighted=example_weighted) - matrices_key = matrices_computations[-1].keys[-1] - - def result( - metrics: Dict[metric_types.MetricKey, - binary_confusion_matrices.Matrices] - ) -> Dict[metric_types.MetricKey, Any]: - return {key: metrics[matrices_key]} - - derived_computation = metric_types.DerivedMetricComputation( - keys=[key], result=result) - computations = matrices_computations - computations.append(derived_computation) - return computations + """Confusion matrix at thresholds.""" + + def __init__( + self, + thresholds: List[float], + name: Optional[str] = None, + top_k: Optional[int] = None, + class_id: Optional[int] = None, + ): + """Initializes confusion matrix at thresholds. + + Args: + ---- + thresholds: Thresholds to use for confusion matrix. + name: (Optional) Metric name. + top_k: (Optional) Used with a multi-class model to specify that the top-k + values should be used to compute the confusion matrix. The net effect is + that the non-top-k values are set to -inf and the matrix is then + constructed from the average TP, FP, TN, FN across the classes. When + top_k is used, metrics_specs.binarize settings must not be present. Only + one of class_id or top_k should be configured. When top_k is set, the + default thresholds are [float('-inf')]. + class_id: (Optional) Used with a multi-class model to specify which class + to compute the confusion matrix for. When class_id is used, + metrics_specs.binarize settings must not be present. Only one of + class_id or top_k should be configured. + """ + super().__init__( + metric_util.merge_per_key_computations(self._metric_computations), + thresholds=thresholds, + name=name, + top_k=top_k, + class_id=class_id, + ) + + def _default_name(self) -> str: + return CONFUSION_MATRIX_AT_THRESHOLDS_NAME + + def _metric_computations( + self, + thresholds: List[float], + top_k: Optional[int] = None, + class_id: Optional[int] = None, + name: Optional[str] = None, + eval_config: Optional[config_pb2.EvalConfig] = None, + model_name: str = "", + output_name: str = "", + sub_key: Optional[metric_types.SubKey] = None, + aggregation_type: Optional[metric_types.AggregationType] = None, + class_weights: Optional[Dict[int, float]] = None, + example_weighted: bool = False, + ) -> metric_types.MetricComputations: + """Returns metric computations for confusion matrix at thresholds.""" + sub_key = _validate_and_update_sub_key( + name, model_name, output_name, sub_key, top_k, class_id + ) + key = metric_types.MetricKey( + name=name, + model_name=model_name, + output_name=output_name, + sub_key=sub_key, + example_weighted=example_weighted, + aggregation_type=aggregation_type, + ) + + # Make sure matrices are calculated. 
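+ # The confusion matrix counts are not computed by this metric itself: they
+ # come from the shared binary_confusion_matrices computation created below.
+ # This metric then appends a derived computation whose only job is to re-key
+ # the resulting matrices under this metric's own MetricKey.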
+ matrices_computations = binary_confusion_matrices.binary_confusion_matrices(
+ eval_config=eval_config,
+ model_name=model_name,
+ output_name=output_name,
+ sub_key=sub_key,
+ aggregation_type=aggregation_type,
+ class_weights=class_weights,
+ thresholds=thresholds,
+ example_weighted=example_weighted,
+ )
+ matrices_key = matrices_computations[-1].keys[-1]
+
+ def result(
+ metrics: Dict[metric_types.MetricKey, binary_confusion_matrices.Matrices],
+ ) -> Dict[metric_types.MetricKey, Any]:
+ return {key: metrics[matrices_key]}
+
+ derived_computation = metric_types.DerivedMetricComputation(
+ keys=[key], result=result
+ )
+ computations = matrices_computations
+ computations.append(derived_computation)
+ return computations


 metric_types.register_metric(ConfusionMatrixAtThresholds)


@@ -2671,119 +2872,132 @@ def result(


 class MaxRecall(Recall):
- """Computes the max recall of the predictions with respect to the labels.
-
- The metric uses true positives and false negatives to compute recall by
- dividing the true positives by the sum of true positives and false negatives.
-
- Effectively the recall at threshold = epsilon(1.0e-12). It is equilvalent
- to the recall defined in COCO metrics.
+ """Computes the max recall of the predictions with respect to the labels.

- If `sample_weight` is `None`, weights default to 1.
- Use `sample_weight` of 0 to mask values.
- """
+ The metric uses true positives and false negatives to compute recall by
+ dividing the true positives by the sum of true positives and false negatives.

- def __init__(self,
- top_k: Optional[int] = None,
- class_id: Optional[int] = None,
- name: Optional[str] = None,
- **kwargs):
- """Initializes MaxRecall metrics, it calculates the maximum recall.
+ Effectively the recall at threshold = epsilon(1.0e-12). It is equivalent
+ to the recall defined in COCO metrics.

- Args:
- top_k: (Optional) Used with a multi-class model to specify that the top-k
- values should be used to compute the confusion matrix. The net effect is
- that the non-top-k values are set to -inf and the matrix is then
- constructed from the average TP, FP, TN, FN across the classes. When
- top_k is used, metrics_specs.binarize settings must not be present. Only
- one of class_id or top_k should be configured. When top_k is set, the
- default thresholds are [float('-inf')].
- class_id: (Optional) Used with a multi-class model to specify which class
- to compute the confusion matrix for. When class_id is used,
- metrics_specs.binarize settings must not be present. Only one of
- class_id or top_k should be configured.
- name: (Optional) string name of the metric instance.
- **kwargs: (Optional) Additional args to pass along to init (and eventually
- on to _metric_computation and _metric_value)
+ If `sample_weight` is `None`, weights default to 1.
+ Use `sample_weight` of 0 to mask values.
 """
- super().__init__(
- thresholds=_DEFAULT_THRESHOLD_FOR_MAX_RECALL,
- top_k=top_k,
- class_id=class_id,
- name=name,
- **kwargs)

- def _default_name(self) -> str:
- return MAX_RECALL_NAME
+ def __init__(
+ self,
+ top_k: Optional[int] = None,
+ class_id: Optional[int] = None,
+ name: Optional[str] = None,
+ **kwargs,
+ ):
+ """Initializes MaxRecall metric, which calculates the maximum recall.
+
+ Args:
+ ----
+ top_k: (Optional) Used with a multi-class model to specify that the top-k
+ values should be used to compute the confusion matrix.
The net effect is + that the non-top-k values are set to -inf and the matrix is then + constructed from the average TP, FP, TN, FN across the classes. When + top_k is used, metrics_specs.binarize settings must not be present. Only + one of class_id or top_k should be configured. When top_k is set, the + default thresholds are [float('-inf')]. + class_id: (Optional) Used with a multi-class model to specify which class + to compute the confusion matrix for. When class_id is used, + metrics_specs.binarize settings must not be present. Only one of + class_id or top_k should be configured. + name: (Optional) string name of the metric instance. + **kwargs: (Optional) Additional args to pass along to init (and eventually + on to _metric_computation and _metric_value) + """ + super().__init__( + thresholds=_DEFAULT_THRESHOLD_FOR_MAX_RECALL, + top_k=top_k, + class_id=class_id, + name=name, + **kwargs, + ) + + def _default_name(self) -> str: + return MAX_RECALL_NAME metric_types.register_metric(MaxRecall) class ThresholdAtRecall(ConfusionMatrixMetricBase): - """Computes the maximum threshold where recall is >= specified value. - - If `sample_weight` is `None`, weights default to 1. - Use `sample_weight` of 0 to mask values. - """ - - def __init__(self, - recall: Union[float, List[float]], - thresholds: Optional[List[float]] = None, - num_thresholds: Optional[int] = None, - class_id: Optional[int] = None, - name: Optional[str] = None, - top_k: Optional[int] = None, - **kwargs): - """Initializes ThresholdAtRecall metric. + """Computes the maximum threshold where recall is >= specified value. - Args: - recall: A scalar or a list of scalar values in range `[0, 1]`. - thresholds: (Optional) Thresholds to use for calculating the matrices. Use - one of either thresholds or num_thresholds. - num_thresholds: (Optional) Defaults to 1000. The number of thresholds to - use for matching the given recall. - class_id: (Optional) Used with a multi-class model to specify which class - to compute the confusion matrix for. When class_id is used, - metrics_specs.binarize settings must not be present. Only one of - class_id or top_k should be configured. - name: (Optional) string name of the metric instance. - top_k: (Optional) Used with a multi-class model to specify that the top-k - values should be used to compute the confusion matrix. The net effect is - that the non-top-k values are set to -inf and the matrix is then - constructed from the average TP, FP, TN, FN across the classes. When - top_k is used, metrics_specs.binarize settings must not be present. Only - one of class_id or top_k should be configured. When top_k is set, the - default thresholds are [float('-inf')]. - **kwargs: (Optional) Additional args to pass along to init (and eventually - on to _metric_computation and _metric_value) + If `sample_weight` is `None`, weights default to 1. + Use `sample_weight` of 0 to mask values. """ - for r in [recall] if isinstance(recall, float) else recall: - if r < 0 or r > 1: - raise ValueError('Argument `recall` must be in the range [0, 1]. 
' - f'Received: recall={r}') - - super().__init__( - thresholds=thresholds, - num_thresholds=num_thresholds, - recall=recall, - class_id=class_id, - name=name, - top_k=top_k, - **kwargs) - - def _default_name(self) -> str: - return THRESHOLD_AT_RECALL_NAME - - def _metric_value( - self, recall: Union[float, List[float]], key: metric_types.MetricKey, - matrices: binary_confusion_matrices.Matrices) -> Union[float, np.ndarray]: - del key - tp = np.array(matrices.tp) - fn = np.array(matrices.fn) - recalls = tp / (tp + fn) - thresholds = np.array(matrices.thresholds) - return _find_max_under_constraint(recalls, thresholds, recall) + + def __init__( + self, + recall: Union[float, List[float]], + thresholds: Optional[List[float]] = None, + num_thresholds: Optional[int] = None, + class_id: Optional[int] = None, + name: Optional[str] = None, + top_k: Optional[int] = None, + **kwargs, + ): + """Initializes ThresholdAtRecall metric. + + Args: + ---- + recall: A scalar or a list of scalar values in range `[0, 1]`. + thresholds: (Optional) Thresholds to use for calculating the matrices. Use + one of either thresholds or num_thresholds. + num_thresholds: (Optional) Defaults to 1000. The number of thresholds to + use for matching the given recall. + class_id: (Optional) Used with a multi-class model to specify which class + to compute the confusion matrix for. When class_id is used, + metrics_specs.binarize settings must not be present. Only one of + class_id or top_k should be configured. + name: (Optional) string name of the metric instance. + top_k: (Optional) Used with a multi-class model to specify that the top-k + values should be used to compute the confusion matrix. The net effect is + that the non-top-k values are set to -inf and the matrix is then + constructed from the average TP, FP, TN, FN across the classes. When + top_k is used, metrics_specs.binarize settings must not be present. Only + one of class_id or top_k should be configured. When top_k is set, the + default thresholds are [float('-inf')]. + **kwargs: (Optional) Additional args to pass along to init (and eventually + on to _metric_computation and _metric_value) + """ + for r in [recall] if isinstance(recall, float) else recall: + if r < 0 or r > 1: + raise ValueError( + "Argument `recall` must be in the range [0, 1]. 
" + f"Received: recall={r}" + ) + + super().__init__( + thresholds=thresholds, + num_thresholds=num_thresholds, + recall=recall, + class_id=class_id, + name=name, + top_k=top_k, + **kwargs, + ) + + def _default_name(self) -> str: + return THRESHOLD_AT_RECALL_NAME + + def _metric_value( + self, + recall: Union[float, List[float]], + key: metric_types.MetricKey, + matrices: binary_confusion_matrices.Matrices, + ) -> Union[float, np.ndarray]: + del key + tp = np.array(matrices.tp) + fn = np.array(matrices.fn) + recalls = tp / (tp + fn) + thresholds = np.array(matrices.thresholds) + return _find_max_under_constraint(recalls, thresholds, recall) metric_types.register_metric(ThresholdAtRecall) diff --git a/tensorflow_model_analysis/metrics/confusion_matrix_metrics_test.py b/tensorflow_model_analysis/metrics/confusion_matrix_metrics_test.py index f7df21fa41..5655b65f88 100644 --- a/tensorflow_model_analysis/metrics/confusion_matrix_metrics_test.py +++ b/tensorflow_model_analysis/metrics/confusion_matrix_metrics_test.py @@ -15,20 +15,22 @@ import math -from absl.testing import parameterized import apache_beam as beam -from apache_beam.testing import util import numpy as np import tensorflow as tf -from tensorflow_model_analysis.metrics import binary_confusion_matrices -from tensorflow_model_analysis.metrics import confusion_matrix_metrics -from tensorflow_model_analysis.metrics import metric_types -from tensorflow_model_analysis.metrics import metric_util +from absl.testing import parameterized +from apache_beam.testing import util + +from tensorflow_model_analysis.metrics import ( + binary_confusion_matrices, + confusion_matrix_metrics, + metric_types, + metric_util, +) from tensorflow_model_analysis.metrics import test_util as metric_test_util from tensorflow_model_analysis.utils import test_util - -_TF_MAJOR_VERSION = int(tf.version.VERSION.split('.')[0]) +_TF_MAJOR_VERSION = int(tf.version.VERSION.split(".")[0]) _TRUE_POISITIVE = (1, 1) _TRUE_NEGATIVE = (0, 0) @@ -38,1142 +40,1205 @@ class ConfusionMatrixMetricsTest( metric_test_util.TestCase, parameterized.TestCase, ): - - @parameterized.named_parameters( - ( - 'Precision', - confusion_matrix_metrics.Precision(), - _TRUE_NEGATIVE, - float('nan'), - ), - ( - 'Recall', - confusion_matrix_metrics.Recall(), - _TRUE_NEGATIVE, - float('nan'), - ), - ( - 'Specificity', - confusion_matrix_metrics.Specificity(), - _TRUE_POISITIVE, - float('nan'), - ), - ( - 'FallOut', - confusion_matrix_metrics.FallOut(), - _TRUE_POISITIVE, - float('nan'), - ), - ( - 'MissRate', - confusion_matrix_metrics.MissRate(), - _TRUE_NEGATIVE, - float('nan'), - ), - ( - 'NegativePredictiveValue', - confusion_matrix_metrics.NegativePredictiveValue(), - _TRUE_POISITIVE, - float('nan'), - ), - ( - 'FalseDiscoveryRate', - confusion_matrix_metrics.FalseDiscoveryRate(), - _TRUE_NEGATIVE, - float('nan'), - ), - ( - 'FalseOmissionRate', - confusion_matrix_metrics.FalseOmissionRate(), - _TRUE_POISITIVE, - float('nan'), - ), - ( - 'ThreatScore', - confusion_matrix_metrics.ThreatScore(), - _TRUE_NEGATIVE, - float('nan'), - ), - ( - 'F1Score', - confusion_matrix_metrics.F1Score(), - _TRUE_NEGATIVE, - float('nan'), - ), - ( - 'MatthewsCorrelationCoefficient', - confusion_matrix_metrics.MatthewsCorrelationCoefficient(), - _TRUE_NEGATIVE, - float('nan'), - ), - ) - def testConfusionMatrixMetrics_DivideByZero_( - self, metric, pred_label, expected_value - ): - if _TF_MAJOR_VERSION < 2 and metric.__class__.__name__ in ( - 'SpecificityAtSensitivity', - 'SensitivityAtSpecificity', - 
'PrecisionAtRecall', - 'RecallAtPrecision', - 'RecallAtFalsePositiveRate', - ): - self.skipTest('Not supported in TFv1.') - - computations = metric.computations(example_weighted=True) - histogram = computations[0] - matrices = computations[1] - metrics = computations[2] - - # Using one example to create a situation where the denominator in the - # corresponding calculation is 0. - pred, label = pred_label - example1 = { - 'labels': np.array([label]), - 'predictions': np.array([pred]), - } - - with beam.Pipeline() as pipeline: - # pylint: disable=no-value-for-parameter - result = ( - pipeline - | 'Create' >> beam.Create([example1]) - | 'Process' >> beam.Map(metric_util.to_standard_metric_inputs) - | 'AddSlice' >> beam.Map(lambda x: ((), x)) - | 'ComputeHistogram' >> beam.CombinePerKey(histogram.combiner) - | 'ComputeMatrices' - >> beam.Map( - lambda x: (x[0], matrices.result(x[1])) - ) # pyformat: ignore - | 'ComputeMetrics' >> beam.Map(lambda x: (x[0], metrics.result(x[1]))) - ) # pyformat: ignore - - # pylint: enable=no-value-for-parameter - def check_result(got): - try: - self.assertLen(got, 1) - got_slice_key, got_metrics = got[0] - self.assertEqual(got_slice_key, ()) - self.assertLen(got_metrics, 1) - key = metrics.keys[0] - self.assertIn(key, got_metrics) - # np.testing utils automatically cast floats to arrays which fails - # to catch type mismatches. - self.assertEqual(type(expected_value), type(got_metrics[key])) - np.testing.assert_almost_equal( - got_metrics[key], expected_value, decimal=5 - ) - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that(result, check_result, label='result') - - # LINT.IfChange(tfma_confusion_matrix_metrics_tests) - @parameterized.named_parameters( - ('auc', confusion_matrix_metrics.AUC(), np.float64(0.26)), - ( - 'auc_precision_recall', - confusion_matrix_metrics.AUCPrecisionRecall(), - np.float64(0.36205), - ), - ( - 'specificity_at_sensitivity', - confusion_matrix_metrics.SpecificityAtSensitivity(0.5), - 0.2, - ), - ( - 'sensitivity_at_specificity', - confusion_matrix_metrics.SensitivityAtSpecificity(0.5), - 0.0, - ), - ( - 'precision_at_recall', - confusion_matrix_metrics.PrecisionAtRecall(0.5), - 0.5, - ), - ( - 'recall_at_precision', - confusion_matrix_metrics.RecallAtPrecision(0.5), - 1.0, - ), - ( - 'recall_at_false_positive_rate', - confusion_matrix_metrics.RecallAtFalsePositiveRate(0.6), - 0.4, - ), - ('true_positives', confusion_matrix_metrics.TruePositives(), 1.0), - ('tp', confusion_matrix_metrics.TP(), 1.0), - ('false_positives', confusion_matrix_metrics.FalsePositives(), 3.0), - ('fp', confusion_matrix_metrics.FP(), 3.0), - ('true_negatives', confusion_matrix_metrics.TrueNegatives(), 2.0), - ('tn', confusion_matrix_metrics.TN(), 2.0), - ('false_negatives', confusion_matrix_metrics.FalseNegatives(), 4.0), - ('fn', confusion_matrix_metrics.FN(), 4.0), - ( - 'binary_accuracy', - confusion_matrix_metrics.BinaryAccuracy(), - (1.0 + 2.0) / (1.0 + 2.0 + 3.0 + 4.0), - ), - ('precision', confusion_matrix_metrics.Precision(), 1.0 / (1.0 + 3.0)), - ('ppv', confusion_matrix_metrics.PPV(), 1.0 / (1.0 + 3.0)), - ('recall', confusion_matrix_metrics.Recall(), 1.0 / (1.0 + 4.0)), - ('tpr', confusion_matrix_metrics.TPR(), 1.0 / (1.0 + 4.0)), - ( - 'specificity', - confusion_matrix_metrics.Specificity(), - 2.0 / (2.0 + 3.0), - ), - ('tnr', confusion_matrix_metrics.TNR(), 2.0 / (2.0 + 3.0)), - ('fall_out', confusion_matrix_metrics.FallOut(), 3.0 / (3.0 + 2.0)), - ('fpr', confusion_matrix_metrics.FPR(), 3.0 / (3.0 + 2.0)), 
- ('miss_rate', confusion_matrix_metrics.MissRate(), 4.0 / (4.0 + 1.0)), - ('fnr', confusion_matrix_metrics.FNR(), 4.0 / (4.0 + 1.0)), - ( - 'negative_predictive_value', - confusion_matrix_metrics.NegativePredictiveValue(), - 2.0 / (2.0 + 4.0), - ), - ('npv', confusion_matrix_metrics.NPV(), 2.0 / (2.0 + 4.0)), - ( - 'false_discovery_rate', - confusion_matrix_metrics.FalseDiscoveryRate(), - 3.0 / (3.0 + 1.0), - ), - ( - 'false_omission_rate', - confusion_matrix_metrics.FalseOmissionRate(), - 4.0 / (4.0 + 2.0), - ), - ( - 'prevalence', - confusion_matrix_metrics.Prevalence(), - (1.0 + 4.0) / (1.0 + 2.0 + 3.0 + 4.0), - ), - ( - 'prevalence_threshold', - confusion_matrix_metrics.PrevalenceThreshold(), - ( - math.sqrt((1.0 / (1.0 + 4.0)) * (1.0 - (2.0 / (2.0 + 3.0)))) - + (2.0 / (2.0 + 3.0) - 1.0) - ) - / ((1.0 / (1.0 + 4.0) + (2.0 / (2.0 + 3.0)) - 1.0)), - ), - ( - 'threat_score', - confusion_matrix_metrics.ThreatScore(), - 1.0 / (1.0 + 4.0 + 3.0), - ), - ( - 'balanced_accuracy', - confusion_matrix_metrics.BalancedAccuracy(), - ((1.0 / (1.0 + 4.0)) + (2.0 / (2.0 + 3.0))) / 2, - ), - ( - 'f1_score', - confusion_matrix_metrics.F1Score(), - 2 * 1.0 / (2 * 1.0 + 3.0 + 4.0), - ), - ( - 'matthews_correlation_coefficient', - confusion_matrix_metrics.MatthewsCorrelationCoefficient(), - (1.0 * 2.0 - 3.0 * 4.0) - / math.sqrt((1.0 + 3.0) * (1.0 + 4.0) * (2.0 + 3.0) * (2.0 + 4.0)), - ), - ( - 'fowlkes_mallows_index', - confusion_matrix_metrics.FowlkesMallowsIndex(), - math.sqrt(1.0 / (1.0 + 3.0) * 1.0 / (1.0 + 4.0)), - ), - ( - 'informedness', - confusion_matrix_metrics.Informedness(), - (1.0 / (1.0 + 4.0)) + (2.0 / (2.0 + 3.0)) - 1.0, - ), - ( - 'markedness', - confusion_matrix_metrics.Markedness(), - (1.0 / (1.0 + 3.0)) + (2.0 / (2.0 + 4.0)) - 1.0, - ), - ( - 'positive_likelihood_ratio', - confusion_matrix_metrics.PositiveLikelihoodRatio(), - (1.0 / (1.0 + 4.0)) / (3.0 / (3.0 + 2.0)), - ), - ( - 'negative_likelihood_ratio', - confusion_matrix_metrics.NegativeLikelihoodRatio(), - (4.0 / (4.0 + 1.0)) / (2.0 / (2.0 + 3.0)), - ), - ( - 'diagnostic_odds_ratio', - confusion_matrix_metrics.DiagnosticOddsRatio(), - ((1.0 / 3.0)) / (4.0 / 2.0), - ), - ( - 'predicted_positive_rate', - confusion_matrix_metrics.PredictedPositiveRate(), - (1.0 + 3.0) / (1.0 + 2.0 + 3.0 + 4.0), - ), - ( - 'threshold_at_recall', - confusion_matrix_metrics.ThresholdAtRecall(0.5), - 0.29993, - ), - ) - def testConfusionMatrixMetrics(self, metric, expected_value): - if _TF_MAJOR_VERSION < 2 and metric.__class__.__name__ in ( - 'SpecificityAtSensitivity', - 'SensitivityAtSpecificity', - 'PrecisionAtRecall', - 'RecallAtPrecision', - 'RecallAtFalsePositiveRate', - ): - self.skipTest('Not supported in TFv1.') - - computations = metric.computations(example_weighted=True) - histogram = computations[0] - matrices = computations[1] - metrics = computations[2] - - # tp = 1 - # tn = 2 - # fp = 3 - # fn = 4 - example1 = { - 'labels': np.array([1.0]), - 'predictions': np.array([0.6]), - 'example_weights': np.array([1.0]), - } - example2 = { - 'labels': np.array([0.0]), - 'predictions': np.array([0.3]), - 'example_weights': np.array([1.0]), - } - example3 = { - 'labels': np.array([0.0]), - 'predictions': np.array([0.2]), - 'example_weights': np.array([1.0]), - } - example4 = { - 'labels': np.array([0.0]), - 'predictions': np.array([0.6]), - 'example_weights': np.array([1.0]), - } - example5 = { - 'labels': np.array([0.0]), - 'predictions': np.array([0.7]), - 'example_weights': np.array([1.0]), - } - example6 = { - 'labels': np.array([0.0]), - 
'predictions': np.array([0.8]), - 'example_weights': np.array([1.0]), - } - example7 = { - 'labels': np.array([1.0]), - 'predictions': np.array([0.1]), - 'example_weights': np.array([1.0]), - } - example8 = { - 'labels': np.array([1.0]), - 'predictions': np.array([0.2]), - 'example_weights': np.array([1.0]), - } - example9 = { - 'labels': np.array([1.0]), - 'predictions': np.array([0.3]), - 'example_weights': np.array([1.0]), - } - example10 = { - 'labels': np.array([1.0]), - 'predictions': np.array([0.4]), - 'example_weights': np.array([1.0]), - } - - with beam.Pipeline() as pipeline: - # pylint: disable=no-value-for-parameter - result = ( - pipeline - | 'Create' - >> beam.Create([ - example1, - example2, - example3, - example4, - example5, - example6, - example7, - example8, - example9, - example10, - ]) - | 'Process' >> beam.Map(metric_util.to_standard_metric_inputs) - | 'AddSlice' >> beam.Map(lambda x: ((), x)) - | 'ComputeHistogram' >> beam.CombinePerKey(histogram.combiner) - | 'ComputeMatrices' - >> beam.Map(lambda x: (x[0], matrices.result(x[1]))) - | 'ComputeMetrics' >> beam.Map(lambda x: (x[0], metrics.result(x[1]))) - ) - - # pylint: enable=no-value-for-parameter - - def check_result(got): - try: - self.assertLen(got, 1) - got_slice_key, got_metrics = got[0] - self.assertEqual(got_slice_key, ()) - self.assertLen(got_metrics, 1) - key = metrics.keys[0] - self.assertIn(key, got_metrics) - # np.testing utils automatically cast floats to arrays which fails - # to catch type mismatches. - self.assertEqual(type(expected_value), type(got_metrics[key])) - np.testing.assert_almost_equal( - got_metrics[key], expected_value, decimal=5) - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that(result, check_result, label='result') - - @parameterized.named_parameters( - ('auc', confusion_matrix_metrics.AUC(), np.float64(0.64286)), - ( - 'auc_precision_recall', - confusion_matrix_metrics.AUCPrecisionRecall(), - np.float64(0.37467), - ), - ( - 'specificity_at_sensitivity', - confusion_matrix_metrics.SpecificityAtSensitivity(0.5), - 0.642857, - ), - ( - 'sensitivity_at_specificity', - confusion_matrix_metrics.SensitivityAtSpecificity(0.5), - 1.0, - ), - ( - 'precision_at_recall', - confusion_matrix_metrics.PrecisionAtRecall(0.5), - 0.58333, - ), - ( - 'recall_at_precision', - confusion_matrix_metrics.RecallAtPrecision(0.5), - 1.0, - ), - ( - 'recall_at_false_positive_rate', - confusion_matrix_metrics.RecallAtFalsePositiveRate(0.5 / (0.5 + 0.9)), - 1.0, - ), - ('true_positives', confusion_matrix_metrics.TruePositives(), 0.7), - ('false_positives', confusion_matrix_metrics.FalsePositives(), 0.5), - ('true_negatives', confusion_matrix_metrics.TrueNegatives(), 0.9), - ('false_negatives', confusion_matrix_metrics.FalseNegatives(), 0.0), - ( - 'binary_accuracy', - confusion_matrix_metrics.BinaryAccuracy(), - (0.7 + 0.9) / (0.7 + 0.9 + 0.5 + 0.0), - ), - ('precision', confusion_matrix_metrics.Precision(), 0.7 / (0.7 + 0.5)), - ('recall', confusion_matrix_metrics.Recall(), 0.7 / (0.7 + 0.0)), - ( - 'specificity', - confusion_matrix_metrics.Specificity(), - 0.9 / (0.9 + 0.5), - ), - ('fall_out', confusion_matrix_metrics.FallOut(), 0.5 / (0.5 + 0.9)), - ('miss_rate', confusion_matrix_metrics.MissRate(), 0.0 / (0.0 + 0.7)), - ( - 'negative_predictive_value', - confusion_matrix_metrics.NegativePredictiveValue(), - 0.9 / (0.9 + 0.0), - ), - ( - 'false_discovery_rate', - confusion_matrix_metrics.FalseDiscoveryRate(), - 0.5 / (0.5 + 0.7), - ), - ( - 'false_omission_rate', - 
confusion_matrix_metrics.FalseOmissionRate(), - 0.0 / (0.0 + 0.9), - ), - ( - 'prevalence', - confusion_matrix_metrics.Prevalence(), - (0.7 + 0.0) / (0.7 + 0.9 + 0.5 + 0.0), - ), - ( - 'prevalence_threshold', - confusion_matrix_metrics.PrevalenceThreshold(), - ( - math.sqrt((0.7 / (0.7 + 0.0)) * (1.0 - (0.9 / (0.9 + 0.5)))) - + (0.9 / (0.9 + 0.5) - 1.0) - ) - / ((0.7 / (0.7 + 0.0) + (0.9 / (0.9 + 0.5)) - 1.0)), - ), - ( - 'threat_score', - confusion_matrix_metrics.ThreatScore(), - 0.7 / (0.7 + 0.0 + 0.5), - ), - ( - 'balanced_accuracy', - confusion_matrix_metrics.BalancedAccuracy(), - ((0.7 / (0.7 + 0.0)) + (0.9 / (0.9 + 0.5))) / 2, - ), - ( - 'f1_score', - confusion_matrix_metrics.F1Score(), - 2 * 0.7 / (2 * 0.7 + 0.5 + 0.0), - ), - ( - 'matthews_correlation_coefficient', - confusion_matrix_metrics.MatthewsCorrelationCoefficient(), - (0.7 * 0.9 - 0.5 * 0.0) - / math.sqrt((0.7 + 0.5) * (0.7 + 0.0) * (0.9 + 0.5) * (0.9 + 0.0)), - ), - ( - 'fowlkes_mallows_index', - confusion_matrix_metrics.FowlkesMallowsIndex(), - math.sqrt(0.7 / (0.7 + 0.5) * 0.7 / (0.7 + 0.0)), - ), - ( - 'informedness', - confusion_matrix_metrics.Informedness(), - (0.7 / (0.7 + 0.0)) + (0.9 / (0.9 + 0.5)) - 1.0, - ), - ( - 'markedness', - confusion_matrix_metrics.Markedness(), - (0.7 / (0.7 + 0.5)) + (0.9 / (0.9 + 0.0)) - 1.0, - ), - ( - 'positive_likelihood_ratio', - confusion_matrix_metrics.PositiveLikelihoodRatio(), - (0.7 / (0.7 + 0.0)) / (0.5 / (0.5 + 0.9)), - ), - ( - 'negative_likelihood_ratio', - confusion_matrix_metrics.NegativeLikelihoodRatio(), - (0.0 / (0.0 + 0.7)) / (0.9 / (0.9 + 0.5)), - ), - ( - 'predicted_positive_rate', - confusion_matrix_metrics.PredictedPositiveRate(), - (0.7 + 0.5) / (0.7 + 0.9 + 0.5 + 0.0), - ), - ) - def testConfusionMatrixMetricsWithWeights(self, metric, expected_value): - if _TF_MAJOR_VERSION < 2 and metric.__class__.__name__ in ( - 'SpecificityAtSensitivity', - 'SensitivityAtSpecificity', - 'PrecisionAtRecall', - 'RecallAtPrecision', - 'RecallAtFalsePositiveRate', + @parameterized.named_parameters( + ( + "Precision", + confusion_matrix_metrics.Precision(), + _TRUE_NEGATIVE, + float("nan"), + ), + ( + "Recall", + confusion_matrix_metrics.Recall(), + _TRUE_NEGATIVE, + float("nan"), + ), + ( + "Specificity", + confusion_matrix_metrics.Specificity(), + _TRUE_POISITIVE, + float("nan"), + ), + ( + "FallOut", + confusion_matrix_metrics.FallOut(), + _TRUE_POISITIVE, + float("nan"), + ), + ( + "MissRate", + confusion_matrix_metrics.MissRate(), + _TRUE_NEGATIVE, + float("nan"), + ), + ( + "NegativePredictiveValue", + confusion_matrix_metrics.NegativePredictiveValue(), + _TRUE_POISITIVE, + float("nan"), + ), + ( + "FalseDiscoveryRate", + confusion_matrix_metrics.FalseDiscoveryRate(), + _TRUE_NEGATIVE, + float("nan"), + ), + ( + "FalseOmissionRate", + confusion_matrix_metrics.FalseOmissionRate(), + _TRUE_POISITIVE, + float("nan"), + ), + ( + "ThreatScore", + confusion_matrix_metrics.ThreatScore(), + _TRUE_NEGATIVE, + float("nan"), + ), + ( + "F1Score", + confusion_matrix_metrics.F1Score(), + _TRUE_NEGATIVE, + float("nan"), + ), + ( + "MatthewsCorrelationCoefficient", + confusion_matrix_metrics.MatthewsCorrelationCoefficient(), + _TRUE_NEGATIVE, + float("nan"), + ), + ) + def testConfusionMatrixMetrics_DivideByZero_( + self, metric, pred_label, expected_value ): - self.skipTest('Not supported in TFv1.') - - computations = metric.computations(example_weighted=True) - histogram = computations[0] - matrix = computations[1] - derived_metric = computations[2] - - # tp = 0.7 - # tn = 0.9 - # fp = 
0.5 - # fn = 0.0 - example1 = { - 'labels': np.array([0.0]), - 'predictions': np.array([1.0]), - 'example_weights': np.array([0.5]), - } - example2 = { - 'labels': np.array([1.0]), - 'predictions': np.array([0.7]), - 'example_weights': np.array([0.7]), - } - example3 = { - 'labels': np.array([0.0]), - 'predictions': np.array([0.5]), - 'example_weights': np.array([0.9]), - } - - with beam.Pipeline() as pipeline: - # pylint: disable=no-value-for-parameter - result = ( - pipeline - | 'Create' >> beam.Create([example1, example2, example3]) - | 'Process' >> beam.Map(metric_util.to_standard_metric_inputs) - | 'AddSlice' >> beam.Map(lambda x: ((), x)) - | 'ComputeHistogram' >> beam.CombinePerKey(histogram.combiner) - | 'ComputeConfusionMatrix' - >> beam.Map(lambda x: (x[0], matrix.result(x[1]))) - | 'ComputeMetric' - >> beam.Map(lambda x: (x[0], derived_metric.result(x[1]))) - ) - - # pylint: enable=no-value-for-parameter - - def check_result(got): - try: - self.assertLen(got, 1) - got_slice_key, got_metrics = got[0] - self.assertEqual(got_slice_key, ()) - key = metric_types.MetricKey(name=metric.name, example_weighted=True) - self.assertIn(key, got_metrics) - # np.testing utils automatically cast floats to arrays which fails - # to catch type mismatches. - self.assertEqual(type(expected_value), type(got_metrics[key])) - np.testing.assert_almost_equal( - np.array(got_metrics[key]), np.array(expected_value), decimal=5) - - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that(result, check_result, label='result') - - # LINT.ThenChange(../google/sql:uda_auc_tests) - - @parameterized.named_parameters( - ('auc', confusion_matrix_metrics.AUC(), 0.8571428), - ('auc_precision_recall', confusion_matrix_metrics.AUCPrecisionRecall(), - 0.77369833), - ('true_positives', confusion_matrix_metrics.TruePositives(), 1.4), - ('false_positives', confusion_matrix_metrics.FalsePositives(), 0.6), - ('true_negatives', confusion_matrix_metrics.TrueNegatives(), 1.0), - ('false_negatives', confusion_matrix_metrics.FalseNegatives(), 0.0), - ) - def testConfusionMatrixMetricsWithFractionalLabels(self, metric, - expected_value): - computations = metric.computations(example_weighted=True) - histogram = computations[0] - matrix = computations[1] - derived_metric = computations[2] - - # The following examples will be expanded to: - # - # prediction | label | weight - # 0.0 | - | 1.0 - # 0.7 | - | 0.4 - # 0.7 | + | 0.6 - # 1.0 | - | 0.2 - # 1.0 | + | 0.8 - example1 = { - 'labels': np.array([0.0]), - 'predictions': np.array([0.0]), - 'example_weights': np.array([1.0]), - } - example2 = { - 'labels': np.array([0.6]), - 'predictions': np.array([0.7]), - 'example_weights': np.array([1.0]), - } - example3 = { - 'labels': np.array([0.8]), - 'predictions': np.array([1.0]), - 'example_weights': np.array([1.0]), - } - - with beam.Pipeline() as pipeline: - # pylint: disable=no-value-for-parameter - result = ( - pipeline - | 'Create' >> beam.Create([example1, example2, example3]) - | 'Process' >> beam.Map(metric_util.to_standard_metric_inputs) - | 'AddSlice' >> beam.Map(lambda x: ((), x)) - | 'ComputeHistogram' >> beam.CombinePerKey(histogram.combiner) - | 'ComputeConfusionMatrix' - >> beam.Map(lambda x: (x[0], matrix.result(x[1]))) - | 'ComputeMetric' - >> beam.Map(lambda x: (x[0], derived_metric.result(x[1]))) - ) - - # pylint: enable=no-value-for-parameter - - def check_result(got): - try: - self.assertLen(got, 1) - got_slice_key, got_metrics = got[0] - self.assertEqual(got_slice_key, ()) - key = 
metric_types.MetricKey(name=metric.name, example_weighted=True) - self.assertDictElementsAlmostEqual( - got_metrics, {key: expected_value}, places=5) - - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that(result, check_result, label='result') - - @parameterized.named_parameters( - ('precision@2 (using sub_key)', confusion_matrix_metrics.Precision(), 2, - 1.6 / (1.6 + 3.2)), - ('precision@2 (using param)', confusion_matrix_metrics.Precision(top_k=2), - None, 1.6 / (1.6 + 3.2)), - ('recall@2 (using sub_key)', confusion_matrix_metrics.Recall(), 2, 1.6 / - (1.6 + 0.8)), - ('recall@2 (using param)', confusion_matrix_metrics.Recall(top_k=2), None, - 1.6 / (1.6 + 0.8)), - ('precision@3 (using sub_key)', confusion_matrix_metrics.Precision(), 3, - 1.9 / (1.9 + 5.3)), - ('recall@3 (using sub_key)', confusion_matrix_metrics.Recall(), 3, 1.9 / - (1.9 + 0.5)), - ) - def testConfusionMatrixMetricsWithTopK(self, metric, top_k, expected_value): - computations = metric.computations( - sub_keys=[metric_types.SubKey(top_k=top_k)], example_weighted=True) - histogram = computations[0] - matrix = computations[1] - derived_metric = computations[2] - - # top_k = 2 - # TP = 0.5*0 + 0.7*1 + 0.9*1 + 0.3*0 = 1.6 - # FP = 0.5*2 + 0.7*1 + 0.9*1 + 0.3*2 = 3.2 - # FN = 0.5*1 + 0.7*0 + 0.9*0 + 0.3*1 = 0.8 - # - # top_k = 3 - # TP = 0.5*0 + 0.7*1 + 0.9*1 + 0.3*1 = 1.9 - # FP = 0.5*3 + 0.7*2 + 0.9*2 + 0.3*2 = 5.3 - # FN = 0.5*1 + 0.7*0 + 0.9*0 + 0.3*0 = 0.5 - example1 = { - 'labels': np.array([2]), - 'predictions': np.array([0.1, 0.2, 0.1, 0.25, 0.35]), - 'example_weights': np.array([0.5]), - } - example2 = { - 'labels': np.array([1]), - 'predictions': np.array([0.2, 0.3, 0.05, 0.15, 0.3]), - 'example_weights': np.array([0.7]), - } - example3 = { - 'labels': np.array([3]), - 'predictions': np.array([0.01, 0.2, 0.09, 0.5, 0.2]), - 'example_weights': np.array([0.9]), - } - example4 = { - 'labels': np.array([1]), - 'predictions': np.array([0.3, 0.2, 0.05, 0.4, 0.05]), - 'example_weights': np.array([0.3]), - } - - with beam.Pipeline() as pipeline: - # pylint: disable=no-value-for-parameter - result = ( - pipeline - | 'Create' >> beam.Create([example1, example2, example3, example4]) - | 'Process' >> beam.Map(metric_util.to_standard_metric_inputs) - | 'AddSlice' >> beam.Map(lambda x: ((), x)) - | 'ComputeHistogram' >> beam.CombinePerKey(histogram.combiner) - | 'ComputeConfusionMatrix' - >> beam.Map(lambda x: (x[0], matrix.result(x[1]))) - | 'ComputeMetric' - >> beam.Map(lambda x: (x[0], derived_metric.result(x[1]))) - ) - - # pylint: enable=no-value-for-parameter - - def check_result(got): - try: - self.assertLen(got, 1) - got_slice_key, got_metrics = got[0] - self.assertEqual(got_slice_key, ()) - if top_k: - sub_key = metric_types.SubKey(top_k=top_k) - else: - sub_key = metric_types.SubKey(top_k=metric.get_config()['top_k']) - key = metric_types.MetricKey( - name=metric.name, sub_key=sub_key, example_weighted=True) - self.assertDictElementsAlmostEqual( - got_metrics, {key: expected_value}, places=5) - - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that(result, check_result, label='result') - - @parameterized.named_parameters( - ('precision (class_id=1 using sub_key)', - confusion_matrix_metrics.Precision(thresholds=[0.1]), 1, 0.5 / - (0.5 + 1.6)), - ('precision (class_id=1 using param)', - confusion_matrix_metrics.Precision( - class_id=1, thresholds=[0.1]), None, 0.5 / (0.5 + 1.6)), - ('recall (class_id=3 using sub_key)', - 
confusion_matrix_metrics.Recall(thresholds=[0.1]), 3, 0.7 / (0.7 + 0.9)), - ('recall (class_id=3 using param)', - confusion_matrix_metrics.Recall( - class_id=3, thresholds=[0.1]), None, 0.7 / (0.7 + 0.9)), - ) - def testConfusionMatrixMetricsWithClassId(self, metric, class_id, - expected_value): - computations = metric.computations( - sub_keys=[metric_types.SubKey(class_id=class_id)], - example_weighted=True) - histogram = computations[0] - matrix = computations[1] - derived_metric = computations[2] - - # class_id = 1, threshold = 0.1 - # TP = 0.5*1 + 0.7*0 + 0.9*0 + 0.3*0 = 0.5 - # FP = 0.5*0 + 0.7*1 + 0.9*1 + 0.3*0 = 1.6 - # FN = 0.5*0 + 0.7*0 + 0.9*0 + 0.3*1 = 0.3 - # - # class_id = 3, threshold = 0.1 - # TP = 0.5*0 + 0.7*1 + 0.9*0 + 0.3*0 = 0.7 - # FP = 0.5*1 + 0.7*0 + 0.9*0 + 0.3*1 = 0.8 - # FN = 0.5*0 + 0.7*0 + 0.9*1 + 0.3*0 = 0.9 - example1 = { - 'labels': np.array([1]), - 'predictions': np.array([0.1, 0.2, 0.1, 0.25, 0.35]), - 'example_weights': np.array([0.5]), - } - example2 = { - 'labels': np.array([3]), - 'predictions': np.array([0.2, 0.3, 0.05, 0.15, 0.3]), - 'example_weights': np.array([0.7]), - } - example3 = { - 'labels': np.array([3]), - 'predictions': np.array([0.01, 0.2, 0.2, 0.09, 0.5]), - 'example_weights': np.array([0.9]), - } - example4 = { - 'labels': np.array([1]), - 'predictions': np.array([0.1, 0.05, 0.3, 0.4, 0.05]), - 'example_weights': np.array([0.3]), - } - - with beam.Pipeline() as pipeline: - # pylint: disable=no-value-for-parameter - result = ( - pipeline - | 'Create' >> beam.Create([example1, example2, example3, example4]) - | 'Process' >> beam.Map(metric_util.to_standard_metric_inputs) - | 'AddSlice' >> beam.Map(lambda x: ((), x)) - | 'ComputeHistogram' >> beam.CombinePerKey(histogram.combiner) - | 'ComputeConfusionMatrix' - >> beam.Map(lambda x: (x[0], matrix.result(x[1]))) - | 'ComputeMetric' - >> beam.Map(lambda x: (x[0], derived_metric.result(x[1]))) - ) - - # pylint: enable=no-value-for-parameter - - def check_result(got): - try: - self.assertLen(got, 1) - got_slice_key, got_metrics = got[0] - self.assertEqual(got_slice_key, ()) - if class_id: - sub_key = metric_types.SubKey(class_id=class_id) - else: - sub_key = metric_types.SubKey( - class_id=metric.get_config()['class_id']) - key = metric_types.MetricKey( - name=metric.name, sub_key=sub_key, example_weighted=True) - self.assertDictElementsAlmostEqual( - got_metrics, {key: expected_value}, places=5) - - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that(result, check_result, label='result') - - def testConfusionMatrixMetricsWithNan(self): - computations = confusion_matrix_metrics.Specificity().computations( - example_weighted=True) - histogram = computations[0] - matrices = computations[1] - metrics = computations[2] - - example1 = { - 'labels': np.array([1.0]), - 'predictions': np.array([1.0]), - 'example_weights': np.array([1.0]), - } - - with beam.Pipeline() as pipeline: - # pylint: disable=no-value-for-parameter - result = ( - pipeline - | 'Create' >> beam.Create([example1]) - | 'Process' >> beam.Map(metric_util.to_standard_metric_inputs) - | 'AddSlice' >> beam.Map(lambda x: ((), x)) - | 'ComputeHistogram' >> beam.CombinePerKey(histogram.combiner) - | 'ComputeMatrices' - >> beam.Map(lambda x: (x[0], matrices.result(x[1]))) - | 'ComputeMetrics' >> beam.Map(lambda x: (x[0], metrics.result(x[1]))) - ) - - # pylint: enable=no-value-for-parameter - - def check_result(got): - try: - self.assertLen(got, 1) - got_slice_key, got_metrics = got[0] - 
self.assertEqual(got_slice_key, ()) - self.assertLen(got_metrics, 1) - key = metrics.keys[0] - self.assertIn(key, got_metrics) - self.assertTrue(math.isnan(got_metrics[key])) - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that(result, check_result, label='result') - - @parameterized.named_parameters( - ('class_id as param and class_id as sub_key', - confusion_matrix_metrics.Precision(class_id=2), 2, None), - ('top_k as param and top_k as sub_key', - confusion_matrix_metrics.Precision(top_k=2), None, 2), - ) - def testRaisesErrorIfOverlappingSettings(self, metric, class_id, top_k): - with self.assertRaisesRegex(ValueError, - '.*is configured with overlapping settings.*'): - metric.computations( - sub_keys=[metric_types.SubKey(class_id=class_id, top_k=top_k)]) - - def testConfusionMatrixAtThresholds(self): - computations = confusion_matrix_metrics.ConfusionMatrixAtThresholds( - thresholds=[0.3, 0.5, 0.8]).computations(example_weighted=True) - histogram = computations[0] - matrices = computations[1] - metrics = computations[2] - - example1 = { - 'labels': np.array([0.0]), - 'predictions': np.array([0.0]), - 'example_weights': np.array([1.0]), - } - example2 = { - 'labels': np.array([0.0]), - 'predictions': np.array([0.5]), - 'example_weights': np.array([1.0]), - } - example3 = { - 'labels': np.array([1.0]), - 'predictions': np.array([0.3]), - 'example_weights': np.array([1.0]), - } - example4 = { - 'labels': np.array([1.0]), - 'predictions': np.array([0.9]), - 'example_weights': np.array([1.0]), - } - - with beam.Pipeline() as pipeline: - # pylint: disable=no-value-for-parameter - result = ( - pipeline - | 'Create' >> beam.Create([example1, example2, example3, example4]) - | 'Process' >> beam.Map(metric_util.to_standard_metric_inputs) - | 'AddSlice' >> beam.Map(lambda x: ((), x)) - | 'ComputeHistogram' >> beam.CombinePerKey(histogram.combiner) - | 'ComputeMatrices' - >> beam.Map(lambda x: (x[0], matrices.result(x[1]))) - | 'ComputeMetrics' >> beam.Map(lambda x: (x[0], metrics.result(x[1]))) - ) - - # pylint: enable=no-value-for-parameter - - def check_result(got): - try: - self.assertLen(got, 1) - got_slice_key, got_metrics = got[0] - self.assertEqual(got_slice_key, ()) - self.assertLen(got_metrics, 1) - key = metric_types.MetricKey( - name='confusion_matrix_at_thresholds', example_weighted=True) - self.assertIn(key, got_metrics) - got_metric = got_metrics[key] - self.assertEqual( - binary_confusion_matrices.Matrices( - thresholds=[0.3, 0.5, 0.8], - tp=[1.0, 1.0, 1.0], - tn=[1.0, 2.0, 2.0], - fp=[1.0, 0.0, 0.0], - fn=[1.0, 1.0, 1.0]), got_metric) - - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that(result, check_result, label='result') - - @parameterized.named_parameters( - ( - 'precision (class_id=1 top_k=2 using sub_key)', - confusion_matrix_metrics.Precision(thresholds=[0.1]), - 1, - 2, - 0.5 / (0.5 + 1.6), - ), - ( - 'precision (class_id=1 using param and top_k=2 using sub_key)', - confusion_matrix_metrics.Precision(class_id=1, thresholds=[0.1]), - None, - 2, - 0.5 / (0.5 + 1.6), - ), - ( - 'recall (class_id=3 using sub_key and top_k=2 using param)', - confusion_matrix_metrics.Recall(thresholds=[0.1], top_k=2), - 3, - None, - 0.7 / (0.7 + 0.9), - ), - ( - 'recall (class_id=3 top_k=2 using param)', - confusion_matrix_metrics.Recall( - class_id=3, top_k=2, thresholds=[0.1] - ), - None, - None, - 0.7 / (0.7 + 0.9), - ), - ) - def testConfusionMatrixMetricsWithClassIdAndTopK( - self, metric, class_id, top_k, 
expected_value - ): - computations = metric.computations( - sub_keys=[metric_types.SubKey(class_id=class_id, top_k=top_k)], - example_weighted=True, + if _TF_MAJOR_VERSION < 2 and metric.__class__.__name__ in ( + "SpecificityAtSensitivity", + "SensitivityAtSpecificity", + "PrecisionAtRecall", + "RecallAtPrecision", + "RecallAtFalsePositiveRate", + ): + self.skipTest("Not supported in TFv1.") + + computations = metric.computations(example_weighted=True) + histogram = computations[0] + matrices = computations[1] + metrics = computations[2] + + # Using one example to create a situation where the denominator in the + # corresponding calculation is 0. + pred, label = pred_label + example1 = { + "labels": np.array([label]), + "predictions": np.array([pred]), + } + + with beam.Pipeline() as pipeline: + # pylint: disable=no-value-for-parameter + result = ( + pipeline + | "Create" >> beam.Create([example1]) + | "Process" >> beam.Map(metric_util.to_standard_metric_inputs) + | "AddSlice" >> beam.Map(lambda x: ((), x)) + | "ComputeHistogram" >> beam.CombinePerKey(histogram.combiner) + | "ComputeMatrices" + >> beam.Map(lambda x: (x[0], matrices.result(x[1]))) # pyformat: ignore + | "ComputeMetrics" >> beam.Map(lambda x: (x[0], metrics.result(x[1]))) + ) # pyformat: ignore + + # pylint: enable=no-value-for-parameter + def check_result(got): + try: + self.assertLen(got, 1) + got_slice_key, got_metrics = got[0] + self.assertEqual(got_slice_key, ()) + self.assertLen(got_metrics, 1) + key = metrics.keys[0] + self.assertIn(key, got_metrics) + # np.testing utils automatically cast floats to arrays which fails + # to catch type mismatches. + self.assertEqual(type(expected_value), type(got_metrics[key])) + np.testing.assert_almost_equal( + got_metrics[key], expected_value, decimal=5 + ) + except AssertionError as err: + raise util.BeamAssertException(err) + + util.assert_that(result, check_result, label="result") + + # LINT.IfChange(tfma_confusion_matrix_metrics_tests) + @parameterized.named_parameters( + ("auc", confusion_matrix_metrics.AUC(), np.float64(0.26)), + ( + "auc_precision_recall", + confusion_matrix_metrics.AUCPrecisionRecall(), + np.float64(0.36205), + ), + ( + "specificity_at_sensitivity", + confusion_matrix_metrics.SpecificityAtSensitivity(0.5), + 0.2, + ), + ( + "sensitivity_at_specificity", + confusion_matrix_metrics.SensitivityAtSpecificity(0.5), + 0.0, + ), + ( + "precision_at_recall", + confusion_matrix_metrics.PrecisionAtRecall(0.5), + 0.5, + ), + ( + "recall_at_precision", + confusion_matrix_metrics.RecallAtPrecision(0.5), + 1.0, + ), + ( + "recall_at_false_positive_rate", + confusion_matrix_metrics.RecallAtFalsePositiveRate(0.6), + 0.4, + ), + ("true_positives", confusion_matrix_metrics.TruePositives(), 1.0), + ("tp", confusion_matrix_metrics.TP(), 1.0), + ("false_positives", confusion_matrix_metrics.FalsePositives(), 3.0), + ("fp", confusion_matrix_metrics.FP(), 3.0), + ("true_negatives", confusion_matrix_metrics.TrueNegatives(), 2.0), + ("tn", confusion_matrix_metrics.TN(), 2.0), + ("false_negatives", confusion_matrix_metrics.FalseNegatives(), 4.0), + ("fn", confusion_matrix_metrics.FN(), 4.0), + ( + "binary_accuracy", + confusion_matrix_metrics.BinaryAccuracy(), + (1.0 + 2.0) / (1.0 + 2.0 + 3.0 + 4.0), + ), + ("precision", confusion_matrix_metrics.Precision(), 1.0 / (1.0 + 3.0)), + ("ppv", confusion_matrix_metrics.PPV(), 1.0 / (1.0 + 3.0)), + ("recall", confusion_matrix_metrics.Recall(), 1.0 / (1.0 + 4.0)), + ("tpr", confusion_matrix_metrics.TPR(), 1.0 / (1.0 + 4.0)), + ( + 
"specificity", + confusion_matrix_metrics.Specificity(), + 2.0 / (2.0 + 3.0), + ), + ("tnr", confusion_matrix_metrics.TNR(), 2.0 / (2.0 + 3.0)), + ("fall_out", confusion_matrix_metrics.FallOut(), 3.0 / (3.0 + 2.0)), + ("fpr", confusion_matrix_metrics.FPR(), 3.0 / (3.0 + 2.0)), + ("miss_rate", confusion_matrix_metrics.MissRate(), 4.0 / (4.0 + 1.0)), + ("fnr", confusion_matrix_metrics.FNR(), 4.0 / (4.0 + 1.0)), + ( + "negative_predictive_value", + confusion_matrix_metrics.NegativePredictiveValue(), + 2.0 / (2.0 + 4.0), + ), + ("npv", confusion_matrix_metrics.NPV(), 2.0 / (2.0 + 4.0)), + ( + "false_discovery_rate", + confusion_matrix_metrics.FalseDiscoveryRate(), + 3.0 / (3.0 + 1.0), + ), + ( + "false_omission_rate", + confusion_matrix_metrics.FalseOmissionRate(), + 4.0 / (4.0 + 2.0), + ), + ( + "prevalence", + confusion_matrix_metrics.Prevalence(), + (1.0 + 4.0) / (1.0 + 2.0 + 3.0 + 4.0), + ), + ( + "prevalence_threshold", + confusion_matrix_metrics.PrevalenceThreshold(), + ( + math.sqrt((1.0 / (1.0 + 4.0)) * (1.0 - (2.0 / (2.0 + 3.0)))) + + (2.0 / (2.0 + 3.0) - 1.0) + ) + / (1.0 / (1.0 + 4.0) + (2.0 / (2.0 + 3.0)) - 1.0), + ), + ( + "threat_score", + confusion_matrix_metrics.ThreatScore(), + 1.0 / (1.0 + 4.0 + 3.0), + ), + ( + "balanced_accuracy", + confusion_matrix_metrics.BalancedAccuracy(), + ((1.0 / (1.0 + 4.0)) + (2.0 / (2.0 + 3.0))) / 2, + ), + ( + "f1_score", + confusion_matrix_metrics.F1Score(), + 2 * 1.0 / (2 * 1.0 + 3.0 + 4.0), + ), + ( + "matthews_correlation_coefficient", + confusion_matrix_metrics.MatthewsCorrelationCoefficient(), + (1.0 * 2.0 - 3.0 * 4.0) + / math.sqrt((1.0 + 3.0) * (1.0 + 4.0) * (2.0 + 3.0) * (2.0 + 4.0)), + ), + ( + "fowlkes_mallows_index", + confusion_matrix_metrics.FowlkesMallowsIndex(), + math.sqrt(1.0 / (1.0 + 3.0) * 1.0 / (1.0 + 4.0)), + ), + ( + "informedness", + confusion_matrix_metrics.Informedness(), + (1.0 / (1.0 + 4.0)) + (2.0 / (2.0 + 3.0)) - 1.0, + ), + ( + "markedness", + confusion_matrix_metrics.Markedness(), + (1.0 / (1.0 + 3.0)) + (2.0 / (2.0 + 4.0)) - 1.0, + ), + ( + "positive_likelihood_ratio", + confusion_matrix_metrics.PositiveLikelihoodRatio(), + (1.0 / (1.0 + 4.0)) / (3.0 / (3.0 + 2.0)), + ), + ( + "negative_likelihood_ratio", + confusion_matrix_metrics.NegativeLikelihoodRatio(), + (4.0 / (4.0 + 1.0)) / (2.0 / (2.0 + 3.0)), + ), + ( + "diagnostic_odds_ratio", + confusion_matrix_metrics.DiagnosticOddsRatio(), + (1.0 / 3.0) / (4.0 / 2.0), + ), + ( + "predicted_positive_rate", + confusion_matrix_metrics.PredictedPositiveRate(), + (1.0 + 3.0) / (1.0 + 2.0 + 3.0 + 4.0), + ), + ( + "threshold_at_recall", + confusion_matrix_metrics.ThresholdAtRecall(0.5), + 0.29993, + ), ) - histogram = computations[0] - matrix = computations[1] - derived_metric = computations[2] - - # class_id = 1, top_k=2 - # TP = 0.5*1 + 0.7*0 + 0.9*0 + 0.3*0 = 0.5 - # FP = 0.5*0 + 0.7*1 + 0.9*1 + 0.3*0 = 1.6 - # FN = 0.5*0 + 0.7*0 + 0.9*0 + 0.3*1 = 0.3 - # - # class_id = 3, top_k=2 - # TP = 0.5*0 + 0.7*1 + 0.9*0 + 0.3*0 = 0.7 - # FP = 0.5*1 + 0.7*0 + 0.9*0 + 0.3*1 = 0.8 - # FN = 0.5*0 + 0.7*0 + 0.9*1 + 0.3*0 = 0.9 - example1 = { - 'labels': np.array([1]), - 'predictions': np.array([0.1, 0.5, 0.1, 0.45, 0.35]), - 'example_weights': np.array([0.5]), - } - example2 = { - 'labels': np.array([3]), - 'predictions': np.array([0.2, 0.3, 0.05, 0.31, 0.3]), - 'example_weights': np.array([0.7]), - } - example3 = { - 'labels': np.array([3]), - 'predictions': np.array([0.01, 0.2, 0.2, 0.09, 0.5]), - 'example_weights': np.array([0.9]), - } - example4 = { - 'labels': np.array([1]), - 
'predictions': np.array([0.1, 0.05, 0.3, 0.4, 0.05]), - 'example_weights': np.array([0.3]), - } - - with beam.Pipeline() as pipeline: - # pylint: disable=no-value-for-parameter - result = ( - pipeline - | 'Create' >> beam.Create([example1, example2, example3, example4]) - | 'Process' >> beam.Map(metric_util.to_standard_metric_inputs) - | 'AddSlice' >> beam.Map(lambda x: ((), x)) - | 'ComputeHistogram' >> beam.CombinePerKey(histogram.combiner) - | 'ComputeConfusionMatrix' - >> beam.Map(lambda x: (x[0], matrix.result(x[1]))) - | 'ComputeMetric' - >> beam.Map(lambda x: (x[0], derived_metric.result(x[1]))) - ) - - # pylint: enable=no-value-for-parameter - - def check_result(got): - try: - self.assertLen(got, 1) - got_slice_key, got_metrics = got[0] - self.assertEqual(got_slice_key, ()) - subkey_class_id = class_id or metric.get_config()['class_id'] - subkey_top_k = top_k or metric.get_config()['top_k'] - sub_key = metric_types.SubKey( - class_id=subkey_class_id, top_k=subkey_top_k - ) - key = metric_types.MetricKey( - name=metric.name, sub_key=sub_key, example_weighted=True - ) - self.assertDictElementsAlmostEqual( - got_metrics, {key: expected_value}, places=5 - ) - - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that(result, check_result, label='result') - - @parameterized.named_parameters( - ( - 'false_positives', - confusion_matrix_metrics.FalsePositiveFeatureSampler( - threshold=0.5, feature_key='example_id', sample_size=2 - ), - 'false_positive_feature_sampler', - np.array(['example1', 'example2'], dtype=str), - ), - ( - 'false_negatives', - confusion_matrix_metrics.FalseNegativeFeatureSampler( - threshold=0.5, feature_key='example_id', sample_size=2 - ), - 'false_negative_feature_sampler', - np.array(['example3', 'example4'], dtype=str), - ), - ) - def testConfusionMatrixFeatureSamplers( - self, metric, expected_metric_name, expected_value - ): - # false positive - example1 = { - 'labels': np.array([0.0]), - 'predictions': np.array([1.0]), - 'example_weights': np.array([1.0]), - 'features': {'example_id': np.array(['example1'])}, - } - # false positive - example2 = { - 'labels': np.array([0.0]), - 'predictions': np.array([1.0]), - 'example_weights': np.array([1.0]), - 'features': {'example_id': np.array(['example2'])}, - } - # false negative - example3 = { - 'labels': np.array([1.0]), - 'predictions': np.array([0.0]), - 'example_weights': np.array([1.0]), - 'features': {'example_id': np.array(['example3'])}, - } - # false negative - example4 = { - 'labels': np.array([1.0]), - 'predictions': np.array([0.0]), - 'example_weights': np.array([1.0]), - 'features': {'example_id': np.array(['example4'])}, - } - - expected_metrics = { - metric_types.MetricKey( - name=expected_metric_name, example_weighted=True - ): expected_value, - } - self.assertDerivedMetricsEqual( - expected_metrics=expected_metrics, - extracts=[example1, example2, example3, example4], - metric=metric, - enable_debug_print=True, + def testConfusionMatrixMetrics(self, metric, expected_value): + if _TF_MAJOR_VERSION < 2 and metric.__class__.__name__ in ( + "SpecificityAtSensitivity", + "SensitivityAtSpecificity", + "PrecisionAtRecall", + "RecallAtPrecision", + "RecallAtFalsePositiveRate", + ): + self.skipTest("Not supported in TFv1.") + + computations = metric.computations(example_weighted=True) + histogram = computations[0] + matrices = computations[1] + metrics = computations[2] + + # tp = 1 + # tn = 2 + # fp = 3 + # fn = 4 + example1 = { + "labels": np.array([1.0]), + "predictions": 
np.array([0.6]), + "example_weights": np.array([1.0]), + } + example2 = { + "labels": np.array([0.0]), + "predictions": np.array([0.3]), + "example_weights": np.array([1.0]), + } + example3 = { + "labels": np.array([0.0]), + "predictions": np.array([0.2]), + "example_weights": np.array([1.0]), + } + example4 = { + "labels": np.array([0.0]), + "predictions": np.array([0.6]), + "example_weights": np.array([1.0]), + } + example5 = { + "labels": np.array([0.0]), + "predictions": np.array([0.7]), + "example_weights": np.array([1.0]), + } + example6 = { + "labels": np.array([0.0]), + "predictions": np.array([0.8]), + "example_weights": np.array([1.0]), + } + example7 = { + "labels": np.array([1.0]), + "predictions": np.array([0.1]), + "example_weights": np.array([1.0]), + } + example8 = { + "labels": np.array([1.0]), + "predictions": np.array([0.2]), + "example_weights": np.array([1.0]), + } + example9 = { + "labels": np.array([1.0]), + "predictions": np.array([0.3]), + "example_weights": np.array([1.0]), + } + example10 = { + "labels": np.array([1.0]), + "predictions": np.array([0.4]), + "example_weights": np.array([1.0]), + } + + with beam.Pipeline() as pipeline: + # pylint: disable=no-value-for-parameter + result = ( + pipeline + | "Create" + >> beam.Create( + [ + example1, + example2, + example3, + example4, + example5, + example6, + example7, + example8, + example9, + example10, + ] + ) + | "Process" >> beam.Map(metric_util.to_standard_metric_inputs) + | "AddSlice" >> beam.Map(lambda x: ((), x)) + | "ComputeHistogram" >> beam.CombinePerKey(histogram.combiner) + | "ComputeMatrices" >> beam.Map(lambda x: (x[0], matrices.result(x[1]))) + | "ComputeMetrics" >> beam.Map(lambda x: (x[0], metrics.result(x[1]))) + ) + + # pylint: enable=no-value-for-parameter + + def check_result(got): + try: + self.assertLen(got, 1) + got_slice_key, got_metrics = got[0] + self.assertEqual(got_slice_key, ()) + self.assertLen(got_metrics, 1) + key = metrics.keys[0] + self.assertIn(key, got_metrics) + # np.testing utils automatically cast floats to arrays which fails + # to catch type mismatches. 
+ self.assertEqual(type(expected_value), type(got_metrics[key])) + np.testing.assert_almost_equal( + got_metrics[key], expected_value, decimal=5 + ) + except AssertionError as err: + raise util.BeamAssertException(err) + + util.assert_that(result, check_result, label="result") + + @parameterized.named_parameters( + ("auc", confusion_matrix_metrics.AUC(), np.float64(0.64286)), + ( + "auc_precision_recall", + confusion_matrix_metrics.AUCPrecisionRecall(), + np.float64(0.37467), + ), + ( + "specificity_at_sensitivity", + confusion_matrix_metrics.SpecificityAtSensitivity(0.5), + 0.642857, + ), + ( + "sensitivity_at_specificity", + confusion_matrix_metrics.SensitivityAtSpecificity(0.5), + 1.0, + ), + ( + "precision_at_recall", + confusion_matrix_metrics.PrecisionAtRecall(0.5), + 0.58333, + ), + ( + "recall_at_precision", + confusion_matrix_metrics.RecallAtPrecision(0.5), + 1.0, + ), + ( + "recall_at_false_positive_rate", + confusion_matrix_metrics.RecallAtFalsePositiveRate(0.5 / (0.5 + 0.9)), + 1.0, + ), + ("true_positives", confusion_matrix_metrics.TruePositives(), 0.7), + ("false_positives", confusion_matrix_metrics.FalsePositives(), 0.5), + ("true_negatives", confusion_matrix_metrics.TrueNegatives(), 0.9), + ("false_negatives", confusion_matrix_metrics.FalseNegatives(), 0.0), + ( + "binary_accuracy", + confusion_matrix_metrics.BinaryAccuracy(), + (0.7 + 0.9) / (0.7 + 0.9 + 0.5 + 0.0), + ), + ("precision", confusion_matrix_metrics.Precision(), 0.7 / (0.7 + 0.5)), + ("recall", confusion_matrix_metrics.Recall(), 0.7 / (0.7 + 0.0)), + ( + "specificity", + confusion_matrix_metrics.Specificity(), + 0.9 / (0.9 + 0.5), + ), + ("fall_out", confusion_matrix_metrics.FallOut(), 0.5 / (0.5 + 0.9)), + ("miss_rate", confusion_matrix_metrics.MissRate(), 0.0 / (0.0 + 0.7)), + ( + "negative_predictive_value", + confusion_matrix_metrics.NegativePredictiveValue(), + 0.9 / (0.9 + 0.0), + ), + ( + "false_discovery_rate", + confusion_matrix_metrics.FalseDiscoveryRate(), + 0.5 / (0.5 + 0.7), + ), + ( + "false_omission_rate", + confusion_matrix_metrics.FalseOmissionRate(), + 0.0 / (0.0 + 0.9), + ), + ( + "prevalence", + confusion_matrix_metrics.Prevalence(), + (0.7 + 0.0) / (0.7 + 0.9 + 0.5 + 0.0), + ), + ( + "prevalence_threshold", + confusion_matrix_metrics.PrevalenceThreshold(), + ( + math.sqrt((0.7 / (0.7 + 0.0)) * (1.0 - (0.9 / (0.9 + 0.5)))) + + (0.9 / (0.9 + 0.5) - 1.0) + ) + / (0.7 / (0.7 + 0.0) + (0.9 / (0.9 + 0.5)) - 1.0), + ), + ( + "threat_score", + confusion_matrix_metrics.ThreatScore(), + 0.7 / (0.7 + 0.0 + 0.5), + ), + ( + "balanced_accuracy", + confusion_matrix_metrics.BalancedAccuracy(), + ((0.7 / (0.7 + 0.0)) + (0.9 / (0.9 + 0.5))) / 2, + ), + ( + "f1_score", + confusion_matrix_metrics.F1Score(), + 2 * 0.7 / (2 * 0.7 + 0.5 + 0.0), + ), + ( + "matthews_correlation_coefficient", + confusion_matrix_metrics.MatthewsCorrelationCoefficient(), + (0.7 * 0.9 - 0.5 * 0.0) + / math.sqrt((0.7 + 0.5) * (0.7 + 0.0) * (0.9 + 0.5) * (0.9 + 0.0)), + ), + ( + "fowlkes_mallows_index", + confusion_matrix_metrics.FowlkesMallowsIndex(), + math.sqrt(0.7 / (0.7 + 0.5) * 0.7 / (0.7 + 0.0)), + ), + ( + "informedness", + confusion_matrix_metrics.Informedness(), + (0.7 / (0.7 + 0.0)) + (0.9 / (0.9 + 0.5)) - 1.0, + ), + ( + "markedness", + confusion_matrix_metrics.Markedness(), + (0.7 / (0.7 + 0.5)) + (0.9 / (0.9 + 0.0)) - 1.0, + ), + ( + "positive_likelihood_ratio", + confusion_matrix_metrics.PositiveLikelihoodRatio(), + (0.7 / (0.7 + 0.0)) / (0.5 / (0.5 + 0.9)), + ), + ( + "negative_likelihood_ratio", + 
confusion_matrix_metrics.NegativeLikelihoodRatio(), + (0.0 / (0.0 + 0.7)) / (0.9 / (0.9 + 0.5)), + ), + ( + "predicted_positive_rate", + confusion_matrix_metrics.PredictedPositiveRate(), + (0.7 + 0.5) / (0.7 + 0.9 + 0.5 + 0.0), + ), ) - - -if __name__ == '__main__': - tf.compat.v1.enable_v2_behavior() - tf.test.main() + def testConfusionMatrixMetricsWithWeights(self, metric, expected_value): + if _TF_MAJOR_VERSION < 2 and metric.__class__.__name__ in ( + "SpecificityAtSensitivity", + "SensitivityAtSpecificity", + "PrecisionAtRecall", + "RecallAtPrecision", + "RecallAtFalsePositiveRate", + ): + self.skipTest("Not supported in TFv1.") + + computations = metric.computations(example_weighted=True) + histogram = computations[0] + matrix = computations[1] + derived_metric = computations[2] + + # tp = 0.7 + # tn = 0.9 + # fp = 0.5 + # fn = 0.0 + example1 = { + "labels": np.array([0.0]), + "predictions": np.array([1.0]), + "example_weights": np.array([0.5]), + } + example2 = { + "labels": np.array([1.0]), + "predictions": np.array([0.7]), + "example_weights": np.array([0.7]), + } + example3 = { + "labels": np.array([0.0]), + "predictions": np.array([0.5]), + "example_weights": np.array([0.9]), + } + + with beam.Pipeline() as pipeline: + # pylint: disable=no-value-for-parameter + result = ( + pipeline + | "Create" >> beam.Create([example1, example2, example3]) + | "Process" >> beam.Map(metric_util.to_standard_metric_inputs) + | "AddSlice" >> beam.Map(lambda x: ((), x)) + | "ComputeHistogram" >> beam.CombinePerKey(histogram.combiner) + | "ComputeConfusionMatrix" + >> beam.Map(lambda x: (x[0], matrix.result(x[1]))) + | "ComputeMetric" + >> beam.Map(lambda x: (x[0], derived_metric.result(x[1]))) + ) + + # pylint: enable=no-value-for-parameter + + def check_result(got): + try: + self.assertLen(got, 1) + got_slice_key, got_metrics = got[0] + self.assertEqual(got_slice_key, ()) + key = metric_types.MetricKey( + name=metric.name, example_weighted=True + ) + self.assertIn(key, got_metrics) + # np.testing utils automatically cast floats to arrays which fails + # to catch type mismatches. 
+ self.assertEqual(type(expected_value), type(got_metrics[key])) + np.testing.assert_almost_equal( + np.array(got_metrics[key]), np.array(expected_value), decimal=5 + ) + + except AssertionError as err: + raise util.BeamAssertException(err) + + util.assert_that(result, check_result, label="result") + + # LINT.ThenChange(../google/sql:uda_auc_tests) + + @parameterized.named_parameters( + ("auc", confusion_matrix_metrics.AUC(), 0.8571428), + ( + "auc_precision_recall", + confusion_matrix_metrics.AUCPrecisionRecall(), + 0.77369833, + ), + ("true_positives", confusion_matrix_metrics.TruePositives(), 1.4), + ("false_positives", confusion_matrix_metrics.FalsePositives(), 0.6), + ("true_negatives", confusion_matrix_metrics.TrueNegatives(), 1.0), + ("false_negatives", confusion_matrix_metrics.FalseNegatives(), 0.0), + ) + def testConfusionMatrixMetricsWithFractionalLabels(self, metric, expected_value): + computations = metric.computations(example_weighted=True) + histogram = computations[0] + matrix = computations[1] + derived_metric = computations[2] + + # The following examples will be expanded to: + # + # prediction | label | weight + # 0.0 | - | 1.0 + # 0.7 | - | 0.4 + # 0.7 | + | 0.6 + # 1.0 | - | 0.2 + # 1.0 | + | 0.8 + example1 = { + "labels": np.array([0.0]), + "predictions": np.array([0.0]), + "example_weights": np.array([1.0]), + } + example2 = { + "labels": np.array([0.6]), + "predictions": np.array([0.7]), + "example_weights": np.array([1.0]), + } + example3 = { + "labels": np.array([0.8]), + "predictions": np.array([1.0]), + "example_weights": np.array([1.0]), + } + + with beam.Pipeline() as pipeline: + # pylint: disable=no-value-for-parameter + result = ( + pipeline + | "Create" >> beam.Create([example1, example2, example3]) + | "Process" >> beam.Map(metric_util.to_standard_metric_inputs) + | "AddSlice" >> beam.Map(lambda x: ((), x)) + | "ComputeHistogram" >> beam.CombinePerKey(histogram.combiner) + | "ComputeConfusionMatrix" + >> beam.Map(lambda x: (x[0], matrix.result(x[1]))) + | "ComputeMetric" + >> beam.Map(lambda x: (x[0], derived_metric.result(x[1]))) + ) + + # pylint: enable=no-value-for-parameter + + def check_result(got): + try: + self.assertLen(got, 1) + got_slice_key, got_metrics = got[0] + self.assertEqual(got_slice_key, ()) + key = metric_types.MetricKey( + name=metric.name, example_weighted=True + ) + self.assertDictElementsAlmostEqual( + got_metrics, {key: expected_value}, places=5 + ) + + except AssertionError as err: + raise util.BeamAssertException(err) + + util.assert_that(result, check_result, label="result") + + @parameterized.named_parameters( + ( + "precision@2 (using sub_key)", + confusion_matrix_metrics.Precision(), + 2, + 1.6 / (1.6 + 3.2), + ), + ( + "precision@2 (using param)", + confusion_matrix_metrics.Precision(top_k=2), + None, + 1.6 / (1.6 + 3.2), + ), + ( + "recall@2 (using sub_key)", + confusion_matrix_metrics.Recall(), + 2, + 1.6 / (1.6 + 0.8), + ), + ( + "recall@2 (using param)", + confusion_matrix_metrics.Recall(top_k=2), + None, + 1.6 / (1.6 + 0.8), + ), + ( + "precision@3 (using sub_key)", + confusion_matrix_metrics.Precision(), + 3, + 1.9 / (1.9 + 5.3), + ), + ( + "recall@3 (using sub_key)", + confusion_matrix_metrics.Recall(), + 3, + 1.9 / (1.9 + 0.5), + ), + ) + def testConfusionMatrixMetricsWithTopK(self, metric, top_k, expected_value): + computations = metric.computations( + sub_keys=[metric_types.SubKey(top_k=top_k)], example_weighted=True + ) + histogram = computations[0] + matrix = computations[1] + derived_metric = computations[2] + + 
# top_k = 2 + # TP = 0.5*0 + 0.7*1 + 0.9*1 + 0.3*0 = 1.6 + # FP = 0.5*2 + 0.7*1 + 0.9*1 + 0.3*2 = 3.2 + # FN = 0.5*1 + 0.7*0 + 0.9*0 + 0.3*1 = 0.8 + # + # top_k = 3 + # TP = 0.5*0 + 0.7*1 + 0.9*1 + 0.3*1 = 1.9 + # FP = 0.5*3 + 0.7*2 + 0.9*2 + 0.3*2 = 5.3 + # FN = 0.5*1 + 0.7*0 + 0.9*0 + 0.3*0 = 0.5 + example1 = { + "labels": np.array([2]), + "predictions": np.array([0.1, 0.2, 0.1, 0.25, 0.35]), + "example_weights": np.array([0.5]), + } + example2 = { + "labels": np.array([1]), + "predictions": np.array([0.2, 0.3, 0.05, 0.15, 0.3]), + "example_weights": np.array([0.7]), + } + example3 = { + "labels": np.array([3]), + "predictions": np.array([0.01, 0.2, 0.09, 0.5, 0.2]), + "example_weights": np.array([0.9]), + } + example4 = { + "labels": np.array([1]), + "predictions": np.array([0.3, 0.2, 0.05, 0.4, 0.05]), + "example_weights": np.array([0.3]), + } + + with beam.Pipeline() as pipeline: + # pylint: disable=no-value-for-parameter + result = ( + pipeline + | "Create" >> beam.Create([example1, example2, example3, example4]) + | "Process" >> beam.Map(metric_util.to_standard_metric_inputs) + | "AddSlice" >> beam.Map(lambda x: ((), x)) + | "ComputeHistogram" >> beam.CombinePerKey(histogram.combiner) + | "ComputeConfusionMatrix" + >> beam.Map(lambda x: (x[0], matrix.result(x[1]))) + | "ComputeMetric" + >> beam.Map(lambda x: (x[0], derived_metric.result(x[1]))) + ) + + # pylint: enable=no-value-for-parameter + + def check_result(got): + try: + self.assertLen(got, 1) + got_slice_key, got_metrics = got[0] + self.assertEqual(got_slice_key, ()) + if top_k: + sub_key = metric_types.SubKey(top_k=top_k) + else: + sub_key = metric_types.SubKey( + top_k=metric.get_config()["top_k"] + ) + key = metric_types.MetricKey( + name=metric.name, sub_key=sub_key, example_weighted=True + ) + self.assertDictElementsAlmostEqual( + got_metrics, {key: expected_value}, places=5 + ) + + except AssertionError as err: + raise util.BeamAssertException(err) + + util.assert_that(result, check_result, label="result") + + @parameterized.named_parameters( + ( + "precision (class_id=1 using sub_key)", + confusion_matrix_metrics.Precision(thresholds=[0.1]), + 1, + 0.5 / (0.5 + 1.6), + ), + ( + "precision (class_id=1 using param)", + confusion_matrix_metrics.Precision(class_id=1, thresholds=[0.1]), + None, + 0.5 / (0.5 + 1.6), + ), + ( + "recall (class_id=3 using sub_key)", + confusion_matrix_metrics.Recall(thresholds=[0.1]), + 3, + 0.7 / (0.7 + 0.9), + ), + ( + "recall (class_id=3 using param)", + confusion_matrix_metrics.Recall(class_id=3, thresholds=[0.1]), + None, + 0.7 / (0.7 + 0.9), + ), + ) + def testConfusionMatrixMetricsWithClassId(self, metric, class_id, expected_value): + computations = metric.computations( + sub_keys=[metric_types.SubKey(class_id=class_id)], example_weighted=True + ) + histogram = computations[0] + matrix = computations[1] + derived_metric = computations[2] + + # class_id = 1, threshold = 0.1 + # TP = 0.5*1 + 0.7*0 + 0.9*0 + 0.3*0 = 0.5 + # FP = 0.5*0 + 0.7*1 + 0.9*1 + 0.3*0 = 1.6 + # FN = 0.5*0 + 0.7*0 + 0.9*0 + 0.3*1 = 0.3 + # + # class_id = 3, threshold = 0.1 + # TP = 0.5*0 + 0.7*1 + 0.9*0 + 0.3*0 = 0.7 + # FP = 0.5*1 + 0.7*0 + 0.9*0 + 0.3*1 = 0.8 + # FN = 0.5*0 + 0.7*0 + 0.9*1 + 0.3*0 = 0.9 + example1 = { + "labels": np.array([1]), + "predictions": np.array([0.1, 0.2, 0.1, 0.25, 0.35]), + "example_weights": np.array([0.5]), + } + example2 = { + "labels": np.array([3]), + "predictions": np.array([0.2, 0.3, 0.05, 0.15, 0.3]), + "example_weights": np.array([0.7]), + } + example3 = { + "labels": 
np.array([3]), + "predictions": np.array([0.01, 0.2, 0.2, 0.09, 0.5]), + "example_weights": np.array([0.9]), + } + example4 = { + "labels": np.array([1]), + "predictions": np.array([0.1, 0.05, 0.3, 0.4, 0.05]), + "example_weights": np.array([0.3]), + } + + with beam.Pipeline() as pipeline: + # pylint: disable=no-value-for-parameter + result = ( + pipeline + | "Create" >> beam.Create([example1, example2, example3, example4]) + | "Process" >> beam.Map(metric_util.to_standard_metric_inputs) + | "AddSlice" >> beam.Map(lambda x: ((), x)) + | "ComputeHistogram" >> beam.CombinePerKey(histogram.combiner) + | "ComputeConfusionMatrix" + >> beam.Map(lambda x: (x[0], matrix.result(x[1]))) + | "ComputeMetric" + >> beam.Map(lambda x: (x[0], derived_metric.result(x[1]))) + ) + + # pylint: enable=no-value-for-parameter + + def check_result(got): + try: + self.assertLen(got, 1) + got_slice_key, got_metrics = got[0] + self.assertEqual(got_slice_key, ()) + if class_id: + sub_key = metric_types.SubKey(class_id=class_id) + else: + sub_key = metric_types.SubKey( + class_id=metric.get_config()["class_id"] + ) + key = metric_types.MetricKey( + name=metric.name, sub_key=sub_key, example_weighted=True + ) + self.assertDictElementsAlmostEqual( + got_metrics, {key: expected_value}, places=5 + ) + + except AssertionError as err: + raise util.BeamAssertException(err) + + util.assert_that(result, check_result, label="result") + + def testConfusionMatrixMetricsWithNan(self): + computations = confusion_matrix_metrics.Specificity().computations( + example_weighted=True + ) + histogram = computations[0] + matrices = computations[1] + metrics = computations[2] + + example1 = { + "labels": np.array([1.0]), + "predictions": np.array([1.0]), + "example_weights": np.array([1.0]), + } + + with beam.Pipeline() as pipeline: + # pylint: disable=no-value-for-parameter + result = ( + pipeline + | "Create" >> beam.Create([example1]) + | "Process" >> beam.Map(metric_util.to_standard_metric_inputs) + | "AddSlice" >> beam.Map(lambda x: ((), x)) + | "ComputeHistogram" >> beam.CombinePerKey(histogram.combiner) + | "ComputeMatrices" >> beam.Map(lambda x: (x[0], matrices.result(x[1]))) + | "ComputeMetrics" >> beam.Map(lambda x: (x[0], metrics.result(x[1]))) + ) + + # pylint: enable=no-value-for-parameter + + def check_result(got): + try: + self.assertLen(got, 1) + got_slice_key, got_metrics = got[0] + self.assertEqual(got_slice_key, ()) + self.assertLen(got_metrics, 1) + key = metrics.keys[0] + self.assertIn(key, got_metrics) + self.assertTrue(math.isnan(got_metrics[key])) + except AssertionError as err: + raise util.BeamAssertException(err) + + util.assert_that(result, check_result, label="result") + + @parameterized.named_parameters( + ( + "class_id as param and class_id as sub_key", + confusion_matrix_metrics.Precision(class_id=2), + 2, + None, + ), + ( + "top_k as param and top_k as sub_key", + confusion_matrix_metrics.Precision(top_k=2), + None, + 2, + ), + ) + def testRaisesErrorIfOverlappingSettings(self, metric, class_id, top_k): + with self.assertRaisesRegex( + ValueError, ".*is configured with overlapping settings.*" + ): + metric.computations( + sub_keys=[metric_types.SubKey(class_id=class_id, top_k=top_k)] + ) + + def testConfusionMatrixAtThresholds(self): + computations = confusion_matrix_metrics.ConfusionMatrixAtThresholds( + thresholds=[0.3, 0.5, 0.8] + ).computations(example_weighted=True) + histogram = computations[0] + matrices = computations[1] + metrics = computations[2] + + example1 = { + "labels": np.array([0.0]), + 
"predictions": np.array([0.0]), + "example_weights": np.array([1.0]), + } + example2 = { + "labels": np.array([0.0]), + "predictions": np.array([0.5]), + "example_weights": np.array([1.0]), + } + example3 = { + "labels": np.array([1.0]), + "predictions": np.array([0.3]), + "example_weights": np.array([1.0]), + } + example4 = { + "labels": np.array([1.0]), + "predictions": np.array([0.9]), + "example_weights": np.array([1.0]), + } + + with beam.Pipeline() as pipeline: + # pylint: disable=no-value-for-parameter + result = ( + pipeline + | "Create" >> beam.Create([example1, example2, example3, example4]) + | "Process" >> beam.Map(metric_util.to_standard_metric_inputs) + | "AddSlice" >> beam.Map(lambda x: ((), x)) + | "ComputeHistogram" >> beam.CombinePerKey(histogram.combiner) + | "ComputeMatrices" >> beam.Map(lambda x: (x[0], matrices.result(x[1]))) + | "ComputeMetrics" >> beam.Map(lambda x: (x[0], metrics.result(x[1]))) + ) + + # pylint: enable=no-value-for-parameter + + def check_result(got): + try: + self.assertLen(got, 1) + got_slice_key, got_metrics = got[0] + self.assertEqual(got_slice_key, ()) + self.assertLen(got_metrics, 1) + key = metric_types.MetricKey( + name="confusion_matrix_at_thresholds", example_weighted=True + ) + self.assertIn(key, got_metrics) + got_metric = got_metrics[key] + self.assertEqual( + binary_confusion_matrices.Matrices( + thresholds=[0.3, 0.5, 0.8], + tp=[1.0, 1.0, 1.0], + tn=[1.0, 2.0, 2.0], + fp=[1.0, 0.0, 0.0], + fn=[1.0, 1.0, 1.0], + ), + got_metric, + ) + + except AssertionError as err: + raise util.BeamAssertException(err) + + util.assert_that(result, check_result, label="result") + + @parameterized.named_parameters( + ( + "precision (class_id=1 top_k=2 using sub_key)", + confusion_matrix_metrics.Precision(thresholds=[0.1]), + 1, + 2, + 0.5 / (0.5 + 1.6), + ), + ( + "precision (class_id=1 using param and top_k=2 using sub_key)", + confusion_matrix_metrics.Precision(class_id=1, thresholds=[0.1]), + None, + 2, + 0.5 / (0.5 + 1.6), + ), + ( + "recall (class_id=3 using sub_key and top_k=2 using param)", + confusion_matrix_metrics.Recall(thresholds=[0.1], top_k=2), + 3, + None, + 0.7 / (0.7 + 0.9), + ), + ( + "recall (class_id=3 top_k=2 using param)", + confusion_matrix_metrics.Recall(class_id=3, top_k=2, thresholds=[0.1]), + None, + None, + 0.7 / (0.7 + 0.9), + ), + ) + def testConfusionMatrixMetricsWithClassIdAndTopK( + self, metric, class_id, top_k, expected_value + ): + computations = metric.computations( + sub_keys=[metric_types.SubKey(class_id=class_id, top_k=top_k)], + example_weighted=True, + ) + histogram = computations[0] + matrix = computations[1] + derived_metric = computations[2] + + # class_id = 1, top_k=2 + # TP = 0.5*1 + 0.7*0 + 0.9*0 + 0.3*0 = 0.5 + # FP = 0.5*0 + 0.7*1 + 0.9*1 + 0.3*0 = 1.6 + # FN = 0.5*0 + 0.7*0 + 0.9*0 + 0.3*1 = 0.3 + # + # class_id = 3, top_k=2 + # TP = 0.5*0 + 0.7*1 + 0.9*0 + 0.3*0 = 0.7 + # FP = 0.5*1 + 0.7*0 + 0.9*0 + 0.3*1 = 0.8 + # FN = 0.5*0 + 0.7*0 + 0.9*1 + 0.3*0 = 0.9 + example1 = { + "labels": np.array([1]), + "predictions": np.array([0.1, 0.5, 0.1, 0.45, 0.35]), + "example_weights": np.array([0.5]), + } + example2 = { + "labels": np.array([3]), + "predictions": np.array([0.2, 0.3, 0.05, 0.31, 0.3]), + "example_weights": np.array([0.7]), + } + example3 = { + "labels": np.array([3]), + "predictions": np.array([0.01, 0.2, 0.2, 0.09, 0.5]), + "example_weights": np.array([0.9]), + } + example4 = { + "labels": np.array([1]), + "predictions": np.array([0.1, 0.05, 0.3, 0.4, 0.05]), + "example_weights": np.array([0.3]), 
+ } + + with beam.Pipeline() as pipeline: + # pylint: disable=no-value-for-parameter + result = ( + pipeline + | "Create" >> beam.Create([example1, example2, example3, example4]) + | "Process" >> beam.Map(metric_util.to_standard_metric_inputs) + | "AddSlice" >> beam.Map(lambda x: ((), x)) + | "ComputeHistogram" >> beam.CombinePerKey(histogram.combiner) + | "ComputeConfusionMatrix" + >> beam.Map(lambda x: (x[0], matrix.result(x[1]))) + | "ComputeMetric" + >> beam.Map(lambda x: (x[0], derived_metric.result(x[1]))) + ) + + # pylint: enable=no-value-for-parameter + + def check_result(got): + try: + self.assertLen(got, 1) + got_slice_key, got_metrics = got[0] + self.assertEqual(got_slice_key, ()) + subkey_class_id = class_id or metric.get_config()["class_id"] + subkey_top_k = top_k or metric.get_config()["top_k"] + sub_key = metric_types.SubKey( + class_id=subkey_class_id, top_k=subkey_top_k + ) + key = metric_types.MetricKey( + name=metric.name, sub_key=sub_key, example_weighted=True + ) + self.assertDictElementsAlmostEqual( + got_metrics, {key: expected_value}, places=5 + ) + + except AssertionError as err: + raise util.BeamAssertException(err) + + util.assert_that(result, check_result, label="result") + + @parameterized.named_parameters( + ( + "false_positives", + confusion_matrix_metrics.FalsePositiveFeatureSampler( + threshold=0.5, feature_key="example_id", sample_size=2 + ), + "false_positive_feature_sampler", + np.array(["example1", "example2"], dtype=str), + ), + ( + "false_negatives", + confusion_matrix_metrics.FalseNegativeFeatureSampler( + threshold=0.5, feature_key="example_id", sample_size=2 + ), + "false_negative_feature_sampler", + np.array(["example3", "example4"], dtype=str), + ), + ) + def testConfusionMatrixFeatureSamplers( + self, metric, expected_metric_name, expected_value + ): + # false positive + example1 = { + "labels": np.array([0.0]), + "predictions": np.array([1.0]), + "example_weights": np.array([1.0]), + "features": {"example_id": np.array(["example1"])}, + } + # false positive + example2 = { + "labels": np.array([0.0]), + "predictions": np.array([1.0]), + "example_weights": np.array([1.0]), + "features": {"example_id": np.array(["example2"])}, + } + # false negative + example3 = { + "labels": np.array([1.0]), + "predictions": np.array([0.0]), + "example_weights": np.array([1.0]), + "features": {"example_id": np.array(["example3"])}, + } + # false negative + example4 = { + "labels": np.array([1.0]), + "predictions": np.array([0.0]), + "example_weights": np.array([1.0]), + "features": {"example_id": np.array(["example4"])}, + } + + expected_metrics = { + metric_types.MetricKey( + name=expected_metric_name, example_weighted=True + ): expected_value, + } + self.assertDerivedMetricsEqual( + expected_metrics=expected_metrics, + extracts=[example1, example2, example3, example4], + metric=metric, + enable_debug_print=True, + ) + + +if __name__ == "__main__": + tf.compat.v1.enable_v2_behavior() + tf.test.main() diff --git a/tensorflow_model_analysis/metrics/confusion_matrix_plot.py b/tensorflow_model_analysis/metrics/confusion_matrix_plot.py index e563db9156..4badfabf90 100644 --- a/tensorflow_model_analysis/metrics/confusion_matrix_plot.py +++ b/tensorflow_model_analysis/metrics/confusion_matrix_plot.py @@ -13,105 +13,114 @@ # limitations under the License. 
"""Confusion matrix Plot.""" -from typing import Any, Dict, Optional, List +from typing import Any, Dict, List, Optional -from tensorflow_model_analysis.metrics import binary_confusion_matrices -from tensorflow_model_analysis.metrics import metric_types -from tensorflow_model_analysis.metrics import metric_util +from tensorflow_model_analysis.metrics import ( + binary_confusion_matrices, + metric_types, + metric_util, +) from tensorflow_model_analysis.proto import config_pb2 DEFAULT_NUM_THRESHOLDS = 1000 -CONFUSION_MATRIX_PLOT_NAME = 'confusion_matrix_plot' +CONFUSION_MATRIX_PLOT_NAME = "confusion_matrix_plot" class ConfusionMatrixPlot(metric_types.Metric): - """Confusion matrix plot.""" + """Confusion matrix plot.""" - def __init__(self, - num_thresholds: int = DEFAULT_NUM_THRESHOLDS, - name: str = CONFUSION_MATRIX_PLOT_NAME, - **kwargs): - """Initializes confusion matrix plot. + def __init__( + self, + num_thresholds: int = DEFAULT_NUM_THRESHOLDS, + name: str = CONFUSION_MATRIX_PLOT_NAME, + **kwargs, + ): + """Initializes confusion matrix plot. - Args: - num_thresholds: Number of thresholds to use when discretizing the curve. - Values must be > 1. Defaults to 1000. - name: Metric name. - **kwargs: (Optional) Additional args to pass along to init (and eventually - on to _confusion_matrix_plot). These kwargs are useful for subclasses to - pass information from their init to the create_computation_fn. - """ - super().__init__( - metric_util.merge_per_key_computations(self._confusion_matrix_plot), - num_thresholds=num_thresholds, - name=name, - **kwargs) + Args: + ---- + num_thresholds: Number of thresholds to use when discretizing the curve. + Values must be > 1. Defaults to 1000. + name: Metric name. + **kwargs: (Optional) Additional args to pass along to init (and eventually + on to _confusion_matrix_plot). These kwargs are useful for subclasses to + pass information from their init to the create_computation_fn. 
+ """ + super().__init__( + metric_util.merge_per_key_computations(self._confusion_matrix_plot), + num_thresholds=num_thresholds, + name=name, + **kwargs, + ) - def _confusion_matrix_plot( - self, - num_thresholds: int = DEFAULT_NUM_THRESHOLDS, - name: str = CONFUSION_MATRIX_PLOT_NAME, - eval_config: Optional[config_pb2.EvalConfig] = None, - model_name: str = '', - output_name: str = '', - sub_key: Optional[metric_types.SubKey] = None, - aggregation_type: Optional[metric_types.AggregationType] = None, - class_weights: Optional[Dict[int, float]] = None, - example_weighted: bool = False, - preprocessors: Optional[List[metric_types.Preprocessor]] = None, - plot_key: Optional[metric_types.PlotKey] = None, - ) -> metric_types.MetricComputations: - """Returns metric computations for confusion matrix plots.""" - if plot_key: - key = plot_key - else: - key = metric_types.PlotKey( - name=name, - model_name=model_name, - output_name=output_name, - sub_key=sub_key, - example_weighted=example_weighted, - ) + def _confusion_matrix_plot( + self, + num_thresholds: int = DEFAULT_NUM_THRESHOLDS, + name: str = CONFUSION_MATRIX_PLOT_NAME, + eval_config: Optional[config_pb2.EvalConfig] = None, + model_name: str = "", + output_name: str = "", + sub_key: Optional[metric_types.SubKey] = None, + aggregation_type: Optional[metric_types.AggregationType] = None, + class_weights: Optional[Dict[int, float]] = None, + example_weighted: bool = False, + preprocessors: Optional[List[metric_types.Preprocessor]] = None, + plot_key: Optional[metric_types.PlotKey] = None, + ) -> metric_types.MetricComputations: + """Returns metric computations for confusion matrix plots.""" + if plot_key: + key = plot_key + else: + key = metric_types.PlotKey( + name=name, + model_name=model_name, + output_name=output_name, + sub_key=sub_key, + example_weighted=example_weighted, + ) - # The interoploation strategy used here matches how the legacy post export - # metrics calculated its plots. - thresholds = [ - i * 1.0 / num_thresholds for i in range(0, num_thresholds + 1) - ] - thresholds = [-1e-6] + thresholds + # The interoploation strategy used here matches how the legacy post export + # metrics calculated its plots. + thresholds = [i * 1.0 / num_thresholds for i in range(0, num_thresholds + 1)] + thresholds = [-1e-6] + thresholds - # Make sure matrices are calculated. - matrices_computations = binary_confusion_matrices.binary_confusion_matrices( - # Use a custom name since we have a custom interpolation strategy which - # will cause the default naming used by the binary confusion matrix to - # be very long. - name=(binary_confusion_matrices.BINARY_CONFUSION_MATRICES_NAME + '_' + - name), - eval_config=eval_config, - model_name=model_name, - output_name=output_name, - sub_key=sub_key, - aggregation_type=aggregation_type, - class_weights=class_weights, - example_weighted=example_weighted, - thresholds=thresholds, - use_histogram=True, - preprocessors=preprocessors) - matrices_key = matrices_computations[-1].keys[-1] + # Make sure matrices are calculated. + matrices_computations = binary_confusion_matrices.binary_confusion_matrices( + # Use a custom name since we have a custom interpolation strategy which + # will cause the default naming used by the binary confusion matrix to + # be very long. 
+ name=( + binary_confusion_matrices.BINARY_CONFUSION_MATRICES_NAME + "_" + name + ), + eval_config=eval_config, + model_name=model_name, + output_name=output_name, + sub_key=sub_key, + aggregation_type=aggregation_type, + class_weights=class_weights, + example_weighted=example_weighted, + thresholds=thresholds, + use_histogram=True, + preprocessors=preprocessors, + ) + matrices_key = matrices_computations[-1].keys[-1] - def result( - metrics: Dict[metric_types.MetricKey, Any] - ) -> Dict[metric_types.MetricKey, binary_confusion_matrices.Matrices]: # pytype: disable=signature-mismatch # always-use-return-annotations - return { - key: metrics[matrices_key].to_proto().confusion_matrix_at_thresholds - } + def result( + metrics: Dict[metric_types.MetricKey, Any], + ) -> Dict[ + metric_types.MetricKey, binary_confusion_matrices.Matrices + ]: # pytype: disable=signature-mismatch # always-use-return-annotations + return { + key: metrics[matrices_key].to_proto().confusion_matrix_at_thresholds + } - derived_computation = metric_types.DerivedMetricComputation( - keys=[key], result=result) - computations = matrices_computations - computations.append(derived_computation) - return computations + derived_computation = metric_types.DerivedMetricComputation( + keys=[key], result=result + ) + computations = matrices_computations + computations.append(derived_computation) + return computations metric_types.register_metric(ConfusionMatrixPlot) diff --git a/tensorflow_model_analysis/metrics/confusion_matrix_plot_test.py b/tensorflow_model_analysis/metrics/confusion_matrix_plot_test.py index bcb70ad824..70c78d2819 100644 --- a/tensorflow_model_analysis/metrics/confusion_matrix_plot_test.py +++ b/tensorflow_model_analysis/metrics/confusion_matrix_plot_test.py @@ -14,70 +14,74 @@ """Tests for confusion matrix plot.""" import apache_beam as beam -from apache_beam.testing import util import numpy as np import tensorflow as tf -from tensorflow_model_analysis.metrics import confusion_matrix_plot -from tensorflow_model_analysis.metrics import metric_types -from tensorflow_model_analysis.metrics import metric_util +from apache_beam.testing import util + +from tensorflow_model_analysis.metrics import ( + confusion_matrix_plot, + metric_types, + metric_util, +) from tensorflow_model_analysis.utils import test_util class ConfusionMatrixPlotTest(test_util.TensorflowModelAnalysisTest): + def testConfusionMatrixPlot(self): + computations = confusion_matrix_plot.ConfusionMatrixPlot( + num_thresholds=4 + ).computations() + histogram = computations[0] + matrices = computations[1] + plot = computations[2] - def testConfusionMatrixPlot(self): - computations = confusion_matrix_plot.ConfusionMatrixPlot( - num_thresholds=4).computations() - histogram = computations[0] - matrices = computations[1] - plot = computations[2] - - example1 = { - 'labels': np.array([0.0]), - 'predictions': np.array([0.0]), - 'example_weights': np.array([1.0]), - } - example2 = { - 'labels': np.array([0.0]), - 'predictions': np.array([0.5]), - 'example_weights': np.array([1.0]), - } - example3 = { - 'labels': np.array([1.0]), - 'predictions': np.array([0.3]), - 'example_weights': np.array([1.0]), - } - example4 = { - 'labels': np.array([1.0]), - 'predictions': np.array([0.9]), - 'example_weights': np.array([1.0]), - } + example1 = { + "labels": np.array([0.0]), + "predictions": np.array([0.0]), + "example_weights": np.array([1.0]), + } + example2 = { + "labels": np.array([0.0]), + "predictions": np.array([0.5]), + "example_weights": np.array([1.0]), + } + 
example3 = { + "labels": np.array([1.0]), + "predictions": np.array([0.3]), + "example_weights": np.array([1.0]), + } + example4 = { + "labels": np.array([1.0]), + "predictions": np.array([0.9]), + "example_weights": np.array([1.0]), + } - with beam.Pipeline() as pipeline: - # pylint: disable=no-value-for-parameter - result = ( - pipeline - | 'Create' >> beam.Create([example1, example2, example3, example4]) - | 'Process' >> beam.Map(metric_util.to_standard_metric_inputs) - | 'AddSlice' >> beam.Map(lambda x: ((), x)) - | 'ComputeHistogram' >> beam.CombinePerKey(histogram.combiner) - | 'ComputeMatrices' >> beam.Map( - lambda x: (x[0], matrices.result(x[1]))) # pyformat: ignore - | 'ComputePlot' >> beam.Map(lambda x: (x[0], plot.result(x[1])))) + with beam.Pipeline() as pipeline: + # pylint: disable=no-value-for-parameter + result = ( + pipeline + | "Create" >> beam.Create([example1, example2, example3, example4]) + | "Process" >> beam.Map(metric_util.to_standard_metric_inputs) + | "AddSlice" >> beam.Map(lambda x: ((), x)) + | "ComputeHistogram" >> beam.CombinePerKey(histogram.combiner) + | "ComputeMatrices" + >> beam.Map(lambda x: (x[0], matrices.result(x[1]))) # pyformat: ignore + | "ComputePlot" >> beam.Map(lambda x: (x[0], plot.result(x[1]))) + ) - # pylint: enable=no-value-for-parameter + # pylint: enable=no-value-for-parameter - def check_result(got): - try: - self.assertLen(got, 1) - got_slice_key, got_plots = got[0] - self.assertEqual(got_slice_key, ()) - self.assertLen(got_plots, 1) - key = metric_types.PlotKey(name='confusion_matrix_plot') - self.assertIn(key, got_plots) - got_plot = got_plots[key] - self.assertProtoEquals( - """ + def check_result(got): + try: + self.assertLen(got, 1) + got_slice_key, got_plots = got[0] + self.assertEqual(got_slice_key, ()) + self.assertLen(got_plots, 1) + key = metric_types.PlotKey(name="confusion_matrix_plot") + self.assertIn(key, got_plots) + got_plot = got_plots[key] + self.assertProtoEquals( + """ matrices { threshold: -1e-06 false_positives: 2.0 @@ -147,14 +151,14 @@ def check_result(got): false_omission_rate: 0.5 } """, - got_plot, - ) + got_plot, + ) - except AssertionError as err: - raise util.BeamAssertException(err) + except AssertionError as err: + raise util.BeamAssertException(err) - util.assert_that(result, check_result, label='result') + util.assert_that(result, check_result, label="result") -if __name__ == '__main__': - tf.test.main() +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow_model_analysis/metrics/cross_entropy_metrics.py b/tensorflow_model_analysis/metrics/cross_entropy_metrics.py index 7f5b6d4155..c5b4b21197 100644 --- a/tensorflow_model_analysis/metrics/cross_entropy_metrics.py +++ b/tensorflow_model_analysis/metrics/cross_entropy_metrics.py @@ -15,51 +15,49 @@ import abc import dataclasses -from typing import Iterable, Optional, Dict +from typing import Dict, Iterable, Optional import apache_beam as beam import numpy as np -from tensorflow_model_analysis.metrics import metric_types -from tensorflow_model_analysis.metrics import metric_util -from tensorflow_model_analysis.proto import config_pb2 +from tensorflow_model_analysis.metrics import metric_types, metric_util +from tensorflow_model_analysis.proto import config_pb2 -BINARY_CROSSENTROPY_NAME = 'binary_crossentropy' -CATEGORICAL_CROSSENTROPY_NAME = 'categorical_crossentropy' +BINARY_CROSSENTROPY_NAME = "binary_crossentropy" +CATEGORICAL_CROSSENTROPY_NAME = "categorical_crossentropy" class BinaryCrossEntropy(metric_types.Metric): - """Calculates 
the binary cross entropy. - - The metric computes the cross entropy when there are only two label classes - (0 and 1). See definition at: https://en.wikipedia.org/wiki/Cross_entropy - """ - - def __init__( - self, - name: str = BINARY_CROSSENTROPY_NAME, - from_logits: bool = False, - label_smoothing: float = 0.0, - ): - """Initializes binary cross entropy metric. + """Calculates the binary cross entropy. - Args: - name: The name of the metric. - from_logits: (Optional) Whether output is expected to be a logits tensor. - By default, we consider that output encodes a probability distribution. - label_smoothing: Float in [0, 1]. If > `0` then smooth the labels by - squeezing them towards 0.5 That is, using `1. - 0.5 * label_smoothing` - for the target class and `0.5 * label_smoothing` for the non-target - class. + The metric computes the cross entropy when there are only two label classes + (0 and 1). See definition at: https://en.wikipedia.org/wiki/Cross_entropy """ - super().__init__( - metric_util.merge_per_key_computations( - _binary_cross_entropy_computations - ), - name=name, - from_logits=from_logits, - label_smoothing=label_smoothing, - ) + + def __init__( + self, + name: str = BINARY_CROSSENTROPY_NAME, + from_logits: bool = False, + label_smoothing: float = 0.0, + ): + """Initializes binary cross entropy metric. + + Args: + ---- + name: The name of the metric. + from_logits: (Optional) Whether output is expected to be a logits tensor. + By default, we consider that output encodes a probability distribution. + label_smoothing: Float in [0, 1]. If > `0` then smooth the labels by + squeezing them towards 0.5 That is, using `1. - 0.5 * label_smoothing` + for the target class and `0.5 * label_smoothing` for the non-target + class. + """ + super().__init__( + metric_util.merge_per_key_computations(_binary_cross_entropy_computations), + name=name, + from_logits=from_logits, + label_smoothing=label_smoothing, + ) def _binary_cross_entropy_computations( @@ -74,86 +72,88 @@ def _binary_cross_entropy_computations( aggregation_type: Optional[metric_types.AggregationType] = None, class_weights: Optional[Dict[int, float]] = None, ) -> metric_types.MetricComputations: - """Returns metric computations for binary cross entropy. - - Args: - name: The name of the metric. - from_logits: (Optional) Whether output is expected to be a logits tensor. By - default, we consider that output encodes a probability distribution. - label_smoothing: Float in [0, 1]. If > `0` then smooth the labels by - squeezing them towards 0.5 That is, using `1. - 0.5 * label_smoothing` for - the target class and `0.5 * label_smoothing` for the non-target class. - eval_config: The configurations for TFMA pipeline. - model_name: The name of the model to get predictions from. - output_name: The name of the output under the model to get predictions from. - example_weighted: Whether the examples have specified weights. - sub_key: The key includes class, top-k, k information. It should only be in - classfication problems. - aggregation_type: The method to aggregate over classes. It should only be in - classfication problems. - class_weights: The weight of classes. It should only be in classfication - problems. 
- """ - key = metric_types.MetricKey( - name=name, - model_name=model_name, - output_name=output_name, - sub_key=sub_key, - example_weighted=example_weighted, - ) - return [ - metric_types.MetricComputation( - keys=[key], - preprocessors=None, - combiner=_BinaryCrossEntropyCombiner( - from_logits=from_logits, - label_smoothing=label_smoothing, - eval_config=eval_config, - model_name=model_name, - output_name=output_name, - metric_key=key, - example_weighted=example_weighted, - aggregation_type=aggregation_type, - class_weights=class_weights, - ), - ) - ] + """Returns metric computations for binary cross entropy. + + Args: + ---- + name: The name of the metric. + from_logits: (Optional) Whether output is expected to be a logits tensor. By + default, we consider that output encodes a probability distribution. + label_smoothing: Float in [0, 1]. If > `0` then smooth the labels by + squeezing them towards 0.5 That is, using `1. - 0.5 * label_smoothing` for + the target class and `0.5 * label_smoothing` for the non-target class. + eval_config: The configurations for TFMA pipeline. + model_name: The name of the model to get predictions from. + output_name: The name of the output under the model to get predictions from. + example_weighted: Whether the examples have specified weights. + sub_key: The key includes class, top-k, k information. It should only be in + classfication problems. + aggregation_type: The method to aggregate over classes. It should only be in + classfication problems. + class_weights: The weight of classes. It should only be in classfication + problems. + """ + key = metric_types.MetricKey( + name=name, + model_name=model_name, + output_name=output_name, + sub_key=sub_key, + example_weighted=example_weighted, + ) + return [ + metric_types.MetricComputation( + keys=[key], + preprocessors=None, + combiner=_BinaryCrossEntropyCombiner( + from_logits=from_logits, + label_smoothing=label_smoothing, + eval_config=eval_config, + model_name=model_name, + output_name=output_name, + metric_key=key, + example_weighted=example_weighted, + aggregation_type=aggregation_type, + class_weights=class_weights, + ), + ) + ] metric_types.register_metric(BinaryCrossEntropy) class CategoricalCrossEntropy(metric_types.Metric): - """Calculates the categorical cross entropy. - - The metric computes the cross entropy when there are multiple classes. - It outputs a numpy array. - """ + """Calculates the categorical cross entropy. - def __init__( - self, - name: str = CATEGORICAL_CROSSENTROPY_NAME, - from_logits: bool = False, - label_smoothing: float = 0.0, - ): - """Initializes categorical cross entropy metric. - - Args: - name: The name of the metric. - from_logits: (Optional) Whether output is expected to be a logits tensor. - By default, we consider that output encodes a probability distribution. - label_smoothing: Float in [0, 1]. If > `0` then smooth the labels. For - example, if `0.1`, use `0.1 / num_classes` for non-target labels and - `0.9 + 0.1 / num_classes` for target labels. + The metric computes the cross entropy when there are multiple classes. + It outputs a numpy array. """ - super().__init__( - metric_util.merge_per_key_computations( - _categorical_cross_entropy_computations - ), - name=name, - from_logits=from_logits, - label_smoothing=label_smoothing, - ) + + def __init__( + self, + name: str = CATEGORICAL_CROSSENTROPY_NAME, + from_logits: bool = False, + label_smoothing: float = 0.0, + ): + """Initializes categorical cross entropy metric. 
+ + Args: + ---- + name: The name of the metric. + from_logits: (Optional) Whether output is expected to be a logits tensor. + By default, we consider that output encodes a probability distribution. + label_smoothing: Float in [0, 1]. If > `0` then smooth the labels. For + example, if `0.1`, use `0.1 / num_classes` for non-target labels and + `0.9 + 0.1 / num_classes` for target labels. + """ + super().__init__( + metric_util.merge_per_key_computations( + _categorical_cross_entropy_computations + ), + name=name, + from_logits=from_logits, + label_smoothing=label_smoothing, + ) def _categorical_cross_entropy_computations( @@ -168,50 +168,51 @@ def _categorical_cross_entropy_computations( aggregation_type: Optional[metric_types.AggregationType] = None, class_weights: Optional[Dict[int, float]] = None, ) -> metric_types.MetricComputations: - """Returns metric computations for categorical cross entropy. - - Args: - name: The name of the metric. - from_logits: (Optional) Whether output is expected to be a logits tensor. By - default, we consider that output encodes a probability distribution. - label_smoothing: Float in [0, 1]. If > `0` then smooth the labels. For - example, if `0.1`, use `0.1 / num_classes` for non-target labels and `0.9 - + 0.1 / num_classes` for target labels. - eval_config: The configurations for TFMA pipeline. - model_name: The name of the model to get predictions from. - output_name: The name of the output under the model to get predictions from. - example_weighted: Whether the examples have specified weights. - sub_key: The key includes class, top-k, k information. It should only be in - classfication problems. - aggregation_type: The method to aggregate over classes. It should only be in - classfication problems. - class_weights: The weight of classes. It should only be in classfication - problems. - """ - key = metric_types.MetricKey( - name=name, - model_name=model_name, - output_name=output_name, - sub_key=sub_key, - example_weighted=example_weighted, - ) - return [ - metric_types.MetricComputation( - keys=[key], - preprocessors=None, - combiner=_CategoricalCrossEntropyCombiner( - from_logits=from_logits, - label_smoothing=label_smoothing, - eval_config=eval_config, - model_name=model_name, - output_name=output_name, - metric_key=key, - example_weighted=example_weighted, - aggregation_type=aggregation_type, - class_weights=class_weights, - ), - ) - ] + """Returns metric computations for categorical cross entropy. + + Args: + ---- + name: The name of the metric. + from_logits: (Optional) Whether output is expected to be a logits tensor. By + default, we consider that output encodes a probability distribution. + label_smoothing: Float in [0, 1]. If > `0` then smooth the labels. For + example, if `0.1`, use `0.1 / num_classes` for non-target labels and `0.9 + + 0.1 / num_classes` for target labels. + eval_config: The configurations for TFMA pipeline. + model_name: The name of the model to get predictions from. + output_name: The name of the output under the model to get predictions from. + example_weighted: Whether the examples have specified weights. + sub_key: The key includes class, top-k, k information. It should only be in + classfication problems. + aggregation_type: The method to aggregate over classes. It should only be in + classfication problems. + class_weights: The weight of classes. It should only be in classfication + problems. 
+ """ + key = metric_types.MetricKey( + name=name, + model_name=model_name, + output_name=output_name, + sub_key=sub_key, + example_weighted=example_weighted, + ) + return [ + metric_types.MetricComputation( + keys=[key], + preprocessors=None, + combiner=_CategoricalCrossEntropyCombiner( + from_logits=from_logits, + label_smoothing=label_smoothing, + eval_config=eval_config, + model_name=model_name, + output_name=output_name, + metric_key=key, + example_weighted=example_weighted, + aggregation_type=aggregation_type, + class_weights=class_weights, + ), + ) + ] metric_types.register_metric(CategoricalCrossEntropy) @@ -219,178 +220,176 @@ def _categorical_cross_entropy_computations( @dataclasses.dataclass class _CrossEntropyAccumulator: - """Accumulator for computing cross entropy metrics.""" + """Accumulator for computing cross entropy metrics.""" - total_cross_entropy: float = 0.0 - total_example_weights: float = 0.0 + total_cross_entropy: float = 0.0 + total_example_weights: float = 0.0 - def merge(self, other: '_CrossEntropyAccumulator'): - self.total_cross_entropy += other.total_cross_entropy - self.total_example_weights += other.total_example_weights + def merge(self, other: "_CrossEntropyAccumulator"): + self.total_cross_entropy += other.total_cross_entropy + self.total_example_weights += other.total_example_weights class _CrossEntropyCombiner(beam.CombineFn, metaclass=abc.ABCMeta): - """A combiner which computes cross entropy metrics. - - Two importnat parameters for cross entropy calcualtion. - from_logits: (Optional) Whether output is expected to be a logits tensor. - By default, we consider that output encodes a probability distribution. - label_smoothing: Float in [0, 1]. If > `0` then smooth the labels. For - example, if `0.1`, use `0.1 / num_classes` for non-target labels and - `0.9 + 0.1 / num_classes` for target labels. - """ - - def __init__( - self, - eval_config: config_pb2.EvalConfig, - model_name: str, - output_name: str, - metric_key: metric_types.MetricKey, - aggregation_type: Optional[metric_types.AggregationType], - class_weights: Optional[Dict[int, float]], - example_weighted: bool, - from_logits: bool = False, - label_smoothing: float = 0.0, - ): - self._eval_config = eval_config - self._model_name = model_name - self._output_name = output_name - self._metric_key = metric_key - self._example_weighted = example_weighted - self._aggregation_type = aggregation_type - self._class_weights = class_weights - self._from_logits = from_logits - self._label_smoothing = label_smoothing - - @abc.abstractmethod - def _cross_entropy(self, label, prediction) -> float: - """Returns the cross entropy between the label and prediction. - - Subclasses must override this method. Preditctions should encode the - probability distribution. The output of cross entropy is a numpy array. + """A combiner which computes cross entropy metrics. - Args: - label: The numpy array of floats. It should be the class probabilities - prediction: The numpy array of floats. It should be probabilities or - logits. + Two importnat parameters for cross entropy calcualtion. + from_logits: (Optional) Whether output is expected to be a logits tensor. + By default, we consider that output encodes a probability distribution. + label_smoothing: Float in [0, 1]. If > `0` then smooth the labels. For + example, if `0.1`, use `0.1 / num_classes` for non-target labels and + `0.9 + 0.1 / num_classes` for target labels. 
""" - raise NotImplementedError('Must be implemented in subclasses.') - - def create_accumulator(self) -> _CrossEntropyAccumulator: - return _CrossEntropyAccumulator() - - def add_input( - self, - accumulator: _CrossEntropyAccumulator, - element: metric_types.StandardMetricInputs, - ) -> _CrossEntropyAccumulator: - lpe_iterator = metric_util.to_label_prediction_example_weight( - element, - eval_config=self._eval_config, - model_name=self._metric_key.model_name, - output_name=self._metric_key.output_name, - aggregation_type=self._aggregation_type, - class_weights=self._class_weights, - example_weighted=self._example_weighted, - sub_key=self._metric_key.sub_key, - flatten=False, - ) - for label, prediction, example_weight in lpe_iterator: - # The np.item method makes sure the result is a one element numpy array - # and returns the single element as a float. - accumulator.total_cross_entropy += self._cross_entropy( - label, prediction - ) * metric_util.safe_to_scalar(example_weight) - accumulator.total_example_weights += metric_util.safe_to_scalar( - example_weight - ) - - return accumulator - - def merge_accumulators( - self, accumulators: Iterable[_CrossEntropyAccumulator] - ) -> _CrossEntropyAccumulator: - result = next(iter(accumulators)) - for accumulator in accumulators: - result.merge(accumulator) - return result - - def extract_output( - self, accumulator: _CrossEntropyAccumulator - ) -> metric_types.MetricsDict: - result = np.divide( - accumulator.total_cross_entropy, accumulator.total_example_weights - ) - return {self._metric_key: result} + + def __init__( + self, + eval_config: config_pb2.EvalConfig, + model_name: str, + output_name: str, + metric_key: metric_types.MetricKey, + aggregation_type: Optional[metric_types.AggregationType], + class_weights: Optional[Dict[int, float]], + example_weighted: bool, + from_logits: bool = False, + label_smoothing: float = 0.0, + ): + self._eval_config = eval_config + self._model_name = model_name + self._output_name = output_name + self._metric_key = metric_key + self._example_weighted = example_weighted + self._aggregation_type = aggregation_type + self._class_weights = class_weights + self._from_logits = from_logits + self._label_smoothing = label_smoothing + + @abc.abstractmethod + def _cross_entropy(self, label, prediction) -> float: + """Returns the cross entropy between the label and prediction. + + Subclasses must override this method. Preditctions should encode the + probability distribution. The output of cross entropy is a numpy array. + + Args: + ---- + label: The numpy array of floats. It should be the class probabilities + prediction: The numpy array of floats. It should be probabilities or + logits. 
+ """ + raise NotImplementedError("Must be implemented in subclasses.") + + def create_accumulator(self) -> _CrossEntropyAccumulator: + return _CrossEntropyAccumulator() + + def add_input( + self, + accumulator: _CrossEntropyAccumulator, + element: metric_types.StandardMetricInputs, + ) -> _CrossEntropyAccumulator: + lpe_iterator = metric_util.to_label_prediction_example_weight( + element, + eval_config=self._eval_config, + model_name=self._metric_key.model_name, + output_name=self._metric_key.output_name, + aggregation_type=self._aggregation_type, + class_weights=self._class_weights, + example_weighted=self._example_weighted, + sub_key=self._metric_key.sub_key, + flatten=False, + ) + for label, prediction, example_weight in lpe_iterator: + # The np.item method makes sure the result is a one element numpy array + # and returns the single element as a float. + accumulator.total_cross_entropy += self._cross_entropy( + label, prediction + ) * metric_util.safe_to_scalar(example_weight) + accumulator.total_example_weights += metric_util.safe_to_scalar( + example_weight + ) + + return accumulator + + def merge_accumulators( + self, accumulators: Iterable[_CrossEntropyAccumulator] + ) -> _CrossEntropyAccumulator: + result = next(iter(accumulators)) + for accumulator in accumulators: + result.merge(accumulator) + return result + + def extract_output( + self, accumulator: _CrossEntropyAccumulator + ) -> metric_types.MetricsDict: + result = np.divide( + accumulator.total_cross_entropy, accumulator.total_example_weights + ) + return {self._metric_key: result} class _BinaryCrossEntropyCombiner(_CrossEntropyCombiner): - """A combiner which computes binary cross entropy.""" - - def _cross_entropy( - self, - label: np.ndarray, - prediction: np.ndarray, - ) -> float: - # smooth labels - label = label * (1.0 - self._label_smoothing) + 0.5 * self._label_smoothing - - # If predictions are logits rather than probability, then the probability - # should be sigmoid(logits). In this case, starting from logits, we can - # derive the formula for cross entropy which is expressed in logits. - # Let y = label, x = prediction logits - # If x > 0, - # Cross entropy loss = y * - log(sigmoid(x)) + (1-y) * -log(1-sigmoid(x)) - # = x - x * y + log(1 + exp(-x)) - # If x < 0, - # Cross entropy loss = -x * y + log(1 + exp(x)) - # In summary, to merge the x > 0 and x < 0 cases, we obtain, - # Cross entryopy loss = max(x, 0) - x * y + log(1 + exp(-abs(x))) - if self._from_logits: - elementwise_binary_cross_entropy = ( - np.maximum(prediction, 0) - - np.multiply(prediction, label) - + np.log(1 + np.exp(-np.abs(prediction))) - ) - else: - elementwise_binary_cross_entropy = -np.multiply( - label, np.log(prediction) - ) - np.multiply((1 - label), np.log(1 - prediction)) - binary_cross_entropy = np.mean(elementwise_binary_cross_entropy) - # The np.item method makes sure the result is a one element numpy array and - # returns the single element as a float. - return metric_util.safe_to_scalar(binary_cross_entropy) + """A combiner which computes binary cross entropy.""" + + def _cross_entropy( + self, + label: np.ndarray, + prediction: np.ndarray, + ) -> float: + # smooth labels + label = label * (1.0 - self._label_smoothing) + 0.5 * self._label_smoothing + + # If predictions are logits rather than probability, then the probability + # should be sigmoid(logits). In this case, starting from logits, we can + # derive the formula for cross entropy which is expressed in logits. 
+ # Let y = label, x = prediction logits + # If x > 0, + # Cross entropy loss = y * - log(sigmoid(x)) + (1-y) * -log(1-sigmoid(x)) + # = x - x * y + log(1 + exp(-x)) + # If x < 0, + # Cross entropy loss = -x * y + log(1 + exp(x)) + # In summary, to merge the x > 0 and x < 0 cases, we obtain, + # Cross entryopy loss = max(x, 0) - x * y + log(1 + exp(-abs(x))) + if self._from_logits: + elementwise_binary_cross_entropy = ( + np.maximum(prediction, 0) + - np.multiply(prediction, label) + + np.log(1 + np.exp(-np.abs(prediction))) + ) + else: + elementwise_binary_cross_entropy = -np.multiply( + label, np.log(prediction) + ) - np.multiply((1 - label), np.log(1 - prediction)) + binary_cross_entropy = np.mean(elementwise_binary_cross_entropy) + # The np.item method makes sure the result is a one element numpy array and + # returns the single element as a float. + return metric_util.safe_to_scalar(binary_cross_entropy) class _CategoricalCrossEntropyCombiner(_CrossEntropyCombiner): - """A combiner which computes categorical cross entropy.""" - - def _cross_entropy( - self, - label: np.ndarray, - prediction: np.ndarray, - ) -> float: - # smooth labels - num_classes = prediction.shape[0] - label = ( - label * (1.0 - self._label_smoothing) - + self._label_smoothing / num_classes - ) - - if self._from_logits: - # Let z_i be the logits of probability p_i - # z_i = log( p_i / sum_(j!=i) p_j ) - # p_i = exp(z_i) / sum(exp(z)) - prediction = np.exp(prediction) - # Normalize prediction probability to 1 - prediction /= np.sum(prediction) - - # It assumes each row is a single prediction and each column is a class. - # The reduction on axis -1 is a classwise reduction. - categorical_cross_entropy = -np.sum( - np.multiply(label, np.ma.log(prediction)) - ) - - # The np.item method makes sure the result is a one element numpy array and - # returns the single element as a float. - return metric_util.safe_to_scalar(categorical_cross_entropy) + """A combiner which computes categorical cross entropy.""" + + def _cross_entropy( + self, + label: np.ndarray, + prediction: np.ndarray, + ) -> float: + # smooth labels + num_classes = prediction.shape[0] + label = ( + label * (1.0 - self._label_smoothing) + self._label_smoothing / num_classes + ) + + if self._from_logits: + # Let z_i be the logits of probability p_i + # z_i = log( p_i / sum_(j!=i) p_j ) + # p_i = exp(z_i) / sum(exp(z)) + prediction = np.exp(prediction) + # Normalize prediction probability to 1 + prediction /= np.sum(prediction) + + # It assumes each row is a single prediction and each column is a class. + # The reduction on axis -1 is a classwise reduction. + categorical_cross_entropy = -np.sum(np.multiply(label, np.ma.log(prediction))) + + # The np.item method makes sure the result is a one element numpy array and + # returns the single element as a float. + return metric_util.safe_to_scalar(categorical_cross_entropy) diff --git a/tensorflow_model_analysis/metrics/cross_entropy_metrics_test.py b/tensorflow_model_analysis/metrics/cross_entropy_metrics_test.py index 133abf1b29..b183db5dc1 100644 --- a/tensorflow_model_analysis/metrics/cross_entropy_metrics_test.py +++ b/tensorflow_model_analysis/metrics/cross_entropy_metrics_test.py @@ -12,171 +12,170 @@ # See the License for the specific language governing permissions and # limitations under the License. 
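A minimal, standalone sketch (not part of the patch itself) of the two identities spelled out in the comments of _BinaryCrossEntropyCombiner and _CategoricalCrossEntropyCombiner above; the numeric inputs are illustrative, loosely borrowed from the test cases that follow:

import numpy as np

def sigmoid(x):
    # Numerically fine for the illustrative values used here.
    return 1.0 / (1.0 + np.exp(-x))

# Binary case: the logits form max(x, 0) - x*y + log(1 + exp(-|x|)) agrees with
# the probability form -y*log(p) - (1 - y)*log(1 - p), where p = sigmoid(x).
y = np.array([0.0, 1.0, 0.0])
x = np.array([-18.6, 0.51, 2.94])  # logits
bce_from_logits = np.maximum(x, 0) - x * y + np.log(1 + np.exp(-np.abs(x)))
bce_from_probs = -y * np.log(sigmoid(x)) - (1 - y) * np.log(1 - sigmoid(x))
print(np.allclose(bce_from_logits, bce_from_probs))  # True

# Categorical case: softmax the logits first, then take -sum(label * log(prob)),
# mirroring the normalization done in the combiner.
label = np.array([0.0, 1.0, 0.0])
logits = np.array([-5.0, 0.95, -10.0])
probs = np.exp(logits) / np.sum(np.exp(logits))
print(-np.sum(label * np.log(probs)))  # ~0.0026 for this well-separated example

Both checks simply confirm that the logits-based shortcuts used by the combiners reproduce the direct probability-space definitions of the losses.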
"""Tests for cross entropy related metrics.""" -from absl.testing import absltest -from absl.testing import parameterized + import apache_beam as beam -from apache_beam.testing import util import numpy as np -from tensorflow_model_analysis.metrics import cross_entropy_metrics -from tensorflow_model_analysis.metrics import metric_util +from absl.testing import absltest, parameterized +from apache_beam.testing import util +from tensorflow_model_analysis.metrics import cross_entropy_metrics, metric_util -class CrossEntropyTest(parameterized.TestCase): - @parameterized.named_parameters( - # To be consistent with Keras, a single example can have multiple - # predictions and labels. - dict( - testcase_name='_binary_two_examples', - extracts=[ - { - 'labels': np.array([0]), - 'predictions': np.array([0.6]), - }, - { - 'labels': np.array([1]), - 'predictions': np.array([0.6]), - 'example_weights': np.array([0.8]), - }, - ], - metric=cross_entropy_metrics.BinaryCrossEntropy( - from_logits=False, label_smoothing=0.0 - ), - expected_value=0.736083, - ), - dict( - testcase_name='_binary_two_examples_per_batch_from_logits', - extracts=[ - { - 'labels': np.array([0, 1]), - 'predictions': np.array([-18.6, 0.51]), - }, - { - 'labels': np.array([0, 0]), - 'predictions': np.array([2.94, -12.8]), - }, - ], - metric=cross_entropy_metrics.BinaryCrossEntropy( - from_logits=True, label_smoothing=0.0 - ), - expected_value=0.865457, - ), - dict( - testcase_name='_binary_two_examples_with_label_smoothing', - extracts=[ - { - 'labels': np.array([0]), - 'predictions': np.array([0.6]), - }, - { - 'labels': np.array([1]), - 'predictions': np.array([0.6]), - 'example_weights': np.array([0.8]), - }, - ], - metric=cross_entropy_metrics.BinaryCrossEntropy( - from_logits=False, label_smoothing=0.1 - ), - expected_value=0.733831, - ), - dict( - testcase_name='_categorical_two_examples', - extracts=[ - { - 'labels': np.array([0, 1, 0]), - 'predictions': np.array([0.05, 0.95, 0]), - }, - { - 'labels': np.array([0, 0, 1]), - 'predictions': np.array([0.1, 0.8, 0.1]), - }, - ], - metric=cross_entropy_metrics.CategoricalCrossEntropy( - from_logits=False, label_smoothing=0.0 - ), - expected_value=1.176939, - ), - dict( - testcase_name='_categorical_two_examples_with_weights', - extracts=[ - { - 'labels': np.array([0, 1, 0]), - 'predictions': np.array([0.05, 0.95, 0]), - 'example_weights': np.array([0.3]), - }, - { - 'labels': np.array([0, 0, 1]), - 'predictions': np.array([0.1, 0.8, 0.1]), - 'example_weights': np.array([0.7]), - }, - ], - metric=cross_entropy_metrics.CategoricalCrossEntropy( - from_logits=False, label_smoothing=0.0 - ), - expected_value=1.627198, - ), - dict( - testcase_name='_categorical_two_examples_from_logits', - extracts=[ - { - 'labels': np.array([0, 1, 0]), - 'predictions': np.array([-5, 0.95, -10]), - }, - { - 'labels': np.array([0, 0, 1]), - 'predictions': np.array([5, -1, 0.5]), - }, - ], - metric=cross_entropy_metrics.CategoricalCrossEntropy( - from_logits=True, label_smoothing=0.0 - ), - expected_value=2.258058, - ), - dict( - testcase_name='_categorical_two_examples_from_logits_with_smoothing', - extracts=[ - { - 'labels': np.array([0, 1, 0]), - 'predictions': np.array([-5, 0.95, -10]), - }, - { - 'labels': np.array([0, 0, 1]), - 'predictions': np.array([5, -1, 0.5]), - }, - ], - metric=cross_entropy_metrics.CategoricalCrossEntropy( - from_logits=True, label_smoothing=0.1 - ), - expected_value=2.489725, - ), - ) - def testBinaryCrossEntropy(self, extracts, metric, expected_value): - computations = 
metric.computations(example_weighted=True) - computation = computations[0] - with beam.Pipeline() as pipeline: - # pylint: disable=no-value-for-parameter - result = ( - pipeline - | 'Create' >> beam.Create(extracts) - | 'Process' >> beam.Map(metric_util.to_standard_metric_inputs) - | 'AddSlice' >> beam.Map(lambda x: ((), x)) - | 'ComputeMetric' >> beam.CombinePerKey(computation.combiner) - ) +class CrossEntropyTest(parameterized.TestCase): + @parameterized.named_parameters( + # To be consistent with Keras, a single example can have multiple + # predictions and labels. + dict( + testcase_name="_binary_two_examples", + extracts=[ + { + "labels": np.array([0]), + "predictions": np.array([0.6]), + }, + { + "labels": np.array([1]), + "predictions": np.array([0.6]), + "example_weights": np.array([0.8]), + }, + ], + metric=cross_entropy_metrics.BinaryCrossEntropy( + from_logits=False, label_smoothing=0.0 + ), + expected_value=0.736083, + ), + dict( + testcase_name="_binary_two_examples_per_batch_from_logits", + extracts=[ + { + "labels": np.array([0, 1]), + "predictions": np.array([-18.6, 0.51]), + }, + { + "labels": np.array([0, 0]), + "predictions": np.array([2.94, -12.8]), + }, + ], + metric=cross_entropy_metrics.BinaryCrossEntropy( + from_logits=True, label_smoothing=0.0 + ), + expected_value=0.865457, + ), + dict( + testcase_name="_binary_two_examples_with_label_smoothing", + extracts=[ + { + "labels": np.array([0]), + "predictions": np.array([0.6]), + }, + { + "labels": np.array([1]), + "predictions": np.array([0.6]), + "example_weights": np.array([0.8]), + }, + ], + metric=cross_entropy_metrics.BinaryCrossEntropy( + from_logits=False, label_smoothing=0.1 + ), + expected_value=0.733831, + ), + dict( + testcase_name="_categorical_two_examples", + extracts=[ + { + "labels": np.array([0, 1, 0]), + "predictions": np.array([0.05, 0.95, 0]), + }, + { + "labels": np.array([0, 0, 1]), + "predictions": np.array([0.1, 0.8, 0.1]), + }, + ], + metric=cross_entropy_metrics.CategoricalCrossEntropy( + from_logits=False, label_smoothing=0.0 + ), + expected_value=1.176939, + ), + dict( + testcase_name="_categorical_two_examples_with_weights", + extracts=[ + { + "labels": np.array([0, 1, 0]), + "predictions": np.array([0.05, 0.95, 0]), + "example_weights": np.array([0.3]), + }, + { + "labels": np.array([0, 0, 1]), + "predictions": np.array([0.1, 0.8, 0.1]), + "example_weights": np.array([0.7]), + }, + ], + metric=cross_entropy_metrics.CategoricalCrossEntropy( + from_logits=False, label_smoothing=0.0 + ), + expected_value=1.627198, + ), + dict( + testcase_name="_categorical_two_examples_from_logits", + extracts=[ + { + "labels": np.array([0, 1, 0]), + "predictions": np.array([-5, 0.95, -10]), + }, + { + "labels": np.array([0, 0, 1]), + "predictions": np.array([5, -1, 0.5]), + }, + ], + metric=cross_entropy_metrics.CategoricalCrossEntropy( + from_logits=True, label_smoothing=0.0 + ), + expected_value=2.258058, + ), + dict( + testcase_name="_categorical_two_examples_from_logits_with_smoothing", + extracts=[ + { + "labels": np.array([0, 1, 0]), + "predictions": np.array([-5, 0.95, -10]), + }, + { + "labels": np.array([0, 0, 1]), + "predictions": np.array([5, -1, 0.5]), + }, + ], + metric=cross_entropy_metrics.CategoricalCrossEntropy( + from_logits=True, label_smoothing=0.1 + ), + expected_value=2.489725, + ), + ) + def testBinaryCrossEntropy(self, extracts, metric, expected_value): + computations = metric.computations(example_weighted=True) + computation = computations[0] + with beam.Pipeline() as pipeline: + # 
pylint: disable=no-value-for-parameter + result = ( + pipeline + | "Create" >> beam.Create(extracts) + | "Process" >> beam.Map(metric_util.to_standard_metric_inputs) + | "AddSlice" >> beam.Map(lambda x: ((), x)) + | "ComputeMetric" >> beam.CombinePerKey(computation.combiner) + ) - # pylint: enable=no-value-for-parameter - def check_result(got): - try: - self.assertLen(got, 1) - got_slice_key, got_metrics = got[0] - self.assertEqual(got_slice_key, ()) - key = computation.keys[0] - self.assertIn(key, got_metrics) - self.assertAlmostEqual(got_metrics[key], expected_value, places=5) - except AssertionError as err: - raise util.BeamAssertException() from err + # pylint: enable=no-value-for-parameter + def check_result(got): + try: + self.assertLen(got, 1) + got_slice_key, got_metrics = got[0] + self.assertEqual(got_slice_key, ()) + key = computation.keys[0] + self.assertIn(key, got_metrics) + self.assertAlmostEqual(got_metrics[key], expected_value, places=5) + except AssertionError as err: + raise util.BeamAssertException() from err - util.assert_that(result, check_result, label='result') + util.assert_that(result, check_result, label="result") -if __name__ == '__main__': - absltest.main() +if __name__ == "__main__": + absltest.main() diff --git a/tensorflow_model_analysis/metrics/exact_match.py b/tensorflow_model_analysis/metrics/exact_match.py index 851c803f99..b1e971a795 100644 --- a/tensorflow_model_analysis/metrics/exact_match.py +++ b/tensorflow_model_analysis/metrics/exact_match.py @@ -12,39 +12,40 @@ # See the License for the specific language governing permissions and # limitations under the License. """Exact match metric.""" + import json from typing import Dict, Iterable, Optional import apache_beam as beam -from tensorflow_model_analysis.metrics import metric_types -from tensorflow_model_analysis.metrics import metric_util + +from tensorflow_model_analysis.metrics import metric_types, metric_util from tensorflow_model_analysis.proto import config_pb2 -EXACT_MATCH_NAME = 'exact_match' -_JSON = 'json' +EXACT_MATCH_NAME = "exact_match" +_JSON = "json" _CONVERT_TO_VALUES = frozenset([_JSON]) class ExactMatch(metric_types.Metric): - """Exact Match Metric.""" - - def __init__(self, - name: str = EXACT_MATCH_NAME, - convert_to: Optional[str] = None): - """Initializes exact match metric. - - Args: - name: The name of the metric to use. - convert_to: The conversion to perform before checking equality. - """ - - super().__init__( - metric_util.merge_per_key_computations(_exact_match), - name=name, - convert_to=convert_to) - if convert_to and convert_to not in _CONVERT_TO_VALUES: - raise ValueError('convert_to can only be one of the following: %s' % - str(convert_to)) + """Exact Match Metric.""" + + def __init__(self, name: str = EXACT_MATCH_NAME, convert_to: Optional[str] = None): + """Initializes exact match metric. + + Args: + ---- + name: The name of the metric to use. + convert_to: The conversion to perform before checking equality. 
+ """ + super().__init__( + metric_util.merge_per_key_computations(_exact_match), + name=name, + convert_to=convert_to, + ) + if convert_to and convert_to not in _CONVERT_TO_VALUES: + raise ValueError( + "convert_to can only be one of the following: %s" % str(convert_to) + ) metric_types.register_metric(ExactMatch) @@ -53,97 +54,121 @@ def __init__(self, def _exact_match( name: str, eval_config: Optional[config_pb2.EvalConfig] = None, - model_name: str = '', - output_name: str = '', + model_name: str = "", + output_name: str = "", sub_key: Optional[metric_types.SubKey] = None, aggregation_type: Optional[metric_types.AggregationType] = None, class_weights: Optional[Dict[int, float]] = None, example_weighted: bool = False, - convert_to: Optional[str] = None) -> metric_types.MetricComputations: - """Returns metric computations for computing the exact match score.""" - key = metric_types.MetricKey( - name=name, - model_name=model_name, - output_name=output_name, - sub_key=sub_key, - example_weighted=example_weighted) - return [ - metric_types.MetricComputation( - keys=[key], - preprocessors=None, - combiner=_ExactMatchCombiner(key, eval_config, aggregation_type, - class_weights, example_weighted, - convert_to)) - ] + convert_to: Optional[str] = None, +) -> metric_types.MetricComputations: + """Returns metric computations for computing the exact match score.""" + key = metric_types.MetricKey( + name=name, + model_name=model_name, + output_name=output_name, + sub_key=sub_key, + example_weighted=example_weighted, + ) + return [ + metric_types.MetricComputation( + keys=[key], + preprocessors=None, + combiner=_ExactMatchCombiner( + key, + eval_config, + aggregation_type, + class_weights, + example_weighted, + convert_to, + ), + ) + ] class _ExactMatchAccumulator: - """Exact match accumulator.""" - __slots__ = ['total_weighted_exact_match_scores', 'total_weighted_examples'] + """Exact match accumulator.""" + + __slots__ = ["total_weighted_exact_match_scores", "total_weighted_examples"] - def __init__(self): - self.total_weighted_exact_match_scores = 0.0 - self.total_weighted_examples = 0.0 + def __init__(self): + self.total_weighted_exact_match_scores = 0.0 + self.total_weighted_examples = 0.0 - def __iadd__(self, other): - self.total_weighted_exact_match_scores += other.total_weighted_exact_match_scores - self.total_weighted_examples += other.total_weighted_examples - return self + def __iadd__(self, other): + self.total_weighted_exact_match_scores += ( + other.total_weighted_exact_match_scores + ) + self.total_weighted_examples += other.total_weighted_examples + return self class _ExactMatchCombiner(beam.CombineFn): - """Combines Exact Match scores.""" - - def __init__(self, key: metric_types.MetricKey, - eval_config: Optional[config_pb2.EvalConfig], - aggregation_type: Optional[metric_types.AggregationType], - class_weights: Optional[Dict[int, float]], - exampled_weighted: bool, convert_to: Optional[str]): - self._key = key - self._eval_config = eval_config - self._aggregation_type = aggregation_type - self._class_weights = class_weights - self._example_weighted = exampled_weighted - self._convert_to = convert_to - - def create_accumulator(self) -> _ExactMatchAccumulator: - return _ExactMatchAccumulator() - - def add_input( - self, accumulator: _ExactMatchAccumulator, - element: metric_types.StandardMetricInputs) -> _ExactMatchAccumulator: - for label, prediction, example_weight in ( - metric_util.to_label_prediction_example_weight( + """Combines Exact Match scores.""" + + def __init__( + self, + 
key: metric_types.MetricKey, + eval_config: Optional[config_pb2.EvalConfig], + aggregation_type: Optional[metric_types.AggregationType], + class_weights: Optional[Dict[int, float]], + exampled_weighted: bool, + convert_to: Optional[str], + ): + self._key = key + self._eval_config = eval_config + self._aggregation_type = aggregation_type + self._class_weights = class_weights + self._example_weighted = exampled_weighted + self._convert_to = convert_to + + def create_accumulator(self) -> _ExactMatchAccumulator: + return _ExactMatchAccumulator() + + def add_input( + self, + accumulator: _ExactMatchAccumulator, + element: metric_types.StandardMetricInputs, + ) -> _ExactMatchAccumulator: + for ( + label, + prediction, + example_weight, + ) in metric_util.to_label_prediction_example_weight( element, eval_config=self._eval_config, model_name=self._key.model_name, output_name=self._key.output_name, aggregation_type=self._aggregation_type, class_weights=self._class_weights, - example_weighted=self._example_weighted)): - label = label.tolist() - prediction = prediction.tolist() - if self._convert_to == _JSON: - label = [json.loads(l) for l in label] - prediction = [json.loads(p) for p in prediction] - match = [p == l for p, l in zip(prediction, label)] - score = int(all(match)) - example_weight = metric_util.safe_to_scalar(example_weight) - accumulator.total_weighted_exact_match_scores += score * example_weight - accumulator.total_weighted_examples += example_weight - return accumulator - - def merge_accumulators( - self, - accumulators: Iterable[_ExactMatchAccumulator]) -> _ExactMatchAccumulator: - accumulators = iter(accumulators) - result = next(accumulators) - for accumulator in accumulators: - result += accumulator - return result - - def extract_output( - self, accumulator: _ExactMatchAccumulator - ) -> Dict[metric_types.MetricKey, float]: - score = accumulator.total_weighted_exact_match_scores / accumulator.total_weighted_examples - return {self._key: score} + example_weighted=self._example_weighted, + ): + label = label.tolist() + prediction = prediction.tolist() + if self._convert_to == _JSON: + label = [json.loads(l) for l in label] + prediction = [json.loads(p) for p in prediction] + match = [p == l for p, l in zip(prediction, label)] + score = int(all(match)) + example_weight = metric_util.safe_to_scalar(example_weight) + accumulator.total_weighted_exact_match_scores += score * example_weight + accumulator.total_weighted_examples += example_weight + return accumulator + + def merge_accumulators( + self, accumulators: Iterable[_ExactMatchAccumulator] + ) -> _ExactMatchAccumulator: + accumulators = iter(accumulators) + result = next(accumulators) + for accumulator in accumulators: + result += accumulator + return result + + def extract_output( + self, accumulator: _ExactMatchAccumulator + ) -> Dict[metric_types.MetricKey, float]: + score = ( + accumulator.total_weighted_exact_match_scores + / accumulator.total_weighted_examples + ) + return {self._key: score} diff --git a/tensorflow_model_analysis/metrics/exact_match_test.py b/tensorflow_model_analysis/metrics/exact_match_test.py index bf8b5a2667..f78fbc2949 100644 --- a/tensorflow_model_analysis/metrics/exact_match_test.py +++ b/tensorflow_model_analysis/metrics/exact_match_test.py @@ -15,110 +15,112 @@ import json -from absl.testing import parameterized import apache_beam as beam -from apache_beam.testing import util import numpy as np import tensorflow as tf -from tensorflow_model_analysis.metrics import exact_match -from 
tensorflow_model_analysis.metrics import metric_util -from tensorflow_model_analysis.utils import test_util +from absl.testing import parameterized +from apache_beam.testing import util +from tensorflow_model_analysis.metrics import exact_match, metric_util +from tensorflow_model_analysis.utils import test_util -class ExactMatchTest( - test_util.TensorflowModelAnalysisTest, parameterized.TestCase -): - @parameterized.named_parameters(('text', False), ('json', True)) - def testExactMatchWithoutWeights(self, test_json): - convert_to = 'json' if test_json else None - computations = exact_match.ExactMatch(convert_to=convert_to).computations() - metric = computations[0] +class ExactMatchTest(test_util.TensorflowModelAnalysisTest, parameterized.TestCase): + @parameterized.named_parameters(("text", False), ("json", True)) + def testExactMatchWithoutWeights(self, test_json): + convert_to = "json" if test_json else None + computations = exact_match.ExactMatch(convert_to=convert_to).computations() + metric = computations[0] - def _maybe_convert_feature(f): - return json.dumps(f) if test_json else f + def _maybe_convert_feature(f): + return json.dumps(f) if test_json else f - example1 = { - 'labels': np.array([_maybe_convert_feature('Test 1 two 3')]), - 'predictions': np.array([_maybe_convert_feature('Test 1 two 3')]), - 'example_weights': np.array([1.0]), - } - example2 = { - 'labels': np.array([_maybe_convert_feature('Testing')]), - 'predictions': np.array([_maybe_convert_feature('Dog')]), - 'example_weights': np.array([1.0]), - } - example3 = { - 'labels': np.array([_maybe_convert_feature('Test 1 two 3') + ' ']), - 'predictions': np.array([_maybe_convert_feature('Test 1 two 3')]), - 'example_weights': np.array([1.0]), - } - with beam.Pipeline() as pipeline: - # pylint: disable=no-value-for-parameter - result = ( - pipeline - | 'Create' >> beam.Create([example1, example2, example3]) - | 'Process' >> beam.Map(metric_util.to_standard_metric_inputs) - | 'AddSlice' >> beam.Map(lambda x: ((), x)) - | 'ComputeMetric' >> beam.CombinePerKey(metric.combiner)) + example1 = { + "labels": np.array([_maybe_convert_feature("Test 1 two 3")]), + "predictions": np.array([_maybe_convert_feature("Test 1 two 3")]), + "example_weights": np.array([1.0]), + } + example2 = { + "labels": np.array([_maybe_convert_feature("Testing")]), + "predictions": np.array([_maybe_convert_feature("Dog")]), + "example_weights": np.array([1.0]), + } + example3 = { + "labels": np.array([_maybe_convert_feature("Test 1 two 3") + " "]), + "predictions": np.array([_maybe_convert_feature("Test 1 two 3")]), + "example_weights": np.array([1.0]), + } + with beam.Pipeline() as pipeline: + # pylint: disable=no-value-for-parameter + result = ( + pipeline + | "Create" >> beam.Create([example1, example2, example3]) + | "Process" >> beam.Map(metric_util.to_standard_metric_inputs) + | "AddSlice" >> beam.Map(lambda x: ((), x)) + | "ComputeMetric" >> beam.CombinePerKey(metric.combiner) + ) - # pylint: enable=no-value-for-parameter - def check_result(got): - try: - self.assertLen(got, 1) - got_slice_key, got_metrics = got[0] - self.assertEqual(got_slice_key, ()) - key = metric.keys[0] - # example1 is a perfect match (score 100) - # example2 is a complete miss (score 0) - # example3 is a perfect match for json but a miss for text. - # average score: 0.6666.. for json, 0.3333... for text. 
- score = 2.0 / 3 if test_json else 1.0 / 3 - self.assertDictElementsAlmostEqual( - got_metrics, {key: score}, places=5) - except AssertionError as err: - raise util.BeamAssertException(err) + # pylint: enable=no-value-for-parameter + def check_result(got): + try: + self.assertLen(got, 1) + got_slice_key, got_metrics = got[0] + self.assertEqual(got_slice_key, ()) + key = metric.keys[0] + # example1 is a perfect match (score 100) + # example2 is a complete miss (score 0) + # example3 is a perfect match for json but a miss for text. + # average score: 0.6666.. for json, 0.3333... for text. + score = 2.0 / 3 if test_json else 1.0 / 3 + self.assertDictElementsAlmostEqual( + got_metrics, {key: score}, places=5 + ) + except AssertionError as err: + raise util.BeamAssertException(err) - util.assert_that(result, check_result, label='result') + util.assert_that(result, check_result, label="result") - def testExactMatchScoreWithWeights(self): - computations = exact_match.ExactMatch().computations(example_weighted=True) - metric = computations[0] - example1 = { - 'labels': np.array(['Test 1 two 3']), - 'predictions': np.array(['Test 1 two 3']), - 'example_weights': np.array([3.0]), - } - example2 = { - 'labels': np.array(['Testing']), - 'predictions': np.array(['Dog']), - 'example_weights': np.array([1.0]), - } - with beam.Pipeline() as pipeline: - # pylint: disable=no-value-for-parameter - result = ( - pipeline - | 'Create' >> beam.Create([example1, example2]) - | 'Process' >> beam.Map(metric_util.to_standard_metric_inputs) - | 'AddSlice' >> beam.Map(lambda x: ((), x)) - | 'ComputeMetric' >> beam.CombinePerKey(metric.combiner)) + def testExactMatchScoreWithWeights(self): + computations = exact_match.ExactMatch().computations(example_weighted=True) + metric = computations[0] + example1 = { + "labels": np.array(["Test 1 two 3"]), + "predictions": np.array(["Test 1 two 3"]), + "example_weights": np.array([3.0]), + } + example2 = { + "labels": np.array(["Testing"]), + "predictions": np.array(["Dog"]), + "example_weights": np.array([1.0]), + } + with beam.Pipeline() as pipeline: + # pylint: disable=no-value-for-parameter + result = ( + pipeline + | "Create" >> beam.Create([example1, example2]) + | "Process" >> beam.Map(metric_util.to_standard_metric_inputs) + | "AddSlice" >> beam.Map(lambda x: ((), x)) + | "ComputeMetric" >> beam.CombinePerKey(metric.combiner) + ) - # pylint: enable=no-value-for-parameter - def check_result(got): - try: - self.assertLen(got, 1) - got_slice_key, got_metrics = got[0] - self.assertEqual(got_slice_key, ()) - key = metric.keys[0] - # example1 is a perfect match (score 100) - # example2 is a complete miss (score 0) - # average score: (1*3 + 0*1) / 4 = 0.75 - self.assertDictElementsAlmostEqual(got_metrics, {key: 0.75}, places=5) - except AssertionError as err: - raise util.BeamAssertException(err) + # pylint: enable=no-value-for-parameter + def check_result(got): + try: + self.assertLen(got, 1) + got_slice_key, got_metrics = got[0] + self.assertEqual(got_slice_key, ()) + key = metric.keys[0] + # example1 is a perfect match (score 100) + # example2 is a complete miss (score 0) + # average score: (1*3 + 0*1) / 4 = 0.75 + self.assertDictElementsAlmostEqual( + got_metrics, {key: 0.75}, places=5 + ) + except AssertionError as err: + raise util.BeamAssertException(err) - util.assert_that(result, check_result, label='result') + util.assert_that(result, check_result, label="result") -if __name__ == '__main__': - tf.test.main() +if __name__ == "__main__": + tf.test.main() diff --git 
a/tensorflow_model_analysis/metrics/example_count.py b/tensorflow_model_analysis/metrics/example_count.py index 0481f4f998..bbe85d5ab2 100644 --- a/tensorflow_model_analysis/metrics/example_count.py +++ b/tensorflow_model_analysis/metrics/example_count.py @@ -13,45 +13,47 @@ # limitations under the License. """Example count metric.""" -from typing import Optional, Dict, Iterable, List +from typing import Dict, Iterable, List, Optional import apache_beam as beam import numpy as np + from tensorflow_model_analysis import constants from tensorflow_model_analysis.metrics import metric_types from tensorflow_model_analysis.utils import util -EXAMPLE_COUNT_NAME = 'example_count' +EXAMPLE_COUNT_NAME = "example_count" class ExampleCount(metric_types.Metric): - """Example count. - - Note that although the example_count is independent of the model, this metric - will be associated with a model for consistency with other metrics. - """ - - def __init__(self, name: str = EXAMPLE_COUNT_NAME): - """Initializes example count. + """Example count. - Args: - name: Metric name. + Note that although the example_count is independent of the model, this metric + will be associated with a model for consistency with other metrics. """ - super().__init__(example_count, name=name) + def __init__(self, name: str = EXAMPLE_COUNT_NAME): + """Initializes example count. - @property - def compute_confidence_interval(self) -> bool: - """Always disable confidence intervals for ExampleCount. + Args: + ---- + name: Metric name. + """ + super().__init__(example_count, name=name) - Confidence intervals capture uncertainty in a metric if it were computed on - more examples. For ExampleCount, this sort of uncertainty is not meaningful, - so confidence intervals are disabled. + @property + def compute_confidence_interval(self) -> bool: + """Always disable confidence intervals for ExampleCount. - Returns: - Whether to compute confidence intervals. - """ - return False + Confidence intervals capture uncertainty in a metric if it were computed on + more examples. For ExampleCount, this sort of uncertainty is not meaningful, + so confidence intervals are disabled. + + Returns + ------- + Whether to compute confidence intervals. + """ + return False metric_types.register_metric(ExampleCount) @@ -62,89 +64,92 @@ def example_count( model_names: Optional[List[str]] = None, output_names: Optional[List[str]] = None, sub_keys: Optional[List[metric_types.SubKey]] = None, - example_weighted: bool = False) -> metric_types.MetricComputations: - """Returns metric computations for example count.""" - computations = [] - for model_name in model_names or ['']: - for output_name in output_names or ['']: - keys = [] - for sub_key in sub_keys or [None]: - key = metric_types.MetricKey( - name=name, - model_name=model_name, - output_name=output_name, - sub_key=sub_key, - example_weighted=example_weighted) - keys.append(key) - - # Note: This cannot be implemented based on the weight stored in - # calibration because weighted example count is used with multi-class, etc - # models that do not use calibration metrics. - # The combiner only needs example weights in case users do not have - # predictions or labels. 
- computations.append( - metric_types.MetricComputation( - keys=keys, - preprocessors=[ - metric_types.StandardMetricInputsPreprocessor( - include_filter={constants.EXAMPLE_WEIGHTS_KEY: {}}, - include_default_inputs=False, - ) - ], - combiner=_ExampleCountCombiner( - model_name, output_name, keys, example_weighted - ), - ) - ) - return computations + example_weighted: bool = False, +) -> metric_types.MetricComputations: + """Returns metric computations for example count.""" + computations = [] + for model_name in model_names or [""]: + for output_name in output_names or [""]: + keys = [] + for sub_key in sub_keys or [None]: + key = metric_types.MetricKey( + name=name, + model_name=model_name, + output_name=output_name, + sub_key=sub_key, + example_weighted=example_weighted, + ) + keys.append(key) + + # Note: This cannot be implemented based on the weight stored in + # calibration because weighted example count is used with multi-class, etc + # models that do not use calibration metrics. + # The combiner only needs example weights in case users do not have + # predictions or labels. + computations.append( + metric_types.MetricComputation( + keys=keys, + preprocessors=[ + metric_types.StandardMetricInputsPreprocessor( + include_filter={constants.EXAMPLE_WEIGHTS_KEY: {}}, + include_default_inputs=False, + ) + ], + combiner=_ExampleCountCombiner( + model_name, output_name, keys, example_weighted + ), + ) + ) + return computations class _ExampleCountCombiner(beam.CombineFn): - """Computes example count.""" - - def __init__( - self, - model_name: str, - output_name: str, - keys: List[metric_types.MetricKey], - example_weighted, - ): - self._model_name = model_name - self._output_name = output_name - self._keys = keys - self._example_weighted = example_weighted - - def create_accumulator(self) -> float: - return 0.0 - - def add_input(self, accumulator: float, - element: metric_types.StandardMetricInputs) -> float: - if not self._example_weighted or element.example_weight is None: - example_weight = np.array(1.0) - else: - example_weight = element.example_weight - if isinstance(example_weight, dict) and self._model_name: - value = util.get_by_keys( - example_weight, [self._model_name], optional=True) - if value is not None: - example_weight = value - if isinstance(example_weight, dict) and self._output_name: - example_weight = util.get_by_keys(example_weight, [self._output_name], - np.array(1.0)) - if isinstance(example_weight, dict): - raise ValueError( - f'example_count cannot be calculated on a dict {example_weight}: ' - f'model_name={self._model_name}, output_name={self._output_name}.\n\n' - 'This is most likely a configuration error (for multi-output models' - 'a separate metric is needed for each output).') - return accumulator + np.sum(example_weight) - - def merge_accumulators(self, accumulators: Iterable[float]) -> float: - result = 0.0 - for accumulator in accumulators: - result += accumulator - return result - - def extract_output(self, - accumulator: float) -> Dict[metric_types.MetricKey, float]: - return {k: accumulator for k in self._keys} + """Computes example count.""" + + def __init__( + self, + model_name: str, + output_name: str, + keys: List[metric_types.MetricKey], + example_weighted, + ): + self._model_name = model_name + self._output_name = output_name + self._keys = keys + self._example_weighted = example_weighted + + def create_accumulator(self) -> float: + return 0.0 + + def add_input( + self, accumulator: float, element: metric_types.StandardMetricInputs + ) -> float: + 
if not self._example_weighted or element.example_weight is None: + example_weight = np.array(1.0) + else: + example_weight = element.example_weight + if isinstance(example_weight, dict) and self._model_name: + value = util.get_by_keys(example_weight, [self._model_name], optional=True) + if value is not None: + example_weight = value + if isinstance(example_weight, dict) and self._output_name: + example_weight = util.get_by_keys( + example_weight, [self._output_name], np.array(1.0) + ) + if isinstance(example_weight, dict): + raise ValueError( + f"example_count cannot be calculated on a dict {example_weight}: " + f"model_name={self._model_name}, output_name={self._output_name}.\n\n" + "This is most likely a configuration error (for multi-output models" + "a separate metric is needed for each output)." + ) + return accumulator + np.sum(example_weight) + + def merge_accumulators(self, accumulators: Iterable[float]) -> float: + result = 0.0 + for accumulator in accumulators: + result += accumulator + return result + + def extract_output(self, accumulator: float) -> Dict[metric_types.MetricKey, float]: + return {k: accumulator for k in self._keys} diff --git a/tensorflow_model_analysis/metrics/example_count_test.py b/tensorflow_model_analysis/metrics/example_count_test.py index 5e11515743..482a19c384 100644 --- a/tensorflow_model_analysis/metrics/example_count_test.py +++ b/tensorflow_model_analysis/metrics/example_count_test.py @@ -13,91 +13,93 @@ # limitations under the License. """Tests for example count metric.""" -from absl.testing import parameterized import apache_beam as beam -from apache_beam.testing import util import numpy as np import tensorflow as tf +from absl.testing import parameterized +from apache_beam.testing import util +from google.protobuf import text_format + import tensorflow_model_analysis as tfma +from tensorflow_model_analysis.metrics import example_count, metric_types, metric_util from tensorflow_model_analysis.proto import config_pb2 -from tensorflow_model_analysis.metrics import example_count -from tensorflow_model_analysis.metrics import metric_types -from tensorflow_model_analysis.metrics import metric_util from tensorflow_model_analysis.utils import test_util -from google.protobuf import text_format +class ExampleCountTest(test_util.TensorflowModelAnalysisTest, parameterized.TestCase): + @parameterized.named_parameters( + ("unweighted", "", "", False), + ("basic", "", "", True), + ("multi-model", "model", "", True), + ("multi-output", "", "output", True), + ("multi-model-multi-output", "model", "output", True), + ) + def testExampleCount(self, model_name, output_name, example_weighted): + metric = example_count.ExampleCount().computations( + model_names=[model_name], + output_names=[output_name], + example_weighted=example_weighted, + )[0] + + example0 = {"labels": None, "predictions": None, "example_weights": [0.0]} + example1 = {"labels": None, "predictions": None, "example_weights": [0.5]} + example2 = {"labels": None, "predictions": None, "example_weights": [1.0]} + example3 = {"labels": None, "predictions": None, "example_weights": [0.7]} + + if output_name: + example0["example_weights"] = {output_name: example0["example_weights"]} + example1["example_weights"] = {output_name: example1["example_weights"]} + example2["example_weights"] = {output_name: example2["example_weights"]} + example3["example_weights"] = {output_name: example3["example_weights"]} + + if model_name: + example0["example_weights"] = {model_name: example0["example_weights"]} + 
example1["example_weights"] = {model_name: example1["example_weights"]} + example2["example_weights"] = {model_name: example2["example_weights"]} + example3["example_weights"] = {model_name: example3["example_weights"]} + + with beam.Pipeline() as pipeline: + # pylint: disable=no-value-for-parameter + result = ( + pipeline + | "Create" >> beam.Create([example0, example1, example2, example3]) + | "Process" >> beam.Map(metric_util.to_standard_metric_inputs) + | "AddSlice" >> beam.Map(lambda x: ((), x)) + | "ComputeMetric" >> beam.CombinePerKey(metric.combiner) + ) -class ExampleCountTest( - test_util.TensorflowModelAnalysisTest, parameterized.TestCase -): - - @parameterized.named_parameters( - ('unweighted', '', '', False), ('basic', '', '', True), - ('multi-model', 'model', '', True), ('multi-output', '', 'output', True), - ('multi-model-multi-output', 'model', 'output', True)) - def testExampleCount(self, model_name, output_name, example_weighted): - metric = example_count.ExampleCount().computations( - model_names=[model_name], - output_names=[output_name], - example_weighted=example_weighted)[0] - - example0 = {'labels': None, 'predictions': None, 'example_weights': [0.0]} - example1 = {'labels': None, 'predictions': None, 'example_weights': [0.5]} - example2 = {'labels': None, 'predictions': None, 'example_weights': [1.0]} - example3 = {'labels': None, 'predictions': None, 'example_weights': [0.7]} - - if output_name: - example0['example_weights'] = {output_name: example0['example_weights']} - example1['example_weights'] = {output_name: example1['example_weights']} - example2['example_weights'] = {output_name: example2['example_weights']} - example3['example_weights'] = {output_name: example3['example_weights']} - - if model_name: - example0['example_weights'] = {model_name: example0['example_weights']} - example1['example_weights'] = {model_name: example1['example_weights']} - example2['example_weights'] = {model_name: example2['example_weights']} - example3['example_weights'] = {model_name: example3['example_weights']} - - with beam.Pipeline() as pipeline: - # pylint: disable=no-value-for-parameter - result = ( - pipeline - | 'Create' >> beam.Create([example0, example1, example2, example3]) - | 'Process' >> beam.Map(metric_util.to_standard_metric_inputs) - | 'AddSlice' >> beam.Map(lambda x: ((), x)) - | 'ComputeMetric' >> beam.CombinePerKey(metric.combiner)) - - # pylint: enable=no-value-for-parameter - - def check_result(got): - try: - self.assertLen(got, 1) - got_slice_key, got_metrics = got[0] - self.assertEqual(got_slice_key, ()) - example_count_key = metric_types.MetricKey( - name='example_count', - model_name=model_name, - output_name=output_name, - example_weighted=example_weighted) - if example_weighted: - self.assertDictElementsAlmostEqual( - got_metrics, {example_count_key: (0.0 + 0.5 + 1.0 + 0.7)}) - else: - self.assertDictElementsAlmostEqual( - got_metrics, {example_count_key: (1.0 + 1.0 + 1.0 + 1.0)}) - - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that(result, check_result, label='result') + # pylint: enable=no-value-for-parameter + + def check_result(got): + try: + self.assertLen(got, 1) + got_slice_key, got_metrics = got[0] + self.assertEqual(got_slice_key, ()) + example_count_key = metric_types.MetricKey( + name="example_count", + model_name=model_name, + output_name=output_name, + example_weighted=example_weighted, + ) + if example_weighted: + self.assertDictElementsAlmostEqual( + got_metrics, {example_count_key: (0.0 + 0.5 + 
1.0 + 0.7)} + ) + else: + self.assertDictElementsAlmostEqual( + got_metrics, {example_count_key: (1.0 + 1.0 + 1.0 + 1.0)} + ) + + except AssertionError as err: + raise util.BeamAssertException(err) + + util.assert_that(result, check_result, label="result") class ExampleCountEnd2EndTest(parameterized.TestCase): - - def testExampleCountsWithoutLabelPredictions(self): - eval_config = text_format.Parse( - """ + def testExampleCountsWithoutLabelPredictions(self): + eval_config = text_format.Parse( + """ model_specs { signature_name: "serving_default" example_weight_key: "example_weights" @@ -110,56 +112,54 @@ def testExampleCountsWithoutLabelPredictions(self): } } """, - config_pb2.EvalConfig(), - ) - name_list = ['example_count'] - expected_results = [0.6] - extracts = [ - { - 'features': { - 'example_weights': np.array([0.5]), - } - }, - {'features': {}, 'example_weights': np.array([0.1])}, - ] - - evaluators = tfma.default_evaluators(eval_config=eval_config) - extractors = tfma.default_extractors( - eval_shared_model=None, eval_config=eval_config - ) - - with beam.Pipeline() as p: - result = ( - p - | 'LoadData' >> beam.Create(extracts) - | 'ExtractEval' - >> tfma.ExtractAndEvaluate( - extractors=extractors, evaluators=evaluators - ) - ) - - def check_result(got): - try: - self.assertLen(got, 1) - got_slice_key, got_metrics = got[0] - self.assertEqual(got_slice_key, ()) - self.assertLen(got_metrics, len(name_list)) - for name, expected_result in zip(name_list, expected_results): - key = metric_types.MetricKey(name=name, example_weighted=True) - self.assertIn(key, got_metrics) - got_metric = got_metrics[key] - np.testing.assert_allclose( - expected_result, - got_metric, - rtol=1e-3, - err_msg=f'This {name} metric fails.', + config_pb2.EvalConfig(), + ) + name_list = ["example_count"] + expected_results = [0.6] + extracts = [ + { + "features": { + "example_weights": np.array([0.5]), + } + }, + {"features": {}, "example_weights": np.array([0.1])}, + ] + + evaluators = tfma.default_evaluators(eval_config=eval_config) + extractors = tfma.default_extractors( + eval_shared_model=None, eval_config=eval_config + ) + + with beam.Pipeline() as p: + result = ( + p + | "LoadData" >> beam.Create(extracts) + | "ExtractEval" + >> tfma.ExtractAndEvaluate(extractors=extractors, evaluators=evaluators) ) - except AssertionError as err: - raise util.BeamAssertException(err) - - self.assertIn('metrics', result) - util.assert_that(result['metrics'], check_result, label='result') - -if __name__ == '__main__': - tf.test.main() + def check_result(got): + try: + self.assertLen(got, 1) + got_slice_key, got_metrics = got[0] + self.assertEqual(got_slice_key, ()) + self.assertLen(got_metrics, len(name_list)) + for name, expected_result in zip(name_list, expected_results): + key = metric_types.MetricKey(name=name, example_weighted=True) + self.assertIn(key, got_metrics) + got_metric = got_metrics[key] + np.testing.assert_allclose( + expected_result, + got_metric, + rtol=1e-3, + err_msg=f"This {name} metric fails.", + ) + except AssertionError as err: + raise util.BeamAssertException(err) + + self.assertIn("metrics", result) + util.assert_that(result["metrics"], check_result, label="result") + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow_model_analysis/metrics/flip_metrics.py b/tensorflow_model_analysis/metrics/flip_metrics.py index 4e4db05112..8fe94f636d 100644 --- a/tensorflow_model_analysis/metrics/flip_metrics.py +++ b/tensorflow_model_analysis/metrics/flip_metrics.py @@ -14,27 +14,26 @@ 
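For readers skimming the example_count changes above, the following standalone sketch is illustrative only (it is not part of the patch; it simply mirrors the Beam pattern used in example_count_test.py) and shows how the reformatted ExampleCount combiner is typically exercised:

import apache_beam as beam
import numpy as np

from tensorflow_model_analysis.metrics import example_count, metric_util

# Weighted example count: the combiner sums example_weights, so the expected
# result for these two records is 0.5 + 1.0 = 1.5.
metric = example_count.ExampleCount().computations(example_weighted=True)[0]
examples = [
    {"labels": None, "predictions": None, "example_weights": np.array([0.5])},
    {"labels": None, "predictions": None, "example_weights": np.array([1.0])},
]
with beam.Pipeline() as pipeline:
    _ = (
        pipeline
        | "Create" >> beam.Create(examples)
        | "Process" >> beam.Map(metric_util.to_standard_metric_inputs)
        | "AddSlice" >> beam.Map(lambda x: ((), x))
        | "ComputeMetric" >> beam.CombinePerKey(metric.combiner)
    )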
"""Flip rate metrics.""" import abc -from collections.abc import Iterable import dataclasses import functools +from collections.abc import Iterable from typing import Any, Callable, Optional import apache_beam as beam -from tensorflow_model_analysis.metrics import metric_types -from tensorflow_model_analysis.metrics import metric_util + +from tensorflow_model_analysis.metrics import metric_types, metric_util from tensorflow_model_analysis.proto import config_pb2 from tensorflow_model_analysis.utils import model_util - # Flip Metrics Names -FLIP_RATE_NAME = 'flip_rate' # Symmetric Flip Rate Name in BooleanFlipRates(). -SYMMETRIC_FLIP_RATE_NAME = 'symmetric_flip_rate' # In SymmetricFlipRates(). -NEG_TO_NEG_FLIP_RATE_NAME = 'neg_to_neg_flip_rate' -NEG_TO_POS_FLIP_RATE_NAME = 'neg_to_pos_flip_rate' -POS_TO_NEG_FLIP_RATE_NAME = 'pos_to_neg_flip_rate' -POS_TO_POS_FLIP_RATE_NAME = 'pos_to_pos_flip_rate' +FLIP_RATE_NAME = "flip_rate" # Symmetric Flip Rate Name in BooleanFlipRates(). +SYMMETRIC_FLIP_RATE_NAME = "symmetric_flip_rate" # In SymmetricFlipRates(). +NEG_TO_NEG_FLIP_RATE_NAME = "neg_to_neg_flip_rate" +NEG_TO_POS_FLIP_RATE_NAME = "neg_to_pos_flip_rate" +POS_TO_NEG_FLIP_RATE_NAME = "pos_to_neg_flip_rate" +POS_TO_POS_FLIP_RATE_NAME = "pos_to_pos_flip_rate" -_FLIP_COUNTS_BASE_NAME = '_flip_counts' # flip_counts_computation name. +_FLIP_COUNTS_BASE_NAME = "_flip_counts" # flip_counts_computation name. _DEFAULT_FLIP_RATE_THRESHOLD = 0.5 @@ -42,110 +41,110 @@ @dataclasses.dataclass class _BooleanFlipCountsAccumulator: - """Accumulator for computing BooleanFlipRates.""" + """Accumulator for computing BooleanFlipRates.""" - num_weighted_examples: float = 0.0 - num_weighted_neg_to_neg: float = 0.0 - num_weighted_neg_to_pos: float = 0.0 - num_weighted_pos_to_neg: float = 0.0 - num_weighted_pos_to_pos: float = 0.0 + num_weighted_examples: float = 0.0 + num_weighted_neg_to_neg: float = 0.0 + num_weighted_neg_to_pos: float = 0.0 + num_weighted_pos_to_neg: float = 0.0 + num_weighted_pos_to_pos: float = 0.0 - def merge(self, other: '_BooleanFlipCountsAccumulator'): - self.num_weighted_examples += other.num_weighted_examples - self.num_weighted_neg_to_neg += other.num_weighted_neg_to_neg - self.num_weighted_neg_to_pos += other.num_weighted_neg_to_pos - self.num_weighted_pos_to_neg += other.num_weighted_pos_to_neg - self.num_weighted_pos_to_pos += other.num_weighted_pos_to_pos + def merge(self, other: "_BooleanFlipCountsAccumulator"): + self.num_weighted_examples += other.num_weighted_examples + self.num_weighted_neg_to_neg += other.num_weighted_neg_to_neg + self.num_weighted_neg_to_pos += other.num_weighted_neg_to_pos + self.num_weighted_pos_to_neg += other.num_weighted_pos_to_neg + self.num_weighted_pos_to_pos += other.num_weighted_pos_to_pos class _BooleanFlipCountsCombiner(beam.CombineFn): - """A combiner that computes the counts needed to calculate the Flip Rates.""" - - def __init__( - self, - key: metric_types.MetricKey, - eval_config: config_pb2.EvalConfig, - baseline_model_name: str, - model_name: str, - output_name: str, - example_weighted: bool, - threshold: float, - ): - self._key = key - self._eval_config = eval_config - self._baseline_model_name = baseline_model_name - self._model_name = model_name - self._output_name = output_name - self._example_weighted = example_weighted - self._threshold = threshold - - def create_accumulator(self) -> _BooleanFlipCountsAccumulator: - return _BooleanFlipCountsAccumulator() - - def add_input( - self, - accumulator: _BooleanFlipCountsAccumulator, - element: 
metric_types.StandardMetricInputs, - ) -> _BooleanFlipCountsAccumulator: - _, base_prediction, base_example_weight = next( - metric_util.to_label_prediction_example_weight( - inputs=element, - eval_config=self._eval_config, - model_name=self._baseline_model_name, - output_name=self._output_name, - example_weighted=self._example_weighted, - flatten=True, - allow_none=True, + """A combiner that computes the counts needed to calculate the Flip Rates.""" + + def __init__( + self, + key: metric_types.MetricKey, + eval_config: config_pb2.EvalConfig, + baseline_model_name: str, + model_name: str, + output_name: str, + example_weighted: bool, + threshold: float, + ): + self._key = key + self._eval_config = eval_config + self._baseline_model_name = baseline_model_name + self._model_name = model_name + self._output_name = output_name + self._example_weighted = example_weighted + self._threshold = threshold + + def create_accumulator(self) -> _BooleanFlipCountsAccumulator: + return _BooleanFlipCountsAccumulator() + + def add_input( + self, + accumulator: _BooleanFlipCountsAccumulator, + element: metric_types.StandardMetricInputs, + ) -> _BooleanFlipCountsAccumulator: + _, base_prediction, base_example_weight = next( + metric_util.to_label_prediction_example_weight( + inputs=element, + eval_config=self._eval_config, + model_name=self._baseline_model_name, + output_name=self._output_name, + example_weighted=self._example_weighted, + flatten=True, + allow_none=True, + ) ) - ) - _, model_prediction, _ = next( - metric_util.to_label_prediction_example_weight( - inputs=element, - eval_config=self._eval_config, - model_name=self._model_name, - output_name=self._output_name, - example_weighted=self._example_weighted, - flatten=True, - allow_none=True, + _, model_prediction, _ = next( + metric_util.to_label_prediction_example_weight( + inputs=element, + eval_config=self._eval_config, + model_name=self._model_name, + output_name=self._output_name, + example_weighted=self._example_weighted, + flatten=True, + allow_none=True, + ) ) - ) - base_example_weight = metric_util.safe_to_scalar(base_example_weight) - base_prediciton_bool = base_prediction > self._threshold - model_prediction_bool = model_prediction > self._threshold - - accumulator.merge( - _BooleanFlipCountsAccumulator( - num_weighted_examples=base_example_weight, - num_weighted_neg_to_neg=base_example_weight - * int(not base_prediciton_bool and not model_prediction_bool), - num_weighted_neg_to_pos=base_example_weight - * int(not base_prediciton_bool and model_prediction_bool), - num_weighted_pos_to_neg=base_example_weight - * int(base_prediciton_bool and not model_prediction_bool), - num_weighted_pos_to_pos=base_example_weight - * int(base_prediciton_bool and model_prediction_bool), + base_example_weight = metric_util.safe_to_scalar(base_example_weight) + base_prediciton_bool = base_prediction > self._threshold + model_prediction_bool = model_prediction > self._threshold + + accumulator.merge( + _BooleanFlipCountsAccumulator( + num_weighted_examples=base_example_weight, + num_weighted_neg_to_neg=base_example_weight + * int(not base_prediciton_bool and not model_prediction_bool), + num_weighted_neg_to_pos=base_example_weight + * int(not base_prediciton_bool and model_prediction_bool), + num_weighted_pos_to_neg=base_example_weight + * int(base_prediciton_bool and not model_prediction_bool), + num_weighted_pos_to_pos=base_example_weight + * int(base_prediciton_bool and model_prediction_bool), + ) ) - ) - return accumulator + return accumulator - def 
merge_accumulators( - self, accumulators: Iterable[_BooleanFlipCountsAccumulator] - ) -> _BooleanFlipCountsAccumulator: - result = next(iter(accumulators)) + def merge_accumulators( + self, accumulators: Iterable[_BooleanFlipCountsAccumulator] + ) -> _BooleanFlipCountsAccumulator: + result = next(iter(accumulators)) - for accumulator in accumulators: - result.merge(accumulator) + for accumulator in accumulators: + result.merge(accumulator) - return result + return result - def extract_output( - self, accumulator: _BooleanFlipCountsAccumulator - ) -> dict[metric_types.MetricKey, _BooleanFlipCountsAccumulator]: - # We return a _BooleanFlipCountsAccumulator here, not a metric value. - return {self._key: accumulator} + def extract_output( + self, accumulator: _BooleanFlipCountsAccumulator + ) -> dict[metric_types.MetricKey, _BooleanFlipCountsAccumulator]: + # We return a _BooleanFlipCountsAccumulator here, not a metric value. + return {self._key: accumulator} def _flip_counts( @@ -156,337 +155,334 @@ def _flip_counts( baseline_model_name: str, threshold: float, ) -> metric_types.MetricComputation: - """Returns the metric computations for calculating the boolean flip rates. - - Args: - model_name: The model for which to compute this metric. - output_name: The output name for which to compute this metric. - example_weighted: Whether to compute this metric using example weights. - eval_config: The EvalConfig for this TFMA evaluation. This is used to - identify which model is the baseline. - baseline_model_name: The baseline model to compare the model to. - threshold: The threshold to use for converting both the baseline and - candidate predictions into boolean values that can be compared. - """ - key = metric_types.MetricKey( - name=metric_util.generate_private_name_from_arguments( - name=_FLIP_COUNTS_BASE_NAME, - eval_config=eval_config, - baseline_model_name=baseline_model_name, - threshold=threshold, - ), - model_name=model_name, - output_name=output_name, - example_weighted=example_weighted, - ) - - return metric_types.MetricComputation( - keys=[key], - preprocessors=None, - combiner=_BooleanFlipCountsCombiner( - key=key, - eval_config=eval_config, - baseline_model_name=baseline_model_name, - model_name=model_name, - output_name=output_name, - example_weighted=example_weighted, - threshold=threshold, - ), - ) - - -class _FlipRateBase(metric_types.Metric, abc.ABC): - """Base class to generate the computations for all individual flip rates.""" - - def __init__(self, name, threshold): - super().__init__( - self._metric_computations, - name=name, - threshold=threshold, - ) - - @abc.abstractmethod - def result(self, flip_counts: _BooleanFlipCountsAccumulator) -> float: - """This method will be overriden in each Individual Metric Class. + """Returns the metric computations for calculating the boolean flip rates. Args: - flip_counts: A _BooleanFlipCountsAccumulator containing the necessary - counts to calculate the individual flip rate. - - Returns: - The individual flip rate. - """ - pass - - def _get_derived_metric_result_fn( - self, - metrics: dict[metric_types.MetricKey, Any], - flip_counts_key: metric_types.MetricKey, - metric_key: metric_types.MetricKey, - calculate_flip_rate_fn: Callable[[_BooleanFlipCountsAccumulator], float], - ) -> dict[metric_types.MetricKey, float]: - """Generates the result() function for the Derived Metric Computations. - - Args: - metrics: All the metrics (computed by any computation) including the - individual flip rate metric. 
- flip_counts_key: The key of the flip counts computation in "metrics". - metric_key: The key of this metric computation. - calculate_flip_rate_fn: A function that calculates the necessary flip - rate. - - Returns: - The result() function. - """ - - def _result() -> dict[metric_types.MetricKey, float]: - # We only need the accumulator to calculate the result. - flip_counts = metrics[flip_counts_key] - - return {metric_key: calculate_flip_rate_fn(flip_counts)} - - return _result() - - def _metric_computations( - self, - name: str, - eval_config: config_pb2.EvalConfig, - example_weighted: bool, - threshold: float, - model_names: Iterable[str], - output_names: Optional[Iterable[str]] = ('',), - sub_keys: Optional[Iterable[metric_types.SubKey]] = None, - ) -> metric_types.MetricComputations: - """Returns metric computations for an individual boolean flip rate. - - This is not meant to be used with merge_per_key_computations because we - don't want to create computations for the baseline model, and we want to - provide the baseline model name to each Combiner - - Args: - name: Metric name for individual flip rate. + ---- + model_name: The model for which to compute this metric. + output_name: The output name for which to compute this metric. + example_weighted: Whether to compute this metric using example weights. eval_config: The EvalConfig for this TFMA evaluation. This is used to identify which model is the baseline. - example_weighted: Whether to compute this metric using example weights. + baseline_model_name: The baseline model to compare the model to. threshold: The threshold to use for converting both the baseline and candidate predictions into boolean values that can be compared. - model_names: The name of the baseline model and the candidate model. - output_names: The set of output names for which to compute this metric. - sub_keys: The set of sub_key settings for which to compute this metric. """ - computations = [] + key = metric_types.MetricKey( + name=metric_util.generate_private_name_from_arguments( + name=_FLIP_COUNTS_BASE_NAME, + eval_config=eval_config, + baseline_model_name=baseline_model_name, + threshold=threshold, + ), + model_name=model_name, + output_name=output_name, + example_weighted=example_weighted, + ) + + return metric_types.MetricComputation( + keys=[key], + preprocessors=None, + combiner=_BooleanFlipCountsCombiner( + key=key, + eval_config=eval_config, + baseline_model_name=baseline_model_name, + model_name=model_name, + output_name=output_name, + example_weighted=example_weighted, + threshold=threshold, + ), + ) - # Get the baseline model name. - baseline_spec = model_util.get_baseline_model_spec(eval_config) - baseline_model_name = baseline_spec.name if baseline_spec else None - - for candidate_model_name in model_names: - if candidate_model_name == baseline_model_name: - continue - for output_name in output_names: - for sub_key in sub_keys or (None,): - # Define the metric key. - metric_key = metric_types.MetricKey( - name=name, - model_name=candidate_model_name, - output_name=output_name, - sub_key=sub_key, - example_weighted=example_weighted, - is_diff=True, - ) - - flip_counts_computation = _flip_counts( - model_name=candidate_model_name, - output_name=output_name, - example_weighted=example_weighted, - eval_config=eval_config, - baseline_model_name=baseline_model_name, - threshold=threshold, - ) - - # Append flip counts to computations. - computations.append(flip_counts_computation) - - # Append flip rate (derived metric computation) to computations. 
- computations.append( - metric_types.DerivedMetricComputation( - keys=[metric_key], - result=functools.partial( - self._get_derived_metric_result_fn, - flip_counts_key=flip_counts_computation.keys[0], - metric_key=metric_key, - calculate_flip_rate_fn=self.result, - ), - ) - ) - return computations +class _FlipRateBase(metric_types.Metric, abc.ABC): + """Base class to generate the computations for all individual flip rates.""" + def __init__(self, name, threshold): + super().__init__( + self._metric_computations, + name=name, + threshold=threshold, + ) -class SymmetricFlipRate(_FlipRateBase): - """FlipRate is the rate at which predictions between models switch. + @abc.abstractmethod + def result(self, flip_counts: _BooleanFlipCountsAccumulator) -> float: + """This method will be overriden in each Individual Metric Class. + + Args: + ---- + flip_counts: A _BooleanFlipCountsAccumulator containing the necessary + counts to calculate the individual flip rate. + + Returns: + ------- + The individual flip rate. + """ + pass + + def _get_derived_metric_result_fn( + self, + metrics: dict[metric_types.MetricKey, Any], + flip_counts_key: metric_types.MetricKey, + metric_key: metric_types.MetricKey, + calculate_flip_rate_fn: Callable[[_BooleanFlipCountsAccumulator], float], + ) -> dict[metric_types.MetricKey, float]: + """Generates the result() function for the Derived Metric Computations. + + Args: + ---- + metrics: All the metrics (computed by any computation) including the + individual flip rate metric. + flip_counts_key: The key of the flip counts computation in "metrics". + metric_key: The key of this metric computation. + calculate_flip_rate_fn: A function that calculates the necessary flip + rate. + + Returns: + ------- + The result() function. + """ + + def _result() -> dict[metric_types.MetricKey, float]: + # We only need the accumulator to calculate the result. + flip_counts = metrics[flip_counts_key] + + return {metric_key: calculate_flip_rate_fn(flip_counts)} + + return _result() + + def _metric_computations( + self, + name: str, + eval_config: config_pb2.EvalConfig, + example_weighted: bool, + threshold: float, + model_names: Iterable[str], + output_names: Optional[Iterable[str]] = ("",), + sub_keys: Optional[Iterable[metric_types.SubKey]] = None, + ) -> metric_types.MetricComputations: + """Returns metric computations for an individual boolean flip rate. + + This is not meant to be used with merge_per_key_computations because we + don't want to create computations for the baseline model, and we want to + provide the baseline model name to each Combiner + + Args: + ---- + name: Metric name for individual flip rate. + eval_config: The EvalConfig for this TFMA evaluation. This is used to + identify which model is the baseline. + example_weighted: Whether to compute this metric using example weights. + threshold: The threshold to use for converting both the baseline and + candidate predictions into boolean values that can be compared. + model_names: The name of the baseline model and the candidate model. + output_names: The set of output names for which to compute this metric. + sub_keys: The set of sub_key settings for which to compute this metric. + """ + computations = [] + + # Get the baseline model name. 
+ baseline_spec = model_util.get_baseline_model_spec(eval_config) + baseline_model_name = baseline_spec.name if baseline_spec else None + + for candidate_model_name in model_names: + if candidate_model_name == baseline_model_name: + continue + for output_name in output_names: + for sub_key in sub_keys or (None,): + # Define the metric key. + metric_key = metric_types.MetricKey( + name=name, + model_name=candidate_model_name, + output_name=output_name, + sub_key=sub_key, + example_weighted=example_weighted, + is_diff=True, + ) + + flip_counts_computation = _flip_counts( + model_name=candidate_model_name, + output_name=output_name, + example_weighted=example_weighted, + eval_config=eval_config, + baseline_model_name=baseline_model_name, + threshold=threshold, + ) + + # Append flip counts to computations. + computations.append(flip_counts_computation) + + # Append flip rate (derived metric computation) to computations. + computations.append( + metric_types.DerivedMetricComputation( + keys=[metric_key], + result=functools.partial( + self._get_derived_metric_result_fn, + flip_counts_key=flip_counts_computation.keys[0], + metric_key=metric_key, + calculate_flip_rate_fn=self.result, + ), + ) + ) + + return computations - Given a pair of models and a threshold for converting continuous model outputs - into boolean predictions, this metric will produce the symmetric flip rate - (i.e. the number of times the boolean predictions don't match, regardless of - the direction of the flip). - """ - def __init__( - self, - name: str = SYMMETRIC_FLIP_RATE_NAME, - threshold: float = _DEFAULT_FLIP_RATE_THRESHOLD, - ): - """Initializes BooleanFlipRates metric. +class SymmetricFlipRate(_FlipRateBase): + """FlipRate is the rate at which predictions between models switch. - Args: - name: Metric name for the symmetric flip rate. - threshold: The threshold to use for converting the model prediction into a - boolean value that can be used for comparison between models. + Given a pair of models and a threshold for converting continuous model outputs + into boolean predictions, this metric will produce the symmetric flip rate + (i.e. the number of times the boolean predictions don't match, regardless of + the direction of the flip). """ - super().__init__( - name=name, - threshold=threshold, - ) + def __init__( + self, + name: str = SYMMETRIC_FLIP_RATE_NAME, + threshold: float = _DEFAULT_FLIP_RATE_THRESHOLD, + ): + """Initializes BooleanFlipRates metric. + + Args: + ---- + name: Metric name for the symmetric flip rate. + threshold: The threshold to use for converting the model prediction into a + boolean value that can be used for comparison between models. + """ + super().__init__( + name=name, + threshold=threshold, + ) - def result(self, flip_counts: _BooleanFlipCountsAccumulator) -> float: - return ( - flip_counts.num_weighted_neg_to_pos - + flip_counts.num_weighted_pos_to_neg - ) / flip_counts.num_weighted_examples + def result(self, flip_counts: _BooleanFlipCountsAccumulator) -> float: + return ( + flip_counts.num_weighted_neg_to_pos + flip_counts.num_weighted_pos_to_neg + ) / flip_counts.num_weighted_examples class NegToNegFlipRate(_FlipRateBase): - """FlipRate is the rate at which predictions between models switch. - - Given a pair of models and a threshold for converting continuous model outputs - into boolean predictions, this metric will produce the neg-to-neg flip rate - (i.e. the rate at which the baseline model's and the candidate model's - predictions are both negative). 
- """ + """FlipRate is the rate at which predictions between models switch. - def __init__( - self, - name: str = NEG_TO_NEG_FLIP_RATE_NAME, - threshold: float = _DEFAULT_FLIP_RATE_THRESHOLD, - ): - """Initializes BooleanFlipRates metric. - - Args: - name: Metric name for the neg-to-neg flip rate. - threshold: The threshold to use for converting the model prediction into a - boolean value that can be used for comparison between models. + Given a pair of models and a threshold for converting continuous model outputs + into boolean predictions, this metric will produce the neg-to-neg flip rate + (i.e. the rate at which the baseline model's and the candidate model's + predictions are both negative). """ - super().__init__( - name=name, - threshold=threshold, - ) + def __init__( + self, + name: str = NEG_TO_NEG_FLIP_RATE_NAME, + threshold: float = _DEFAULT_FLIP_RATE_THRESHOLD, + ): + """Initializes BooleanFlipRates metric. + + Args: + ---- + name: Metric name for the neg-to-neg flip rate. + threshold: The threshold to use for converting the model prediction into a + boolean value that can be used for comparison between models. + """ + super().__init__( + name=name, + threshold=threshold, + ) - def result(self, flip_counts: _BooleanFlipCountsAccumulator) -> float: - return ( - flip_counts.num_weighted_neg_to_neg / flip_counts.num_weighted_examples - ) + def result(self, flip_counts: _BooleanFlipCountsAccumulator) -> float: + return flip_counts.num_weighted_neg_to_neg / flip_counts.num_weighted_examples class NegToPosFlipRate(_FlipRateBase): - """FlipRate is the rate at which predictions between models switch. - - Given a pair of models and a threshold for converting continuous model outputs - into boolean predictions, this metric will produce the neg-to-pos flip rate - (i.e. the rate at which the baseline model's boolean prediction is negative - and the candidate model's is positive). - """ - - def __init__( - self, - name: str = NEG_TO_POS_FLIP_RATE_NAME, - threshold: float = _DEFAULT_FLIP_RATE_THRESHOLD, - ): - """Initializes BooleanFlipRates metric. + """FlipRate is the rate at which predictions between models switch. - Args: - name: Metric name for the neg-to-pos flip rate. - threshold: The threshold to use for converting the model prediction into a - boolean value that can be used for comparison between models. + Given a pair of models and a threshold for converting continuous model outputs + into boolean predictions, this metric will produce the neg-to-pos flip rate + (i.e. the rate at which the baseline model's boolean prediction is negative + and the candidate model's is positive). """ - super().__init__( - name=name, - threshold=threshold, - ) + def __init__( + self, + name: str = NEG_TO_POS_FLIP_RATE_NAME, + threshold: float = _DEFAULT_FLIP_RATE_THRESHOLD, + ): + """Initializes BooleanFlipRates metric. + + Args: + ---- + name: Metric name for the neg-to-pos flip rate. + threshold: The threshold to use for converting the model prediction into a + boolean value that can be used for comparison between models. 
+ """ + super().__init__( + name=name, + threshold=threshold, + ) - def result(self, flip_counts: _BooleanFlipCountsAccumulator) -> float: - return ( - flip_counts.num_weighted_neg_to_pos / flip_counts.num_weighted_examples - ) + def result(self, flip_counts: _BooleanFlipCountsAccumulator) -> float: + return flip_counts.num_weighted_neg_to_pos / flip_counts.num_weighted_examples class PosToNegFlipRate(_FlipRateBase): - """FlipRate is the rate at which predictions between models switch. - - Given a pair of models and a threshold for converting continuous model outputs - into boolean predictions, this metric will produce the pos-to-neg flip rate - (i.e. the rate at which the baseline model's boolean prediction is positive - and the candidate model's is negative). - """ + """FlipRate is the rate at which predictions between models switch. - def __init__( - self, - name: str = POS_TO_NEG_FLIP_RATE_NAME, - threshold: float = _DEFAULT_FLIP_RATE_THRESHOLD, - ): - """Initializes BooleanFlipRates metric. - - Args: - name: Metric name for the pos-to-neg flip rate. - threshold: The threshold to use for converting the model prediction into a - boolean value that can be used for comparison between models. + Given a pair of models and a threshold for converting continuous model outputs + into boolean predictions, this metric will produce the pos-to-neg flip rate + (i.e. the rate at which the baseline model's boolean prediction is positive + and the candidate model's is negative). """ - super().__init__( - name=name, - threshold=threshold, - ) + def __init__( + self, + name: str = POS_TO_NEG_FLIP_RATE_NAME, + threshold: float = _DEFAULT_FLIP_RATE_THRESHOLD, + ): + """Initializes BooleanFlipRates metric. + + Args: + ---- + name: Metric name for the pos-to-neg flip rate. + threshold: The threshold to use for converting the model prediction into a + boolean value that can be used for comparison between models. + """ + super().__init__( + name=name, + threshold=threshold, + ) - def result(self, flip_counts: _BooleanFlipCountsAccumulator) -> float: - return ( - flip_counts.num_weighted_pos_to_neg / flip_counts.num_weighted_examples - ) + def result(self, flip_counts: _BooleanFlipCountsAccumulator) -> float: + return flip_counts.num_weighted_pos_to_neg / flip_counts.num_weighted_examples class PosToPosFlipRate(_FlipRateBase): - """FlipRate is the rate at which predictions between models switch. + """FlipRate is the rate at which predictions between models switch. - Given a pair of models and a threshold for converting continuous model outputs - into boolean predictions, this metric will produce the pos-to-pos flip rate - (i.e. the rate at which the baseline model's and the candidate model's - predictions are both positive). - """ - - def __init__( - self, - name: str = POS_TO_POS_FLIP_RATE_NAME, - threshold: float = _DEFAULT_FLIP_RATE_THRESHOLD, - ): - """Initializes BooleanFlipRates metric. - - Args: - name: Metric name for the pos-to-pos flip rate. - threshold: The threshold to use for converting the model prediction into a - boolean value that can be used for comparison between models. + Given a pair of models and a threshold for converting continuous model outputs + into boolean predictions, this metric will produce the pos-to-pos flip rate + (i.e. the rate at which the baseline model's and the candidate model's + predictions are both positive). 
""" - super().__init__( - name=name, - threshold=threshold, - ) + def __init__( + self, + name: str = POS_TO_POS_FLIP_RATE_NAME, + threshold: float = _DEFAULT_FLIP_RATE_THRESHOLD, + ): + """Initializes BooleanFlipRates metric. + + Args: + ---- + name: Metric name for the pos-to-pos flip rate. + threshold: The threshold to use for converting the model prediction into a + boolean value that can be used for comparison between models. + """ + super().__init__( + name=name, + threshold=threshold, + ) - def result(self, flip_counts: _BooleanFlipCountsAccumulator) -> float: - return ( - flip_counts.num_weighted_pos_to_pos / flip_counts.num_weighted_examples - ) + def result(self, flip_counts: _BooleanFlipCountsAccumulator) -> float: + return flip_counts.num_weighted_pos_to_pos / flip_counts.num_weighted_examples def _boolean_flip_rates_computations( @@ -502,99 +498,15 @@ def _boolean_flip_rates_computations( sub_keys: Optional[list[Optional[metric_types.SubKey]]] = None, example_weighted: bool = False, ) -> metric_types.MetricComputations: - """Returns metric computations for all boolean flip rates. - - This is not meant to be used with merge_per_key_computations because we - don't want to create computations for the baseline model, and we want to - provide the baseline model name to each Combiner - - Args: - symmetric_flip_rate_name: Metric name for symmetric flip rate. - neg_to_neg_flip_rate_name: Metric name for the negative-to-negative flip - rate. - neg_to_pos_flip_rate_name: Metric name for the negative-to-positive flip - rate. - pos_to_neg_flip_rate_name: Metric name for the positive-to-negative flip - rate. - pos_to_pos_flip_rate_name: Metric name for the positive-to-positive flip - rate. - threshold: The threshold to use for converting both the baseline and - candidate predictions into boolean values that can be compared. - eval_config: The EvalConfig for this TFMA evaluation. This is used to - identify which model is the baseline. - model_names: The name of the baseline model and the candidate model. - output_names: The set of output names for which to compute this metric. - sub_keys: The set of sub_key settings for which to compute this metric. - example_weighted: Whether to compute this metric using example weights. - candidate predictions into boolean values that can be compared. - """ - symmetric_metric = SymmetricFlipRate( - name=symmetric_flip_rate_name, threshold=threshold - ) - neg_to_neg_metric = NegToNegFlipRate( - name=neg_to_neg_flip_rate_name, threshold=threshold - ) - neg_to_pos_metric = NegToPosFlipRate( - name=neg_to_pos_flip_rate_name, threshold=threshold - ) - pos_to_neg_metric = PosToNegFlipRate( - name=pos_to_neg_flip_rate_name, threshold=threshold - ) - pos_to_pos_metric = PosToPosFlipRate( - name=pos_to_pos_flip_rate_name, threshold=threshold - ) - - all_metrics = ( - symmetric_metric, - neg_to_neg_metric, - neg_to_pos_metric, - pos_to_neg_metric, - pos_to_pos_metric, - ) - - computations = [] - for metric in all_metrics: - computations += metric.computations( - eval_config=eval_config, - model_names=model_names, - output_names=output_names, - example_weighted=example_weighted, - sub_keys=sub_keys, - ) - - return computations - + """Returns metric computations for all boolean flip rates. -class BooleanFlipRates(metric_types.Metric): - """FlipRate is the rate at which predictions between models switch. 
- - Given a pair of models and a threshold for converting continuous model outputs - into boolean predictions, this metric will produce three numbers (keyed by - separate MetricKeys): - - - (symmetric) flip rate: The number of times the boolean predictions don't - match, regardless of the direction of the flip. - - negative-to-positive flip rate: The rate at which the baseline model's - boolean prediction is negative but the candidate model's is positive. - - positive-to-negative flip rate: The rate at which the baseline model's - boolean prediction is positive but the candidate model's is negative. - """ - - def __init__( - self, - threshold: float = _DEFAULT_FLIP_RATE_THRESHOLD, - flip_rate_name: str = FLIP_RATE_NAME, - neg_to_neg_flip_rate_name: str = NEG_TO_NEG_FLIP_RATE_NAME, - neg_to_pos_flip_rate_name: str = NEG_TO_POS_FLIP_RATE_NAME, - pos_to_neg_flip_rate_name: str = POS_TO_NEG_FLIP_RATE_NAME, - pos_to_pos_flip_rate_name: str = POS_TO_POS_FLIP_RATE_NAME, - ): - """Initializes BooleanFlipRates metric. + This is not meant to be used with merge_per_key_computations because we + don't want to create computations for the baseline model, and we want to + provide the baseline model name to each Combiner Args: - threshold: The threshold to use for converting the model prediction into a - boolean value that can be used for comparison between models. - flip_rate_name: Metric name for symmetric flip rate. + ---- + symmetric_flip_rate_name: Metric name for symmetric flip rate. neg_to_neg_flip_rate_name: Metric name for the negative-to-negative flip rate. neg_to_pos_flip_rate_name: Metric name for the negative-to-positive flip @@ -603,18 +515,103 @@ def __init__( rate. pos_to_pos_flip_rate_name: Metric name for the positive-to-positive flip rate. + threshold: The threshold to use for converting both the baseline and + candidate predictions into boolean values that can be compared. + eval_config: The EvalConfig for this TFMA evaluation. This is used to + identify which model is the baseline. + model_names: The name of the baseline model and the candidate model. + output_names: The set of output names for which to compute this metric. + sub_keys: The set of sub_key settings for which to compute this metric. + example_weighted: Whether to compute this metric using example weights. + candidate predictions into boolean values that can be compared. 
""" + symmetric_metric = SymmetricFlipRate( + name=symmetric_flip_rate_name, threshold=threshold + ) + neg_to_neg_metric = NegToNegFlipRate( + name=neg_to_neg_flip_rate_name, threshold=threshold + ) + neg_to_pos_metric = NegToPosFlipRate( + name=neg_to_pos_flip_rate_name, threshold=threshold + ) + pos_to_neg_metric = PosToNegFlipRate( + name=pos_to_neg_flip_rate_name, threshold=threshold + ) + pos_to_pos_metric = PosToPosFlipRate( + name=pos_to_pos_flip_rate_name, threshold=threshold + ) - super().__init__( - _boolean_flip_rates_computations, - symmetric_flip_rate_name=flip_rate_name, - neg_to_neg_flip_rate_name=neg_to_neg_flip_rate_name, - neg_to_pos_flip_rate_name=neg_to_pos_flip_rate_name, - pos_to_neg_flip_rate_name=pos_to_neg_flip_rate_name, - pos_to_pos_flip_rate_name=pos_to_pos_flip_rate_name, - threshold=threshold, + all_metrics = ( + symmetric_metric, + neg_to_neg_metric, + neg_to_pos_metric, + pos_to_neg_metric, + pos_to_pos_metric, ) + computations = [] + for metric in all_metrics: + computations += metric.computations( + eval_config=eval_config, + model_names=model_names, + output_names=output_names, + example_weighted=example_weighted, + sub_keys=sub_keys, + ) + + return computations + + +class BooleanFlipRates(metric_types.Metric): + """FlipRate is the rate at which predictions between models switch. + + Given a pair of models and a threshold for converting continuous model outputs + into boolean predictions, this metric will produce three numbers (keyed by + separate MetricKeys): + + - (symmetric) flip rate: The number of times the boolean predictions don't + match, regardless of the direction of the flip. + - negative-to-positive flip rate: The rate at which the baseline model's + boolean prediction is negative but the candidate model's is positive. + - positive-to-negative flip rate: The rate at which the baseline model's + boolean prediction is positive but the candidate model's is negative. + """ + + def __init__( + self, + threshold: float = _DEFAULT_FLIP_RATE_THRESHOLD, + flip_rate_name: str = FLIP_RATE_NAME, + neg_to_neg_flip_rate_name: str = NEG_TO_NEG_FLIP_RATE_NAME, + neg_to_pos_flip_rate_name: str = NEG_TO_POS_FLIP_RATE_NAME, + pos_to_neg_flip_rate_name: str = POS_TO_NEG_FLIP_RATE_NAME, + pos_to_pos_flip_rate_name: str = POS_TO_POS_FLIP_RATE_NAME, + ): + """Initializes BooleanFlipRates metric. + + Args: + ---- + threshold: The threshold to use for converting the model prediction into a + boolean value that can be used for comparison between models. + flip_rate_name: Metric name for symmetric flip rate. + neg_to_neg_flip_rate_name: Metric name for the negative-to-negative flip + rate. + neg_to_pos_flip_rate_name: Metric name for the negative-to-positive flip + rate. + pos_to_neg_flip_rate_name: Metric name for the positive-to-negative flip + rate. + pos_to_pos_flip_rate_name: Metric name for the positive-to-positive flip + rate. + """ + super().__init__( + _boolean_flip_rates_computations, + symmetric_flip_rate_name=flip_rate_name, + neg_to_neg_flip_rate_name=neg_to_neg_flip_rate_name, + neg_to_pos_flip_rate_name=neg_to_pos_flip_rate_name, + pos_to_neg_flip_rate_name=pos_to_neg_flip_rate_name, + pos_to_pos_flip_rate_name=pos_to_pos_flip_rate_name, + threshold=threshold, + ) + # Register Individual Metrics. 
metric_types.register_metric(SymmetricFlipRate) diff --git a/tensorflow_model_analysis/metrics/flip_metrics_test.py b/tensorflow_model_analysis/metrics/flip_metrics_test.py index 6cd1130744..2f9e13cada 100644 --- a/tensorflow_model_analysis/metrics/flip_metrics_test.py +++ b/tensorflow_model_analysis/metrics/flip_metrics_test.py @@ -15,57 +15,53 @@ import copy -from absl.testing import absltest -from absl.testing import parameterized import apache_beam as beam +from absl.testing import absltest, parameterized from apache_beam.testing import util +from google.protobuf import text_format + from tensorflow_model_analysis import constants -from tensorflow_model_analysis.metrics import flip_metrics -from tensorflow_model_analysis.metrics import metric_types -from tensorflow_model_analysis.metrics import metric_util +from tensorflow_model_analysis.metrics import flip_metrics, metric_types, metric_util from tensorflow_model_analysis.proto import config_pb2 from tensorflow_model_analysis.writers import metrics_plots_and_validations_writer -from google.protobuf import text_format - class FlipRateMetricsTest(parameterized.TestCase): - - @parameterized.named_parameters( - dict( - testcase_name='symmetric_flip_rate', - metric=flip_metrics.SymmetricFlipRate(threshold=0.5), - metric_name=flip_metrics.SYMMETRIC_FLIP_RATE_NAME, - expected_result=3 / 10, - ), - dict( - testcase_name='neg_to_neg_flip_rate', - metric=flip_metrics.NegToNegFlipRate(threshold=0.5), - metric_name=flip_metrics.NEG_TO_NEG_FLIP_RATE_NAME, - expected_result=3 / 10, - ), - dict( - testcase_name='neg_to_pos_flip_rate', - metric=flip_metrics.NegToPosFlipRate(threshold=0.5), - metric_name=flip_metrics.NEG_TO_POS_FLIP_RATE_NAME, - expected_result=1 / 10, - ), - dict( - testcase_name='pos_to_neg_flip_rate', - metric=flip_metrics.PosToNegFlipRate(threshold=0.5), - metric_name=flip_metrics.POS_TO_NEG_FLIP_RATE_NAME, - expected_result=2 / 10, - ), - dict( - testcase_name='pos_to_pos_flip_rate', - metric=flip_metrics.PosToPosFlipRate(threshold=0.5), - metric_name=flip_metrics.POS_TO_POS_FLIP_RATE_NAME, - expected_result=4 / 10, - ), - ) - def testIndividualFlipRates(self, metric, metric_name, expected_result): - eval_config = text_format.Parse( - """ + @parameterized.named_parameters( + dict( + testcase_name="symmetric_flip_rate", + metric=flip_metrics.SymmetricFlipRate(threshold=0.5), + metric_name=flip_metrics.SYMMETRIC_FLIP_RATE_NAME, + expected_result=3 / 10, + ), + dict( + testcase_name="neg_to_neg_flip_rate", + metric=flip_metrics.NegToNegFlipRate(threshold=0.5), + metric_name=flip_metrics.NEG_TO_NEG_FLIP_RATE_NAME, + expected_result=3 / 10, + ), + dict( + testcase_name="neg_to_pos_flip_rate", + metric=flip_metrics.NegToPosFlipRate(threshold=0.5), + metric_name=flip_metrics.NEG_TO_POS_FLIP_RATE_NAME, + expected_result=1 / 10, + ), + dict( + testcase_name="pos_to_neg_flip_rate", + metric=flip_metrics.PosToNegFlipRate(threshold=0.5), + metric_name=flip_metrics.POS_TO_NEG_FLIP_RATE_NAME, + expected_result=2 / 10, + ), + dict( + testcase_name="pos_to_pos_flip_rate", + metric=flip_metrics.PosToPosFlipRate(threshold=0.5), + metric_name=flip_metrics.POS_TO_POS_FLIP_RATE_NAME, + expected_result=4 / 10, + ), + ) + def testIndividualFlipRates(self, metric, metric_name, expected_result): + eval_config = text_format.Parse( + """ model_specs { name: "baseline" is_baseline: true @@ -73,101 +69,103 @@ def testIndividualFlipRates(self, metric, metric_name, expected_result): model_specs { name: "candidate" } - """, config_pb2.EvalConfig()) - 
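The expected_result values in the parameterized cases above follow directly from the result() formulas in flip_metrics.py. As a rough standalone illustration (the counts below are hypothetical, and the module-private accumulator is constructed directly purely for clarity):

from tensorflow_model_analysis.metrics import flip_metrics

# Hypothetical weighted flip counts matching the test's 3/10, 1/10, 2/10, 4/10 split.
counts = flip_metrics._BooleanFlipCountsAccumulator(
    num_weighted_examples=10.0,
    num_weighted_neg_to_neg=3.0,
    num_weighted_neg_to_pos=1.0,
    num_weighted_pos_to_neg=2.0,
    num_weighted_pos_to_pos=4.0,
)
# Symmetric flip rate counts flips in either direction.
assert flip_metrics.SymmetricFlipRate().result(counts) == (1.0 + 2.0) / 10.0
assert flip_metrics.NegToNegFlipRate().result(counts) == 3.0 / 10.0
assert flip_metrics.PosToNegFlipRate().result(counts) == 2.0 / 10.0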
baseline_model_name = 'baseline' - candidate_model_name = 'candidate' - - computations = metric.computations( - eval_config=eval_config, - model_names=['baseline', 'candidate'], - output_names=[''], - example_weighted=True, - ) - self.assertLen(computations, 2) - - flip_counts = computations[0] - flip_rate = computations[1] - - examples = [ - { - constants.LABELS_KEY: [0], - constants.PREDICTIONS_KEY: { - baseline_model_name: [0.1], - candidate_model_name: [0.9], + """, + config_pb2.EvalConfig(), + ) + baseline_model_name = "baseline" + candidate_model_name = "candidate" + + computations = metric.computations( + eval_config=eval_config, + model_names=["baseline", "candidate"], + output_names=[""], + example_weighted=True, + ) + self.assertLen(computations, 2) + + flip_counts = computations[0] + flip_rate = computations[1] + + examples = [ + { + constants.LABELS_KEY: [0], + constants.PREDICTIONS_KEY: { + baseline_model_name: [0.1], + candidate_model_name: [0.9], + }, + constants.EXAMPLE_WEIGHTS_KEY: [1], }, - constants.EXAMPLE_WEIGHTS_KEY: [1], - }, - { - constants.LABELS_KEY: [0], - constants.PREDICTIONS_KEY: { - baseline_model_name: [0.9], - candidate_model_name: [0.1], + { + constants.LABELS_KEY: [0], + constants.PREDICTIONS_KEY: { + baseline_model_name: [0.9], + candidate_model_name: [0.1], + }, + constants.EXAMPLE_WEIGHTS_KEY: [2], }, - constants.EXAMPLE_WEIGHTS_KEY: [2], - }, - { - constants.LABELS_KEY: [1], - constants.PREDICTIONS_KEY: { - baseline_model_name: [0.1], - candidate_model_name: [0.2], + { + constants.LABELS_KEY: [1], + constants.PREDICTIONS_KEY: { + baseline_model_name: [0.1], + candidate_model_name: [0.2], + }, + constants.EXAMPLE_WEIGHTS_KEY: [3], }, - constants.EXAMPLE_WEIGHTS_KEY: [3], - }, - { - constants.LABELS_KEY: [1], - constants.PREDICTIONS_KEY: { - baseline_model_name: [0.9], - candidate_model_name: [0.8], + { + constants.LABELS_KEY: [1], + constants.PREDICTIONS_KEY: { + baseline_model_name: [0.9], + candidate_model_name: [0.8], + }, + constants.EXAMPLE_WEIGHTS_KEY: [4], }, - constants.EXAMPLE_WEIGHTS_KEY: [4], - }, - ] - - with beam.Pipeline() as pipeline: - result = ( - pipeline - | 'Create' >> beam.Create(examples) - | 'Process' >> beam.Map(metric_util.to_standard_metric_inputs) - | 'AddSlice' >> beam.Map(lambda x: ((), x)) - | 'ComputeFlipCounts' >> beam.CombinePerKey(flip_counts.combiner) - | 'ComputeFlipRates' - >> beam.Map(lambda x: (x[0], flip_rate.result(x[1]))) - ) - - def check_result(got): - try: - self.assertLen(got, 1) - got_proto = metrics_plots_and_validations_writer.convert_slice_metrics_to_proto( - got[0], add_metrics_callbacks=None - ) - got_slice_key, got_metrics = got[0] - self.assertEqual(got_slice_key, ()) - self.assertLen(got_proto.metric_keys_and_values, 1) - - model_name = candidate_model_name - output_name = '' - example_weighted = True - metric_key = metric_types.MetricKey( - name=metric_name, - model_name=model_name, - output_name=output_name, - example_weighted=example_weighted, - is_diff=True, - ) - - self.assertIn(metric_key, got_metrics) - # Verify that metric is not a 0-D np.ndarray. 
- self.assertIsInstance(got_metrics[metric_key], float) - self.assertAlmostEqual(got_metrics[metric_key], expected_result) - - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that(result, check_result, label='result') - - def testBooleanFlipRates(self): - eval_config = text_format.Parse( - """ + ] + + with beam.Pipeline() as pipeline: + result = ( + pipeline + | "Create" >> beam.Create(examples) + | "Process" >> beam.Map(metric_util.to_standard_metric_inputs) + | "AddSlice" >> beam.Map(lambda x: ((), x)) + | "ComputeFlipCounts" >> beam.CombinePerKey(flip_counts.combiner) + | "ComputeFlipRates" + >> beam.Map(lambda x: (x[0], flip_rate.result(x[1]))) + ) + + def check_result(got): + try: + self.assertLen(got, 1) + got_proto = metrics_plots_and_validations_writer.convert_slice_metrics_to_proto( + got[0], add_metrics_callbacks=None + ) + got_slice_key, got_metrics = got[0] + self.assertEqual(got_slice_key, ()) + self.assertLen(got_proto.metric_keys_and_values, 1) + + model_name = candidate_model_name + output_name = "" + example_weighted = True + metric_key = metric_types.MetricKey( + name=metric_name, + model_name=model_name, + output_name=output_name, + example_weighted=example_weighted, + is_diff=True, + ) + + self.assertIn(metric_key, got_metrics) + # Verify that metric is not a 0-D np.ndarray. + self.assertIsInstance(got_metrics[metric_key], float) + self.assertAlmostEqual(got_metrics[metric_key], expected_result) + + except AssertionError as err: + raise util.BeamAssertException(err) + + util.assert_that(result, check_result, label="result") + + def testBooleanFlipRates(self): + eval_config = text_format.Parse( + """ model_specs { name: "baseline" is_baseline: true @@ -176,168 +174,167 @@ def testBooleanFlipRates(self): name: "candidate" } """, - config_pb2.EvalConfig(), - ) - baseline_model_name = 'baseline' - candidate_model_name = 'candidate' - - computations = flip_metrics.BooleanFlipRates(threshold=0.5).computations( - eval_config=eval_config, - model_names=['baseline', 'candidate'], - output_names=[''], - example_weighted=True, - ) - self.assertLen(computations, 10) - - flip_counts = computations[0] - - symmetric_flip_rate = computations[1] - neg_to_neg_flip_rate = computations[3] - neg_to_pos_flip_rate = computations[5] - pos_to_neg_flip_rate = computations[7] - pos_to_pos_flip_rate = computations[9] - - all_flip_rates = ( - symmetric_flip_rate, - neg_to_neg_flip_rate, - neg_to_pos_flip_rate, - pos_to_neg_flip_rate, - pos_to_pos_flip_rate, - ) - - examples = [ - { - constants.LABELS_KEY: [0], - constants.PREDICTIONS_KEY: { - baseline_model_name: [0.1], - candidate_model_name: [0.9], + config_pb2.EvalConfig(), + ) + baseline_model_name = "baseline" + candidate_model_name = "candidate" + + computations = flip_metrics.BooleanFlipRates(threshold=0.5).computations( + eval_config=eval_config, + model_names=["baseline", "candidate"], + output_names=[""], + example_weighted=True, + ) + self.assertLen(computations, 10) + + flip_counts = computations[0] + + symmetric_flip_rate = computations[1] + neg_to_neg_flip_rate = computations[3] + neg_to_pos_flip_rate = computations[5] + pos_to_neg_flip_rate = computations[7] + pos_to_pos_flip_rate = computations[9] + + all_flip_rates = ( + symmetric_flip_rate, + neg_to_neg_flip_rate, + neg_to_pos_flip_rate, + pos_to_neg_flip_rate, + pos_to_pos_flip_rate, + ) + + examples = [ + { + constants.LABELS_KEY: [0], + constants.PREDICTIONS_KEY: { + baseline_model_name: [0.1], + candidate_model_name: [0.9], + }, + 
constants.EXAMPLE_WEIGHTS_KEY: [1], }, - constants.EXAMPLE_WEIGHTS_KEY: [1], - }, - { - constants.LABELS_KEY: [0], - constants.PREDICTIONS_KEY: { - baseline_model_name: [0.9], - candidate_model_name: [0.1], + { + constants.LABELS_KEY: [0], + constants.PREDICTIONS_KEY: { + baseline_model_name: [0.9], + candidate_model_name: [0.1], + }, + constants.EXAMPLE_WEIGHTS_KEY: [2], }, - constants.EXAMPLE_WEIGHTS_KEY: [2], - }, - { - constants.LABELS_KEY: [1], - constants.PREDICTIONS_KEY: { - baseline_model_name: [0.1], - candidate_model_name: [0.2], + { + constants.LABELS_KEY: [1], + constants.PREDICTIONS_KEY: { + baseline_model_name: [0.1], + candidate_model_name: [0.2], + }, + constants.EXAMPLE_WEIGHTS_KEY: [3], }, - constants.EXAMPLE_WEIGHTS_KEY: [3], - }, - { - constants.LABELS_KEY: [1], - constants.PREDICTIONS_KEY: { - baseline_model_name: [0.9], - candidate_model_name: [0.8], + { + constants.LABELS_KEY: [1], + constants.PREDICTIONS_KEY: { + baseline_model_name: [0.9], + candidate_model_name: [0.8], + }, + constants.EXAMPLE_WEIGHTS_KEY: [4], }, - constants.EXAMPLE_WEIGHTS_KEY: [4], - }, - ] - - def _add_derived_metrics(sliced_metrics, derived_computations): - slice_key, metrics = sliced_metrics - result = copy.copy(metrics) - for c in derived_computations: - result.update(c.result(result)) - return slice_key, result - - with beam.Pipeline() as pipeline: - result = ( - pipeline - | 'Create' >> beam.Create(examples) - | 'Process' >> beam.Map(metric_util.to_standard_metric_inputs) - | 'AddSlice' >> beam.Map(lambda x: ((), x)) - | 'ComputeFlipCounts' >> beam.CombinePerKey(flip_counts.combiner) - | 'AddDerivedMetrics' - >> beam.Map(_add_derived_metrics, derived_computations=all_flip_rates) - ) - - def check_result(got): - try: - self.assertLen(got, 1) - got_proto = ( - metrics_plots_and_validations_writer - .convert_slice_metrics_to_proto( - got[0], add_metrics_callbacks=None)) - got_slice_key, got_metrics = got[0] - self.assertEqual(got_slice_key, ()) - self.assertLen(got_proto.metric_keys_and_values, 6) - - model_name = candidate_model_name - output_name = '' - example_weighted = True - is_diff = True - sym_fr_key = metric_types.MetricKey( - name=flip_metrics.FLIP_RATE_NAME, - model_name=model_name, - output_name=output_name, - example_weighted=example_weighted, - is_diff=is_diff, - ) - self.assertIn(sym_fr_key, got_metrics) - # Verify that metric is not a 0-D np.ndarray. - self.assertIsInstance(got_metrics[sym_fr_key], float) - self.assertAlmostEqual(got_metrics[sym_fr_key], 3 / 10) - - n2n_fr_key = metric_types.MetricKey( - name=flip_metrics.NEG_TO_NEG_FLIP_RATE_NAME, - model_name=model_name, - output_name=output_name, - example_weighted=example_weighted, - is_diff=is_diff, - ) - self.assertIn(n2n_fr_key, got_metrics) - # Verify that metric is not a 0-D np.ndarray. - self.assertIsInstance(got_metrics[n2n_fr_key], float) - self.assertAlmostEqual(got_metrics[n2n_fr_key], 3 / 10) - - n2p_fr_key = metric_types.MetricKey( - name=flip_metrics.NEG_TO_POS_FLIP_RATE_NAME, - model_name=model_name, - output_name=output_name, - example_weighted=example_weighted, - is_diff=is_diff, - ) - self.assertIn(n2p_fr_key, got_metrics) - # Verify that metric is not a 0-D np.ndarray. 
- self.assertIsInstance(got_metrics[n2p_fr_key], float) - self.assertAlmostEqual(got_metrics[n2p_fr_key], 1 / 10) - - p2n_fr_key = metric_types.MetricKey( - name=flip_metrics.POS_TO_NEG_FLIP_RATE_NAME, - model_name=model_name, - output_name=output_name, - example_weighted=example_weighted, - is_diff=is_diff, - ) - self.assertIn(p2n_fr_key, got_metrics) - # Verify that metric is not a 0-D np.ndarray. - self.assertIsInstance(got_metrics[p2n_fr_key], float) - self.assertAlmostEqual(got_metrics[p2n_fr_key], 2 / 10) - - p2p_fr_key = metric_types.MetricKey( - name=flip_metrics.POS_TO_POS_FLIP_RATE_NAME, - model_name=model_name, - output_name=output_name, - example_weighted=example_weighted, - is_diff=is_diff, - ) - self.assertIn(p2p_fr_key, got_metrics) - # Verify that metric is not a 0-D np.ndarray. - self.assertIsInstance(got_metrics[p2p_fr_key], float) - self.assertAlmostEqual(got_metrics[p2p_fr_key], 4 / 10) - - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that(result, check_result, label='result') - - -if __name__ == '__main__': - absltest.main() + ] + + def _add_derived_metrics(sliced_metrics, derived_computations): + slice_key, metrics = sliced_metrics + result = copy.copy(metrics) + for c in derived_computations: + result.update(c.result(result)) + return slice_key, result + + with beam.Pipeline() as pipeline: + result = ( + pipeline + | "Create" >> beam.Create(examples) + | "Process" >> beam.Map(metric_util.to_standard_metric_inputs) + | "AddSlice" >> beam.Map(lambda x: ((), x)) + | "ComputeFlipCounts" >> beam.CombinePerKey(flip_counts.combiner) + | "AddDerivedMetrics" + >> beam.Map(_add_derived_metrics, derived_computations=all_flip_rates) + ) + + def check_result(got): + try: + self.assertLen(got, 1) + got_proto = metrics_plots_and_validations_writer.convert_slice_metrics_to_proto( + got[0], add_metrics_callbacks=None + ) + got_slice_key, got_metrics = got[0] + self.assertEqual(got_slice_key, ()) + self.assertLen(got_proto.metric_keys_and_values, 6) + + model_name = candidate_model_name + output_name = "" + example_weighted = True + is_diff = True + sym_fr_key = metric_types.MetricKey( + name=flip_metrics.FLIP_RATE_NAME, + model_name=model_name, + output_name=output_name, + example_weighted=example_weighted, + is_diff=is_diff, + ) + self.assertIn(sym_fr_key, got_metrics) + # Verify that metric is not a 0-D np.ndarray. + self.assertIsInstance(got_metrics[sym_fr_key], float) + self.assertAlmostEqual(got_metrics[sym_fr_key], 3 / 10) + + n2n_fr_key = metric_types.MetricKey( + name=flip_metrics.NEG_TO_NEG_FLIP_RATE_NAME, + model_name=model_name, + output_name=output_name, + example_weighted=example_weighted, + is_diff=is_diff, + ) + self.assertIn(n2n_fr_key, got_metrics) + # Verify that metric is not a 0-D np.ndarray. + self.assertIsInstance(got_metrics[n2n_fr_key], float) + self.assertAlmostEqual(got_metrics[n2n_fr_key], 3 / 10) + + n2p_fr_key = metric_types.MetricKey( + name=flip_metrics.NEG_TO_POS_FLIP_RATE_NAME, + model_name=model_name, + output_name=output_name, + example_weighted=example_weighted, + is_diff=is_diff, + ) + self.assertIn(n2p_fr_key, got_metrics) + # Verify that metric is not a 0-D np.ndarray. 
+ self.assertIsInstance(got_metrics[n2p_fr_key], float) + self.assertAlmostEqual(got_metrics[n2p_fr_key], 1 / 10) + + p2n_fr_key = metric_types.MetricKey( + name=flip_metrics.POS_TO_NEG_FLIP_RATE_NAME, + model_name=model_name, + output_name=output_name, + example_weighted=example_weighted, + is_diff=is_diff, + ) + self.assertIn(p2n_fr_key, got_metrics) + # Verify that metric is not a 0-D np.ndarray. + self.assertIsInstance(got_metrics[p2n_fr_key], float) + self.assertAlmostEqual(got_metrics[p2n_fr_key], 2 / 10) + + p2p_fr_key = metric_types.MetricKey( + name=flip_metrics.POS_TO_POS_FLIP_RATE_NAME, + model_name=model_name, + output_name=output_name, + example_weighted=example_weighted, + is_diff=is_diff, + ) + self.assertIn(p2p_fr_key, got_metrics) + # Verify that metric is not a 0-D np.ndarray. + self.assertIsInstance(got_metrics[p2p_fr_key], float) + self.assertAlmostEqual(got_metrics[p2p_fr_key], 4 / 10) + + except AssertionError as err: + raise util.BeamAssertException(err) + + util.assert_that(result, check_result, label="result") + + +if __name__ == "__main__": + absltest.main() diff --git a/tensorflow_model_analysis/metrics/mean_regression_error.py b/tensorflow_model_analysis/metrics/mean_regression_error.py index ce60c36d14..a3441d7c03 100644 --- a/tensorflow_model_analysis/metrics/mean_regression_error.py +++ b/tensorflow_model_analysis/metrics/mean_regression_error.py @@ -19,40 +19,38 @@ import apache_beam as beam import numpy as np -from tensorflow_model_analysis.metrics import metric_types -from tensorflow_model_analysis.metrics import metric_util -from tensorflow_model_analysis.proto import config_pb2 +from tensorflow_model_analysis.metrics import metric_types, metric_util +from tensorflow_model_analysis.proto import config_pb2 -MEAN_ABSOLUTE_ERROR_NAME = 'mean_absolute_error' -MEAN_SQUARED_ERROR_NAME = 'mean_squared_error' -MEAN_ABSOLUTE_PERCENTAGE_ERROR_NAME = 'mean_absolute_percentage_error' -MEAN_SQUARED_LOGARITHMIC_ERROR_NAME = 'mean_squared_logarithmic_error' +MEAN_ABSOLUTE_ERROR_NAME = "mean_absolute_error" +MEAN_SQUARED_ERROR_NAME = "mean_squared_error" +MEAN_ABSOLUTE_PERCENTAGE_ERROR_NAME = "mean_absolute_percentage_error" +MEAN_SQUARED_LOGARITHMIC_ERROR_NAME = "mean_squared_logarithmic_error" class MeanAbsoluteError(metric_types.Metric): - """Calculates the mean of absolute error between labels and predictions. - - Formula: error = abs(label - prediction) + """Calculates the mean of absolute error between labels and predictions. - The metric computes the mean of absolute error between labels and - predictions. The labels and predictions should be floats. - """ + Formula: error = abs(label - prediction) - def __init__(self, name: str = MEAN_ABSOLUTE_ERROR_NAME, **kwargs): - """Initializes mean regression error metric. - - Args: - name: The name of the metric. - **kwargs: Additional named keyword arguments. + The metric computes the mean of absolute error between labels and + predictions. The labels and predictions should be floats. """ - super().__init__( - metric_util.merge_per_key_computations( - _mean_absolute_error_computations - ), - name=name, - **kwargs - ) + + def __init__(self, name: str = MEAN_ABSOLUTE_ERROR_NAME, **kwargs): + """Initializes mean regression error metric. + + Args: + ---- + name: The name of the metric. + **kwargs: Additional named keyword arguments. 
+ """ + super().__init__( + metric_util.merge_per_key_computations(_mean_absolute_error_computations), + name=name, + **kwargs, + ) def _mean_absolute_error_computations( @@ -66,71 +64,74 @@ def _mean_absolute_error_computations( class_weights: Optional[Dict[int, float]] = None, preprocessors: Optional[List[metric_types.Preprocessor]] = None, ) -> metric_types.MetricComputations: - """Returns metric computations for mean absolute error computations. - - Args: - name: The name of the metric. - eval_config: The configurations for TFMA pipeline. - model_name: The name of the model to get predictions from. - output_name: The name of the output under the model to get predictions from. - example_weighted: Whether the examples have specified weights. - sub_key: The key includes class, top-k, k information. It should only be in - classfication problems. - aggregation_type: The method to aggregate over classes. It should only be in - classfication problems. - class_weights: The weight of classes. It should only be in classfication - problems. - preprocessors: The preprocessors to apply to the input data. - """ - key = metric_types.MetricKey( - name=name, - model_name=model_name, - output_name=output_name, - sub_key=sub_key, - example_weighted=example_weighted) - return [ - metric_types.MetricComputation( - keys=[key], - preprocessors=preprocessors, - combiner=_MeanAbsoluteErrorCombiner( - eval_config=eval_config, - model_name=model_name, - output_name=output_name, - metric_key=key, - class_weights=class_weights, - aggregation_type=aggregation_type, - example_weighted=example_weighted, - ), - ) - ] + """Returns metric computations for mean absolute error computations. + + Args: + ---- + name: The name of the metric. + eval_config: The configurations for TFMA pipeline. + model_name: The name of the model to get predictions from. + output_name: The name of the output under the model to get predictions from. + example_weighted: Whether the examples have specified weights. + sub_key: The key includes class, top-k, k information. It should only be in + classfication problems. + aggregation_type: The method to aggregate over classes. It should only be in + classfication problems. + class_weights: The weight of classes. It should only be in classfication + problems. + preprocessors: The preprocessors to apply to the input data. + """ + key = metric_types.MetricKey( + name=name, + model_name=model_name, + output_name=output_name, + sub_key=sub_key, + example_weighted=example_weighted, + ) + return [ + metric_types.MetricComputation( + keys=[key], + preprocessors=preprocessors, + combiner=_MeanAbsoluteErrorCombiner( + eval_config=eval_config, + model_name=model_name, + output_name=output_name, + metric_key=key, + class_weights=class_weights, + aggregation_type=aggregation_type, + example_weighted=example_weighted, + ), + ) + ] metric_types.register_metric(MeanAbsoluteError) class MeanAbsolutePercentageError(metric_types.Metric): - """Calculates the mean of absolute percentage error. - - Formula: error = 100 * abs( (label - prediction) / label ) - - The metric computes the mean of absolute percentage error between labels and - predictions. The labels and predictions should be floats. - """ + """Calculates the mean of absolute percentage error. - def __init__(self, name: str = MEAN_ABSOLUTE_PERCENTAGE_ERROR_NAME, **kwargs): - """Initializes mean regression error metric. + Formula: error = 100 * abs( (label - prediction) / label ) - Args: - name: The name of the metric. 
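# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the patch: how a metric defined in this
# module is typically driven, mirroring the tests later in this change. It
# relies on names defined in this file; the pipeline wiring in the trailing
# comments is only a sketch.
computations = MeanAbsoluteError().computations(example_weighted=True)
combiner = computations[0].combiner  # the Beam CombineFn for this metric
# ... | beam.Map(metric_util.to_standard_metric_inputs)
#     | beam.Map(lambda x: ((), x))        # key everything to one overall slice
#     | beam.CombinePerKey(combiner)       # -> {MetricKey: value} per slice
# ---------------------------------------------------------------------------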
- **kwargs: Additional named keyword arguments. + The metric computes the mean of absolute percentage error between labels and + predictions. The labels and predictions should be floats. """ - super().__init__( - metric_util.merge_per_key_computations( - _mean_absolute_percentage_error_computations - ), - name=name, - **kwargs - ) + + def __init__(self, name: str = MEAN_ABSOLUTE_PERCENTAGE_ERROR_NAME, **kwargs): + """Initializes mean regression error metric. + + Args: + ---- + name: The name of the metric. + **kwargs: Additional named keyword arguments. + """ + super().__init__( + metric_util.merge_per_key_computations( + _mean_absolute_percentage_error_computations + ), + name=name, + **kwargs, + ) def _mean_absolute_percentage_error_computations( @@ -144,72 +145,73 @@ def _mean_absolute_percentage_error_computations( class_weights: Optional[Dict[int, float]] = None, preprocessors: Optional[List[metric_types.Preprocessor]] = None, ) -> metric_types.MetricComputations: - """Returns metric computations for mean absolute percentage error. - - Args: - name: The name of the metric. - eval_config: The configurations for TFMA pipeline. - model_name: The name of the model to get predictions from. - output_name: The name of the output under the model to get predictions from. - example_weighted: Whether the examples have specified weights. - sub_key: The key includes class, top-k, k information. It should only be in - classfication problems. - aggregation_type: The method to aggregate over classes. It should only be in - classfication problems. - class_weights: The weight of classes. It should only be in classfication - problems. - preprocessors: The preprocessors to apply to the input data. - """ - key = metric_types.MetricKey( - name=name, - model_name=model_name, - output_name=output_name, - sub_key=sub_key, - example_weighted=example_weighted) - return [ - metric_types.MetricComputation( - keys=[key], - preprocessors=preprocessors, - combiner=_MeanAbsolutePercentageErrorCombiner( - eval_config=eval_config, - model_name=model_name, - output_name=output_name, - metric_key=key, - class_weights=class_weights, - aggregation_type=aggregation_type, - example_weighted=example_weighted, - ), - ) - ] + """Returns metric computations for mean absolute percentage error. + + Args: + ---- + name: The name of the metric. + eval_config: The configurations for TFMA pipeline. + model_name: The name of the model to get predictions from. + output_name: The name of the output under the model to get predictions from. + example_weighted: Whether the examples have specified weights. + sub_key: The key includes class, top-k, k information. It should only be in + classfication problems. + aggregation_type: The method to aggregate over classes. It should only be in + classfication problems. + class_weights: The weight of classes. It should only be in classfication + problems. + preprocessors: The preprocessors to apply to the input data. 
+ """ + key = metric_types.MetricKey( + name=name, + model_name=model_name, + output_name=output_name, + sub_key=sub_key, + example_weighted=example_weighted, + ) + return [ + metric_types.MetricComputation( + keys=[key], + preprocessors=preprocessors, + combiner=_MeanAbsolutePercentageErrorCombiner( + eval_config=eval_config, + model_name=model_name, + output_name=output_name, + metric_key=key, + class_weights=class_weights, + aggregation_type=aggregation_type, + example_weighted=example_weighted, + ), + ) + ] metric_types.register_metric(MeanAbsolutePercentageError) class MeanSquaredError(metric_types.Metric): - """Calculates the mean of squared error between labels and predictions. - - Formula: error = L2_norm(label - prediction)**2 + """Calculates the mean of squared error between labels and predictions. - The metric computes the mean of squared error (square of L2 norm) between - labels and predictions. The labels and predictions could be arrays of - arbitrary dimensions. Their dimension should match. - """ + Formula: error = L2_norm(label - prediction)**2 - def __init__(self, name: str = MEAN_SQUARED_ERROR_NAME, **kwargs): - """Initializes mean regression error metric. - - Args: - name: The name of the metric. - **kwargs: Additional named keyword arguments. + The metric computes the mean of squared error (square of L2 norm) between + labels and predictions. The labels and predictions could be arrays of + arbitrary dimensions. Their dimension should match. """ - super().__init__( - metric_util.merge_per_key_computations( - _mean_squared_error_computations - ), - name=name, - **kwargs - ) + + def __init__(self, name: str = MEAN_SQUARED_ERROR_NAME, **kwargs): + """Initializes mean regression error metric. + + Args: + ---- + name: The name of the metric. + **kwargs: Additional named keyword arguments. + """ + super().__init__( + metric_util.merge_per_key_computations(_mean_squared_error_computations), + name=name, + **kwargs, + ) def _mean_squared_error_computations( @@ -223,74 +225,77 @@ def _mean_squared_error_computations( class_weights: Optional[Dict[int, float]] = None, preprocessors: Optional[List[metric_types.Preprocessor]] = None, ) -> metric_types.MetricComputations: - """Returns metric computations for mean squared error computations. - - Args: - name: The name of the metric. - eval_config: The configurations for TFMA pipeline. - model_name: The name of the model to get predictions from. - output_name: The name of the output under the model to get predictions from. - example_weighted: Whether the examples have specified weights. - sub_key: The key includes class, top-k, k information. It should only be in - classfication problems. - aggregation_type: The method to aggregate over classes. It should only be in - classfication problems. - class_weights: The weight of classes. It should only be in classfication - problems. - preprocessors: The preprocessors to apply to the input data. - """ - key = metric_types.MetricKey( - name=name, - model_name=model_name, - output_name=output_name, - sub_key=sub_key, - example_weighted=example_weighted) - return [ - metric_types.MetricComputation( - keys=[key], - preprocessors=preprocessors, - combiner=_MeanSquaredErrorCombiner( - eval_config=eval_config, - model_name=model_name, - output_name=output_name, - metric_key=key, - example_weighted=example_weighted, - aggregation_type=aggregation_type, - class_weights=class_weights, - ), - ) - ] + """Returns metric computations for mean squared error computations. 
+ + Args: + ---- + name: The name of the metric. + eval_config: The configurations for TFMA pipeline. + model_name: The name of the model to get predictions from. + output_name: The name of the output under the model to get predictions from. + example_weighted: Whether the examples have specified weights. + sub_key: The key includes class, top-k, k information. It should only be in + classfication problems. + aggregation_type: The method to aggregate over classes. It should only be in + classfication problems. + class_weights: The weight of classes. It should only be in classfication + problems. + preprocessors: The preprocessors to apply to the input data. + """ + key = metric_types.MetricKey( + name=name, + model_name=model_name, + output_name=output_name, + sub_key=sub_key, + example_weighted=example_weighted, + ) + return [ + metric_types.MetricComputation( + keys=[key], + preprocessors=preprocessors, + combiner=_MeanSquaredErrorCombiner( + eval_config=eval_config, + model_name=model_name, + output_name=output_name, + metric_key=key, + example_weighted=example_weighted, + aggregation_type=aggregation_type, + class_weights=class_weights, + ), + ) + ] metric_types.register_metric(MeanSquaredError) class MeanSquaredLogarithmicError(metric_types.Metric): - """Calculates the mean of squared logarithmic error. + """Calculates the mean of squared logarithmic error. - Formula: error = L2_norm(log(label + 1) - log(prediction + 1))**2 - Note: log of an array will be elementwise, - i.e. log([x1, x2]) = [log(x1), log(x2)] + Formula: error = L2_norm(log(label + 1) - log(prediction + 1))**2 + Note: log of an array will be elementwise, + i.e. log([x1, x2]) = [log(x1), log(x2)] - The metric computes the mean of squared logarithmic error (square of L2 norm) - between labels and predictions. The labels and predictions could be arrays of - arbitrary dimensions. Their dimension should match. - """ - - def __init__(self, name: str = MEAN_SQUARED_LOGARITHMIC_ERROR_NAME, **kwargs): - """Initializes mean regression error metric. - - Args: - name: The name of the metric. - **kwargs: Additional named keyword arguments. + The metric computes the mean of squared logarithmic error (square of L2 norm) + between labels and predictions. The labels and predictions could be arrays of + arbitrary dimensions. Their dimension should match. """ - super().__init__( - metric_util.merge_per_key_computations( - _mean_squared_logarithmic_error_computations - ), - name=name, - **kwargs - ) + + def __init__(self, name: str = MEAN_SQUARED_LOGARITHMIC_ERROR_NAME, **kwargs): + """Initializes mean regression error metric. + + Args: + ---- + name: The name of the metric. + **kwargs: Additional named keyword arguments. + """ + super().__init__( + metric_util.merge_per_key_computations( + _mean_squared_logarithmic_error_computations + ), + name=name, + **kwargs, + ) def _mean_squared_logarithmic_error_computations( @@ -304,43 +309,45 @@ def _mean_squared_logarithmic_error_computations( class_weights: Optional[Dict[int, float]] = None, preprocessors: Optional[List[metric_types.Preprocessor]] = None, ) -> metric_types.MetricComputations: - """Returns metric computations for mean squared logarithmic error. - - Args: - name: The name of the metric. - eval_config: The configurations for TFMA pipeline. - model_name: The name of the model to get predictions from. - output_name: The name of the output under the model to get predictions from. - example_weighted: Whether the examples have specified weights. 
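# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the patch: the four per-example errors
# described in the docstrings above, written out with plain NumPy for a single
# (label, prediction) pair. This mirrors the stated formulas only; the
# combiners below are the actual implementation.
import numpy as np

label = np.array([2.0])
prediction = np.array([3.0])

mae = np.abs(label - prediction).item()                     # |y - p| = 1.0
mse = np.linalg.norm(label - prediction).item() ** 2        # ||y - p||^2 = 1.0
mape = 100.0 * np.abs((label - prediction) / label).item()  # 100 * |y - p| / |y| = 50.0
msle = np.linalg.norm(np.log(label + 1) - np.log(prediction + 1)).item() ** 2  # ~0.0828
# Note: the percentage error is undefined for a zero label; the corresponding
# combiner below returns NaN in that case so such examples are skipped.
# ---------------------------------------------------------------------------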
- sub_key: The key includes class, top-k, k information. It should only be in - classfication problems. - aggregation_type: The method to aggregate over classes. It should only be in - classfication problems. - class_weights: The weight of classes. It should only be in classfication - problems. - preprocessors: The preprocessors to apply to the input data. - """ - key = metric_types.MetricKey( - name=name, - model_name=model_name, - output_name=output_name, - sub_key=sub_key, - example_weighted=example_weighted) - return [ - metric_types.MetricComputation( - keys=[key], - preprocessors=preprocessors, - combiner=_MeanSquaredLogarithmicErrorCombiner( - eval_config=eval_config, - model_name=model_name, - output_name=output_name, - metric_key=key, - example_weighted=example_weighted, - aggregation_type=aggregation_type, - class_weights=class_weights, - ), - ) - ] + """Returns metric computations for mean squared logarithmic error. + + Args: + ---- + name: The name of the metric. + eval_config: The configurations for TFMA pipeline. + model_name: The name of the model to get predictions from. + output_name: The name of the output under the model to get predictions from. + example_weighted: Whether the examples have specified weights. + sub_key: The key includes class, top-k, k information. It should only be in + classfication problems. + aggregation_type: The method to aggregate over classes. It should only be in + classfication problems. + class_weights: The weight of classes. It should only be in classfication + problems. + preprocessors: The preprocessors to apply to the input data. + """ + key = metric_types.MetricKey( + name=name, + model_name=model_name, + output_name=output_name, + sub_key=sub_key, + example_weighted=example_weighted, + ) + return [ + metric_types.MetricComputation( + keys=[key], + preprocessors=preprocessors, + combiner=_MeanSquaredLogarithmicErrorCombiner( + eval_config=eval_config, + model_name=model_name, + output_name=output_name, + metric_key=key, + example_weighted=example_weighted, + aggregation_type=aggregation_type, + class_weights=class_weights, + ), + ) + ] metric_types.register_metric(MeanSquaredLogarithmicError) @@ -348,146 +355,145 @@ def _mean_squared_logarithmic_error_computations( @dataclasses.dataclass class _MeanRegressionErrorAccumulator: - """Accumulator for computing MeanRegressionError.""" - total_example_weights: float = 0.0 - total_regression_error: float = 0.0 + """Accumulator for computing MeanRegressionError.""" - def merge(self, other: '_MeanRegressionErrorAccumulator'): - self.total_example_weights += other.total_example_weights - self.total_regression_error += other.total_regression_error + total_example_weights: float = 0.0 + total_regression_error: float = 0.0 + def merge(self, other: "_MeanRegressionErrorAccumulator"): + self.total_example_weights += other.total_example_weights + self.total_regression_error += other.total_regression_error -class _MeanRegressionErrorCombiner(beam.CombineFn, metaclass=abc.ABCMeta): - """A combiner which computes metrics averaging regression errors.""" - - def __init__(self, eval_config: config_pb2.EvalConfig, model_name: str, - output_name: str, metric_key: metric_types.MetricKey, - aggregation_type: Optional[metric_types.AggregationType], - class_weights: Optional[Dict[int, - float]], example_weighted: bool): - self._eval_config = eval_config - self._model_name = model_name - self._output_name = output_name - self._metric_key = metric_key - self._example_weighted = example_weighted - 
self._aggregation_type = aggregation_type - self._class_weights = class_weights - - @abc.abstractmethod - def _regression_error(self, label, prediction) -> float: - """Returns the regression error between the label and prediction. - - Subclasses must override this method. Labels and preditctions could be - an array, a float, a string and etc. But the output of regression error must - be a float. - Args: - label: label. - prediction: prediction from the model. - """ - raise NotImplementedError('Must be implemented in subclasses.') - - def create_accumulator(self) -> _MeanRegressionErrorAccumulator: - return _MeanRegressionErrorAccumulator() - - def add_input( - self, accumulator: _MeanRegressionErrorAccumulator, - element: metric_types.StandardMetricInputs - ) -> _MeanRegressionErrorAccumulator: - - lpe_iterator = metric_util.to_label_prediction_example_weight( - element, - eval_config=self._eval_config, - model_name=self._metric_key.model_name, - output_name=self._metric_key.output_name, - aggregation_type=self._aggregation_type, - class_weights=self._class_weights, - example_weighted=self._example_weighted, - sub_key=self._metric_key.sub_key, - ) - for label, prediction, example_weight in lpe_iterator: - # The np.item method makes sure the result is a one element numpy array - # and returns the single element as a float. - error = self._regression_error(label, prediction) - if not np.isnan(error): - accumulator.total_regression_error += ( - error * metric_util.safe_to_scalar(example_weight) - ) - accumulator.total_example_weights += metric_util.safe_to_scalar( - example_weight +class _MeanRegressionErrorCombiner(beam.CombineFn, metaclass=abc.ABCMeta): + """A combiner which computes metrics averaging regression errors.""" + + def __init__( + self, + eval_config: config_pb2.EvalConfig, + model_name: str, + output_name: str, + metric_key: metric_types.MetricKey, + aggregation_type: Optional[metric_types.AggregationType], + class_weights: Optional[Dict[int, float]], + example_weighted: bool, + ): + self._eval_config = eval_config + self._model_name = model_name + self._output_name = output_name + self._metric_key = metric_key + self._example_weighted = example_weighted + self._aggregation_type = aggregation_type + self._class_weights = class_weights + + @abc.abstractmethod + def _regression_error(self, label, prediction) -> float: + """Returns the regression error between the label and prediction. + + Subclasses must override this method. Labels and preditctions could be + an array, a float, a string and etc. But the output of regression error must + be a float. + + Args: + ---- + label: label. + prediction: prediction from the model. 
+ """ + raise NotImplementedError("Must be implemented in subclasses.") + + def create_accumulator(self) -> _MeanRegressionErrorAccumulator: + return _MeanRegressionErrorAccumulator() + + def add_input( + self, + accumulator: _MeanRegressionErrorAccumulator, + element: metric_types.StandardMetricInputs, + ) -> _MeanRegressionErrorAccumulator: + lpe_iterator = metric_util.to_label_prediction_example_weight( + element, + eval_config=self._eval_config, + model_name=self._metric_key.model_name, + output_name=self._metric_key.output_name, + aggregation_type=self._aggregation_type, + class_weights=self._class_weights, + example_weighted=self._example_weighted, + sub_key=self._metric_key.sub_key, ) - - return accumulator - - def merge_accumulators( - self, accumulators: Iterable[_MeanRegressionErrorAccumulator] - ) -> _MeanRegressionErrorAccumulator: - result = next(iter(accumulators)) - for accumulator in accumulators: - result.merge(accumulator) - return result - - def extract_output( - self, accumulator: _MeanRegressionErrorAccumulator - ) -> metric_types.MetricsDict: - if accumulator.total_example_weights != 0.0: - result = ( - accumulator.total_regression_error / accumulator.total_example_weights - ) - else: - result = float('nan') - return {self._metric_key: result} + for label, prediction, example_weight in lpe_iterator: + # The np.item method makes sure the result is a one element numpy array + # and returns the single element as a float. + error = self._regression_error(label, prediction) + if not np.isnan(error): + accumulator.total_regression_error += ( + error * metric_util.safe_to_scalar(example_weight) + ) + accumulator.total_example_weights += metric_util.safe_to_scalar( + example_weight + ) + + return accumulator + + def merge_accumulators( + self, accumulators: Iterable[_MeanRegressionErrorAccumulator] + ) -> _MeanRegressionErrorAccumulator: + result = next(iter(accumulators)) + for accumulator in accumulators: + result.merge(accumulator) + return result + + def extract_output( + self, accumulator: _MeanRegressionErrorAccumulator + ) -> metric_types.MetricsDict: + if accumulator.total_example_weights != 0.0: + result = ( + accumulator.total_regression_error / accumulator.total_example_weights + ) + else: + result = float("nan") + return {self._metric_key: result} class _MeanAbsoluteErrorCombiner(_MeanRegressionErrorCombiner): - """A combiner which computes metrics averaging absolute errors.""" + """A combiner which computes metrics averaging absolute errors.""" - def _regression_error( - self, label: np.ndarray, prediction: np.ndarray - ) -> float: - # The np.item method makes sure the result is a one element numpy array and - # returns the single element as a float. - return metric_util.safe_to_scalar(np.absolute(label - prediction)) + def _regression_error(self, label: np.ndarray, prediction: np.ndarray) -> float: + # The np.item method makes sure the result is a one element numpy array and + # returns the single element as a float. + return metric_util.safe_to_scalar(np.absolute(label - prediction)) class _MeanSquaredErrorCombiner(_MeanRegressionErrorCombiner): - """A combiner which computes metrics averaging squared errors.""" + """A combiner which computes metrics averaging squared errors.""" - def _regression_error( - self, label: np.ndarray, prediction: np.ndarray - ) -> float: - # The np.item method makes sure the result is a one element numpy array and - # returns the single element as a float. 
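# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the patch: stripped of the Beam and TFMA
# plumbing, the combiner above is a weighted mean with a NaN fallback. The
# class below is hypothetical and only restates that bookkeeping.
import math


class _WeightedMeanSketch:
    """Accumulates sum(weight * error) and sum(weight), then divides."""

    def __init__(self):
        self.total_error = 0.0
        self.total_weight = 0.0

    def add(self, error: float, weight: float) -> None:
        # NaN errors contribute neither to the numerator nor the denominator,
        # matching add_input above.
        if not math.isnan(error):
            self.total_error += error * weight
            self.total_weight += weight

    def result(self) -> float:
        # Matches extract_output: NaN when no weighted examples were seen.
        if self.total_weight == 0.0:
            return float("nan")
        return self.total_error / self.total_weight
# ---------------------------------------------------------------------------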
- return metric_util.safe_to_scalar(np.linalg.norm(label - prediction)) ** 2 + def _regression_error(self, label: np.ndarray, prediction: np.ndarray) -> float: + # The np.item method makes sure the result is a one element numpy array and + # returns the single element as a float. + return metric_util.safe_to_scalar(np.linalg.norm(label - prediction)) ** 2 class _MeanAbsolutePercentageErrorCombiner(_MeanRegressionErrorCombiner): - """A combiner which computes metrics averaging absolute percentage errors.""" - - def _regression_error( - self, label: np.ndarray, prediction: np.ndarray - ) -> float: - # The np.item method makes sure the result is a one element numpy array and - # returns the single element as a float. - # The error also requires the label to be a one element numpy array. - if label.size == 0 or label.item() == 0: - return float('nan') - return 100 * metric_util.safe_to_scalar( - np.absolute((label - prediction) / label) - ) + """A combiner which computes metrics averaging absolute percentage errors.""" + + def _regression_error(self, label: np.ndarray, prediction: np.ndarray) -> float: + # The np.item method makes sure the result is a one element numpy array and + # returns the single element as a float. + # The error also requires the label to be a one element numpy array. + if label.size == 0 or label.item() == 0: + return float("nan") + return 100 * metric_util.safe_to_scalar( + np.absolute((label - prediction) / label) + ) class _MeanSquaredLogarithmicErrorCombiner(_MeanRegressionErrorCombiner): - """A combiner which computes metrics averaging squared logarithmic errors.""" - - def _regression_error( - self, label: np.ndarray, prediction: np.ndarray - ) -> float: - # The np.item method makes sure the result is a one element numpy array and - # returns the single element as a float. - return ( - metric_util.safe_to_scalar( - np.linalg.norm(np.log(label + 1) - np.log(prediction + 1)) + """A combiner which computes metrics averaging squared logarithmic errors.""" + + def _regression_error(self, label: np.ndarray, prediction: np.ndarray) -> float: + # The np.item method makes sure the result is a one element numpy array and + # returns the single element as a float. 
+ return ( + metric_util.safe_to_scalar( + np.linalg.norm(np.log(label + 1) - np.log(prediction + 1)) + ) + ** 2 ) - ** 2 - ) diff --git a/tensorflow_model_analysis/metrics/mean_regression_error_test.py b/tensorflow_model_analysis/metrics/mean_regression_error_test.py index 493fb62b09..be5c853db9 100644 --- a/tensorflow_model_analysis/metrics/mean_regression_error_test.py +++ b/tensorflow_model_analysis/metrics/mean_regression_error_test.py @@ -14,218 +14,218 @@ """Tests for mean_regression_error related metrics.""" from typing import Iterator -from absl.testing import absltest -from absl.testing import parameterized + import apache_beam as beam -from apache_beam.testing import util import numpy as np +from absl.testing import absltest, parameterized +from apache_beam.testing import util + from tensorflow_model_analysis.api import types -from tensorflow_model_analysis.metrics import mean_regression_error -from tensorflow_model_analysis.metrics import metric_types -from tensorflow_model_analysis.metrics import metric_util +from tensorflow_model_analysis.metrics import ( + mean_regression_error, + metric_types, + metric_util, +) class IdTransformPreprocessor(metric_types.Preprocessor): - """ID transform preprocessor.""" - - def __init__( - self, - ): - super().__init__( - name=metric_util.generate_private_name_from_arguments( - 'test_id_transform_preprocessor', + """ID transform preprocessor.""" + + def __init__( + self, + ): + super().__init__( + name=metric_util.generate_private_name_from_arguments( + "test_id_transform_preprocessor", + ) ) - ) - def process( - self, extracts: types.Extracts - ) -> Iterator[metric_types.StandardMetricInputs]: - yield extracts + def process( + self, extracts: types.Extracts + ) -> Iterator[metric_types.StandardMetricInputs]: + yield extracts class MeanRegressionErrorTest(parameterized.TestCase): + @parameterized.named_parameters( + # example1 is |0.1 - 1| * 0.1 + |0.3 - 0| * 0.5 + |0.5 - 2| * 1 + # = 0.09 + 0.15 + 1.5 = 1.74 + # example2 is |1 - 0.5| + |2 - 1| + |3 - 5| = 3.5 + # example3 is |3 - 5| = 2 + # average error: (1.74 + 3.5 + 2) / 5.6 = 1.292857 + ( + "_mean_absolute_error", + mean_regression_error.MeanAbsoluteError(), + 1.292857, + ), + # example1 is |0.1 - 1|^2 * 0.1 + |0.3 - 0|^2 * 0.5 + |0.5 - 2|^2 * 1 + # = 0.081 + 0.045 + 2.25 = 2.376 + # example2 is |1 - 0.5|^2 + |2 - 1|^2 + |3 - 5|^2 = 5.25 + # example3 is |3 - 5|^2 = 4 + # average error: (2.376 + 5.25 + 4) / 5.6 = 2.07607 + ( + "_mean_squared_error", + mean_regression_error.MeanSquaredError(), + 2.07607, + ), + # example1 is 100 * (|0.1 - 1| / 0.1 * 0.1 + |0.3 - 0| / 0.3 * 0.5 + + # |0.5 - 2| / 0.5 * 1) = 440 + # example2 is 100 * (|1 - 0.5| / 1 + |2 - 1| / 2 + |3 - 5| / 3) = 166.66 + # example3 is 100 * (|3 - 5| / 3) = 66.66 + # average error: (440 + 166.66 + 66.66) / 5.6 = 120.238095 + ( + "_mean_absolute_percentage_error", + mean_regression_error.MeanAbsolutePercentageError(), + 120.238095, + ), + # example1 is |log(0.1+1) - log(1+1)|^2 * 0.1 + + # |log(0.3+1) - log(0+1)|^2 * 0.5 + |log(0.5+1) - log(2+1)|^2 * 1 =0.55061 + # example2 is |log(1+1) - log(0.5+1)|^2 + |log(2+1) - log(1+1)|^2 + # + |log(3+1) - log(5+1)|^2 = 0.41156 + # example3 is |log(3+1) - log(5+1)|^2 = 0.16440 + # average error: (0.55061 + 0.41156 + 0.16440) / 5.6 = 0.20117 + ( + "_mean_squared_logarithmic_error", + mean_regression_error.MeanSquaredLogarithmicError(), + 0.20117, + ), + ) + def testRegressionErrorWithWeights(self, metric, expected_value): + computations = metric.computations(example_weighted=True) + 
computation = computations[0] + example1 = { + "labels": np.array([0.1, 0.3, 0.5]), + "predictions": np.array([1, 0, 2]), + "example_weights": np.array([0.1, 0.5, 1]), + } + example2 = { + "labels": np.array([1, 2, 3]), + "predictions": np.array([0.5, 1.0, 5]), + "example_weights": np.array([1.0]), + } + example3 = { + "labels": np.array([3]), + "predictions": np.array([5]), + "example_weights": np.array([1.0]), + } + with beam.Pipeline() as pipeline: + # pylint: disable=no-value-for-parameter + result = ( + pipeline + | "Create" >> beam.Create([example1, example2, example3]) + | "Process" >> beam.Map(metric_util.to_standard_metric_inputs) + | "AddSlice" >> beam.Map(lambda x: ((), x)) + | "ComputeMetric" >> beam.CombinePerKey(computation.combiner) + ) + + # pylint: enable=no-value-for-parameter + def check_result(got): + try: + self.assertLen(got, 1) + got_slice_key, got_metrics = got[0] + self.assertEqual(got_slice_key, ()) + key = computation.keys[0] + self.assertIn(key, got_metrics) + self.assertAlmostEqual(got_metrics[key], expected_value, places=5) + except AssertionError as err: + raise util.BeamAssertException(err) + + util.assert_that(result, check_result, label="result") + + @parameterized.named_parameters( + # example1 is |0.1 - 1| * 0.1 + |0.3 - 0| * 0.5 + |0.5 - 2| * 1 + # = 0.09 + 0.15 + 1.5 = 1.74 + # example2 is |1 - 0.5| + |2 - 1| + |3 - 5| = 3.5 + # example3 is |3 - 5| = 2 + # average error: (1.74 + 3.5 + 2) / 5.6 = 1.292857 + ( + "_mean_absolute_error_with_preprocessors", + mean_regression_error.MeanAbsoluteError( + preprocessors=[IdTransformPreprocessor()] + ), + 1.292857, + ), + # example1 is |0.1 - 1|^2 * 0.1 + |0.3 - 0|^2 * 0.5 + |0.5 - 2|^2 * 1 + # = 0.081 + 0.045 + 2.25 = 2.376 + # example2 is |1 - 0.5|^2 + |2 - 1|^2 + |3 - 5|^2 = 5.25 + # example3 is |3 - 5|^2 = 4 + # average error: (2.376 + 5.25 + 4) / 5.6 = 2.07607 + ( + "_mean_squared_error_with_preprocessors", + mean_regression_error.MeanSquaredError( + preprocessors=[IdTransformPreprocessor()] + ), + 2.07607, + ), + # example1 is 100 * (|0.1 - 1| / 0.1 * 0.1 + |0.3 - 0| / 0.3 * 0.5 + + # |0.5 - 2| / 0.5 * 1) = 440 + # example2 is 100 * (|1 - 0.5| / 1 + |2 - 1| / 2 + |3 - 5| / 3) = 166.66 + # example3 is 100 * (|3 - 5| / 3) = 66.66 + # average error: (440 + 166.66 + 66.66) / 5.6 = 120.238095 + ( + "_mean_absolute_percentage_error_with_preprocessors", + mean_regression_error.MeanAbsolutePercentageError( + preprocessors=[IdTransformPreprocessor()] + ), + 120.238095, + ), + # example1 is |log(0.1+1) - log(1+1)|^2 * 0.1 + + # |log(0.3+1) - log(0+1)|^2 * 0.5 + |log(0.5+1) - log(2+1)|^2 * 1 =0.55061 + # example2 is |log(1+1) - log(0.5+1)|^2 + |log(2+1) - log(1+1)|^2 + # + |log(3+1) - log(5+1)|^2 = 0.41156 + # example3 is |log(3+1) - log(5+1)|^2 = 0.16440 + # average error: (0.55061 + 0.41156 + 0.16440) / 5.6 = 0.20117 + ( + "_mean_squared_logarithmic_error_with_preprocessors", + mean_regression_error.MeanSquaredLogarithmicError( + preprocessors=[IdTransformPreprocessor()] + ), + 0.20117, + ), + ) + def testRegressionErrorWithWeightsWithPreprocessors(self, metric, expected_value): + computations = metric.computations(example_weighted=True) + computation = computations[0] + example1 = { + "labels": np.array([0.1, 0.3, 0.5]), + "predictions": np.array([1, 0, 2]), + "example_weights": np.array([0.1, 0.5, 1]), + } + example2 = { + "labels": np.array([1, 2, 3]), + "predictions": np.array([0.5, 1.0, 5]), + "example_weights": np.array([1.0]), + } + example3 = { + "labels": np.array([3]), + "predictions": np.array([5]), + 
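# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the patch: the expected MeanAbsoluteError
# of 1.292857 used above, recomputed directly from the three examples with
# plain NumPy. Per the comments above, a single example weight applies to
# every element of that example's arrays, so the denominator is
# 0.1 + 0.5 + 1 + (1 + 1 + 1) + 1 = 5.6.
import numpy as np

examples = [
    (np.array([0.1, 0.3, 0.5]), np.array([1.0, 0.0, 2.0]), np.array([0.1, 0.5, 1.0])),
    (np.array([1.0, 2.0, 3.0]), np.array([0.5, 1.0, 5.0]), np.array([1.0])),
    (np.array([3.0]), np.array([5.0]), np.array([1.0])),
]

total_error = 0.0
total_weight = 0.0
for labels, predictions, weights in examples:
    weights = np.broadcast_to(weights, labels.shape)
    total_error += float(np.sum(weights * np.abs(labels - predictions)))
    total_weight += float(np.sum(weights))

mean_absolute_error = total_error / total_weight  # 7.24 / 5.6 ≈ 1.292857
# ---------------------------------------------------------------------------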
"example_weights": np.array([1.0]), + } + with beam.Pipeline() as pipeline: + # pylint: disable=no-value-for-parameter + result = ( + pipeline + | "Create" >> beam.Create([example1, example2, example3]) + | "PreProcess" >> beam.ParDo(computation.preprocessors[0]) + | "Process" >> beam.Map(metric_util.to_standard_metric_inputs) + | "AddSlice" >> beam.Map(lambda x: ((), x)) + | "ComputeMetric" >> beam.CombinePerKey(computation.combiner) + ) + + # pylint: enable=no-value-for-parameter + def check_result(got): + try: + self.assertLen(got, 1) + got_slice_key, got_metrics = got[0] + self.assertEqual(got_slice_key, ()) + key = computation.keys[0] + self.assertIn(key, got_metrics) + self.assertAlmostEqual(got_metrics[key], expected_value, places=5) + except AssertionError as err: + raise util.BeamAssertException(err) + + util.assert_that(result, check_result, label="result") + - @parameterized.named_parameters( - # example1 is |0.1 - 1| * 0.1 + |0.3 - 0| * 0.5 + |0.5 - 2| * 1 - # = 0.09 + 0.15 + 1.5 = 1.74 - # example2 is |1 - 0.5| + |2 - 1| + |3 - 5| = 3.5 - # example3 is |3 - 5| = 2 - # average error: (1.74 + 3.5 + 2) / 5.6 = 1.292857 - ( - '_mean_absolute_error', - mean_regression_error.MeanAbsoluteError(), - 1.292857, - ), - # example1 is |0.1 - 1|^2 * 0.1 + |0.3 - 0|^2 * 0.5 + |0.5 - 2|^2 * 1 - # = 0.081 + 0.045 + 2.25 = 2.376 - # example2 is |1 - 0.5|^2 + |2 - 1|^2 + |3 - 5|^2 = 5.25 - # example3 is |3 - 5|^2 = 4 - # average error: (2.376 + 5.25 + 4) / 5.6 = 2.07607 - ( - '_mean_squared_error', - mean_regression_error.MeanSquaredError(), - 2.07607, - ), - # example1 is 100 * (|0.1 - 1| / 0.1 * 0.1 + |0.3 - 0| / 0.3 * 0.5 + - # |0.5 - 2| / 0.5 * 1) = 440 - # example2 is 100 * (|1 - 0.5| / 1 + |2 - 1| / 2 + |3 - 5| / 3) = 166.66 - # example3 is 100 * (|3 - 5| / 3) = 66.66 - # average error: (440 + 166.66 + 66.66) / 5.6 = 120.238095 - ( - '_mean_absolute_percentage_error', - mean_regression_error.MeanAbsolutePercentageError(), - 120.238095, - ), - # example1 is |log(0.1+1) - log(1+1)|^2 * 0.1 + - # |log(0.3+1) - log(0+1)|^2 * 0.5 + |log(0.5+1) - log(2+1)|^2 * 1 =0.55061 - # example2 is |log(1+1) - log(0.5+1)|^2 + |log(2+1) - log(1+1)|^2 - # + |log(3+1) - log(5+1)|^2 = 0.41156 - # example3 is |log(3+1) - log(5+1)|^2 = 0.16440 - # average error: (0.55061 + 0.41156 + 0.16440) / 5.6 = 0.20117 - ( - '_mean_squared_logarithmic_error', - mean_regression_error.MeanSquaredLogarithmicError(), - 0.20117, - ), - ) - def testRegressionErrorWithWeights(self, metric, expected_value): - computations = metric.computations(example_weighted=True) - computation = computations[0] - example1 = { - 'labels': np.array([0.1, 0.3, 0.5]), - 'predictions': np.array([1, 0, 2]), - 'example_weights': np.array([0.1, 0.5, 1]), - } - example2 = { - 'labels': np.array([1, 2, 3]), - 'predictions': np.array([0.5, 1.0, 5]), - 'example_weights': np.array([1.0]), - } - example3 = { - 'labels': np.array([3]), - 'predictions': np.array([5]), - 'example_weights': np.array([1.0]), - } - with beam.Pipeline() as pipeline: - # pylint: disable=no-value-for-parameter - result = ( - pipeline - | 'Create' >> beam.Create([example1, example2, example3]) - | 'Process' >> beam.Map(metric_util.to_standard_metric_inputs) - | 'AddSlice' >> beam.Map(lambda x: ((), x)) - | 'ComputeMetric' >> beam.CombinePerKey(computation.combiner) - ) - - # pylint: enable=no-value-for-parameter - def check_result(got): - try: - self.assertLen(got, 1) - got_slice_key, got_metrics = got[0] - self.assertEqual(got_slice_key, ()) - key = computation.keys[0] - 
self.assertIn(key, got_metrics) - self.assertAlmostEqual(got_metrics[key], expected_value, places=5) - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that(result, check_result, label='result') - - @parameterized.named_parameters( - # example1 is |0.1 - 1| * 0.1 + |0.3 - 0| * 0.5 + |0.5 - 2| * 1 - # = 0.09 + 0.15 + 1.5 = 1.74 - # example2 is |1 - 0.5| + |2 - 1| + |3 - 5| = 3.5 - # example3 is |3 - 5| = 2 - # average error: (1.74 + 3.5 + 2) / 5.6 = 1.292857 - ( - '_mean_absolute_error_with_preprocessors', - mean_regression_error.MeanAbsoluteError( - preprocessors=[IdTransformPreprocessor()] - ), - 1.292857, - ), - # example1 is |0.1 - 1|^2 * 0.1 + |0.3 - 0|^2 * 0.5 + |0.5 - 2|^2 * 1 - # = 0.081 + 0.045 + 2.25 = 2.376 - # example2 is |1 - 0.5|^2 + |2 - 1|^2 + |3 - 5|^2 = 5.25 - # example3 is |3 - 5|^2 = 4 - # average error: (2.376 + 5.25 + 4) / 5.6 = 2.07607 - ( - '_mean_squared_error_with_preprocessors', - mean_regression_error.MeanSquaredError( - preprocessors=[IdTransformPreprocessor()] - ), - 2.07607, - ), - # example1 is 100 * (|0.1 - 1| / 0.1 * 0.1 + |0.3 - 0| / 0.3 * 0.5 + - # |0.5 - 2| / 0.5 * 1) = 440 - # example2 is 100 * (|1 - 0.5| / 1 + |2 - 1| / 2 + |3 - 5| / 3) = 166.66 - # example3 is 100 * (|3 - 5| / 3) = 66.66 - # average error: (440 + 166.66 + 66.66) / 5.6 = 120.238095 - ( - '_mean_absolute_percentage_error_with_preprocessors', - mean_regression_error.MeanAbsolutePercentageError( - preprocessors=[IdTransformPreprocessor()] - ), - 120.238095, - ), - # example1 is |log(0.1+1) - log(1+1)|^2 * 0.1 + - # |log(0.3+1) - log(0+1)|^2 * 0.5 + |log(0.5+1) - log(2+1)|^2 * 1 =0.55061 - # example2 is |log(1+1) - log(0.5+1)|^2 + |log(2+1) - log(1+1)|^2 - # + |log(3+1) - log(5+1)|^2 = 0.41156 - # example3 is |log(3+1) - log(5+1)|^2 = 0.16440 - # average error: (0.55061 + 0.41156 + 0.16440) / 5.6 = 0.20117 - ( - '_mean_squared_logarithmic_error_with_preprocessors', - mean_regression_error.MeanSquaredLogarithmicError( - preprocessors=[IdTransformPreprocessor()] - ), - 0.20117, - ), - ) - def testRegressionErrorWithWeightsWithPreprocessors( - self, metric, expected_value - ): - computations = metric.computations(example_weighted=True) - computation = computations[0] - example1 = { - 'labels': np.array([0.1, 0.3, 0.5]), - 'predictions': np.array([1, 0, 2]), - 'example_weights': np.array([0.1, 0.5, 1]), - } - example2 = { - 'labels': np.array([1, 2, 3]), - 'predictions': np.array([0.5, 1.0, 5]), - 'example_weights': np.array([1.0]), - } - example3 = { - 'labels': np.array([3]), - 'predictions': np.array([5]), - 'example_weights': np.array([1.0]), - } - with beam.Pipeline() as pipeline: - # pylint: disable=no-value-for-parameter - result = ( - pipeline - | 'Create' >> beam.Create([example1, example2, example3]) - | 'PreProcess' >> beam.ParDo(computation.preprocessors[0]) - | 'Process' >> beam.Map(metric_util.to_standard_metric_inputs) - | 'AddSlice' >> beam.Map(lambda x: ((), x)) - | 'ComputeMetric' >> beam.CombinePerKey(computation.combiner) - ) - - # pylint: enable=no-value-for-parameter - def check_result(got): - try: - self.assertLen(got, 1) - got_slice_key, got_metrics = got[0] - self.assertEqual(got_slice_key, ()) - key = computation.keys[0] - self.assertIn(key, got_metrics) - self.assertAlmostEqual(got_metrics[key], expected_value, places=5) - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that(result, check_result, label='result') - - -if __name__ == '__main__': - absltest.main() +if __name__ == "__main__": + 
absltest.main() diff --git a/tensorflow_model_analysis/metrics/metric_specs.py b/tensorflow_model_analysis/metrics/metric_specs.py index 2c2a85ae01..ded60eee4b 100644 --- a/tensorflow_model_analysis/metrics/metric_specs.py +++ b/tensorflow_model_analysis/metrics/metric_specs.py @@ -17,29 +17,42 @@ import importlib import json import re -from typing import Any, Dict, FrozenSet, Iterable, Iterator, List, NamedTuple, Optional, Tuple, Type, Union +from typing import ( + Any, + Dict, + FrozenSet, + Iterable, + Iterator, + List, + NamedTuple, + Optional, + Tuple, + Type, + Union, +) import tensorflow as tf -from tensorflow_model_analysis.metrics import aggregation -from tensorflow_model_analysis.metrics import binary_confusion_matrices -from tensorflow_model_analysis.metrics import calibration -from tensorflow_model_analysis.metrics import calibration_plot -from tensorflow_model_analysis.metrics import confusion_matrix_metrics -from tensorflow_model_analysis.metrics import confusion_matrix_plot -from tensorflow_model_analysis.metrics import example_count -from tensorflow_model_analysis.metrics import metric_types -from tensorflow_model_analysis.metrics import metric_util -from tensorflow_model_analysis.metrics import multi_class_confusion_matrix_plot -from tensorflow_model_analysis.metrics import tf_metric_wrapper -from tensorflow_model_analysis.metrics import weighted_example_count +from tensorflow_metadata.proto.v0 import schema_pb2 + +from tensorflow_model_analysis.metrics import ( + aggregation, + binary_confusion_matrices, + calibration, + calibration_plot, + confusion_matrix_metrics, + confusion_matrix_plot, + example_count, + metric_types, + metric_util, + multi_class_confusion_matrix_plot, + tf_metric_wrapper, + weighted_example_count, +) from tensorflow_model_analysis.proto import config_pb2 from tensorflow_model_analysis.slicer import slicer_lib as slicer from tensorflow_model_analysis.utils import model_util from tensorflow_model_analysis.utils.keras_lib import tf_keras -from tensorflow_metadata.proto.v0 import schema_pb2 - - _TF_LOSSES_MODULE = tf_keras.losses.Loss().__class__.__module__ _TFOrTFMAMetricOrLoss = Union[ @@ -48,100 +61,104 @@ _TFMetricOrLoss = Union[tf_keras.metrics.Metric, tf_keras.losses.Loss] # List of metrics or losses optionally keyed by output name. -_MetricsOrLosses = Union[List[_TFOrTFMAMetricOrLoss], - Dict[str, List[_TFOrTFMAMetricOrLoss]]] +_MetricsOrLosses = Union[ + List[_TFOrTFMAMetricOrLoss], Dict[str, List[_TFOrTFMAMetricOrLoss]] +] # TF config settings that TFMA only supports default values for because the # parameters are not supported by the TFMA implementation of the metric. The # settings are keyed by class name -> arg_name -> [allowed defaults]. _UNSUPPORTED_TF_SETTINGS = { - '': { # All classes + "": { # All classes # TFMA only implements float based versions of TF metrics. 
- 'dtype': [None, 'float32', tf.float32] + "dtype": [None, "float32", tf.float32] }, - 'AUC': { - 'multi_label': [False], - 'num_labels': [None], - 'label_weights': [None], - 'from_logits': [False], + "AUC": { + "multi_label": [False], + "num_labels": [None], + "label_weights": [None], + "from_logits": [False], }, - 'MeanAbsoluteError': {'reduction': ['auto']}, - 'MeanSquaredError': {'reduction': ['auto']}, - 'MeanAbsolutePercentageError': {'reduction': ['auto']}, - 'MeanSqauredLogarithmicError': {'reduction': ['auto']}, - 'BinaryCrossEntropy': {'reduction': ['auto']}, - 'CategoricalCrossEntropy': {'reduction': ['auto']}, + "MeanAbsoluteError": {"reduction": ["auto"]}, + "MeanSquaredError": {"reduction": ["auto"]}, + "MeanAbsolutePercentageError": {"reduction": ["auto"]}, + "MeanSqauredLogarithmicError": {"reduction": ["auto"]}, + "BinaryCrossEntropy": {"reduction": ["auto"]}, + "CategoricalCrossEntropy": {"reduction": ["auto"]}, } -def config_from_metric( - metric: _TFOrTFMAMetricOrLoss) -> config_pb2.MetricConfig: - """Returns MetricConfig associated with given metric instance.""" - if isinstance(metric, tf_keras.metrics.Metric): - if _is_supported_tf_metric(metric): - return _remove_unsupported_tf_settings(_serialize_tf_metric(metric)) +def config_from_metric(metric: _TFOrTFMAMetricOrLoss) -> config_pb2.MetricConfig: + """Returns MetricConfig associated with given metric instance.""" + if isinstance(metric, tf_keras.metrics.Metric): + if _is_supported_tf_metric(metric): + return _remove_unsupported_tf_settings(_serialize_tf_metric(metric)) + else: + return _serialize_tf_metric(metric) + elif isinstance(metric, tf_keras.losses.Loss): + # For loss like MeanAbsoluteError, TFMA provides native support. + # The support should be checked here. + if _is_supported_tf_metric(metric): + return _remove_unsupported_tf_settings(_serialize_tf_metric(metric)) + else: + return _serialize_tf_loss(metric) + elif isinstance(metric, metric_types.Metric): + return _serialize_tfma_metric(metric) else: - return _serialize_tf_metric(metric) - elif isinstance(metric, tf_keras.losses.Loss): - # For loss like MeanAbsoluteError, TFMA provides native support. - # The support should be checked here. - if _is_supported_tf_metric(metric): - return _remove_unsupported_tf_settings(_serialize_tf_metric(metric)) + raise NotImplementedError( + f"unknown metric type {type(metric)}: metric={metric}" + ) + + +def _example_weighted_default( + eval_config: config_pb2.EvalConfig, spec: config_pb2.MetricsSpec +) -> bool: + """Returns tue if example weighted is the default for the given spec. + + If any of the models and/or outputs have example weights then example weighted + will be true by default. + + Args: + ---- + eval_config: Eval config. + spec: Metrics spec to get default for. + """ + + # Default to using example example weights if an example weight key is set. 
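# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the patch: the kind of check the
# _UNSUPPORTED_TF_SETTINGS table above enables -- whether a TF metric instance
# is configured compatibly with TFMA's own implementation of it. The helper
# below is hypothetical, not the module's _is_supported_tf_metric, and it
# assumes an argument left unspecified takes the first allowed value.
def _uses_only_allowed_defaults(class_name: str, metric_kwargs: dict) -> bool:
    allowed = dict(_UNSUPPORTED_TF_SETTINGS.get("", {}))
    allowed.update(_UNSUPPORTED_TF_SETTINGS.get(class_name, {}))
    return all(
        metric_kwargs.get(arg, defaults[0]) in defaults
        for arg, defaults in allowed.items()
    )


# e.g. _uses_only_allowed_defaults("AUC", {"multi_label": True}) is False,
# while _uses_only_allowed_defaults("AUC", {"num_thresholds": 100}) is True.
# ---------------------------------------------------------------------------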
+ def has_example_weight_key(model_spec: config_pb2.ModelSpec) -> bool: + if model_spec.example_weight_key: + return True + if spec.output_names and model_spec.example_weight_keys: + for output_name in spec.output_names: + if output_name in model_spec.example_weight_keys: + return True + return False + + if spec.model_names: + for model_name in spec.model_names: + model_spec = model_util.get_model_spec(eval_config, model_name) + if model_spec and has_example_weight_key(model_spec): + return True else: - return _serialize_tf_loss(metric) - elif isinstance(metric, metric_types.Metric): - return _serialize_tfma_metric(metric) - else: - raise NotImplementedError('unknown metric type {}: metric={}'.format( - type(metric), metric)) - - -def _example_weighted_default(eval_config: config_pb2.EvalConfig, - spec: config_pb2.MetricsSpec) -> bool: - """Returns tue if example weighted is the default for the given spec. - - If any of the models and/or outputs have example weights then example weighted - will be true by default. - - Args: - eval_config: Eval config. - spec: Metrics spec to get default for. - """ - - # Default to using example example weights if an example weight key is set. - def has_example_weight_key(model_spec: config_pb2.ModelSpec) -> bool: - if model_spec.example_weight_key: - return True - if spec.output_names and model_spec.example_weight_keys: - for output_name in spec.output_names: - if output_name in model_spec.example_weight_keys: - return True + for model_spec in eval_config.model_specs: + if has_example_weight_key(model_spec): + return True return False - if spec.model_names: - for model_name in spec.model_names: - model_spec = model_util.get_model_spec(eval_config, model_name) - if model_spec and has_example_weight_key(model_spec): - return True - else: - for model_spec in eval_config.model_specs: - if has_example_weight_key(model_spec): - return True - return False - - -def _example_weight_options(eval_config: config_pb2.EvalConfig, - spec: config_pb2.MetricsSpec) -> List[bool]: - """Returns example weight options for given spec.""" - result = [] - if not spec.HasField('example_weights'): - result.append(_example_weighted_default(eval_config, spec)) - else: - if spec.example_weights.weighted: - result.append(True) - if spec.example_weights.unweighted: - result.append(False) - return result + +def _example_weight_options( + eval_config: config_pb2.EvalConfig, spec: config_pb2.MetricsSpec +) -> List[bool]: + """Returns example weight options for given spec.""" + result = [] + if not spec.HasField("example_weights"): + result.append(_example_weighted_default(eval_config, spec)) + else: + if spec.example_weights.weighted: + result.append(True) + if spec.example_weights.unweighted: + result.append(False) + return result def specs_from_metrics( @@ -154,138 +171,151 @@ def specs_from_metrics( aggregate: Optional[config_pb2.AggregationOptions] = None, query_key: Optional[str] = None, include_example_count: Optional[bool] = None, - include_weighted_example_count: Optional[bool] = None + include_weighted_example_count: Optional[bool] = None, ) -> List[config_pb2.MetricsSpec]: - """Returns specs for tf_keras.metrics/losses or tfma.metrics classes. - - Examples: - - metrics_specs = specs_from_metrics( - [ - tf_keras.metrics.BinaryAccuracy(), - tfma.metrics.AUC(), - tfma.metrics.MeanLabel(), - tfma.metrics.MeanPrediction() - ... 
- ], - unweighted=[ - tfma.metrics.Precision(), - tfma.metrics.Recall() - ]) - - metrics_specs = specs_from_metrics({ - 'output1': [ - tf_keras.metrics.BinaryAccuracy(), - tfma.metrics.AUC(), - tfma.metrics.MeanLabel(), - tfma.metrics.MeanPrediction() - ... - ], - 'output2': [ - tfma.metrics.Precision(), - tfma.metrics.Recall(), - ] - }) - - Args: - metrics: List of tfma.metrics.Metric, tf_keras.metrics.Metric, or - tf_keras.losses.Loss. For multi-output models a dict of dicts may be - passed where the first dict is indexed by the output_name. Whether these - metrics are weighted or not will be determined based on whether the - ModelSpec associated with the metrics contains example weight key settings - or not. - unweighted_metrics: Same as metrics only these metrics will not be weighted - by example_weight regardless of the example weight key settings. - model_names: Optional model names (if multi-model evaluation). - output_names: Optional output names (if multi-output models). If the metrics - are a dict this should not be set. - output_weights: Optional output weights for creating overall metric - aggregated across outputs (if multi-output model). If a weight is not - provided for an output, it's weight defaults to 0.0 (i.e. output ignored). - binarize: Optional settings for binarizing multi-class/multi-label metrics. - aggregate: Optional settings for aggregating multi-class/multi-label - metrics. - query_key: Optional query key for query/ranking based metrics. - include_example_count: True to add example_count metric. Default is True. - include_weighted_example_count: True to add weighted example_count metric. - Default is True. A weighted example count will be added per output for - multi-output models. - - Returns: - MetricsSpecs based on options provided. A separate spec is returned for - weighted vs unweighted metrics. A separate spec is also returned for each - output if a dict of metrics per output is passed. - """ - if isinstance(metrics, dict) and output_names: - raise ValueError('metrics cannot be a dict when output_names is used: ' - 'metrics={}, output_names={}'.format( - metrics, output_names)) - if (metrics and unweighted_metrics and - isinstance(metrics, dict) != isinstance(unweighted_metrics, dict)): - raise ValueError( - 'metrics and unweighted_metrics must both be either dicts or lists: ' - f'metrics={metrics}, unweighted_metrics={unweighted_metrics}') - - if isinstance(metrics, dict) or isinstance(unweighted_metrics, dict): - metrics_dict = metrics if isinstance(metrics, dict) else {} - unweighted_metrics_dict = ( - unweighted_metrics if isinstance(unweighted_metrics, dict) else {}) - specs = [] - output_names = set(metrics_dict) | set(unweighted_metrics_dict) - for output_name in sorted(output_names): - specs.extend( - specs_from_metrics( - metrics_dict.get(output_name), - unweighted_metrics=unweighted_metrics_dict.get(output_name), - model_names=model_names, - output_names=[output_name], - binarize=binarize, - aggregate=aggregate, - include_example_count=include_example_count, - include_weighted_example_count=include_weighted_example_count)) - include_example_count = False - return specs + """Returns specs for tf_keras.metrics/losses or tfma.metrics classes. + + Examples: + -------- + metrics_specs = specs_from_metrics( + [ + tf_keras.metrics.BinaryAccuracy(), + tfma.metrics.AUC(), + tfma.metrics.MeanLabel(), + tfma.metrics.MeanPrediction() + ... 
+ ], + unweighted=[ + tfma.metrics.Precision(), + tfma.metrics.Recall() + ]) + + metrics_specs = specs_from_metrics({ + 'output1': [ + tf_keras.metrics.BinaryAccuracy(), + tfma.metrics.AUC(), + tfma.metrics.MeanLabel(), + tfma.metrics.MeanPrediction() + ... + ], + 'output2': [ + tfma.metrics.Precision(), + tfma.metrics.Recall(), + ] + }) + + Args: + ---- + metrics: List of tfma.metrics.Metric, tf_keras.metrics.Metric, or + tf_keras.losses.Loss. For multi-output models a dict of dicts may be + passed where the first dict is indexed by the output_name. Whether these + metrics are weighted or not will be determined based on whether the + ModelSpec associated with the metrics contains example weight key settings + or not. + unweighted_metrics: Same as metrics only these metrics will not be weighted + by example_weight regardless of the example weight key settings. + model_names: Optional model names (if multi-model evaluation). + output_names: Optional output names (if multi-output models). If the metrics + are a dict this should not be set. + output_weights: Optional output weights for creating overall metric + aggregated across outputs (if multi-output model). If a weight is not + provided for an output, it's weight defaults to 0.0 (i.e. output ignored). + binarize: Optional settings for binarizing multi-class/multi-label metrics. + aggregate: Optional settings for aggregating multi-class/multi-label + metrics. + query_key: Optional query key for query/ranking based metrics. + include_example_count: True to add example_count metric. Default is True. + include_weighted_example_count: True to add weighted example_count metric. + Default is True. A weighted example count will be added per output for + multi-output models. + + Returns: + ------- + MetricsSpecs based on options provided. A separate spec is returned for + weighted vs unweighted metrics. A separate spec is also returned for each + output if a dict of metrics per output is passed. + """ + if isinstance(metrics, dict) and output_names: + raise ValueError( + "metrics cannot be a dict when output_names is used: " + f"metrics={metrics}, output_names={output_names}" + ) + if ( + metrics + and unweighted_metrics + and isinstance(metrics, dict) != isinstance(unweighted_metrics, dict) + ): + raise ValueError( + "metrics and unweighted_metrics must both be either dicts or lists: " + f"metrics={metrics}, unweighted_metrics={unweighted_metrics}" + ) + + if isinstance(metrics, dict) or isinstance(unweighted_metrics, dict): + metrics_dict = metrics if isinstance(metrics, dict) else {} + unweighted_metrics_dict = ( + unweighted_metrics if isinstance(unweighted_metrics, dict) else {} + ) + specs = [] + output_names = set(metrics_dict) | set(unweighted_metrics_dict) + for output_name in sorted(output_names): + specs.extend( + specs_from_metrics( + metrics_dict.get(output_name), + unweighted_metrics=unweighted_metrics_dict.get(output_name), + model_names=model_names, + output_names=[output_name], + binarize=binarize, + aggregate=aggregate, + include_example_count=include_example_count, + include_weighted_example_count=include_weighted_example_count, + ) + ) + include_example_count = False + return specs + + if include_example_count is None: + include_example_count = True + if include_weighted_example_count is None: + include_weighted_example_count = True + + # Add the computations for the example counts and weights since they are + # independent of the model and class ID. 
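# --- Illustrative sketch (not part of this patch) ---------------------------
# Minimal usage of specs_from_metrics per the docstring above: weighted and
# explicitly unweighted metrics land in separate MetricsSpec messages, after
# the example-count specs that are added first.
from tensorflow_model_analysis.metrics import calibration, metric_specs
from tensorflow_model_analysis.utils.keras_lib import tf_keras

specs = metric_specs.specs_from_metrics(
    metrics=[tf_keras.metrics.BinaryAccuracy(name="binary_accuracy")],
    unweighted_metrics=[calibration.MeanLabel(name="mean_label")],
)
for metrics_spec in specs:
    print([m.class_name for m in metrics_spec.metrics], metrics_spec.example_weights)
# -----------------------------------------------------------------------------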
+ specs = example_count_specs( + model_names=model_names, + output_names=output_names, + output_weights=output_weights, + include_example_count=include_example_count, + include_weighted_example_count=include_weighted_example_count, + ) - if include_example_count is None: - include_example_count = True - if include_weighted_example_count is None: - include_weighted_example_count = True - - # Add the computations for the example counts and weights since they are - # independent of the model and class ID. - specs = example_count_specs( - model_names=model_names, - output_names=output_names, - output_weights=output_weights, - include_example_count=include_example_count, - include_weighted_example_count=include_weighted_example_count) - - if metrics: - specs.append( - config_pb2.MetricsSpec( - metrics=[config_from_metric(metric) for metric in metrics], - model_names=model_names, - output_names=output_names, - output_weights=output_weights, - binarize=binarize, - aggregate=aggregate, - example_weights=None, - query_key=query_key)) - if unweighted_metrics: - specs.append( - config_pb2.MetricsSpec( - metrics=[ - config_from_metric(metric) for metric in unweighted_metrics - ], - model_names=model_names, - output_names=output_names, - output_weights=output_weights, - binarize=binarize, - aggregate=aggregate, - example_weights=config_pb2.ExampleWeightOptions(unweighted=True), - query_key=query_key)) + if metrics: + specs.append( + config_pb2.MetricsSpec( + metrics=[config_from_metric(metric) for metric in metrics], + model_names=model_names, + output_names=output_names, + output_weights=output_weights, + binarize=binarize, + aggregate=aggregate, + example_weights=None, + query_key=query_key, + ) + ) + if unweighted_metrics: + specs.append( + config_pb2.MetricsSpec( + metrics=[config_from_metric(metric) for metric in unweighted_metrics], + model_names=model_names, + output_names=output_names, + output_weights=output_weights, + binarize=binarize, + aggregate=aggregate, + example_weights=config_pb2.ExampleWeightOptions(unweighted=True), + query_key=query_key, + ) + ) - return specs + return specs def example_count_specs( @@ -293,41 +323,47 @@ def example_count_specs( output_names: Optional[List[str]] = None, output_weights: Optional[Dict[str, float]] = None, include_example_count: bool = True, - include_weighted_example_count: bool = True + include_weighted_example_count: bool = True, ) -> List[config_pb2.MetricsSpec]: - """Returns metric specs for example count and weighted example counts. - - Args: - model_names: Optional list of model names (if multi-model evaluation). - output_names: Optional list of output names (if multi-output model). - output_weights: Optional output weights for creating overall metric - aggregated across outputs (if multi-output model). If a weight is not - provided for an output, it's weight defaults to 0.0 (i.e. output ignored). - include_example_count: True to add example_count metric. - include_weighted_example_count: True to add weighted_example_count metric. A - weighted example count will be added per output for multi-output models. 
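# --- Illustrative sketch (not part of this patch) ---------------------------
# example_count_specs as called above: example_count is attached per model,
# while weighted_example_count is attached per model and output. The model and
# output names are made-up illustration values.
from tensorflow_model_analysis.metrics import metric_specs

count_specs = metric_specs.example_count_specs(
    model_names=["candidate"],
    output_names=["output1"],
    include_example_count=True,
    include_weighted_example_count=True,
)
assert len(count_specs) == 2  # one spec per count metric
# -----------------------------------------------------------------------------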
- """ - specs = [] - if include_example_count: - metric_config = _serialize_tfma_metric(example_count.ExampleCount()) - specs.append( - config_pb2.MetricsSpec( - metrics=[metric_config], - model_names=model_names, - example_weights=config_pb2.ExampleWeightOptions(unweighted=True))) - if include_weighted_example_count: - # TODO(b/143180976): Replace WeightedExampleCount with ExampleCount once the - # UI is updated to distinguish weighted for unweighted metrics. - metric_config = _serialize_tfma_metric( - weighted_example_count.WeightedExampleCount()) - specs.append( - config_pb2.MetricsSpec( - metrics=[metric_config], - model_names=model_names, - output_names=output_names, - output_weights=output_weights, - example_weights=config_pb2.ExampleWeightOptions(weighted=True))) - return specs + """Returns metric specs for example count and weighted example counts. + + Args: + ---- + model_names: Optional list of model names (if multi-model evaluation). + output_names: Optional list of output names (if multi-output model). + output_weights: Optional output weights for creating overall metric + aggregated across outputs (if multi-output model). If a weight is not + provided for an output, it's weight defaults to 0.0 (i.e. output ignored). + include_example_count: True to add example_count metric. + include_weighted_example_count: True to add weighted_example_count metric. A + weighted example count will be added per output for multi-output models. + """ + specs = [] + if include_example_count: + metric_config = _serialize_tfma_metric(example_count.ExampleCount()) + specs.append( + config_pb2.MetricsSpec( + metrics=[metric_config], + model_names=model_names, + example_weights=config_pb2.ExampleWeightOptions(unweighted=True), + ) + ) + if include_weighted_example_count: + # TODO(b/143180976): Replace WeightedExampleCount with ExampleCount once the + # UI is updated to distinguish weighted for unweighted metrics. + metric_config = _serialize_tfma_metric( + weighted_example_count.WeightedExampleCount() + ) + specs.append( + config_pb2.MetricsSpec( + metrics=[metric_config], + model_names=model_names, + output_names=output_names, + output_weights=output_weights, + example_weights=config_pb2.ExampleWeightOptions(weighted=True), + ) + ) + return specs def default_regression_specs( @@ -340,40 +376,43 @@ def default_regression_specs( min_value: Optional[float] = None, max_value: Optional[float] = None, ) -> List[config_pb2.MetricsSpec]: - """Returns default metric specs for for regression problems. - - Args: - model_names: Optional model names (if multi-model evaluation). - output_names: Optional list of output names (if multi-output model). - output_weights: Optional output weights for creating overall metric - aggregated across outputs (if multi-output model). If a weight is not - provided for an output, it's weight defaults to 0.0 (i.e. output ignored). - loss_functions: Loss functions to use (if None MSE is used). - min_value: Min value for calibration plot (if None no plot will be created). - max_value: Max value for calibration plot (if None no plot will be created). 
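# --- Illustrative sketch (not part of this patch) ---------------------------
# default_regression_specs only emits the calibration plot when both bounds
# are supplied; the [0.0, 100.0] range is an arbitrary illustration value, and
# MSE is used as the loss because loss_functions is left as None.
from tensorflow_model_analysis.metrics import metric_specs

regression_specs = metric_specs.default_regression_specs(
    min_value=0.0, max_value=100.0
)
# -----------------------------------------------------------------------------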
- """ - - if loss_functions is None: - loss_functions = [tf_keras.metrics.MeanSquaredError(name='mse')] - - metrics = [ - tf_keras.metrics.Accuracy(name='accuracy'), - calibration.MeanLabel(name='mean_label'), - calibration.MeanPrediction(name='mean_prediction'), - calibration.Calibration(name='calibration'), - ] - for fn in loss_functions: - metrics.append(fn) - if min_value is not None and max_value is not None: - metrics.append( - calibration_plot.CalibrationPlot( - name='calibration_plot', left=min_value, right=max_value)) - - return specs_from_metrics( - metrics, - model_names=model_names, - output_names=output_names, - output_weights=output_weights) + """Returns default metric specs for for regression problems. + + Args: + ---- + model_names: Optional model names (if multi-model evaluation). + output_names: Optional list of output names (if multi-output model). + output_weights: Optional output weights for creating overall metric + aggregated across outputs (if multi-output model). If a weight is not + provided for an output, it's weight defaults to 0.0 (i.e. output ignored). + loss_functions: Loss functions to use (if None MSE is used). + min_value: Min value for calibration plot (if None no plot will be created). + max_value: Max value for calibration plot (if None no plot will be created). + """ + if loss_functions is None: + loss_functions = [tf_keras.metrics.MeanSquaredError(name="mse")] + + metrics = [ + tf_keras.metrics.Accuracy(name="accuracy"), + calibration.MeanLabel(name="mean_label"), + calibration.MeanPrediction(name="mean_prediction"), + calibration.Calibration(name="calibration"), + ] + for fn in loss_functions: + metrics.append(fn) + if min_value is not None and max_value is not None: + metrics.append( + calibration_plot.CalibrationPlot( + name="calibration_plot", left=min_value, right=max_value + ) + ) + + return specs_from_metrics( + metrics, + model_names=model_names, + output_names=output_names, + output_weights=output_weights, + ) def default_binary_classification_specs( @@ -382,48 +421,51 @@ def default_binary_classification_specs( output_weights: Optional[Dict[str, float]] = None, binarize: Optional[config_pb2.BinarizationOptions] = None, aggregate: Optional[config_pb2.AggregationOptions] = None, - include_loss: bool = True) -> List[config_pb2.MetricsSpec]: - """Returns default metric specs for binary classification problems. - - Args: - model_names: Optional model names (if multi-model evaluation). - output_names: Optional list of output names (if multi-output model). - output_weights: Optional output weights for creating overall metric - aggregated across outputs (if multi-output model). If a weight is not - provided for an output, it's weight defaults to 0.0 (i.e. output ignored). - binarize: Optional settings for binarizing multi-class/multi-label metrics. - aggregate: Optional settings for aggregating multi-class/multi-label - metrics. - include_loss: True to include loss. - """ - - metrics = [ - confusion_matrix_metrics.BinaryAccuracy(name='binary_accuracy'), - confusion_matrix_metrics.AUC( - name='auc', - num_thresholds=binary_confusion_matrices.DEFAULT_NUM_THRESHOLDS), - confusion_matrix_metrics.AUC( - name='auc_precison_recall', # Matches default name used by estimator. 
- curve='PR', - num_thresholds=binary_confusion_matrices.DEFAULT_NUM_THRESHOLDS), - confusion_matrix_metrics.Precision(name='precision'), - confusion_matrix_metrics.Recall(name='recall'), - calibration.MeanLabel(name='mean_label'), - calibration.MeanPrediction(name='mean_prediction'), - calibration.Calibration(name='calibration'), - confusion_matrix_plot.ConfusionMatrixPlot(name='confusion_matrix_plot'), - calibration_plot.CalibrationPlot(name='calibration_plot') - ] - if include_loss: - metrics.append(tf_keras.metrics.BinaryCrossentropy(name='loss')) - - return specs_from_metrics( - metrics, - model_names=model_names, - output_names=output_names, - output_weights=output_weights, - binarize=binarize, - aggregate=aggregate) + include_loss: bool = True, +) -> List[config_pb2.MetricsSpec]: + """Returns default metric specs for binary classification problems. + + Args: + ---- + model_names: Optional model names (if multi-model evaluation). + output_names: Optional list of output names (if multi-output model). + output_weights: Optional output weights for creating overall metric + aggregated across outputs (if multi-output model). If a weight is not + provided for an output, it's weight defaults to 0.0 (i.e. output ignored). + binarize: Optional settings for binarizing multi-class/multi-label metrics. + aggregate: Optional settings for aggregating multi-class/multi-label + metrics. + include_loss: True to include loss. + """ + metrics = [ + confusion_matrix_metrics.BinaryAccuracy(name="binary_accuracy"), + confusion_matrix_metrics.AUC( + name="auc", num_thresholds=binary_confusion_matrices.DEFAULT_NUM_THRESHOLDS + ), + confusion_matrix_metrics.AUC( + name="auc_precison_recall", # Matches default name used by estimator. + curve="PR", + num_thresholds=binary_confusion_matrices.DEFAULT_NUM_THRESHOLDS, + ), + confusion_matrix_metrics.Precision(name="precision"), + confusion_matrix_metrics.Recall(name="recall"), + calibration.MeanLabel(name="mean_label"), + calibration.MeanPrediction(name="mean_prediction"), + calibration.Calibration(name="calibration"), + confusion_matrix_plot.ConfusionMatrixPlot(name="confusion_matrix_plot"), + calibration_plot.CalibrationPlot(name="calibration_plot"), + ] + if include_loss: + metrics.append(tf_keras.metrics.BinaryCrossentropy(name="loss")) + + return specs_from_metrics( + metrics, + model_names=model_names, + output_names=output_names, + output_weights=output_weights, + binarize=binarize, + aggregate=aggregate, + ) def default_multi_class_classification_specs( @@ -432,86 +474,98 @@ def default_multi_class_classification_specs( output_weights: Optional[Dict[str, float]] = None, binarize: Optional[config_pb2.BinarizationOptions] = None, aggregate: Optional[config_pb2.AggregationOptions] = None, - sparse: bool = True) -> List[config_pb2.MetricsSpec]: - """Returns default metric specs for multi-class classification problems. - - Args: - model_names: Optional model names if multi-model evaluation. - output_names: Optional list of output names (if multi-output model). - output_weights: Optional output weights for creating overall metric - aggregated across outputs (if multi-output model). If a weight is not - provided for an output, it's weight defaults to 0.0 (i.e. output ignored). - binarize: Optional settings for binarizing multi-class/multi-label metrics. - aggregate: Optional settings for aggregating multi-class/multi-label - metrics. - sparse: True if the labels are sparse. 
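# --- Illustrative sketch (not part of this patch) ---------------------------
# The default binary-classification bundle built above: binary accuracy, AUC
# and AUC-PR, precision/recall, the calibration metrics, both plots, and (with
# include_loss=True) binary cross-entropy reported as "loss".
from tensorflow_model_analysis.metrics import metric_specs

binary_specs = metric_specs.default_binary_classification_specs(include_loss=True)
print([m.class_name for s in binary_specs for m in s.metrics])
# -----------------------------------------------------------------------------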
- """ - - if sparse: - metrics = [ - tf_keras.metrics.SparseCategoricalCrossentropy(name='loss'), - tf_keras.metrics.SparseCategoricalAccuracy(name='accuracy'), - ] - else: - metrics = [ - tf_keras.metrics.CategoricalCrossentropy(name='loss'), - tf_keras.metrics.CategoricalAccuracy(name='accuracy'), - ] - metrics.append( - multi_class_confusion_matrix_plot.MultiClassConfusionMatrixPlot()) - if binarize is not None: - for top_k in binarize.top_k_list.values: - metrics.extend([ - confusion_matrix_metrics.Precision(name='precision', top_k=top_k), - confusion_matrix_metrics.Recall(name='recall', top_k=top_k) - ]) - binarize_without_top_k = config_pb2.BinarizationOptions() - binarize_without_top_k.CopyFrom(binarize) - binarize_without_top_k.ClearField('top_k_list') - binarize = binarize_without_top_k - multi_class_metrics = specs_from_metrics( - metrics, - model_names=model_names, - output_names=output_names, - output_weights=output_weights) - if aggregate is None: - aggregate = config_pb2.AggregationOptions(micro_average=True) - multi_class_metrics.extend( - default_binary_classification_specs( - model_names=model_names, - output_names=output_names, - output_weights=output_weights, - binarize=binarize, - aggregate=aggregate)) - return multi_class_metrics + sparse: bool = True, +) -> List[config_pb2.MetricsSpec]: + """Returns default metric specs for multi-class classification problems. + + Args: + ---- + model_names: Optional model names if multi-model evaluation. + output_names: Optional list of output names (if multi-output model). + output_weights: Optional output weights for creating overall metric + aggregated across outputs (if multi-output model). If a weight is not + provided for an output, it's weight defaults to 0.0 (i.e. output ignored). + binarize: Optional settings for binarizing multi-class/multi-label metrics. + aggregate: Optional settings for aggregating multi-class/multi-label + metrics. + sparse: True if the labels are sparse. 
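# --- Illustrative sketch (not part of this patch) ---------------------------
# Multi-class defaults with top-k binarization, per the body below: Precision
# and Recall are added for each top_k value, top_k_list is then cleared, and
# the rest flows into binarized binary-classification specs that are
# micro-averaged unless an aggregation is supplied. The top_k values are
# illustration values.
import tensorflow_model_analysis as tfma
from tensorflow_model_analysis.metrics import metric_specs

multi_class_specs = metric_specs.default_multi_class_classification_specs(
    sparse=True,
    binarize=tfma.BinarizationOptions(top_k_list={"values": [1, 3]}),
)
# -----------------------------------------------------------------------------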
+ """ + if sparse: + metrics = [ + tf_keras.metrics.SparseCategoricalCrossentropy(name="loss"), + tf_keras.metrics.SparseCategoricalAccuracy(name="accuracy"), + ] + else: + metrics = [ + tf_keras.metrics.CategoricalCrossentropy(name="loss"), + tf_keras.metrics.CategoricalAccuracy(name="accuracy"), + ] + metrics.append(multi_class_confusion_matrix_plot.MultiClassConfusionMatrixPlot()) + if binarize is not None: + for top_k in binarize.top_k_list.values: + metrics.extend( + [ + confusion_matrix_metrics.Precision(name="precision", top_k=top_k), + confusion_matrix_metrics.Recall(name="recall", top_k=top_k), + ] + ) + binarize_without_top_k = config_pb2.BinarizationOptions() + binarize_without_top_k.CopyFrom(binarize) + binarize_without_top_k.ClearField("top_k_list") + binarize = binarize_without_top_k + multi_class_metrics = specs_from_metrics( + metrics, + model_names=model_names, + output_names=output_names, + output_weights=output_weights, + ) + if aggregate is None: + aggregate = config_pb2.AggregationOptions(micro_average=True) + multi_class_metrics.extend( + default_binary_classification_specs( + model_names=model_names, + output_names=output_names, + output_weights=output_weights, + binarize=binarize, + aggregate=aggregate, + ) + ) + return multi_class_metrics def metric_instance( metric_config: config_pb2.MetricConfig, - tfma_metric_classes: Optional[Dict[str, Type[metric_types.Metric]]] = None + tfma_metric_classes: Optional[Dict[str, Type[metric_types.Metric]]] = None, ) -> metric_types.Metric: - """Creates instance of metric associated with config.""" - if tfma_metric_classes is None: - tfma_metric_classes = metric_types.registered_metrics() - if metric_config.class_name in tfma_metric_classes: - return _deserialize_tfma_metric(metric_config, tfma_metric_classes) - elif not metric_config.module: - return _deserialize_tf_metric(metric_config, {}) # pytype: disable=bad-return-type # typed-keras - else: - cls = getattr( - importlib.import_module(metric_config.module), metric_config.class_name) - if issubclass(cls, tf_keras.metrics.Metric): - return _deserialize_tf_metric(metric_config, - {metric_config.class_name: cls}) # pytype: disable=bad-return-type # typed-keras - elif issubclass(cls, tf_keras.losses.Loss): - return _deserialize_tf_loss(metric_config, - {metric_config.class_name: cls}) # pytype: disable=bad-return-type # typed-keras - elif issubclass(cls, metric_types.Metric): - return _deserialize_tfma_metric(metric_config, - {metric_config.class_name: cls}) + """Creates instance of metric associated with config.""" + if tfma_metric_classes is None: + tfma_metric_classes = metric_types.registered_metrics() + if metric_config.class_name in tfma_metric_classes: + return _deserialize_tfma_metric(metric_config, tfma_metric_classes) + elif not metric_config.module: + return _deserialize_tf_metric( + metric_config, {} + ) # pytype: disable=bad-return-type # typed-keras else: - raise NotImplementedError('unknown metric type {}: metric={}'.format( - cls, metric_config)) + cls = getattr( + importlib.import_module(metric_config.module), metric_config.class_name + ) + if issubclass(cls, tf_keras.metrics.Metric): + return _deserialize_tf_metric( + metric_config, {metric_config.class_name: cls} + ) # pytype: disable=bad-return-type # typed-keras + elif issubclass(cls, tf_keras.losses.Loss): + return _deserialize_tf_loss( + metric_config, {metric_config.class_name: cls} + ) # pytype: disable=bad-return-type # typed-keras + elif issubclass(cls, metric_types.Metric): + return 
_deserialize_tfma_metric( + metric_config, {metric_config.class_name: cls} + ) + else: + raise NotImplementedError( + f"unknown metric type {cls}: metric={metric_config}" + ) def _keys_for_metric( @@ -521,750 +575,816 @@ def _keys_for_metric( sub_keys: List[Optional[metric_types.SubKey]], example_weights: List[Optional[bool]], ) -> Iterator[metric_types.MetricKey]: - """Yields all non-diff keys for a specific metric name.""" - for model_name in spec.model_names or ['']: - for output_name in spec.output_names or ['']: - for sub_key in sub_keys: - for example_weighted in example_weights: - key = metric_types.MetricKey( - name=metric_name, - model_name=model_name, - output_name=output_name, - sub_key=sub_key, - aggregation_type=aggregation_type, - example_weighted=example_weighted) - yield key + """Yields all non-diff keys for a specific metric name.""" + for model_name in spec.model_names or [""]: + for output_name in spec.output_names or [""]: + for sub_key in sub_keys: + for example_weighted in example_weights: + key = metric_types.MetricKey( + name=metric_name, + model_name=model_name, + output_name=output_name, + sub_key=sub_key, + aggregation_type=aggregation_type, + example_weighted=example_weighted, + ) + yield key def keys_and_metrics_from_specs( - eval_config: config_pb2.EvalConfig, - metrics_specs: Iterable[config_pb2.MetricsSpec] -) -> Iterator[Tuple[metric_types.MetricKey, config_pb2.MetricConfig, - metric_types.Metric]]: - """Yields key, config, instance tuples for each non-diff metric in specs.""" - tfma_metric_classes = metric_types.registered_metrics() - for spec in metrics_specs: - for aggregation_type, sub_keys in _create_sub_keys(spec).items(): - for metric_config in spec.metrics: - instance = metric_instance(metric_config, tfma_metric_classes) - for key in _keys_for_metric(instance.name, spec, aggregation_type, - sub_keys, - _example_weight_options(eval_config, spec)): - yield key, metric_config, instance + eval_config: config_pb2.EvalConfig, metrics_specs: Iterable[config_pb2.MetricsSpec] +) -> Iterator[ + Tuple[metric_types.MetricKey, config_pb2.MetricConfig, metric_types.Metric] +]: + """Yields key, config, instance tuples for each non-diff metric in specs.""" + tfma_metric_classes = metric_types.registered_metrics() + for spec in metrics_specs: + for aggregation_type, sub_keys in _create_sub_keys(spec).items(): + for metric_config in spec.metrics: + instance = metric_instance(metric_config, tfma_metric_classes) + for key in _keys_for_metric( + instance.name, + spec, + aggregation_type, + sub_keys, + _example_weight_options(eval_config, spec), + ): + yield key, metric_config, instance def metric_keys_to_skip_for_confidence_intervals( - metrics_specs: Iterable[config_pb2.MetricsSpec], - eval_config: config_pb2.EvalConfig) -> FrozenSet[metric_types.MetricKey]: - """Returns metric keys not to be displayed with confidence intervals.""" - skipped_keys = [] - for key, _, instance in keys_and_metrics_from_specs(eval_config, - metrics_specs): - # if metric does not implement compute_confidence_interval, do not skip - if not getattr(instance, 'compute_confidence_interval', True): - skipped_keys.append(key) - return frozenset(skipped_keys) + metrics_specs: Iterable[config_pb2.MetricsSpec], eval_config: config_pb2.EvalConfig +) -> FrozenSet[metric_types.MetricKey]: + """Returns metric keys not to be displayed with confidence intervals.""" + skipped_keys = [] + for key, _, instance in keys_and_metrics_from_specs(eval_config, metrics_specs): + # if metric does not implement 
compute_confidence_interval, do not skip + if not getattr(instance, "compute_confidence_interval", True): + skipped_keys.append(key) + return frozenset(skipped_keys) # Optional slice and associated threshold setting. If slice is not set it # matches all slices. -_SliceAndThreshold = Tuple[Optional[Union[config_pb2.SlicingSpec, - config_pb2.CrossSlicingSpec]], - Union[config_pb2.GenericChangeThreshold, - config_pb2.GenericValueThreshold]] +_SliceAndThreshold = Tuple[ + Optional[Union[config_pb2.SlicingSpec, config_pb2.CrossSlicingSpec]], + Union[config_pb2.GenericChangeThreshold, config_pb2.GenericValueThreshold], +] def metric_thresholds_from_metrics_specs( metrics_specs: Iterable[config_pb2.MetricsSpec], - eval_config: Optional[config_pb2.EvalConfig] = None + eval_config: Optional[config_pb2.EvalConfig] = None, ) -> Dict[metric_types.MetricKey, Iterable[_SliceAndThreshold]]: - """Returns thresholds associated with given metrics specs.""" - if eval_config is None: - eval_config = config_pb2.EvalConfig() - result = collections.defaultdict(list) - existing = collections.defaultdict(dict) - - def add_if_not_exists( - key: metric_types.MetricKey, - slice_spec: Optional[Union[config_pb2.SlicingSpec, - config_pb2.CrossSlicingSpec]], - threshold: Union[config_pb2.GenericChangeThreshold, - config_pb2.GenericValueThreshold]): - """Adds value to results if it doesn't already exist.""" - hashable_slice_spec = None - if slice_spec: - hashable_slice_spec = slicer.deserialize_slice_spec(slice_spec) - # Note that hashing by SerializeToString() is only safe if used within the - # same process. - threshold_hash = threshold.SerializeToString() - if (not (key in existing and hashable_slice_spec in existing[key] and - threshold_hash in existing[key][hashable_slice_spec])): - if hashable_slice_spec not in existing[key]: - existing[key][hashable_slice_spec] = {} - existing[key][hashable_slice_spec][threshold_hash] = True - result[key].append((slice_spec, threshold)) - - def add_threshold(key: metric_types.MetricKey, - slice_spec: Union[Optional[config_pb2.SlicingSpec], - Optional[config_pb2.CrossSlicingSpec]], - threshold: config_pb2.MetricThreshold): - """Adds thresholds to results.""" - if threshold.HasField('value_threshold'): - add_if_not_exists(key, slice_spec, threshold.value_threshold) - if threshold.HasField('change_threshold'): - key = key.make_diff_key() - add_if_not_exists(key, slice_spec, threshold.change_threshold) - - for spec in metrics_specs: - for aggregation_type, sub_keys in _create_sub_keys(spec).items(): - # Add thresholds for metrics computed in-graph. 
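# --- Illustrative sketch (not part of this patch) ---------------------------
# Round-trip through the helpers above: config_from_metric serializes a metric
# into a MetricConfig, and metric_instance rebuilds an equivalent instance from
# that config via the registered TFMA metric classes.
from tensorflow_model_analysis.metrics import calibration, metric_specs

config = metric_specs.config_from_metric(calibration.MeanLabel(name="mean_label"))
rebuilt = metric_specs.metric_instance(config)
assert rebuilt.name == "mean_label"
# -----------------------------------------------------------------------------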
- for metric_name, threshold in spec.thresholds.items(): - for key in _keys_for_metric(metric_name, spec, aggregation_type, - sub_keys, [None]): - add_threshold(key, None, threshold) - for metric_name, per_slice_thresholds in spec.per_slice_thresholds.items( - ): - for key in _keys_for_metric(metric_name, spec, aggregation_type, - sub_keys, [None]): - for per_slice_threshold in per_slice_thresholds.thresholds: + """Returns thresholds associated with given metrics specs.""" + if eval_config is None: + eval_config = config_pb2.EvalConfig() + result = collections.defaultdict(list) + existing = collections.defaultdict(dict) + + def add_if_not_exists( + key: metric_types.MetricKey, + slice_spec: Optional[ + Union[config_pb2.SlicingSpec, config_pb2.CrossSlicingSpec] + ], + threshold: Union[ + config_pb2.GenericChangeThreshold, config_pb2.GenericValueThreshold + ], + ): + """Adds value to results if it doesn't already exist.""" + hashable_slice_spec = None + if slice_spec: + hashable_slice_spec = slicer.deserialize_slice_spec(slice_spec) + # Note that hashing by SerializeToString() is only safe if used within the + # same process. + threshold_hash = threshold.SerializeToString() + if not ( + key in existing + and hashable_slice_spec in existing[key] + and threshold_hash in existing[key][hashable_slice_spec] + ): + if hashable_slice_spec not in existing[key]: + existing[key][hashable_slice_spec] = {} + existing[key][hashable_slice_spec][threshold_hash] = True + result[key].append((slice_spec, threshold)) + + def add_threshold( + key: metric_types.MetricKey, + slice_spec: Union[ + Optional[config_pb2.SlicingSpec], Optional[config_pb2.CrossSlicingSpec] + ], + threshold: config_pb2.MetricThreshold, + ): + """Adds thresholds to results.""" + if threshold.HasField("value_threshold"): + add_if_not_exists(key, slice_spec, threshold.value_threshold) + if threshold.HasField("change_threshold"): + key = key.make_diff_key() + add_if_not_exists(key, slice_spec, threshold.change_threshold) + + for spec in metrics_specs: + for aggregation_type, sub_keys in _create_sub_keys(spec).items(): + # Add thresholds for metrics computed in-graph. + for metric_name, threshold in spec.thresholds.items(): + for key in _keys_for_metric( + metric_name, spec, aggregation_type, sub_keys, [None] + ): + add_threshold(key, None, threshold) + for metric_name, per_slice_thresholds in spec.per_slice_thresholds.items(): + for key in _keys_for_metric( + metric_name, spec, aggregation_type, sub_keys, [None] + ): + for per_slice_threshold in per_slice_thresholds.thresholds: + for slice_spec in per_slice_threshold.slicing_specs: + add_threshold( + key, slice_spec, per_slice_threshold.threshold + ) + for ( + metric_name, + cross_slice_thresholds, + ) in spec.cross_slice_thresholds.items(): + for key in _keys_for_metric( + metric_name, spec, aggregation_type, sub_keys, [None] + ): + for cross_slice_threshold in cross_slice_thresholds.thresholds: + for ( + cross_slice_spec + ) in cross_slice_threshold.cross_slicing_specs: + add_threshold( + key, cross_slice_spec, cross_slice_threshold.threshold + ) + + # Add thresholds for post export metrics defined in MetricConfigs. 
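# --- Illustrative sketch (not part of this patch) ---------------------------
# A name-keyed value threshold as handled above; the metric name and the 0.7
# lower bound are arbitrary illustration values.
import tensorflow_model_analysis as tfma
from tensorflow_model_analysis.metrics import metric_specs

metrics_spec = tfma.MetricsSpec(
    metrics=[tfma.MetricConfig(class_name="BinaryAccuracy")],
    thresholds={
        "binary_accuracy": tfma.MetricThreshold(
            value_threshold=tfma.GenericValueThreshold(lower_bound={"value": 0.7})
        )
    },
)
thresholds = metric_specs.metric_thresholds_from_metrics_specs(
    [metrics_spec], eval_config=tfma.EvalConfig()
)
# -----------------------------------------------------------------------------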
+ for key, metric_config, _ in keys_and_metrics_from_specs( + eval_config, metrics_specs + ): + if metric_config.HasField("threshold"): + add_threshold(key, None, metric_config.threshold) + for per_slice_threshold in metric_config.per_slice_thresholds: for slice_spec in per_slice_threshold.slicing_specs: - add_threshold(key, slice_spec, per_slice_threshold.threshold) - for metric_name, cross_slice_thresholds in ( - spec.cross_slice_thresholds.items()): - for key in _keys_for_metric(metric_name, spec, aggregation_type, - sub_keys, [None]): - for cross_slice_threshold in cross_slice_thresholds.thresholds: + add_threshold(key, slice_spec, per_slice_threshold.threshold) + for cross_slice_threshold in metric_config.cross_slice_thresholds: for cross_slice_spec in cross_slice_threshold.cross_slicing_specs: - add_threshold(key, cross_slice_spec, - cross_slice_threshold.threshold) + add_threshold(key, cross_slice_spec, cross_slice_threshold.threshold) - # Add thresholds for post export metrics defined in MetricConfigs. - for key, metric_config, _ in keys_and_metrics_from_specs( - eval_config, metrics_specs): - if metric_config.HasField('threshold'): - add_threshold(key, None, metric_config.threshold) - for per_slice_threshold in metric_config.per_slice_thresholds: - for slice_spec in per_slice_threshold.slicing_specs: - add_threshold(key, slice_spec, per_slice_threshold.threshold) - for cross_slice_threshold in metric_config.cross_slice_thresholds: - for cross_slice_spec in cross_slice_threshold.cross_slicing_specs: - add_threshold(key, cross_slice_spec, cross_slice_threshold.threshold) - - return result + return result def to_computations( metrics_specs: List[config_pb2.MetricsSpec], eval_config: Optional[config_pb2.EvalConfig] = None, - schema: Optional[schema_pb2.Schema] = None + schema: Optional[schema_pb2.Schema] = None, ) -> metric_types.MetricComputations: - """Returns computations associated with given metrics specs.""" - computations = [] - - # - # Split into TF metrics and TFMA metrics - # - - # Dict[Text, Type[tf_keras.metrics.Metric]] - tf_metric_classes = {} # class_name -> class - # Dict[Text, Type[tf_keras.losses.Loss]] - tf_loss_classes = {} # class_name -> class - # List[metric_types.MetricsSpec] - tf_metrics_specs = [] - # Dict[Text, Type[metric_types.Metric]] - tfma_metric_classes = metric_types.registered_metrics() # class_name -> class - # List[metric_types.MetricsSpec] - tfma_metrics_specs = [] - # - # Note: Lists are used instead of Dicts for the following items because - # protos are are no hashable. - # - # List[List[_TFOrTFMAMetricOrLoss]] (offsets align with metrics_specs). - per_spec_metric_instances = [] - # List[List[MetricConfig]] (offsets align with metrics_specs). - per_spec_metric_configs = [] - # List[List[_TFMetricOrLoss]] (offsets align with tf_metrics_specs). - per_tf_spec_metric_instances = [] - # List[List[metric_types.Metric]]] (offsets align with tfma_metrics_specs). 
- per_tfma_spec_metric_instances = [] - for spec in metrics_specs: - tf_spec = config_pb2.MetricsSpec() - tf_spec.CopyFrom(spec) - del tf_spec.metrics[:] - tfma_spec = config_pb2.MetricsSpec() - tfma_spec.CopyFrom(spec) - del tfma_spec.metrics[:] - for metric in spec.metrics: - if metric.class_name in tfma_metric_classes: - tfma_spec.metrics.append(metric) - elif not metric.module: - tf_spec.metrics.append(metric) - else: - cls = getattr(importlib.import_module(metric.module), metric.class_name) - if issubclass(cls, tf_keras.metrics.Metric): - tf_metric_classes[metric.class_name] = cls - tf_spec.metrics.append(metric) - elif issubclass(cls, tf_keras.losses.Loss): - tf_loss_classes[metric.class_name] = cls - tf_spec.metrics.append(metric) - else: - tfma_metric_classes[metric.class_name] = cls - tfma_spec.metrics.append(metric) - - metric_instances = [] - metric_configs = [] - if tf_spec.metrics: - tf_metrics_specs.append(tf_spec) - tf_metric_instances = [] - for m in tf_spec.metrics: - # To distinguish losses from metrics, losses are required to set the - # module name. - if m.module == _TF_LOSSES_MODULE: - tf_metric_instances.append(_deserialize_tf_loss(m, tf_loss_classes)) - else: - tf_metric_instances.append( - _deserialize_tf_metric(m, tf_metric_classes)) - per_tf_spec_metric_instances.append(tf_metric_instances) - metric_instances.extend(tf_metric_instances) - metric_configs.extend(tf_spec.metrics) - if tfma_spec.metrics: - tfma_metrics_specs.append(tfma_spec) - tfma_metric_instances = [ - _deserialize_tfma_metric(m, tfma_metric_classes) - for m in tfma_spec.metrics - ] - per_tfma_spec_metric_instances.append(tfma_metric_instances) - metric_instances.extend(tfma_metric_instances) - metric_configs.extend(tfma_spec.metrics) - per_spec_metric_instances.append(metric_instances) - per_spec_metric_configs.append(metric_configs) - - # Process TF specs - computations.extend( - _process_tf_metrics_specs(tf_metrics_specs, per_tf_spec_metric_instances, - eval_config)) - - # Process TFMA specs - computations.extend( - _process_tfma_metrics_specs(tfma_metrics_specs, - per_tfma_spec_metric_instances, eval_config, - schema)) - - # Process aggregation based metrics (output aggregation and macro averaging). - # Note that processing of TF and TFMA specs were setup to create the binarized - # metrics that macro averaging depends on. 
- for i, spec in enumerate(metrics_specs): - for example_weighted in _example_weight_options(eval_config, spec): - for aggregation_type, sub_keys in _create_sub_keys(spec).items(): - output_names = spec.output_names or [''] - output_weights = dict(spec.output_weights) - if not set(output_weights).issubset(output_names): - raise ValueError( - 'one or more output_names used in output_weights does not exist: ' - 'output_names={}, output_weights={}'.format( - output_names, output_weights)) - for model_name in spec.model_names or ['']: - for sub_key in sub_keys: - for metric, _ in zip(per_spec_metric_instances[i], - per_spec_metric_configs[i]): - if (aggregation_type and - (aggregation_type.macro_average or - aggregation_type.weighted_macro_average)): - class_weights = _class_weights(spec) or {} - for output_name in output_names: - macro_average_sub_keys = _macro_average_sub_keys( - sub_key, class_weights) - if aggregation_type.macro_average: - computations.extend( - aggregation.macro_average( - metric.get_config()['name'], - sub_keys=macro_average_sub_keys, - eval_config=eval_config, - model_name=model_name, - output_name=output_name, - sub_key=sub_key, - class_weights=class_weights, - example_weighted=example_weighted)) - elif aggregation_type.weighted_macro_average: - computations.extend( - aggregation.weighted_macro_average( - metric.get_config()['name'], - sub_keys=macro_average_sub_keys, - eval_config=eval_config, - model_name=model_name, - output_name=output_name, - sub_key=sub_key, - class_weights=class_weights, - example_weighted=example_weighted)) - if output_weights: - computations.extend( - aggregation.output_average( - metric.get_config()['name'], - output_weights=output_weights, - eval_config=eval_config, - model_name=model_name, - sub_key=sub_key, - example_weighted=example_weighted)) + """Returns computations associated with given metrics specs.""" + computations = [] + + # + # Split into TF metrics and TFMA metrics + # + + # Dict[Text, Type[tf_keras.metrics.Metric]] + tf_metric_classes = {} # class_name -> class + # Dict[Text, Type[tf_keras.losses.Loss]] + tf_loss_classes = {} # class_name -> class + # List[metric_types.MetricsSpec] + tf_metrics_specs = [] + # Dict[Text, Type[metric_types.Metric]] + tfma_metric_classes = metric_types.registered_metrics() # class_name -> class + # List[metric_types.MetricsSpec] + tfma_metrics_specs = [] + # + # Note: Lists are used instead of Dicts for the following items because + # protos are are no hashable. + # + # List[List[_TFOrTFMAMetricOrLoss]] (offsets align with metrics_specs). + per_spec_metric_instances = [] + # List[List[MetricConfig]] (offsets align with metrics_specs). + per_spec_metric_configs = [] + # List[List[_TFMetricOrLoss]] (offsets align with tf_metrics_specs). + per_tf_spec_metric_instances = [] + # List[List[metric_types.Metric]]] (offsets align with tfma_metrics_specs). 
+ per_tfma_spec_metric_instances = [] + for spec in metrics_specs: + tf_spec = config_pb2.MetricsSpec() + tf_spec.CopyFrom(spec) + del tf_spec.metrics[:] + tfma_spec = config_pb2.MetricsSpec() + tfma_spec.CopyFrom(spec) + del tfma_spec.metrics[:] + for metric in spec.metrics: + if metric.class_name in tfma_metric_classes: + tfma_spec.metrics.append(metric) + elif not metric.module: + tf_spec.metrics.append(metric) + else: + cls = getattr(importlib.import_module(metric.module), metric.class_name) + if issubclass(cls, tf_keras.metrics.Metric): + tf_metric_classes[metric.class_name] = cls + tf_spec.metrics.append(metric) + elif issubclass(cls, tf_keras.losses.Loss): + tf_loss_classes[metric.class_name] = cls + tf_spec.metrics.append(metric) + else: + tfma_metric_classes[metric.class_name] = cls + tfma_spec.metrics.append(metric) + + metric_instances = [] + metric_configs = [] + if tf_spec.metrics: + tf_metrics_specs.append(tf_spec) + tf_metric_instances = [] + for m in tf_spec.metrics: + # To distinguish losses from metrics, losses are required to set the + # module name. + if m.module == _TF_LOSSES_MODULE: + tf_metric_instances.append(_deserialize_tf_loss(m, tf_loss_classes)) + else: + tf_metric_instances.append( + _deserialize_tf_metric(m, tf_metric_classes) + ) + per_tf_spec_metric_instances.append(tf_metric_instances) + metric_instances.extend(tf_metric_instances) + metric_configs.extend(tf_spec.metrics) + if tfma_spec.metrics: + tfma_metrics_specs.append(tfma_spec) + tfma_metric_instances = [ + _deserialize_tfma_metric(m, tfma_metric_classes) + for m in tfma_spec.metrics + ] + per_tfma_spec_metric_instances.append(tfma_metric_instances) + metric_instances.extend(tfma_metric_instances) + metric_configs.extend(tfma_spec.metrics) + per_spec_metric_instances.append(metric_instances) + per_spec_metric_configs.append(metric_configs) + + # Process TF specs + computations.extend( + _process_tf_metrics_specs( + tf_metrics_specs, per_tf_spec_metric_instances, eval_config + ) + ) - return computations + # Process TFMA specs + computations.extend( + _process_tfma_metrics_specs( + tfma_metrics_specs, per_tfma_spec_metric_instances, eval_config, schema + ) + ) + + # Process aggregation based metrics (output aggregation and macro averaging). + # Note that processing of TF and TFMA specs were setup to create the binarized + # metrics that macro averaging depends on. 
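# --- Illustrative sketch (not part of this patch) ---------------------------
# to_computations as exercised above: the Keras metric is routed through the
# TF metric wrapper while MeanLabel takes the native TFMA path, and the
# returned computations are what the metrics evaluator runs per slice.
import tensorflow_model_analysis as tfma
from tensorflow_model_analysis.metrics import calibration, metric_specs
from tensorflow_model_analysis.utils.keras_lib import tf_keras

eval_config = tfma.EvalConfig(model_specs=[tfma.ModelSpec()])
computations = metric_specs.to_computations(
    metric_specs.specs_from_metrics(
        [
            tf_keras.metrics.BinaryAccuracy(name="binary_accuracy"),
            calibration.MeanLabel(name="mean_label"),
        ]
    ),
    eval_config=eval_config,
)
# -----------------------------------------------------------------------------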
+ for i, spec in enumerate(metrics_specs): + for example_weighted in _example_weight_options(eval_config, spec): + for aggregation_type, sub_keys in _create_sub_keys(spec).items(): + output_names = spec.output_names or [""] + output_weights = dict(spec.output_weights) + if not set(output_weights).issubset(output_names): + raise ValueError( + "one or more output_names used in output_weights does not exist: " + f"output_names={output_names}, output_weights={output_weights}" + ) + for model_name in spec.model_names or [""]: + for sub_key in sub_keys: + for metric, _ in zip( + per_spec_metric_instances[i], per_spec_metric_configs[i] + ): + if aggregation_type and ( + aggregation_type.macro_average + or aggregation_type.weighted_macro_average + ): + class_weights = _class_weights(spec) or {} + for output_name in output_names: + macro_average_sub_keys = _macro_average_sub_keys( + sub_key, class_weights + ) + if aggregation_type.macro_average: + computations.extend( + aggregation.macro_average( + metric.get_config()["name"], + sub_keys=macro_average_sub_keys, + eval_config=eval_config, + model_name=model_name, + output_name=output_name, + sub_key=sub_key, + class_weights=class_weights, + example_weighted=example_weighted, + ) + ) + elif aggregation_type.weighted_macro_average: + computations.extend( + aggregation.weighted_macro_average( + metric.get_config()["name"], + sub_keys=macro_average_sub_keys, + eval_config=eval_config, + model_name=model_name, + output_name=output_name, + sub_key=sub_key, + class_weights=class_weights, + example_weighted=example_weighted, + ) + ) + if output_weights: + computations.extend( + aggregation.output_average( + metric.get_config()["name"], + output_weights=output_weights, + eval_config=eval_config, + model_name=model_name, + sub_key=sub_key, + example_weighted=example_weighted, + ) + ) + + return computations def _process_tf_metrics_specs( tf_metrics_specs: List[config_pb2.MetricsSpec], per_tf_spec_metric_instances: List[List[_TFMetricOrLoss]], - eval_config: config_pb2.EvalConfig) -> metric_types.MetricComputations: - """Processes list of TF MetricsSpecs to create computations.""" - - # Wrap args into structure that is hashable so we can track unique arg sets. - class UniqueArgs( - NamedTuple('UniqueArgs', - [('model_name', str), - ('sub_key', Optional[metric_types.SubKey]), - ('aggregation_type', Optional[metric_types.AggregationType]), - ('class_weights', Tuple[Tuple[int, float], ...])])): - pass - - def _create_private_tf_metrics( - metrics: List[_TFMetricOrLoss]) -> List[_TFMetricOrLoss]: - """Creates private versions of TF metrics.""" - result = [] - for m in metrics: - if isinstance(m, tf_keras.metrics.Metric): - result.append(_private_tf_metric(m)) - else: - result.append(_private_tf_loss(m)) - return result - - # - # Group TF metrics by the subkeys, models and outputs. This is done in reverse - # because model and subkey processing is done outside of TF and so each unique - # sub key combination needs to be run through a separate model instance. Note - # that output_names are handled by the tf_metric_computation since all the - # outputs are batch calculated in a single model evaluation call. - # - - # UniqueArgs -> output_name -> [_TFMetricOrLoss] - metrics_by_unique_args = collections.defaultdict(dict) - for i, spec in enumerate(tf_metrics_specs): - metrics = per_tf_spec_metric_instances[i] - sub_keys_by_aggregation_type = _create_sub_keys(spec) - # Keep track of metrics that can be shared between macro averaging and - # binarization. 
For example, if macro averaging is being performed over 10 - # classes and 5 of the classes are also being binarized, then those 5 - # classes can be re-used by the macro averaging calculation. The remaining - # 5 classes need to be added as private metrics since those classes were - # not requested but are still needed for the macro averaging calculation. - if None in sub_keys_by_aggregation_type: - shared_sub_keys = set(sub_keys_by_aggregation_type[None]) - else: - shared_sub_keys = set() - for aggregation_type, sub_keys in sub_keys_by_aggregation_type.items(): - if aggregation_type: - class_weights = tuple(sorted((_class_weights(spec) or {}).items())) - else: - class_weights = () - is_macro = ( - aggregation_type and (aggregation_type.macro_average or - aggregation_type.weighted_macro_average)) - for parent_sub_key in sub_keys: - if is_macro: - child_sub_keys = _macro_average_sub_keys(parent_sub_key, - _class_weights(spec)) + eval_config: config_pb2.EvalConfig, +) -> metric_types.MetricComputations: + """Processes list of TF MetricsSpecs to create computations.""" + + # Wrap args into structure that is hashable so we can track unique arg sets. + class UniqueArgs( + NamedTuple( + "UniqueArgs", + [ + ("model_name", str), + ("sub_key", Optional[metric_types.SubKey]), + ("aggregation_type", Optional[metric_types.AggregationType]), + ("class_weights", Tuple[Tuple[int, float], ...]), + ], + ) + ): + pass + + def _create_private_tf_metrics( + metrics: List[_TFMetricOrLoss], + ) -> List[_TFMetricOrLoss]: + """Creates private versions of TF metrics.""" + result = [] + for m in metrics: + if isinstance(m, tf_keras.metrics.Metric): + result.append(_private_tf_metric(m)) + else: + result.append(_private_tf_loss(m)) + return result + + # + # Group TF metrics by the subkeys, models and outputs. This is done in reverse + # because model and subkey processing is done outside of TF and so each unique + # sub key combination needs to be run through a separate model instance. Note + # that output_names are handled by the tf_metric_computation since all the + # outputs are batch calculated in a single model evaluation call. + # + + # UniqueArgs -> output_name -> [_TFMetricOrLoss] + metrics_by_unique_args = collections.defaultdict(dict) + for i, spec in enumerate(tf_metrics_specs): + metrics = per_tf_spec_metric_instances[i] + sub_keys_by_aggregation_type = _create_sub_keys(spec) + # Keep track of metrics that can be shared between macro averaging and + # binarization. For example, if macro averaging is being performed over 10 + # classes and 5 of the classes are also being binarized, then those 5 + # classes can be re-used by the macro averaging calculation. The remaining + # 5 classes need to be added as private metrics since those classes were + # not requested but are still needed for the macro averaging calculation. + if None in sub_keys_by_aggregation_type: + shared_sub_keys = set(sub_keys_by_aggregation_type[None]) else: - child_sub_keys = [parent_sub_key] - for output_name in spec.output_names or ['']: - for sub_key in child_sub_keys: - if is_macro and sub_key not in shared_sub_keys: - # Create private metrics for all non-shared metrics. 
- instances = _create_private_tf_metrics(metrics) + shared_sub_keys = set() + for aggregation_type, sub_keys in sub_keys_by_aggregation_type.items(): + if aggregation_type: + class_weights = tuple(sorted((_class_weights(spec) or {}).items())) else: - instances = metrics - for model_name in spec.model_names or ['']: - unique_args = UniqueArgs( - model_name, sub_key, - aggregation_type if not is_macro else None, - class_weights if not is_macro else ()) - if unique_args not in metrics_by_unique_args: - # Tuple of weighted and unweighted metrics by output - metrics_by_unique_args[unique_args] = ( - collections.defaultdict(list), - collections.defaultdict(list)) - for instance in instances: - for example_weighted in _example_weight_options( - eval_config, spec): - if example_weighted: - metrics_by_unique_args[unique_args][0][output_name].append( - instance) - else: - metrics_by_unique_args[unique_args][1][output_name].append( - instance) - - # Convert Unique args and outputs to calls to compute TF metrics - result = [] - for args, metrics_by_output in metrics_by_unique_args.items(): - class_weights = dict(args.class_weights) if args.class_weights else None - weighted_metrics_by_output, unweighted_metrics_by_output = metrics_by_output - if weighted_metrics_by_output: - result.extend( - tf_metric_wrapper.tf_metric_computations( - weighted_metrics_by_output, - eval_config=eval_config, - model_name=args.model_name, - sub_key=args.sub_key, - aggregation_type=args.aggregation_type, - class_weights=class_weights, - example_weighted=True)) - if unweighted_metrics_by_output: - result.extend( - tf_metric_wrapper.tf_metric_computations( - unweighted_metrics_by_output, - eval_config=eval_config, - model_name=args.model_name, - sub_key=args.sub_key, - aggregation_type=args.aggregation_type, - class_weights=class_weights, - example_weighted=False)) - return result + class_weights = () + is_macro = aggregation_type and ( + aggregation_type.macro_average + or aggregation_type.weighted_macro_average + ) + for parent_sub_key in sub_keys: + if is_macro: + child_sub_keys = _macro_average_sub_keys( + parent_sub_key, _class_weights(spec) + ) + else: + child_sub_keys = [parent_sub_key] + for output_name in spec.output_names or [""]: + for sub_key in child_sub_keys: + if is_macro and sub_key not in shared_sub_keys: + # Create private metrics for all non-shared metrics. 
+ instances = _create_private_tf_metrics(metrics) + else: + instances = metrics + for model_name in spec.model_names or [""]: + unique_args = UniqueArgs( + model_name, + sub_key, + aggregation_type if not is_macro else None, + class_weights if not is_macro else (), + ) + if unique_args not in metrics_by_unique_args: + # Tuple of weighted and unweighted metrics by output + metrics_by_unique_args[unique_args] = ( + collections.defaultdict(list), + collections.defaultdict(list), + ) + for instance in instances: + for example_weighted in _example_weight_options( + eval_config, spec + ): + if example_weighted: + metrics_by_unique_args[unique_args][0][ + output_name + ].append(instance) + else: + metrics_by_unique_args[unique_args][1][ + output_name + ].append(instance) + + # Convert Unique args and outputs to calls to compute TF metrics + result = [] + for args, metrics_by_output in metrics_by_unique_args.items(): + class_weights = dict(args.class_weights) if args.class_weights else None + weighted_metrics_by_output, unweighted_metrics_by_output = metrics_by_output + if weighted_metrics_by_output: + result.extend( + tf_metric_wrapper.tf_metric_computations( + weighted_metrics_by_output, + eval_config=eval_config, + model_name=args.model_name, + sub_key=args.sub_key, + aggregation_type=args.aggregation_type, + class_weights=class_weights, + example_weighted=True, + ) + ) + if unweighted_metrics_by_output: + result.extend( + tf_metric_wrapper.tf_metric_computations( + unweighted_metrics_by_output, + eval_config=eval_config, + model_name=args.model_name, + sub_key=args.sub_key, + aggregation_type=args.aggregation_type, + class_weights=class_weights, + example_weighted=False, + ) + ) + return result def _process_tfma_metrics_specs( tfma_metrics_specs: List[config_pb2.MetricsSpec], per_tfma_spec_metric_instances: List[List[metric_types.Metric]], eval_config: config_pb2.EvalConfig, - schema: Optional[schema_pb2.Schema]) -> metric_types.MetricComputations: - """Processes list of TFMA MetricsSpecs to create computations.""" - - # - # Computations are per metric, so separate by metrics and the specs associated - # with them. - # - - # Dict[bytes,List[config_pb2.MetricSpec]] (hash(MetricConfig)->[MetricSpec]) - tfma_specs_by_metric_config = {} - # Dict[bytes,metric_types.Metric] (hash(MetricConfig)->Metric) - hashed_metrics = {} - hashed_configs = {} - for i, spec in enumerate(tfma_metrics_specs): - for metric_config, metric in zip(spec.metrics, - per_tfma_spec_metric_instances[i]): - # Note that hashing by SerializeToString() is only safe if used within the - # same process. - config_hash = metric_config.SerializeToString() - if config_hash not in tfma_specs_by_metric_config: - hashed_metrics[config_hash] = metric - hashed_configs[config_hash] = metric_config - tfma_specs_by_metric_config[config_hash] = [] - tfma_specs_by_metric_config[config_hash].append(spec) - - # - # Create computations for each metric. - # - - result = [] - for config_hash, specs in tfma_specs_by_metric_config.items(): - metric = hashed_metrics[config_hash] - metric_config = hashed_configs[config_hash] - for spec in specs: - sub_keys_by_aggregation_type = _create_sub_keys(spec) - # Keep track of sub-keys that can be shared between macro averaging and - # binarization. For example, if macro averaging is being performed over - # 10 classes and 5 of the classes are also being binarized, then those 5 - # classes can be re-used by the macro averaging calculation. 
The - # remaining 5 classes need to be added as private metrics since those - # classes were not requested but are still needed for the macro - # averaging calculation. - if None in sub_keys_by_aggregation_type: - shared_sub_keys = set(sub_keys_by_aggregation_type[None]) - else: - shared_sub_keys = set() - for aggregation_type, sub_keys in sub_keys_by_aggregation_type.items(): - class_weights = _class_weights(spec) if aggregation_type else None - is_macro = ( - aggregation_type and (aggregation_type.macro_average or - aggregation_type.weighted_macro_average)) - if is_macro: - updated_sub_keys = [] - for sub_key in sub_keys: - for key in _macro_average_sub_keys(sub_key, class_weights): - if key not in shared_sub_keys: - updated_sub_keys.append(key) - if not updated_sub_keys: - continue - aggregation_type = aggregation_type if not is_macro else None - class_weights = None - sub_keys = updated_sub_keys - instance = _private_tfma_metric(metric) - else: - instance = metric - for example_weighted in _example_weight_options(eval_config, spec): - result.extend( - instance.computations( - eval_config=eval_config, - schema=schema, - model_names=list(spec.model_names) or [''], - output_names=list(spec.output_names) or [''], - sub_keys=sub_keys, - aggregation_type=aggregation_type, - class_weights=class_weights if class_weights else None, - example_weighted=example_weighted, - query_key=spec.query_key)) - return result + schema: Optional[schema_pb2.Schema], +) -> metric_types.MetricComputations: + """Processes list of TFMA MetricsSpecs to create computations.""" + # + # Computations are per metric, so separate by metrics and the specs associated + # with them. + # + + # Dict[bytes,List[config_pb2.MetricSpec]] (hash(MetricConfig)->[MetricSpec]) + tfma_specs_by_metric_config = {} + # Dict[bytes,metric_types.Metric] (hash(MetricConfig)->Metric) + hashed_metrics = {} + hashed_configs = {} + for i, spec in enumerate(tfma_metrics_specs): + for metric_config, metric in zip( + spec.metrics, per_tfma_spec_metric_instances[i] + ): + # Note that hashing by SerializeToString() is only safe if used within the + # same process. + config_hash = metric_config.SerializeToString() + if config_hash not in tfma_specs_by_metric_config: + hashed_metrics[config_hash] = metric + hashed_configs[config_hash] = metric_config + tfma_specs_by_metric_config[config_hash] = [] + tfma_specs_by_metric_config[config_hash].append(spec) + + # + # Create computations for each metric. + # + + result = [] + for config_hash, specs in tfma_specs_by_metric_config.items(): + metric = hashed_metrics[config_hash] + metric_config = hashed_configs[config_hash] + for spec in specs: + sub_keys_by_aggregation_type = _create_sub_keys(spec) + # Keep track of sub-keys that can be shared between macro averaging and + # binarization. For example, if macro averaging is being performed over + # 10 classes and 5 of the classes are also being binarized, then those 5 + # classes can be re-used by the macro averaging calculation. The + # remaining 5 classes need to be added as private metrics since those + # classes were not requested but are still needed for the macro + # averaging calculation. 
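A toy illustration of the sharing described in this comment, assuming binarization over classes 0 through 4 and macro averaging weighted over classes 0 through 9; this is plain Python set arithmetic, not the real TFMA types:

```python
# Binarization already requests per-class results for classes 0-4, while macro
# averaging is configured over ten classes via class weights.
binarized_class_ids = {0, 1, 2, 3, 4}
class_weights = {i: 1.0 for i in range(10)}

# Macro averaging needs a per-class result for every weighted class.
needed_for_macro = set(class_weights)

shared = needed_for_macro & binarized_class_ids   # reuse the binarized results
private = needed_for_macro - binarized_class_ids  # compute these privately

print(sorted(shared))   # [0, 1, 2, 3, 4] -> already requested by binarization
print(sorted(private))  # [5, 6, 7, 8, 9] -> added as private ("_"-prefixed)
                        # metrics that feed the macro average without being
                        # requested outputs
```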
+ if None in sub_keys_by_aggregation_type: + shared_sub_keys = set(sub_keys_by_aggregation_type[None]) + else: + shared_sub_keys = set() + for aggregation_type, sub_keys in sub_keys_by_aggregation_type.items(): + class_weights = _class_weights(spec) if aggregation_type else None + is_macro = aggregation_type and ( + aggregation_type.macro_average + or aggregation_type.weighted_macro_average + ) + if is_macro: + updated_sub_keys = [] + for sub_key in sub_keys: + for key in _macro_average_sub_keys(sub_key, class_weights): + if key not in shared_sub_keys: + updated_sub_keys.append(key) + if not updated_sub_keys: + continue + aggregation_type = aggregation_type if not is_macro else None + class_weights = None + sub_keys = updated_sub_keys + instance = _private_tfma_metric(metric) + else: + instance = metric + for example_weighted in _example_weight_options(eval_config, spec): + result.extend( + instance.computations( + eval_config=eval_config, + schema=schema, + model_names=list(spec.model_names) or [""], + output_names=list(spec.output_names) or [""], + sub_keys=sub_keys, + aggregation_type=aggregation_type, + class_weights=class_weights if class_weights else None, + example_weighted=example_weighted, + query_key=spec.query_key, + ) + ) + return result def _create_sub_keys( - spec: config_pb2.MetricsSpec -) -> Dict[Optional[metric_types.AggregationType], - List[Optional[metric_types.SubKey]]]: - """Creates sub keys per aggregation type.""" - result = {} - if spec.HasField('binarize'): - sub_keys = [] - if spec.binarize.class_ids.values: - for v in spec.binarize.class_ids.values: - sub_keys.append(metric_types.SubKey(class_id=v)) - if spec.binarize.k_list.values: - for v in spec.binarize.k_list.values: - sub_keys.append(metric_types.SubKey(k=v)) - if spec.binarize.top_k_list.values: - for v in spec.binarize.top_k_list.values: - sub_keys.append(metric_types.SubKey(top_k=v)) - if sub_keys: - result[None] = sub_keys - if spec.HasField('aggregate'): - sub_keys = [] - for top_k in spec.aggregate.top_k_list.values: - sub_keys.append(metric_types.SubKey(top_k=top_k)) - if not sub_keys: - sub_keys = [None] - result[_aggregation_type(spec)] = sub_keys - return result if result else {None: [None]} + spec: config_pb2.MetricsSpec, +) -> Dict[Optional[metric_types.AggregationType], List[Optional[metric_types.SubKey]]]: + """Creates sub keys per aggregation type.""" + result = {} + if spec.HasField("binarize"): + sub_keys = [] + if spec.binarize.class_ids.values: + for v in spec.binarize.class_ids.values: + sub_keys.append(metric_types.SubKey(class_id=v)) + if spec.binarize.k_list.values: + for v in spec.binarize.k_list.values: + sub_keys.append(metric_types.SubKey(k=v)) + if spec.binarize.top_k_list.values: + for v in spec.binarize.top_k_list.values: + sub_keys.append(metric_types.SubKey(top_k=v)) + if sub_keys: + result[None] = sub_keys + if spec.HasField("aggregate"): + sub_keys = [] + for top_k in spec.aggregate.top_k_list.values: + sub_keys.append(metric_types.SubKey(top_k=top_k)) + if not sub_keys: + sub_keys = [None] + result[_aggregation_type(spec)] = sub_keys + return result if result else {None: [None]} def _macro_average_sub_keys( - sub_key: Optional[metric_types.SubKey], - class_weights: Dict[int, float]) -> Iterable[metric_types.SubKey]: - """Returns sub-keys required in order to compute macro average sub-key. - - Args: - sub_key: SubKey associated with macro_average or weighted_macro_average. - class_weights: Class weights associated with sub-key. 
- - Raises: - ValueError: If invalid sub-key passed or class weights required but not - passed. - """ - if not sub_key: - if not class_weights: - raise ValueError( - 'class_weights are required in order to compute macro average over ' - 'all classes: sub_key={}, class_weights={}'.format( - sub_key, class_weights)) - return [metric_types.SubKey(class_id=i) for i in class_weights] - elif sub_key.top_k: - return [metric_types.SubKey(k=i + 1) for i in range(sub_key.top_k)] - else: - raise ValueError('invalid sub_key for performing macro averaging: ' - 'sub_key={}'.format(sub_key)) + sub_key: Optional[metric_types.SubKey], class_weights: Dict[int, float] +) -> Iterable[metric_types.SubKey]: + """Returns sub-keys required in order to compute macro average sub-key. + + Args: + ---- + sub_key: SubKey associated with macro_average or weighted_macro_average. + class_weights: Class weights associated with sub-key. + + Raises: + ------ + ValueError: If invalid sub-key passed or class weights required but not + passed. + """ + if not sub_key: + if not class_weights: + raise ValueError( + "class_weights are required in order to compute macro average over " + f"all classes: sub_key={sub_key}, class_weights={class_weights}" + ) + return [metric_types.SubKey(class_id=i) for i in class_weights] + elif sub_key.top_k: + return [metric_types.SubKey(k=i + 1) for i in range(sub_key.top_k)] + else: + raise ValueError( + "invalid sub_key for performing macro averaging: " f"sub_key={sub_key}" + ) def _aggregation_type( - spec: config_pb2.MetricsSpec) -> Optional[metric_types.AggregationType]: - """Returns AggregationType associated with AggregationOptions at offset.""" - if spec.aggregate.micro_average: - return metric_types.AggregationType(micro_average=True) - if spec.aggregate.macro_average: - return metric_types.AggregationType(macro_average=True) - if spec.aggregate.weighted_macro_average: - return metric_types.AggregationType(weighted_macro_average=True) - return None + spec: config_pb2.MetricsSpec, +) -> Optional[metric_types.AggregationType]: + """Returns AggregationType associated with AggregationOptions at offset.""" + if spec.aggregate.micro_average: + return metric_types.AggregationType(micro_average=True) + if spec.aggregate.macro_average: + return metric_types.AggregationType(macro_average=True) + if spec.aggregate.weighted_macro_average: + return metric_types.AggregationType(weighted_macro_average=True) + return None def _class_weights(spec: config_pb2.MetricsSpec) -> Optional[Dict[int, float]]: - """Returns class weights associated with AggregationOptions at offset.""" - if spec.aggregate.HasField('top_k_list'): - if spec.aggregate.class_weights: - raise ValueError('class_weights are not supported when top_k_list used: ' - 'spec={}'.format(spec)) - return None - return dict(spec.aggregate.class_weights) or None + """Returns class weights associated with AggregationOptions at offset.""" + if spec.aggregate.HasField("top_k_list"): + if spec.aggregate.class_weights: + raise ValueError( + "class_weights are not supported when top_k_list used: " f"spec={spec}" + ) + return None + return dict(spec.aggregate.class_weights) or None def _is_supported_tf_metric(tf_metric: _TFMetricOrLoss) -> bool: - """Returns true if TF metric has an equivalent implementation in TFMA.""" - if not metric_types.is_registered_metric(tf_metric.__class__.__name__): - return False - cfg = tf_metric.get_config() - for cls_name, settings in _UNSUPPORTED_TF_SETTINGS.items(): - if not cls_name or cls_name == 
tf_metric.__class__.__name__: - for param, values in settings.items(): - if param in cfg and cfg[param] not in values: - return False - return True + """Returns true if TF metric has an equivalent implementation in TFMA.""" + if not metric_types.is_registered_metric(tf_metric.__class__.__name__): + return False + cfg = tf_metric.get_config() + for cls_name, settings in _UNSUPPORTED_TF_SETTINGS.items(): + if not cls_name or cls_name == tf_metric.__class__.__name__: + for param, values in settings.items(): + if param in cfg and cfg[param] not in values: + return False + return True def _remove_unsupported_tf_settings( - metric_config: config_pb2.MetricConfig) -> config_pb2.MetricConfig: - """Deletes unsupported TF settings from config. - - Removes TF config settings that TFMA only supports default values for because - the parameters are not supported by the TFMA implementation of the metric. - - Args: - metric_config: Metric config. + metric_config: config_pb2.MetricConfig, +) -> config_pb2.MetricConfig: + """Deletes unsupported TF settings from config. + + Removes TF config settings that TFMA only supports default values for because + the parameters are not supported by the TFMA implementation of the metric. + + Args: + ---- + metric_config: Metric config. + + Returns: + ------- + Updated metric config with unsupported settings removed. + """ + cfg = _metric_config(metric_config.config) + for cls_name, settings in _UNSUPPORTED_TF_SETTINGS.items(): + if not cls_name or cls_name == metric_config.class_name: + for param in settings: + if param in cfg: + del cfg[param] + + return config_pb2.MetricConfig( + class_name=metric_config.class_name, config=json.dumps(cfg, sort_keys=True) + ) - Returns: - Updated metric config with unsupported settings removed. - """ - cfg = _metric_config(metric_config.config) - for cls_name, settings in _UNSUPPORTED_TF_SETTINGS.items(): - if not cls_name or cls_name == metric_config.class_name: - for param in settings: - if param in cfg: - del cfg[param] - return config_pb2.MetricConfig( - class_name=metric_config.class_name, - config=json.dumps(cfg, sort_keys=True)) +def _metric_config(cfg: str) -> Dict[str, Any]: + """Returns deserializable metric config from JSON string.""" + if not cfg: + json_cfg = "{}" + elif cfg[0] != "{": + json_cfg = "{" + cfg + "}" + else: + json_cfg = cfg + return json.loads(json_cfg) -def _metric_config(cfg: str) -> Dict[str, Any]: - """Returns deserializable metric config from JSON string.""" - if not cfg: - json_cfg = '{}' - elif cfg[0] != '{': - json_cfg = '{' + cfg + '}' - else: - json_cfg = cfg - return json.loads(json_cfg) - - -def _maybe_add_name_to_config(cfg: Dict[str, Any], - class_name: str) -> Dict[str, Any]: - """Adds default name field to metric config if not present.""" - if 'name' not in cfg: - # Use snake_case version of class name as default name. - intermediate = re.sub('(.)([A-Z][a-z0-9]+)', r'\1_\2', class_name) - cfg['name'] = re.sub('([a-z])([A-Z])', r'\1_\2', intermediate).lower() - return cfg +def _maybe_add_name_to_config(cfg: Dict[str, Any], class_name: str) -> Dict[str, Any]: + """Adds default name field to metric config if not present.""" + if "name" not in cfg: + # Use snake_case version of class name as default name. 
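The default-name logic that follows can be previewed in isolation. The sketch below reuses the same two regular expressions together with the lenient JSON handling from _metric_config; the helper names here are illustrative only and are not part of the module:

```python
import json
import re


def to_snake_case(class_name: str) -> str:
    # Two-step conversion: split CamelCase word boundaries, then lowercase.
    intermediate = re.sub("(.)([A-Z][a-z0-9]+)", r"\1_\2", class_name)
    return re.sub("([a-z])([A-Z])", r"\1_\2", intermediate).lower()


print(to_snake_case("MeanSquaredError"))  # mean_squared_error
print(to_snake_case("AUC"))               # auc


def lenient_json(cfg: str) -> dict:
    # Accept an empty string, a bare key/value list, or a full JSON object.
    if not cfg:
        return {}
    return json.loads(cfg if cfg.startswith("{") else "{" + cfg + "}")


print(lenient_json(""))               # {}
print(lenient_json('"name": "mse"'))  # {'name': 'mse'}
```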
+ intermediate = re.sub("(.)([A-Z][a-z0-9]+)", r"\1_\2", class_name) + cfg["name"] = re.sub("([a-z])([A-Z])", r"\1_\2", intermediate).lower() + return cfg def _tf_class_and_config( - metric_config: config_pb2.MetricConfig) -> Tuple[str, Dict[str, Any]]: - """Returns the tensorflow class and config associated with metric_config.""" - cls_name = metric_config.class_name - cfg = _metric_config(metric_config.config) + metric_config: config_pb2.MetricConfig, +) -> Tuple[str, Dict[str, Any]]: + """Returns the tensorflow class and config associated with metric_config.""" + cls_name = metric_config.class_name + cfg = _metric_config(metric_config.config) - # The same metric type may be used for different keys when multi-class metrics - # are used (e.g. AUC for class0, # class1, etc). TF tries to generate unique - # metric names even though these metrics are already unique within a - # MetricKey. To workaround this issue, if a name is not set, then add a - # default name ourselves. - return cls_name, _maybe_add_name_to_config(cfg, cls_name) + # The same metric type may be used for different keys when multi-class metrics + # are used (e.g. AUC for class0, # class1, etc). TF tries to generate unique + # metric names even though these metrics are already unique within a + # MetricKey. To workaround this issue, if a name is not set, then add a + # default name ourselves. + return cls_name, _maybe_add_name_to_config(cfg, cls_name) def _serialize_tf_metric( metric: tf_keras.metrics.Metric, ) -> config_pb2.MetricConfig: - """Serializes TF metric.""" - cfg = metric_util.serialize_metric(metric, use_legacy_format=True) - if ( - tf_keras.utils.get_registered_name(metric.__class__) - == metric.__class__.__name__ - ): - module = metric.__class__.__module__ - else: - module = None - return config_pb2.MetricConfig( - class_name=cfg['class_name'], - module=module, - config=json.dumps(cfg['config'], sort_keys=True), - ) + """Serializes TF metric.""" + cfg = metric_util.serialize_metric(metric, use_legacy_format=True) + if ( + tf_keras.utils.get_registered_name(metric.__class__) + == metric.__class__.__name__ + ): + module = metric.__class__.__module__ + else: + module = None + return config_pb2.MetricConfig( + class_name=cfg["class_name"], + module=module, + config=json.dumps(cfg["config"], sort_keys=True), + ) def _deserialize_tf_metric( metric_config: config_pb2.MetricConfig, custom_objects: Dict[str, Type[tf_keras.metrics.Metric]], ) -> tf_keras.metrics.Metric: - """Deserializes a tf_keras.metrics metric.""" - cls_name, cfg = _tf_class_and_config(metric_config) - with tf_keras.utils.custom_object_scope(custom_objects): - return metric_util.deserialize_metric( - {'class_name': cls_name, 'config': cfg}, use_legacy_format=True - ) + """Deserializes a tf_keras.metrics metric.""" + cls_name, cfg = _tf_class_and_config(metric_config) + with tf_keras.utils.custom_object_scope(custom_objects): + return metric_util.deserialize_metric( + {"class_name": cls_name, "config": cfg}, use_legacy_format=True + ) def _private_tf_metric( metric: tf_keras.metrics.Metric, ) -> tf_keras.metrics.Metric: - """Creates a private version of given metric.""" - cfg = metric_util.serialize_metric(metric) - if not cfg['config']['name'].startswith('_'): - cfg['config']['name'] = '_' + cfg['config']['name'] - with tf_keras.utils.custom_object_scope( - {metric.__class__.__name__: metric.__class__} - ): - return metric_util.deserialize_metric(cfg, use_legacy_format=True) + """Creates a private version of given metric.""" + cfg = 
metric_util.serialize_metric(metric) + if not cfg["config"]["name"].startswith("_"): + cfg["config"]["name"] = "_" + cfg["config"]["name"] + with tf_keras.utils.custom_object_scope( + {metric.__class__.__name__: metric.__class__} + ): + return metric_util.deserialize_metric(cfg, use_legacy_format=True) def _serialize_tf_loss(loss: tf_keras.losses.Loss) -> config_pb2.MetricConfig: - """Serializes TF loss.""" - cfg = metric_util.serialize_loss(loss, use_legacy_format=True) - return config_pb2.MetricConfig( - class_name=cfg['class_name'], - module=loss.__class__.__module__, - config=json.dumps(cfg['config'], sort_keys=True)) + """Serializes TF loss.""" + cfg = metric_util.serialize_loss(loss, use_legacy_format=True) + return config_pb2.MetricConfig( + class_name=cfg["class_name"], + module=loss.__class__.__module__, + config=json.dumps(cfg["config"], sort_keys=True), + ) def _deserialize_tf_loss( metric_config: config_pb2.MetricConfig, custom_objects: Dict[str, Type[tf_keras.losses.Loss]], ) -> tf_keras.losses.Loss: - """Deserializes a tf_keras.loss metric.""" - cls_name, cfg = _tf_class_and_config(metric_config) - with tf_keras.utils.custom_object_scope(custom_objects): - return metric_util.deserialize_loss( - {'class_name': cls_name, 'config': cfg}, use_legacy_format=True - ) + """Deserializes a tf_keras.loss metric.""" + cls_name, cfg = _tf_class_and_config(metric_config) + with tf_keras.utils.custom_object_scope(custom_objects): + return metric_util.deserialize_loss( + {"class_name": cls_name, "config": cfg}, use_legacy_format=True + ) def _private_tf_loss(loss: tf_keras.losses.Loss) -> tf_keras.losses.Loss: - """Creates a private version of given loss.""" - cfg = metric_util.serialize_loss(loss) - if not cfg['config']['name'].startswith('_'): - cfg['config']['name'] = '_' + cfg['config']['name'] - with tf_keras.utils.custom_object_scope( - {loss.__class__.__name__: loss.__class__} - ): - return metric_util.deserialize_loss(cfg, use_legacy_format=True) - - -def _serialize_tfma_metric( - metric: metric_types.Metric) -> config_pb2.MetricConfig: - """Serializes TFMA metric.""" - cfg = metric_util.serialize_keras_object(metric) - return config_pb2.MetricConfig( - class_name=cfg['class_name'], - config=json.dumps(cfg['config'], sort_keys=True)) + """Creates a private version of given loss.""" + cfg = metric_util.serialize_loss(loss) + if not cfg["config"]["name"].startswith("_"): + cfg["config"]["name"] = "_" + cfg["config"]["name"] + with tf_keras.utils.custom_object_scope({loss.__class__.__name__: loss.__class__}): + return metric_util.deserialize_loss(cfg, use_legacy_format=True) + + +def _serialize_tfma_metric(metric: metric_types.Metric) -> config_pb2.MetricConfig: + """Serializes TFMA metric.""" + cfg = metric_util.serialize_keras_object(metric) + return config_pb2.MetricConfig( + class_name=cfg["class_name"], config=json.dumps(cfg["config"], sort_keys=True) + ) def _deserialize_tfma_metric( metric_config: config_pb2.MetricConfig, - custom_objects: Dict[str, - Type[metric_types.Metric]]) -> metric_types.Metric: - """Deserializes a tfma.metrics metric.""" - with tf_keras.utils.custom_object_scope(custom_objects): - return metric_util.deserialize_keras_object({ - 'class_name': metric_config.class_name, - 'config': _metric_config(metric_config.config), - }) + custom_objects: Dict[str, Type[metric_types.Metric]], +) -> metric_types.Metric: + """Deserializes a tfma.metrics metric.""" + with tf_keras.utils.custom_object_scope(custom_objects): + return metric_util.deserialize_keras_object( + { 
+ "class_name": metric_config.class_name, + "config": _metric_config(metric_config.config), + } + ) def _private_tfma_metric(metric: metric_types.Metric) -> metric_types.Metric: - """Creates a private version of given metric.""" - cfg = metric_util.serialize_keras_object(metric) - if not cfg['config']['name'].startswith('_'): - cfg['config']['name'] = '_' + cfg['config']['name'] - with tf_keras.utils.custom_object_scope( - {metric.__class__.__name__: metric.__class__} - ): - return metric_util.deserialize_keras_object(cfg) + """Creates a private version of given metric.""" + cfg = metric_util.serialize_keras_object(metric) + if not cfg["config"]["name"].startswith("_"): + cfg["config"]["name"] = "_" + cfg["config"]["name"] + with tf_keras.utils.custom_object_scope( + {metric.__class__.__name__: metric.__class__} + ): + return metric_util.deserialize_keras_object(cfg) diff --git a/tensorflow_model_analysis/metrics/metric_specs_test.py b/tensorflow_model_analysis/metrics/metric_specs_test.py index 9c307bc648..902ec6e22f 100644 --- a/tensorflow_model_analysis/metrics/metric_specs_test.py +++ b/tensorflow_model_analysis/metrics/metric_specs_test.py @@ -16,679 +16,773 @@ import json import tensorflow as tf -from tensorflow_model_analysis.metrics import calibration -from tensorflow_model_analysis.metrics import confusion_matrix_metrics + # This module should be included as the tests involves replacement of keras # mean squared error and mean absolute error with native TFMA metrics.x -from tensorflow_model_analysis.metrics import mean_regression_error # pylint: disable=unused-import -from tensorflow_model_analysis.metrics import metric_specs -from tensorflow_model_analysis.metrics import metric_types +from tensorflow_model_analysis.metrics import ( + calibration, + confusion_matrix_metrics, + metric_specs, + metric_types, +) from tensorflow_model_analysis.proto import config_pb2 from tensorflow_model_analysis.utils.keras_lib import tf_keras # TODO(b/272542795): Remove once the Keras version has caught up. def _maybe_add_fn_name(kv, name): - # Check new Keras version behavior per b/272542795#comment17. - if 'fn' in tf_keras.losses.MeanAbsoluteError().get_config(): - kv['fn'] = name - return kv + # Check new Keras version behavior per b/272542795#comment17. 
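Before the tests below, a minimal usage sketch of metric_specs.specs_from_metrics, mirroring the kinds of calls the tests assert against. The model name "candidate" is a placeholder, and the exact number and contents of the returned specs depend on the options passed:

```python
from tensorflow_model_analysis.metrics import calibration, metric_specs
from tensorflow_model_analysis.proto import config_pb2
from tensorflow_model_analysis.utils.keras_lib import tf_keras

specs = metric_specs.specs_from_metrics(
    [
        tf_keras.metrics.MeanSquaredError("mse"),
        calibration.MeanLabel("mean_label"),
    ],
    model_names=["candidate"],
    binarize=config_pb2.BinarizationOptions(class_ids={"values": [0, 1]}),
    aggregate=config_pb2.AggregationOptions(macro_average=True),
)

# Each returned config_pb2.MetricsSpec carries the serialized metric configs
# plus the binarization/aggregation settings checked by the assertions below.
for spec in specs:
    print(spec.metrics, spec.binarize, spec.aggregate)
```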
+ if "fn" in tf_keras.losses.MeanAbsoluteError().get_config(): + kv["fn"] = name + return kv class MetricSpecsTest(tf.test.TestCase): + def testSpecsFromMetrics(self): + metrics_specs = metric_specs.specs_from_metrics( + { + "output_name1": [ + tf_keras.metrics.Precision(name="precision"), + tf_keras.metrics.MeanSquaredError("mse"), + tf_keras.losses.MeanAbsoluteError(name="mae"), + ], + "output_name2": [ + confusion_matrix_metrics.Precision(name="precision"), + tf_keras.losses.MeanAbsolutePercentageError(name="mape"), + calibration.MeanPrediction("mean_prediction"), + ], + }, + unweighted_metrics={ + "output_name1": [calibration.MeanLabel("mean_label")], + "output_name2": [tf_keras.metrics.RootMeanSquaredError("rmse")], + }, + model_names=["model_name1", "model_name2"], + binarize=config_pb2.BinarizationOptions(class_ids={"values": [0, 1]}), + aggregate=config_pb2.AggregationOptions(macro_average=True), + ) - def testSpecsFromMetrics(self): - metrics_specs = metric_specs.specs_from_metrics( - { - 'output_name1': [ - tf_keras.metrics.Precision(name='precision'), - tf_keras.metrics.MeanSquaredError('mse'), - tf_keras.losses.MeanAbsoluteError(name='mae'), - ], - 'output_name2': [ - confusion_matrix_metrics.Precision(name='precision'), - tf_keras.losses.MeanAbsolutePercentageError(name='mape'), - calibration.MeanPrediction('mean_prediction'), - ], - }, - unweighted_metrics={ - 'output_name1': [calibration.MeanLabel('mean_label')], - 'output_name2': [tf_keras.metrics.RootMeanSquaredError('rmse')], - }, - model_names=['model_name1', 'model_name2'], - binarize=config_pb2.BinarizationOptions(class_ids={'values': [0, 1]}), - aggregate=config_pb2.AggregationOptions(macro_average=True), - ) - - self.assertLen(metrics_specs, 7) - self.assertProtoEquals( - metrics_specs[0], - config_pb2.MetricsSpec( - metrics=[ - config_pb2.MetricConfig( - class_name='ExampleCount', - config=json.dumps({'name': 'example_count'})), - ], - model_names=['model_name1', 'model_name2'], - example_weights=config_pb2.ExampleWeightOptions(unweighted=True))) - self.assertProtoEquals( - metrics_specs[1], - config_pb2.MetricsSpec( - metrics=[ - config_pb2.MetricConfig( - class_name='WeightedExampleCount', - config=json.dumps({'name': 'weighted_example_count'})), - ], - model_names=['model_name1', 'model_name2'], - output_names=['output_name1'], - example_weights=config_pb2.ExampleWeightOptions(weighted=True))) - self.assertProtoEquals( - metrics_specs[2], - config_pb2.MetricsSpec( - metrics=[ - config_pb2.MetricConfig( - class_name='Precision', - config=json.dumps( - { - 'name': 'precision', - 'class_id': None, - 'thresholds': None, - 'top_k': None, - }, - sort_keys=True, + self.assertLen(metrics_specs, 7) + self.assertProtoEquals( + metrics_specs[0], + config_pb2.MetricsSpec( + metrics=[ + config_pb2.MetricConfig( + class_name="ExampleCount", + config=json.dumps({"name": "example_count"}), ), - ), - config_pb2.MetricConfig( - class_name='MeanSquaredError', - config=json.dumps( - { - 'name': 'mse', - }, - sort_keys=True, + ], + model_names=["model_name1", "model_name2"], + example_weights=config_pb2.ExampleWeightOptions(unweighted=True), + ), + ) + self.assertProtoEquals( + metrics_specs[1], + config_pb2.MetricsSpec( + metrics=[ + config_pb2.MetricConfig( + class_name="WeightedExampleCount", + config=json.dumps({"name": "weighted_example_count"}), ), - ), - config_pb2.MetricConfig( - class_name='MeanAbsoluteError', - config=json.dumps( - _maybe_add_fn_name( - {'name': 'mae'}, 'mean_absolute_error' + ], + 
model_names=["model_name1", "model_name2"], + output_names=["output_name1"], + example_weights=config_pb2.ExampleWeightOptions(weighted=True), + ), + ) + self.assertProtoEquals( + metrics_specs[2], + config_pb2.MetricsSpec( + metrics=[ + config_pb2.MetricConfig( + class_name="Precision", + config=json.dumps( + { + "name": "precision", + "class_id": None, + "thresholds": None, + "top_k": None, + }, + sort_keys=True, ), - sort_keys=True, ), - ), - ], - model_names=['model_name1', 'model_name2'], - output_names=['output_name1'], - binarize=config_pb2.BinarizationOptions( - class_ids={'values': [0, 1]} + config_pb2.MetricConfig( + class_name="MeanSquaredError", + config=json.dumps( + { + "name": "mse", + }, + sort_keys=True, + ), + ), + config_pb2.MetricConfig( + class_name="MeanAbsoluteError", + config=json.dumps( + _maybe_add_fn_name({"name": "mae"}, "mean_absolute_error"), + sort_keys=True, + ), + ), + ], + model_names=["model_name1", "model_name2"], + output_names=["output_name1"], + binarize=config_pb2.BinarizationOptions(class_ids={"values": [0, 1]}), + aggregate=config_pb2.AggregationOptions(macro_average=True), ), - aggregate=config_pb2.AggregationOptions(macro_average=True), - ), - ) - self.assertProtoEquals( - metrics_specs[3], - config_pb2.MetricsSpec( - metrics=[ - config_pb2.MetricConfig( - class_name='MeanLabel', - config=json.dumps({'name': 'mean_label'})) - ], - model_names=['model_name1', 'model_name2'], - output_names=['output_name1'], - binarize=config_pb2.BinarizationOptions( - class_ids={'values': [0, 1]}), - aggregate=config_pb2.AggregationOptions(macro_average=True), - example_weights=config_pb2.ExampleWeightOptions(unweighted=True))) - self.assertProtoEquals( - metrics_specs[4], - config_pb2.MetricsSpec( - metrics=[ - config_pb2.MetricConfig( - class_name='WeightedExampleCount', - config=json.dumps({'name': 'weighted_example_count'})), - ], - model_names=['model_name1', 'model_name2'], - output_names=['output_name2'], - example_weights=config_pb2.ExampleWeightOptions(weighted=True))) - self.assertProtoEquals( - metrics_specs[5], - config_pb2.MetricsSpec( - metrics=[ - config_pb2.MetricConfig( - class_name='Precision', - config=json.dumps( - { - 'name': 'precision', - }, - sort_keys=True, + ) + self.assertProtoEquals( + metrics_specs[3], + config_pb2.MetricsSpec( + metrics=[ + config_pb2.MetricConfig( + class_name="MeanLabel", + config=json.dumps({"name": "mean_label"}), + ) + ], + model_names=["model_name1", "model_name2"], + output_names=["output_name1"], + binarize=config_pb2.BinarizationOptions(class_ids={"values": [0, 1]}), + aggregate=config_pb2.AggregationOptions(macro_average=True), + example_weights=config_pb2.ExampleWeightOptions(unweighted=True), + ), + ) + self.assertProtoEquals( + metrics_specs[4], + config_pb2.MetricsSpec( + metrics=[ + config_pb2.MetricConfig( + class_name="WeightedExampleCount", + config=json.dumps({"name": "weighted_example_count"}), ), - ), - config_pb2.MetricConfig( - class_name='MeanAbsolutePercentageError', - config=json.dumps( - _maybe_add_fn_name( - {'name': 'mape'}, 'mean_absolute_percentage_error' + ], + model_names=["model_name1", "model_name2"], + output_names=["output_name2"], + example_weights=config_pb2.ExampleWeightOptions(weighted=True), + ), + ) + self.assertProtoEquals( + metrics_specs[5], + config_pb2.MetricsSpec( + metrics=[ + config_pb2.MetricConfig( + class_name="Precision", + config=json.dumps( + { + "name": "precision", + }, + sort_keys=True, ), - sort_keys=True, ), - ), - config_pb2.MetricConfig( - 
class_name='MeanPrediction', - config=json.dumps({'name': 'mean_prediction'}), - ), - ], - model_names=['model_name1', 'model_name2'], - output_names=['output_name2'], - binarize=config_pb2.BinarizationOptions( - class_ids={'values': [0, 1]} + config_pb2.MetricConfig( + class_name="MeanAbsolutePercentageError", + config=json.dumps( + _maybe_add_fn_name( + {"name": "mape"}, "mean_absolute_percentage_error" + ), + sort_keys=True, + ), + ), + config_pb2.MetricConfig( + class_name="MeanPrediction", + config=json.dumps({"name": "mean_prediction"}), + ), + ], + model_names=["model_name1", "model_name2"], + output_names=["output_name2"], + binarize=config_pb2.BinarizationOptions(class_ids={"values": [0, 1]}), + aggregate=config_pb2.AggregationOptions(macro_average=True), ), - aggregate=config_pb2.AggregationOptions(macro_average=True), - ), - ) - # This is for the compatibility issue when migrating from Keras 2 to Keras - # 3. In the older versions, TFMA is using tf.keras which is pointing to - # Keras 2. However, the newest tf.keras is pointing to Keras 3, and TFMA is - # using tf_keras pacakge for Keras 2. In this case, there is a module name - # discrepancy, where tf.keras modulue is pointing to keras.src.metrics but - # tf_keras modulue is pointing to tf_keras.src.metrics. + ) + # This is for the compatibility issue when migrating from Keras 2 to Keras + # 3. In the older versions, TFMA is using tf.keras which is pointing to + # Keras 2. However, the newest tf.keras is pointing to Keras 3, and TFMA is + # using tf_keras pacakge for Keras 2. In this case, there is a module name + # discrepancy, where tf.keras modulue is pointing to keras.src.metrics but + # tf_keras modulue is pointing to tf_keras.src.metrics. - version_fn = getattr(tf.keras, '__version__', None) - if not version_fn or (version_fn and version_fn.startswith('3.')): - self.assertProtoEquals( - metrics_specs[6], - config_pb2.MetricsSpec( - metrics=[ - config_pb2.MetricConfig( - class_name='RootMeanSquaredError', - module='tf_keras.src.metrics.regression_metrics', - config=json.dumps( - {'name': 'rmse', 'dtype': 'float32'}, sort_keys=True - ), - ) - ], - model_names=['model_name1', 'model_name2'], - output_names=['output_name2'], - binarize=config_pb2.BinarizationOptions( - class_ids={'values': [0, 1]} - ), - aggregate=config_pb2.AggregationOptions(macro_average=True), - example_weights=config_pb2.ExampleWeightOptions(unweighted=True), - ), - ) - else: - self.assertProtoEquals( - metrics_specs[6], - config_pb2.MetricsSpec( - metrics=[ - config_pb2.MetricConfig( - class_name='RootMeanSquaredError', - module='tf_keras.metrics.regression_metrics', - config=json.dumps( - {'name': 'rmse', 'dtype': 'float32'}, sort_keys=True - ), - ) - ], - model_names=['model_name1', 'model_name2'], - output_names=['output_name2'], - binarize=config_pb2.BinarizationOptions( - class_ids={'values': [0, 1]} - ), - aggregate=config_pb2.AggregationOptions(macro_average=True), - example_weights=config_pb2.ExampleWeightOptions(unweighted=True), - ), - ) + version_fn = getattr(tf.keras, "__version__", None) + if not version_fn or (version_fn and version_fn.startswith("3.")): + self.assertProtoEquals( + metrics_specs[6], + config_pb2.MetricsSpec( + metrics=[ + config_pb2.MetricConfig( + class_name="RootMeanSquaredError", + module="tf_keras.src.metrics.regression_metrics", + config=json.dumps( + {"name": "rmse", "dtype": "float32"}, sort_keys=True + ), + ) + ], + model_names=["model_name1", "model_name2"], + output_names=["output_name2"], + 
binarize=config_pb2.BinarizationOptions( + class_ids={"values": [0, 1]} + ), + aggregate=config_pb2.AggregationOptions(macro_average=True), + example_weights=config_pb2.ExampleWeightOptions(unweighted=True), + ), + ) + else: + self.assertProtoEquals( + metrics_specs[6], + config_pb2.MetricsSpec( + metrics=[ + config_pb2.MetricConfig( + class_name="RootMeanSquaredError", + module="tf_keras.metrics.regression_metrics", + config=json.dumps( + {"name": "rmse", "dtype": "float32"}, sort_keys=True + ), + ) + ], + model_names=["model_name1", "model_name2"], + output_names=["output_name2"], + binarize=config_pb2.BinarizationOptions( + class_ids={"values": [0, 1]} + ), + aggregate=config_pb2.AggregationOptions(macro_average=True), + example_weights=config_pb2.ExampleWeightOptions(unweighted=True), + ), + ) - def testMetricKeysToSkipForConfidenceIntervals(self): - metrics_specs = [ - config_pb2.MetricsSpec( - metrics=[ - config_pb2.MetricConfig( - class_name='ExampleCount', - config=json.dumps({'name': 'example_count'}), - threshold=config_pb2.MetricThreshold( - value_threshold=config_pb2.GenericValueThreshold())), - config_pb2.MetricConfig( - class_name='MeanLabel', - config=json.dumps({'name': 'mean_label'}), - threshold=config_pb2.MetricThreshold( - change_threshold=config_pb2.GenericChangeThreshold())), - config_pb2.MetricConfig( - class_name='MeanSquaredError', - config=json.dumps({'name': 'mse'}), - threshold=config_pb2.MetricThreshold( - change_threshold=config_pb2.GenericChangeThreshold())) - ], - model_names=['model_name1', 'model_name2'], - output_names=['output_name1', 'output_name2']), - ] - metrics_specs += metric_specs.specs_from_metrics( - [tf_keras.metrics.MeanSquaredError('mse')], - model_names=['model_name1', 'model_name2'], - ) - keys = metric_specs.metric_keys_to_skip_for_confidence_intervals( - metrics_specs, eval_config=config_pb2.EvalConfig()) - self.assertLen(keys, 8) - self.assertIn( - metric_types.MetricKey( - name='example_count', - model_name='model_name1', - output_name='output_name1'), keys) - self.assertIn( - metric_types.MetricKey( - name='example_count', - model_name='model_name1', - output_name='output_name2'), keys) - self.assertIn( - metric_types.MetricKey( - name='example_count', - model_name='model_name2', - output_name='output_name1'), keys) - self.assertIn( - metric_types.MetricKey( - name='example_count', - model_name='model_name2', - output_name='output_name2'), keys) - self.assertIn( - metric_types.MetricKey(name='example_count', model_name='model_name1'), - keys) - self.assertIn( - metric_types.MetricKey( - name='weighted_example_count', - model_name='model_name1', - example_weighted=True), keys) - self.assertIn( - metric_types.MetricKey(name='example_count', model_name='model_name2'), - keys) - self.assertIn( - metric_types.MetricKey( - name='weighted_example_count', - model_name='model_name2', - example_weighted=True), keys) + def testMetricKeysToSkipForConfidenceIntervals(self): + metrics_specs = [ + config_pb2.MetricsSpec( + metrics=[ + config_pb2.MetricConfig( + class_name="ExampleCount", + config=json.dumps({"name": "example_count"}), + threshold=config_pb2.MetricThreshold( + value_threshold=config_pb2.GenericValueThreshold() + ), + ), + config_pb2.MetricConfig( + class_name="MeanLabel", + config=json.dumps({"name": "mean_label"}), + threshold=config_pb2.MetricThreshold( + change_threshold=config_pb2.GenericChangeThreshold() + ), + ), + config_pb2.MetricConfig( + class_name="MeanSquaredError", + config=json.dumps({"name": "mse"}), + 
threshold=config_pb2.MetricThreshold( + change_threshold=config_pb2.GenericChangeThreshold() + ), + ), + ], + model_names=["model_name1", "model_name2"], + output_names=["output_name1", "output_name2"], + ), + ] + metrics_specs += metric_specs.specs_from_metrics( + [tf_keras.metrics.MeanSquaredError("mse")], + model_names=["model_name1", "model_name2"], + ) + keys = metric_specs.metric_keys_to_skip_for_confidence_intervals( + metrics_specs, eval_config=config_pb2.EvalConfig() + ) + self.assertLen(keys, 8) + self.assertIn( + metric_types.MetricKey( + name="example_count", + model_name="model_name1", + output_name="output_name1", + ), + keys, + ) + self.assertIn( + metric_types.MetricKey( + name="example_count", + model_name="model_name1", + output_name="output_name2", + ), + keys, + ) + self.assertIn( + metric_types.MetricKey( + name="example_count", + model_name="model_name2", + output_name="output_name1", + ), + keys, + ) + self.assertIn( + metric_types.MetricKey( + name="example_count", + model_name="model_name2", + output_name="output_name2", + ), + keys, + ) + self.assertIn( + metric_types.MetricKey(name="example_count", model_name="model_name1"), keys + ) + self.assertIn( + metric_types.MetricKey( + name="weighted_example_count", + model_name="model_name1", + example_weighted=True, + ), + keys, + ) + self.assertIn( + metric_types.MetricKey(name="example_count", model_name="model_name2"), keys + ) + self.assertIn( + metric_types.MetricKey( + name="weighted_example_count", + model_name="model_name2", + example_weighted=True, + ), + keys, + ) - def testMetricThresholdsFromMetricsSpecs(self): - slice_specs = [ - config_pb2.SlicingSpec(feature_keys=['feature1']), - config_pb2.SlicingSpec(feature_values={'feature2': 'value1'}) - ] + def testMetricThresholdsFromMetricsSpecs(self): + slice_specs = [ + config_pb2.SlicingSpec(feature_keys=["feature1"]), + config_pb2.SlicingSpec(feature_values={"feature2": "value1"}), + ] - # For cross slice tests. - baseline_slice_spec = config_pb2.SlicingSpec(feature_keys=['feature3']) + # For cross slice tests. 
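A compact sketch of how metric_thresholds_from_metrics_specs attaches thresholds to metric keys, using the same empty threshold protos as the surrounding test: a value_threshold lands on the plain key, while a change_threshold lands on the is_diff=True variant.

```python
from tensorflow_model_analysis.metrics import metric_specs
from tensorflow_model_analysis.proto import config_pb2

spec = config_pb2.MetricsSpec(
    thresholds={
        "auc": config_pb2.MetricThreshold(
            value_threshold=config_pb2.GenericValueThreshold(),
            change_threshold=config_pb2.GenericChangeThreshold(),
        )
    },
    model_names=["model_name"],
    output_names=["output_name"],
)
thresholds = metric_specs.metric_thresholds_from_metrics_specs(
    [spec], eval_config=config_pb2.EvalConfig()
)

# Expect two keys for "auc": the plain key (from value_threshold) and the
# is_diff=True key (from change_threshold), each with one threshold attached.
for key, per_key_thresholds in thresholds.items():
    print(key, len(per_key_thresholds))
```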
+ baseline_slice_spec = config_pb2.SlicingSpec(feature_keys=["feature3"]) - metrics_specs = [ - config_pb2.MetricsSpec( - thresholds={ - 'auc': - config_pb2.MetricThreshold( - value_threshold=config_pb2.GenericValueThreshold()), - 'mean/label': - config_pb2.MetricThreshold( + metrics_specs = [ + config_pb2.MetricsSpec( + thresholds={ + "auc": config_pb2.MetricThreshold( + value_threshold=config_pb2.GenericValueThreshold() + ), + "mean/label": config_pb2.MetricThreshold( value_threshold=config_pb2.GenericValueThreshold(), - change_threshold=config_pb2.GenericChangeThreshold()), - 'mse': - config_pb2.MetricThreshold( - change_threshold=config_pb2.GenericChangeThreshold()) - }, - per_slice_thresholds={ - 'auc': - config_pb2.PerSliceMetricThresholds(thresholds=[ - config_pb2.PerSliceMetricThreshold( - slicing_specs=slice_specs, - threshold=config_pb2.MetricThreshold( - value_threshold=config_pb2 - .GenericValueThreshold())) - ]), - 'mean/label': - config_pb2.PerSliceMetricThresholds(thresholds=[ - config_pb2.PerSliceMetricThreshold( - slicing_specs=slice_specs, - threshold=config_pb2.MetricThreshold( - value_threshold=config_pb2 - .GenericValueThreshold(), - change_threshold=config_pb2 - .GenericChangeThreshold())) - ]) - }, - cross_slice_thresholds={ - 'auc': - config_pb2.CrossSliceMetricThresholds(thresholds=[ - config_pb2.CrossSliceMetricThreshold( - cross_slicing_specs=[ - config_pb2.CrossSlicingSpec( - baseline_spec=baseline_slice_spec, - slicing_specs=slice_specs) - ], - threshold=config_pb2.MetricThreshold( - value_threshold=config_pb2 - .GenericValueThreshold(), - change_threshold=config_pb2 - .GenericChangeThreshold())) - ]), - 'mse': - config_pb2.CrossSliceMetricThresholds(thresholds=[ - config_pb2.CrossSliceMetricThreshold( - cross_slicing_specs=[ - config_pb2.CrossSlicingSpec( - baseline_spec=baseline_slice_spec, - slicing_specs=slice_specs) - ], - threshold=config_pb2.MetricThreshold( - change_threshold=config_pb2 - .GenericChangeThreshold())), - # Test for duplicate cross_slicing_spec. 
- config_pb2.CrossSliceMetricThreshold( - cross_slicing_specs=[ - config_pb2.CrossSlicingSpec( - baseline_spec=baseline_slice_spec, - slicing_specs=slice_specs) - ], - threshold=config_pb2.MetricThreshold( - value_threshold=config_pb2 - .GenericValueThreshold())) - ]) - }, - model_names=['model_name'], - output_names=['output_name']), - config_pb2.MetricsSpec( - metrics=[ - config_pb2.MetricConfig( - class_name='ExampleCount', - config=json.dumps({'name': 'example_count'}), - threshold=config_pb2.MetricThreshold( - value_threshold=config_pb2.GenericValueThreshold())) - ], - model_names=['model_name1', 'model_name2'], - example_weights=config_pb2.ExampleWeightOptions(unweighted=True)), - config_pb2.MetricsSpec( - metrics=[ - config_pb2.MetricConfig( - class_name='WeightedExampleCount', - config=json.dumps({'name': 'weighted_example_count'}), - threshold=config_pb2.MetricThreshold( - value_threshold=config_pb2.GenericValueThreshold())) - ], - model_names=['model_name1', 'model_name2'], - output_names=['output_name1', 'output_name2'], - example_weights=config_pb2.ExampleWeightOptions(weighted=True)), - config_pb2.MetricsSpec( - metrics=[ - config_pb2.MetricConfig( - class_name='MeanSquaredError', - config=json.dumps({'name': 'mse'}), - threshold=config_pb2.MetricThreshold( - change_threshold=config_pb2.GenericChangeThreshold())), - config_pb2.MetricConfig( - class_name='MeanLabel', - config=json.dumps({'name': 'mean_label'}), - threshold=config_pb2.MetricThreshold( - change_threshold=config_pb2.GenericChangeThreshold()), - per_slice_thresholds=[ - config_pb2.PerSliceMetricThreshold( - slicing_specs=slice_specs, - threshold=config_pb2.MetricThreshold( - change_threshold=config_pb2 - .GenericChangeThreshold())), - ], - cross_slice_thresholds=[ - config_pb2.CrossSliceMetricThreshold( - cross_slicing_specs=[ - config_pb2.CrossSlicingSpec( - baseline_spec=baseline_slice_spec, - slicing_specs=slice_specs) - ], - threshold=config_pb2.MetricThreshold( - change_threshold=config_pb2 - .GenericChangeThreshold())) - ]), - ], - model_names=['model_name'], - output_names=['output_name'], - binarize=config_pb2.BinarizationOptions( - class_ids={'values': [0, 1]}), - aggregate=config_pb2.AggregationOptions( - macro_average=True, class_weights={ - 0: 1.0, - 1: 1.0 - })) - ] + change_threshold=config_pb2.GenericChangeThreshold(), + ), + "mse": config_pb2.MetricThreshold( + change_threshold=config_pb2.GenericChangeThreshold() + ), + }, + per_slice_thresholds={ + "auc": config_pb2.PerSliceMetricThresholds( + thresholds=[ + config_pb2.PerSliceMetricThreshold( + slicing_specs=slice_specs, + threshold=config_pb2.MetricThreshold( + value_threshold=config_pb2.GenericValueThreshold() + ), + ) + ] + ), + "mean/label": config_pb2.PerSliceMetricThresholds( + thresholds=[ + config_pb2.PerSliceMetricThreshold( + slicing_specs=slice_specs, + threshold=config_pb2.MetricThreshold( + value_threshold=config_pb2.GenericValueThreshold(), + change_threshold=config_pb2.GenericChangeThreshold(), + ), + ) + ] + ), + }, + cross_slice_thresholds={ + "auc": config_pb2.CrossSliceMetricThresholds( + thresholds=[ + config_pb2.CrossSliceMetricThreshold( + cross_slicing_specs=[ + config_pb2.CrossSlicingSpec( + baseline_spec=baseline_slice_spec, + slicing_specs=slice_specs, + ) + ], + threshold=config_pb2.MetricThreshold( + value_threshold=config_pb2.GenericValueThreshold(), + change_threshold=config_pb2.GenericChangeThreshold(), + ), + ) + ] + ), + "mse": config_pb2.CrossSliceMetricThresholds( + thresholds=[ + 
config_pb2.CrossSliceMetricThreshold( + cross_slicing_specs=[ + config_pb2.CrossSlicingSpec( + baseline_spec=baseline_slice_spec, + slicing_specs=slice_specs, + ) + ], + threshold=config_pb2.MetricThreshold( + change_threshold=config_pb2.GenericChangeThreshold() + ), + ), + # Test for duplicate cross_slicing_spec. + config_pb2.CrossSliceMetricThreshold( + cross_slicing_specs=[ + config_pb2.CrossSlicingSpec( + baseline_spec=baseline_slice_spec, + slicing_specs=slice_specs, + ) + ], + threshold=config_pb2.MetricThreshold( + value_threshold=config_pb2.GenericValueThreshold() + ), + ), + ] + ), + }, + model_names=["model_name"], + output_names=["output_name"], + ), + config_pb2.MetricsSpec( + metrics=[ + config_pb2.MetricConfig( + class_name="ExampleCount", + config=json.dumps({"name": "example_count"}), + threshold=config_pb2.MetricThreshold( + value_threshold=config_pb2.GenericValueThreshold() + ), + ) + ], + model_names=["model_name1", "model_name2"], + example_weights=config_pb2.ExampleWeightOptions(unweighted=True), + ), + config_pb2.MetricsSpec( + metrics=[ + config_pb2.MetricConfig( + class_name="WeightedExampleCount", + config=json.dumps({"name": "weighted_example_count"}), + threshold=config_pb2.MetricThreshold( + value_threshold=config_pb2.GenericValueThreshold() + ), + ) + ], + model_names=["model_name1", "model_name2"], + output_names=["output_name1", "output_name2"], + example_weights=config_pb2.ExampleWeightOptions(weighted=True), + ), + config_pb2.MetricsSpec( + metrics=[ + config_pb2.MetricConfig( + class_name="MeanSquaredError", + config=json.dumps({"name": "mse"}), + threshold=config_pb2.MetricThreshold( + change_threshold=config_pb2.GenericChangeThreshold() + ), + ), + config_pb2.MetricConfig( + class_name="MeanLabel", + config=json.dumps({"name": "mean_label"}), + threshold=config_pb2.MetricThreshold( + change_threshold=config_pb2.GenericChangeThreshold() + ), + per_slice_thresholds=[ + config_pb2.PerSliceMetricThreshold( + slicing_specs=slice_specs, + threshold=config_pb2.MetricThreshold( + change_threshold=config_pb2.GenericChangeThreshold() + ), + ), + ], + cross_slice_thresholds=[ + config_pb2.CrossSliceMetricThreshold( + cross_slicing_specs=[ + config_pb2.CrossSlicingSpec( + baseline_spec=baseline_slice_spec, + slicing_specs=slice_specs, + ) + ], + threshold=config_pb2.MetricThreshold( + change_threshold=config_pb2.GenericChangeThreshold() + ), + ) + ], + ), + ], + model_names=["model_name"], + output_names=["output_name"], + binarize=config_pb2.BinarizationOptions(class_ids={"values": [0, 1]}), + aggregate=config_pb2.AggregationOptions( + macro_average=True, class_weights={0: 1.0, 1: 1.0} + ), + ), + ] - thresholds = metric_specs.metric_thresholds_from_metrics_specs( - metrics_specs, eval_config=config_pb2.EvalConfig()) + thresholds = metric_specs.metric_thresholds_from_metrics_specs( + metrics_specs, eval_config=config_pb2.EvalConfig() + ) - expected_keys_and_threshold_counts = { - metric_types.MetricKey( - name='auc', - model_name='model_name', - output_name='output_name', - is_diff=False, - example_weighted=None): - 4, - metric_types.MetricKey( - name='auc', - model_name='model_name', - output_name='output_name', - is_diff=True, - example_weighted=None): - 1, - metric_types.MetricKey( - name='mean/label', - model_name='model_name', - output_name='output_name', - is_diff=True, - example_weighted=None): - 3, - metric_types.MetricKey( - name='mean/label', - model_name='model_name', - output_name='output_name', - is_diff=False, - example_weighted=None): - 3, - 
metric_types.MetricKey(name='example_count', model_name='model_name1'): - 1, - metric_types.MetricKey(name='example_count', model_name='model_name2'): - 1, - metric_types.MetricKey( - name='weighted_example_count', - model_name='model_name1', - output_name='output_name1', - example_weighted=True): - 1, - metric_types.MetricKey( - name='weighted_example_count', - model_name='model_name1', - output_name='output_name2', - example_weighted=True): - 1, - metric_types.MetricKey( - name='weighted_example_count', - model_name='model_name2', - output_name='output_name1', - example_weighted=True): - 1, - metric_types.MetricKey( - name='weighted_example_count', - model_name='model_name2', - output_name='output_name2', - example_weighted=True): - 1, - metric_types.MetricKey( - name='mse', - model_name='model_name', - output_name='output_name', - sub_key=metric_types.SubKey(class_id=0), - is_diff=True): - 1, - metric_types.MetricKey( - name='mse', - model_name='model_name', - output_name='output_name', - sub_key=metric_types.SubKey(class_id=1), - is_diff=True): - 1, - metric_types.MetricKey( - name='mse', - model_name='model_name', - output_name='output_name', - is_diff=False, - example_weighted=None): - 1, - metric_types.MetricKey( - name='mse', - model_name='model_name', - output_name='output_name', - is_diff=True, - example_weighted=None): - 2, - metric_types.MetricKey( - name='mse', - model_name='model_name', - output_name='output_name', - aggregation_type=metric_types.AggregationType(macro_average=True), - is_diff=True): - 1, - metric_types.MetricKey( - name='mean_label', - model_name='model_name', - output_name='output_name', - sub_key=metric_types.SubKey(class_id=0), - is_diff=True): - 4, - metric_types.MetricKey( - name='mean_label', - model_name='model_name', - output_name='output_name', - sub_key=metric_types.SubKey(class_id=1), - is_diff=True): - 4, - metric_types.MetricKey( - name='mean_label', - model_name='model_name', - output_name='output_name', - aggregation_type=metric_types.AggregationType(macro_average=True), - is_diff=True): - 4 - } - self.assertLen(thresholds, len(expected_keys_and_threshold_counts)) - for key, count in expected_keys_and_threshold_counts.items(): - self.assertIn(key, thresholds) - self.assertLen(thresholds[key], count, 'failed for key {}'.format(key)) + expected_keys_and_threshold_counts = { + metric_types.MetricKey( + name="auc", + model_name="model_name", + output_name="output_name", + is_diff=False, + example_weighted=None, + ): 4, + metric_types.MetricKey( + name="auc", + model_name="model_name", + output_name="output_name", + is_diff=True, + example_weighted=None, + ): 1, + metric_types.MetricKey( + name="mean/label", + model_name="model_name", + output_name="output_name", + is_diff=True, + example_weighted=None, + ): 3, + metric_types.MetricKey( + name="mean/label", + model_name="model_name", + output_name="output_name", + is_diff=False, + example_weighted=None, + ): 3, + metric_types.MetricKey(name="example_count", model_name="model_name1"): 1, + metric_types.MetricKey(name="example_count", model_name="model_name2"): 1, + metric_types.MetricKey( + name="weighted_example_count", + model_name="model_name1", + output_name="output_name1", + example_weighted=True, + ): 1, + metric_types.MetricKey( + name="weighted_example_count", + model_name="model_name1", + output_name="output_name2", + example_weighted=True, + ): 1, + metric_types.MetricKey( + name="weighted_example_count", + model_name="model_name2", + output_name="output_name1", + example_weighted=True, + 
): 1, + metric_types.MetricKey( + name="weighted_example_count", + model_name="model_name2", + output_name="output_name2", + example_weighted=True, + ): 1, + metric_types.MetricKey( + name="mse", + model_name="model_name", + output_name="output_name", + sub_key=metric_types.SubKey(class_id=0), + is_diff=True, + ): 1, + metric_types.MetricKey( + name="mse", + model_name="model_name", + output_name="output_name", + sub_key=metric_types.SubKey(class_id=1), + is_diff=True, + ): 1, + metric_types.MetricKey( + name="mse", + model_name="model_name", + output_name="output_name", + is_diff=False, + example_weighted=None, + ): 1, + metric_types.MetricKey( + name="mse", + model_name="model_name", + output_name="output_name", + is_diff=True, + example_weighted=None, + ): 2, + metric_types.MetricKey( + name="mse", + model_name="model_name", + output_name="output_name", + aggregation_type=metric_types.AggregationType(macro_average=True), + is_diff=True, + ): 1, + metric_types.MetricKey( + name="mean_label", + model_name="model_name", + output_name="output_name", + sub_key=metric_types.SubKey(class_id=0), + is_diff=True, + ): 4, + metric_types.MetricKey( + name="mean_label", + model_name="model_name", + output_name="output_name", + sub_key=metric_types.SubKey(class_id=1), + is_diff=True, + ): 4, + metric_types.MetricKey( + name="mean_label", + model_name="model_name", + output_name="output_name", + aggregation_type=metric_types.AggregationType(macro_average=True), + is_diff=True, + ): 4, + } + self.assertLen(thresholds, len(expected_keys_and_threshold_counts)) + for key, count in expected_keys_and_threshold_counts.items(): + self.assertIn(key, thresholds) + self.assertLen(thresholds[key], count, f"failed for key {key}") - def testToComputations(self): - computations = metric_specs.to_computations( - metric_specs.specs_from_metrics( - [ - tf_keras.metrics.MeanSquaredError('mse'), - # Add a loss exactly same as metric - # (https://github.com/tensorflow/tfx/issues/1550) - tf_keras.losses.MeanSquaredError(name='loss'), - calibration.MeanLabel('mean_label'), - ], - model_names=['model_name'], - output_names=['output_1', 'output_2'], - output_weights={'output_1': 1.0, 'output_2': 1.0}, - binarize=config_pb2.BinarizationOptions( - class_ids={'values': [0, 1]} - ), - aggregate=config_pb2.AggregationOptions( - macro_average=True, class_weights={0: 1.0, 1: 1.0} + def testToComputations(self): + computations = metric_specs.to_computations( + metric_specs.specs_from_metrics( + [ + tf_keras.metrics.MeanSquaredError("mse"), + # Add a loss exactly same as metric + # (https://github.com/tensorflow/tfx/issues/1550) + tf_keras.losses.MeanSquaredError(name="loss"), + calibration.MeanLabel("mean_label"), + ], + model_names=["model_name"], + output_names=["output_1", "output_2"], + output_weights={"output_1": 1.0, "output_2": 1.0}, + binarize=config_pb2.BinarizationOptions(class_ids={"values": [0, 1]}), + aggregate=config_pb2.AggregationOptions( + macro_average=True, class_weights={0: 1.0, 1: 1.0} + ), ), - ), - config_pb2.EvalConfig(), - ) + config_pb2.EvalConfig(), + ) - keys = [] - for m in computations: - for k in m.keys: - if not k.name.startswith('_'): - keys.append(k) - self.assertLen(keys, 31) - self.assertIn( - metric_types.MetricKey(name='example_count', model_name='model_name'), - keys) - for output_name in ('output_1', 'output_2', ''): - self.assertIn( - metric_types.MetricKey( - name='weighted_example_count', - model_name='model_name', - output_name=output_name, - example_weighted=True), keys) - self.assertIn( - 
metric_types.MetricKey( - name='mse', - model_name='model_name', - output_name=output_name, - sub_key=metric_types.SubKey(class_id=0)), keys) - self.assertIn( - metric_types.MetricKey( - name='mse', - model_name='model_name', - output_name=output_name, - sub_key=metric_types.SubKey(class_id=1)), keys) - aggregation_type = metric_types.AggregationType( - macro_average=True) if output_name else None - self.assertIn( - metric_types.MetricKey( - name='mse', - model_name='model_name', - output_name=output_name, - aggregation_type=aggregation_type), keys) - self.assertIn( - metric_types.MetricKey( - name='loss', - model_name='model_name', - output_name=output_name, - sub_key=metric_types.SubKey(class_id=0)), keys) - self.assertIn( - metric_types.MetricKey( - name='loss', - model_name='model_name', - output_name=output_name, - sub_key=metric_types.SubKey(class_id=1)), keys) - aggregation_type = metric_types.AggregationType( - macro_average=True) if output_name else None - self.assertIn( - metric_types.MetricKey( - name='loss', - model_name='model_name', - output_name=output_name, - aggregation_type=aggregation_type), keys) - self.assertIn( - metric_types.MetricKey( - name='mean_label', - model_name='model_name', - output_name=output_name, - sub_key=metric_types.SubKey(class_id=0)), keys) - self.assertIn( - metric_types.MetricKey( - name='mean_label', - model_name='model_name', - output_name=output_name, - sub_key=metric_types.SubKey(class_id=1)), keys) - aggregation_type = metric_types.AggregationType( - macro_average=True) if output_name else None - self.assertIn( - metric_types.MetricKey( - name='mean_label', - model_name='model_name', - output_name=output_name, - aggregation_type=aggregation_type), keys) + keys = [] + for m in computations: + for k in m.keys: + if not k.name.startswith("_"): + keys.append(k) + self.assertLen(keys, 31) + self.assertIn( + metric_types.MetricKey(name="example_count", model_name="model_name"), keys + ) + for output_name in ("output_1", "output_2", ""): + self.assertIn( + metric_types.MetricKey( + name="weighted_example_count", + model_name="model_name", + output_name=output_name, + example_weighted=True, + ), + keys, + ) + self.assertIn( + metric_types.MetricKey( + name="mse", + model_name="model_name", + output_name=output_name, + sub_key=metric_types.SubKey(class_id=0), + ), + keys, + ) + self.assertIn( + metric_types.MetricKey( + name="mse", + model_name="model_name", + output_name=output_name, + sub_key=metric_types.SubKey(class_id=1), + ), + keys, + ) + aggregation_type = ( + metric_types.AggregationType(macro_average=True) + if output_name + else None + ) + self.assertIn( + metric_types.MetricKey( + name="mse", + model_name="model_name", + output_name=output_name, + aggregation_type=aggregation_type, + ), + keys, + ) + self.assertIn( + metric_types.MetricKey( + name="loss", + model_name="model_name", + output_name=output_name, + sub_key=metric_types.SubKey(class_id=0), + ), + keys, + ) + self.assertIn( + metric_types.MetricKey( + name="loss", + model_name="model_name", + output_name=output_name, + sub_key=metric_types.SubKey(class_id=1), + ), + keys, + ) + aggregation_type = ( + metric_types.AggregationType(macro_average=True) + if output_name + else None + ) + self.assertIn( + metric_types.MetricKey( + name="loss", + model_name="model_name", + output_name=output_name, + aggregation_type=aggregation_type, + ), + keys, + ) + self.assertIn( + metric_types.MetricKey( + name="mean_label", + model_name="model_name", + output_name=output_name, + 
sub_key=metric_types.SubKey(class_id=0), + ), + keys, + ) + self.assertIn( + metric_types.MetricKey( + name="mean_label", + model_name="model_name", + output_name=output_name, + sub_key=metric_types.SubKey(class_id=1), + ), + keys, + ) + aggregation_type = ( + metric_types.AggregationType(macro_average=True) + if output_name + else None + ) + self.assertIn( + metric_types.MetricKey( + name="mean_label", + model_name="model_name", + output_name=output_name, + aggregation_type=aggregation_type, + ), + keys, + ) - # This tests b/155810786 - def testToComputationsWithMixedAggregationAndNonAggregationMetrics(self): - computations = metric_specs.to_computations([ - config_pb2.MetricsSpec( - metrics=[config_pb2.MetricConfig( - class_name='CategoricalAccuracy')]), - config_pb2.MetricsSpec( - metrics=[config_pb2.MetricConfig(class_name='BinaryCrossentropy')], - binarize=config_pb2.BinarizationOptions(class_ids={'values': [1]}), - aggregate=config_pb2.AggregationOptions(micro_average=True)) - ], config_pb2.EvalConfig()) + # This tests b/155810786 + def testToComputationsWithMixedAggregationAndNonAggregationMetrics(self): + computations = metric_specs.to_computations( + [ + config_pb2.MetricsSpec( + metrics=[config_pb2.MetricConfig(class_name="CategoricalAccuracy")] + ), + config_pb2.MetricsSpec( + metrics=[config_pb2.MetricConfig(class_name="BinaryCrossentropy")], + binarize=config_pb2.BinarizationOptions(class_ids={"values": [1]}), + aggregate=config_pb2.AggregationOptions(micro_average=True), + ), + ], + config_pb2.EvalConfig(), + ) - # 3 separate computations should be used (one for aggregated metrics, one - # for non-aggregated metrics, and one for metrics associated with class 1) - self.assertLen(computations, 3) + # 3 separate computations should be used (one for aggregated metrics, one + # for non-aggregated metrics, and one for metrics associated with class 1) + self.assertLen(computations, 3) -if __name__ == '__main__': - tf.test.main() +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow_model_analysis/metrics/metric_types.py b/tensorflow_model_analysis/metrics/metric_types.py index 66904c1069..8406a6a7d5 100644 --- a/tensorflow_model_analysis/metrics/metric_types.py +++ b/tensorflow_model_analysis/metrics/metric_types.py @@ -16,121 +16,131 @@ import copy import functools import inspect -from typing import Any, Callable, Dict, Iterable, Iterator, List, MutableMapping, NamedTuple, Optional, Type, Union +from typing import ( + Any, + Callable, + Dict, + Iterable, + Iterator, + List, + MutableMapping, + NamedTuple, + Optional, + Type, + Union, +) import apache_beam as beam +from google.protobuf import text_format +from tensorflow_metadata.proto.v0 import schema_pb2 + from tensorflow_model_analysis import constants from tensorflow_model_analysis.api import types -from tensorflow_model_analysis.proto import config_pb2 -from tensorflow_model_analysis.proto import metrics_for_slice_pb2 +from tensorflow_model_analysis.proto import config_pb2, metrics_for_slice_pb2 from tensorflow_model_analysis.utils import util -from google.protobuf import text_format -from tensorflow_metadata.proto.v0 import schema_pb2 - # LINT.IfChange # A separate version from proto is used here because protos are not hashable and # SerializeToString is not guaranteed to be stable between different binaries. @functools.total_ordering -class SubKey( - NamedTuple('SubKey', [('class_id', int), ('k', int), ('top_k', int)]) -): - """A SubKey identifies a sub-types of metrics and plots. 
- - Only one of class_id, k, or top_k can be set at a time. - - Attributes: - class_id: Used with multi-class metrics to identify a specific class ID. - k: Used with multi-class metrics to identify the kth predicted value. - top_k: Used with multi-class and ranking metrics to identify top-k predicted - values. - """ - - # IfChange (should be preceded by LINT, but cannot nest LINT) - def __new__( - cls, - class_id: Optional[int] = None, - k: Optional[int] = None, - top_k: Optional[int] = None, - ): - if k is not None: - if top_k is not None: - raise ValueError( - 'k and top_k cannot both be set at the same time: ' - f'k={k}, top_k={top_k}' - ) - if class_id is not None: - raise ValueError( - 'k and class_id cannot both be set at the same time: ' - f'k={k}, class_id={class_id}' - ) - if k is not None and k < 1: - raise ValueError('attempt to create metric with k < 1: k={}'.format(k)) - if top_k is not None and top_k < 1: - raise ValueError( - 'attempt to create metric with top_k < 1: top_k={}'.format(top_k) - ) - return super(SubKey, cls).__new__(cls, class_id, k, top_k) - - # ThenChange(../api/model_eval_lib.py) - - def __eq__(self, other): - return tuple(self) == other - - def __lt__(self, other): - # Python3 does not allow comparison of NoneType, remove if present. - return tuple(x if x is not None else -1 for x in self) < tuple( - x if x is not None else -1 for x in other or () - ) +class SubKey(NamedTuple("SubKey", [("class_id", int), ("k", int), ("top_k", int)])): + """A SubKey identifies a sub-types of metrics and plots. - def __hash__(self): - return hash(tuple(self)) + Only one of class_id, k, or top_k can be set at a time. - def __str__(self) -> str: - if self.k is not None: - return 'k:' + str(self.k) - else: - sub_key_str_list = [] - if self.class_id is not None: - sub_key_str_list.append('classId:' + str(self.class_id)) - if self.top_k is not None: - sub_key_str_list.append('topK:' + str(self.top_k)) - if not sub_key_str_list: - raise NotImplementedError(( - 'A non-existent SubKey should be represented as None, not as ', - 'SubKey(None, None, None).', - )) - return ' '.join(sub_key_str_list) - - def to_proto(self) -> metrics_for_slice_pb2.SubKey: - """Converts key to proto.""" - sub_key = metrics_for_slice_pb2.SubKey() - if self.class_id is not None: - sub_key.class_id.value = self.class_id - if self.k is not None: - sub_key.k.value = self.k - if self.top_k is not None: - sub_key.top_k.value = self.top_k - return sub_key - - @staticmethod - def from_proto(pb: metrics_for_slice_pb2.SubKey) -> Optional['SubKey']: - """Creates class from proto.""" - class_id = None - if pb.HasField('class_id'): - class_id = pb.class_id.value - k = None - if pb.HasField('k'): - k = pb.k.value - top_k = None - if pb.HasField('top_k'): - top_k = pb.top_k.value - if class_id is None and k is None and top_k is None: - return None - else: - return SubKey(class_id=class_id, k=k, top_k=top_k) + Attributes + ---------- + class_id: Used with multi-class metrics to identify a specific class ID. + k: Used with multi-class metrics to identify the kth predicted value. + top_k: Used with multi-class and ranking metrics to identify top-k predicted + values. 
+ """ + + # IfChange (should be preceded by LINT, but cannot nest LINT) + def __new__( + cls, + class_id: Optional[int] = None, + k: Optional[int] = None, + top_k: Optional[int] = None, + ): + if k is not None: + if top_k is not None: + raise ValueError( + "k and top_k cannot both be set at the same time: " + f"k={k}, top_k={top_k}" + ) + if class_id is not None: + raise ValueError( + "k and class_id cannot both be set at the same time: " + f"k={k}, class_id={class_id}" + ) + if k is not None and k < 1: + raise ValueError(f"attempt to create metric with k < 1: k={k}") + if top_k is not None and top_k < 1: + raise ValueError(f"attempt to create metric with top_k < 1: top_k={top_k}") + return super(SubKey, cls).__new__(cls, class_id, k, top_k) + + # ThenChange(../api/model_eval_lib.py) + + def __eq__(self, other): + return tuple(self) == other + + def __lt__(self, other): + # Python3 does not allow comparison of NoneType, remove if present. + return tuple(x if x is not None else -1 for x in self) < tuple( + x if x is not None else -1 for x in other or () + ) + + def __hash__(self): + return hash(tuple(self)) + + def __str__(self) -> str: + if self.k is not None: + return "k:" + str(self.k) + else: + sub_key_str_list = [] + if self.class_id is not None: + sub_key_str_list.append("classId:" + str(self.class_id)) + if self.top_k is not None: + sub_key_str_list.append("topK:" + str(self.top_k)) + if not sub_key_str_list: + raise NotImplementedError( + ( + "A non-existent SubKey should be represented as None, not as ", + "SubKey(None, None, None).", + ) + ) + return " ".join(sub_key_str_list) + + def to_proto(self) -> metrics_for_slice_pb2.SubKey: + """Converts key to proto.""" + sub_key = metrics_for_slice_pb2.SubKey() + if self.class_id is not None: + sub_key.class_id.value = self.class_id + if self.k is not None: + sub_key.k.value = self.k + if self.top_k is not None: + sub_key.top_k.value = self.top_k + return sub_key + + @staticmethod + def from_proto(pb: metrics_for_slice_pb2.SubKey) -> Optional["SubKey"]: + """Creates class from proto.""" + class_id = None + if pb.HasField("class_id"): + class_id = pb.class_id.value + k = None + if pb.HasField("k"): + k = pb.k.value + top_k = None + if pb.HasField("top_k"): + top_k = pb.top_k.value + if class_id is None and k is None and top_k is None: + return None + else: + return SubKey(class_id=class_id, k=k, top_k=top_k) # A separate version from proto is used here because protos are not hashable and @@ -138,102 +148,103 @@ def from_proto(pb: metrics_for_slice_pb2.SubKey) -> Optional['SubKey']: @functools.total_ordering class AggregationType( NamedTuple( - 'AggregationType', + "AggregationType", [ - ('micro_average', bool), - ('macro_average', bool), - ('weighted_macro_average', bool), + ("micro_average", bool), + ("macro_average", bool), + ("weighted_macro_average", bool), ], ) ): - """AggregationType identifies aggregation types used with AggregationOptions. - - Only one of micro_average, macro_average, or weighted_macro_average can be set - at a time. - - Attributes: - micro_average: True of macro averaging used. - macro_average: True of macro averaging used. - weighted_macro_average: True of weighted macro averaging used. 
- """ - - # IfChange (should be preceded by LINT, but cannot nest LINT) - def __new__( - cls, - micro_average: Optional[bool] = None, - macro_average: Optional[bool] = None, - weighted_macro_average: Optional[bool] = None, - ): - if ( - sum([ - micro_average or False, - macro_average or False, - weighted_macro_average or False, - ]) - > 1 - ): - raise ValueError( - 'only one of micro_average, macro_average, or ' - 'weighted_macro_average should be set: micro_average={}, ' - 'macro_average={}, weighted_macro_average={}'.format( - micro_average, macro_average, weighted_macro_average - ) - ) - return super(AggregationType, cls).__new__( - cls, micro_average, macro_average, weighted_macro_average - ) + """AggregationType identifies aggregation types used with AggregationOptions. - # ThenChange(../api/model_eval_lib.py) + Only one of micro_average, macro_average, or weighted_macro_average can be set + at a time. - def __eq__(self, other): - return tuple(self) == other + Attributes + ---------- + micro_average: True of macro averaging used. + macro_average: True of macro averaging used. + weighted_macro_average: True of weighted macro averaging used. + """ - def __lt__(self, other): - # Python3 does not allow comparison of NoneType, replace with -1. - return tuple(x if x is not None else -1 for x in self) < tuple( - x if x is not None else -1 for x in other or () - ) + # IfChange (should be preceded by LINT, but cannot nest LINT) + def __new__( + cls, + micro_average: Optional[bool] = None, + macro_average: Optional[bool] = None, + weighted_macro_average: Optional[bool] = None, + ): + if ( + sum( + [ + micro_average or False, + macro_average or False, + weighted_macro_average or False, + ] + ) + > 1 + ): + raise ValueError( + "only one of micro_average, macro_average, or " + f"weighted_macro_average should be set: micro_average={micro_average}, " + f"macro_average={macro_average}, weighted_macro_average={weighted_macro_average}" + ) + return super(AggregationType, cls).__new__( + cls, micro_average, macro_average, weighted_macro_average + ) - def __hash__(self): - return hash(tuple(self)) + # ThenChange(../api/model_eval_lib.py) - def __str__(self) -> str: - if self.micro_average is not None: - return 'micro' - elif self.macro_average is not None: - return 'macro' - elif self.weighted_macro_average is not None: - return 'weighted_macro' - else: - raise NotImplementedError(( - 'A non-existent AggregationType should be represented as None, not ' - 'as AggregationType(None, None, None).' - )) - - def to_proto(self) -> metrics_for_slice_pb2.AggregationType: - """Converts key to proto.""" - aggregration_type = metrics_for_slice_pb2.AggregationType() - if self.micro_average is not None: - aggregration_type.micro_average = True - if self.macro_average is not None: - aggregration_type.macro_average = True - if self.weighted_macro_average is not None: - aggregration_type.weighted_macro_average = True - return aggregration_type - - @staticmethod - def from_proto( - pb: metrics_for_slice_pb2.AggregationType, - ) -> Optional['AggregationType']: - """Creates class from proto.""" - if pb.micro_average or pb.macro_average or pb.weighted_macro_average: - return AggregationType( - micro_average=pb.micro_average or None, - macro_average=pb.macro_average or None, - weighted_macro_average=pb.weighted_macro_average or None, - ) - else: - return None + def __eq__(self, other): + return tuple(self) == other + + def __lt__(self, other): + # Python3 does not allow comparison of NoneType, replace with -1. 
+ return tuple(x if x is not None else -1 for x in self) < tuple( + x if x is not None else -1 for x in other or () + ) + + def __hash__(self): + return hash(tuple(self)) + + def __str__(self) -> str: + if self.micro_average is not None: + return "micro" + elif self.macro_average is not None: + return "macro" + elif self.weighted_macro_average is not None: + return "weighted_macro" + else: + raise NotImplementedError( + "A non-existent AggregationType should be represented as None, not " + "as AggregationType(None, None, None)." + ) + + def to_proto(self) -> metrics_for_slice_pb2.AggregationType: + """Converts key to proto.""" + aggregration_type = metrics_for_slice_pb2.AggregationType() + if self.micro_average is not None: + aggregration_type.micro_average = True + if self.macro_average is not None: + aggregration_type.macro_average = True + if self.weighted_macro_average is not None: + aggregration_type.weighted_macro_average = True + return aggregration_type + + @staticmethod + def from_proto( + pb: metrics_for_slice_pb2.AggregationType, + ) -> Optional["AggregationType"]: + """Creates class from proto.""" + if pb.micro_average or pb.macro_average or pb.weighted_macro_average: + return AggregationType( + micro_average=pb.micro_average or None, + macro_average=pb.macro_average or None, + weighted_macro_average=pb.weighted_macro_average or None, + ) + else: + return None # A separate version from proto is used here because protos are not hashable and @@ -241,129 +252,130 @@ def from_proto( @functools.total_ordering class MetricKey( NamedTuple( - 'MetricKey', + "MetricKey", [ - ('name', str), - ('model_name', str), - ('output_name', str), - ('sub_key', Optional[SubKey]), - ('aggregation_type', Optional[AggregationType]), - ('example_weighted', Optional[bool]), - ('is_diff', bool), + ("name", str), + ("model_name", str), + ("output_name", str), + ("sub_key", Optional[SubKey]), + ("aggregation_type", Optional[AggregationType]), + ("example_weighted", Optional[bool]), + ("is_diff", bool), ], ) ): - """A MetricKey uniquely identifies a metric. - - Attributes: - name: Metric name. Names starting with '_' are private and will be filtered - from the final results. Names starting with two underscores, '__' are - reserved for internal use. - model_name: Optional model name (if multi-model evaluation). - output_name: Optional output name (if multi-output model type). - sub_key: Optional sub key. - aggregation_type: Optional Aggregation type. - example_weighted: Indicates whether this metric was weighted by examples. - is_diff: Optional flag to indicate whether this metrics is a diff metric. - """ - - def __new__( - cls, - name: str, - model_name: str = '', - output_name: str = '', - sub_key: Optional[SubKey] = None, - aggregation_type: Optional[AggregationType] = None, - example_weighted: Optional[bool] = False, - is_diff: Optional[bool] = False, - ): - return super(MetricKey, cls).__new__( + """A MetricKey uniquely identifies a metric. + + Attributes + ---------- + name: Metric name. Names starting with '_' are private and will be filtered + from the final results. Names starting with two underscores, '__' are + reserved for internal use. + model_name: Optional model name (if multi-model evaluation). + output_name: Optional output name (if multi-output model type). + sub_key: Optional sub key. + aggregation_type: Optional Aggregation type. + example_weighted: Indicates whether this metric was weighted by examples. + is_diff: Optional flag to indicate whether this metrics is a diff metric. 
+ """ + + def __new__( cls, - name, - model_name, - output_name, - sub_key, - aggregation_type, - example_weighted, - is_diff, - ) + name: str, + model_name: str = "", + output_name: str = "", + sub_key: Optional[SubKey] = None, + aggregation_type: Optional[AggregationType] = None, + example_weighted: Optional[bool] = False, + is_diff: Optional[bool] = False, + ): + return super(MetricKey, cls).__new__( + cls, + name, + model_name, + output_name, + sub_key, + aggregation_type, + example_weighted, + is_diff, + ) - def __eq__(self, other): - return tuple(self) == other - - def __lt__(self, other): - if other is None: - return False - # Python3 does not allow comparison of NoneType, remove if present. - sub_key = self.sub_key if self.sub_key else () - other_sub_key = other.sub_key if other.sub_key else () - agg_type = self.aggregation_type if self.aggregation_type else () - other_agg_type = other.aggregation_type if other.aggregation_type else () - example_weighted = self.example_weighted if self.example_weighted else () - other_example_weighted = ( - other.example_weighted if other.example_weighted else () - ) - is_diff = self.is_diff - other_is_diff = other.is_diff - # -4 for sub_key, aggregation_type, example_weighted, and is_diff - return ( - (tuple(self[:-4])) < tuple(other[:-4]) - and sub_key < other_sub_key - and agg_type < other_agg_type - and example_weighted < other_example_weighted - and is_diff < other_is_diff - ) + def __eq__(self, other): + return tuple(self) == other + + def __lt__(self, other): + if other is None: + return False + # Python3 does not allow comparison of NoneType, remove if present. + sub_key = self.sub_key if self.sub_key else () + other_sub_key = other.sub_key if other.sub_key else () + agg_type = self.aggregation_type if self.aggregation_type else () + other_agg_type = other.aggregation_type if other.aggregation_type else () + example_weighted = self.example_weighted if self.example_weighted else () + other_example_weighted = ( + other.example_weighted if other.example_weighted else () + ) + is_diff = self.is_diff + other_is_diff = other.is_diff + # -4 for sub_key, aggregation_type, example_weighted, and is_diff + return ( + (tuple(self[:-4])) < tuple(other[:-4]) + and sub_key < other_sub_key + and agg_type < other_agg_type + and example_weighted < other_example_weighted + and is_diff < other_is_diff + ) - def __hash__(self): - return hash(tuple(self)) + def __hash__(self): + return hash(tuple(self)) - def __str__(self): - return text_format.MessageToString( - self.to_proto(), as_one_line=True, force_colon=True - ) + def __str__(self): + return text_format.MessageToString( + self.to_proto(), as_one_line=True, force_colon=True + ) - def to_proto(self) -> metrics_for_slice_pb2.MetricKey: - """Converts key to proto.""" - metric_key = metrics_for_slice_pb2.MetricKey() - if self.name: - metric_key.name = self.name - if self.model_name: - metric_key.model_name = self.model_name - if self.output_name: - metric_key.output_name = self.output_name - if self.sub_key: - metric_key.sub_key.CopyFrom(self.sub_key.to_proto()) - if self.aggregation_type: - metric_key.aggregation_type.CopyFrom(self.aggregation_type.to_proto()) - if self.example_weighted is not None: - metric_key.example_weighted.value = self.example_weighted - if self.is_diff: - metric_key.is_diff = self.is_diff - return metric_key - - @staticmethod - def from_proto(pb: metrics_for_slice_pb2.MetricKey) -> 'MetricKey': - """Configures class from proto.""" - example_weighted = None - if 
pb.HasField('example_weighted'): - example_weighted = pb.example_weighted.value - return MetricKey( - name=pb.name, - model_name=pb.model_name, - output_name=pb.output_name, - sub_key=SubKey.from_proto(pb.sub_key), - aggregation_type=AggregationType.from_proto(pb.aggregation_type), - example_weighted=example_weighted, - is_diff=pb.is_diff, - ) + def to_proto(self) -> metrics_for_slice_pb2.MetricKey: + """Converts key to proto.""" + metric_key = metrics_for_slice_pb2.MetricKey() + if self.name: + metric_key.name = self.name + if self.model_name: + metric_key.model_name = self.model_name + if self.output_name: + metric_key.output_name = self.output_name + if self.sub_key: + metric_key.sub_key.CopyFrom(self.sub_key.to_proto()) + if self.aggregation_type: + metric_key.aggregation_type.CopyFrom(self.aggregation_type.to_proto()) + if self.example_weighted is not None: + metric_key.example_weighted.value = self.example_weighted + if self.is_diff: + metric_key.is_diff = self.is_diff + return metric_key + + @staticmethod + def from_proto(pb: metrics_for_slice_pb2.MetricKey) -> "MetricKey": + """Configures class from proto.""" + example_weighted = None + if pb.HasField("example_weighted"): + example_weighted = pb.example_weighted.value + return MetricKey( + name=pb.name, + model_name=pb.model_name, + output_name=pb.output_name, + sub_key=SubKey.from_proto(pb.sub_key), + aggregation_type=AggregationType.from_proto(pb.aggregation_type), + example_weighted=example_weighted, + is_diff=pb.is_diff, + ) - # Generate a copy of the key except that the is_diff is True. - def make_diff_key(self) -> 'MetricKey': - return self._replace(is_diff=True) + # Generate a copy of the key except that the is_diff is True. + def make_diff_key(self) -> "MetricKey": + return self._replace(is_diff=True) - # Generate a copy of the key with a different model name and is_diff False. - def make_baseline_key(self, model_name: str) -> 'MetricKey': - return self._replace(model_name=model_name, is_diff=False) + # Generate a copy of the key with a different model name and is_diff False. + def make_baseline_key(self, model_name: str) -> "MetricKey": + return self._replace(model_name=model_name, is_diff=False) # The output type of a MetricComputation combiner. @@ -375,36 +387,40 @@ def make_baseline_key(self, model_name: str) -> 'MetricKey': # In addition internally PlotKey is a subclass of MetricKey as each plot is # stored separately. 
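# --- Editor's aside: illustrative sketch, not part of the patch above. ---
# A minimal example of how the MetricKey / SubKey API reformatted in this file
# is typically used. The metric, model, and output names below are hypothetical,
# and the snippet assumes tensorflow_model_analysis is importable.
from tensorflow_model_analysis.metrics import metric_types

key = metric_types.MetricKey(
    name="mse",
    model_name="candidate",
    output_name="output_1",
    sub_key=metric_types.SubKey(class_id=0),
    example_weighted=True,
)

# Keys are hashable named tuples, so they can index result dicts directly.
results = {key: 0.25}
assert results[key] == 0.25

# make_diff_key / make_baseline_key derive related keys for model comparison.
diff_key = key.make_diff_key()
assert diff_key.is_diff
baseline_key = key.make_baseline_key("baseline")
assert baseline_key.model_name == "baseline" and not baseline_key.is_diff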
class PlotKey(MetricKey): - """A PlotKey is a metric key that uniquely identifies a plot.""" - - def to_proto(self) -> metrics_for_slice_pb2.PlotKey: # pytype: disable=signature-mismatch # overriding-return-type-checks - """Converts key to proto.""" - plot_key = metrics_for_slice_pb2.PlotKey() - if self.name: - plot_key.name = self.name - if self.model_name: - plot_key.model_name = self.model_name - if self.output_name: - plot_key.output_name = self.output_name - if self.sub_key: - plot_key.sub_key.CopyFrom(self.sub_key.to_proto()) - if self.example_weighted is not None: - plot_key.example_weighted.value = self.example_weighted - return plot_key - - @staticmethod - def from_proto(pb: metrics_for_slice_pb2.PlotKey) -> 'PlotKey': - """Configures class from proto.""" - example_weighted = None - if pb.HasField('example_weighted'): - example_weighted = pb.example_weighted.value - return PlotKey( - name=pb.name, - model_name=pb.model_name, - output_name=pb.output_name, - sub_key=SubKey.from_proto(pb.sub_key), - example_weighted=example_weighted, - ) + """A PlotKey is a metric key that uniquely identifies a plot.""" + + def to_proto( + self, + ) -> ( + metrics_for_slice_pb2.PlotKey + ): # pytype: disable=signature-mismatch # overriding-return-type-checks + """Converts key to proto.""" + plot_key = metrics_for_slice_pb2.PlotKey() + if self.name: + plot_key.name = self.name + if self.model_name: + plot_key.model_name = self.model_name + if self.output_name: + plot_key.output_name = self.output_name + if self.sub_key: + plot_key.sub_key.CopyFrom(self.sub_key.to_proto()) + if self.example_weighted is not None: + plot_key.example_weighted.value = self.example_weighted + return plot_key + + @staticmethod + def from_proto(pb: metrics_for_slice_pb2.PlotKey) -> "PlotKey": + """Configures class from proto.""" + example_weighted = None + if pb.HasField("example_weighted"): + example_weighted = pb.example_weighted.value + return PlotKey( + name=pb.name, + model_name=pb.model_name, + output_name=pb.output_name, + sub_key=SubKey.from_proto(pb.sub_key), + example_weighted=example_weighted, + ) # A separate version from proto is used here because protos are not hashable and @@ -412,85 +428,90 @@ def from_proto(pb: metrics_for_slice_pb2.PlotKey) -> 'PlotKey': # In addition internally AttributionsKey is a subclass of MetricKey as each # attribution is stored separately. 
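# --- Editor's aside: illustrative sketch, not part of the patch above. ---
# Demonstrates the validation rules and string forms of SubKey / AggregationType
# defined earlier in this file. Assumes tensorflow_model_analysis is importable.
from tensorflow_model_analysis.metrics import metric_types

assert str(metric_types.SubKey(class_id=3)) == "classId:3"
assert str(metric_types.SubKey(k=2)) == "k:2"
assert str(metric_types.AggregationType(macro_average=True)) == "macro"

# Only one of class_id / k / top_k (and only one aggregation flag) may be set.
try:
    metric_types.SubKey(k=2, top_k=5)
except ValueError as e:
    print("rejected as expected:", e)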
class AttributionsKey(MetricKey): - """An AttributionsKey is a metric key uniquely identifying attributions.""" - - def to_proto(self) -> metrics_for_slice_pb2.AttributionsKey: # pytype: disable=signature-mismatch # overriding-return-type-checks - """Converts key to proto.""" - attribution_key = metrics_for_slice_pb2.AttributionsKey() - if self.name: - attribution_key.name = self.name - if self.model_name: - attribution_key.model_name = self.model_name - if self.output_name: - attribution_key.output_name = self.output_name - if self.sub_key: - attribution_key.sub_key.CopyFrom(self.sub_key.to_proto()) - if self.example_weighted is not None: - attribution_key.example_weighted.value = self.example_weighted - if self.is_diff: - attribution_key.is_diff = self.is_diff - return attribution_key - - @staticmethod - def from_proto( - pb: metrics_for_slice_pb2.AttributionsKey, - ) -> 'AttributionsKey': - """Configures class from proto.""" - example_weighted = None - if pb.HasField('example_weighted'): - example_weighted = pb.example_weighted.value - return AttributionsKey( - name=pb.name, - model_name=pb.model_name, - output_name=pb.output_name, - sub_key=SubKey.from_proto(pb.sub_key), - example_weighted=example_weighted, - is_diff=pb.is_diff, - ) + """An AttributionsKey is a metric key uniquely identifying attributions.""" + + def to_proto( + self, + ) -> ( + metrics_for_slice_pb2.AttributionsKey + ): # pytype: disable=signature-mismatch # overriding-return-type-checks + """Converts key to proto.""" + attribution_key = metrics_for_slice_pb2.AttributionsKey() + if self.name: + attribution_key.name = self.name + if self.model_name: + attribution_key.model_name = self.model_name + if self.output_name: + attribution_key.output_name = self.output_name + if self.sub_key: + attribution_key.sub_key.CopyFrom(self.sub_key.to_proto()) + if self.example_weighted is not None: + attribution_key.example_weighted.value = self.example_weighted + if self.is_diff: + attribution_key.is_diff = self.is_diff + return attribution_key + + @staticmethod + def from_proto( + pb: metrics_for_slice_pb2.AttributionsKey, + ) -> "AttributionsKey": + """Configures class from proto.""" + example_weighted = None + if pb.HasField("example_weighted"): + example_weighted = pb.example_weighted.value + return AttributionsKey( + name=pb.name, + model_name=pb.model_name, + output_name=pb.output_name, + sub_key=SubKey.from_proto(pb.sub_key), + example_weighted=example_weighted, + is_diff=pb.is_diff, + ) class Preprocessor(beam.DoFn): - """Preprocessor wrapper for preprocessing data in the metric computation. - - The preprocessor is a beam.DoFn that takes a extracts (or a list of extracts) - as input (which typically will contain labels, predictions, example weights, - and optionally features) and should return the initial state that the combiner - will use as input. The output of a processor should only contain - information needed by the combiner. Note that if a query_key is used the - preprocessor will be passed a list of extracts as input representing the - extracts that matched the query_key. The special FeaturePreprocessor can - be used to add additional features to the default standard metric inputs. - - Attributes: - name: The name of the preprocessor. It should only be accessed by a property - function. It is a read only attribute, and is used to distinguish - different preprocessors. 
- """ - - def __init__(self, name: Optional[str] = None, **kwargs): - super().__init__(**kwargs) - self._name = name - - @property - def name(self) -> str: - # if name is not specified, it returns the class name instead. - return self._name or self.__class__.__name__ - - @property - def preprocessor_id(self): - # TODO(b/243206889) develop a more robust hash id for deduplication of - # preprocessors. The name is used as the preprocessor_id to distinguish - # preprocessors. However, it could be brittle. - return self.name - - def __eq__(self, other): - if isinstance(other, Preprocessor): - return self.preprocessor_id == other.preprocessor_id - else: - return False + """Preprocessor wrapper for preprocessing data in the metric computation. + + The preprocessor is a beam.DoFn that takes a extracts (or a list of extracts) + as input (which typically will contain labels, predictions, example weights, + and optionally features) and should return the initial state that the combiner + will use as input. The output of a processor should only contain + information needed by the combiner. Note that if a query_key is used the + preprocessor will be passed a list of extracts as input representing the + extracts that matched the query_key. The special FeaturePreprocessor can + be used to add additional features to the default standard metric inputs. + + Attributes + ---------- + name: The name of the preprocessor. It should only be accessed by a property + function. It is a read only attribute, and is used to distinguish + different preprocessors. + """ + + def __init__(self, name: Optional[str] = None, **kwargs): + super().__init__(**kwargs) + self._name = name - def __hash__(self): - return hash(self._preprocessor_id()) + @property + def name(self) -> str: + # if name is not specified, it returns the class name instead. + return self._name or self.__class__.__name__ + + @property + def preprocessor_id(self): + # TODO(b/243206889) develop a more robust hash id for deduplication of + # preprocessors. The name is used as the preprocessor_id to distinguish + # preprocessors. However, it could be brittle. + return self.name + + def __eq__(self, other): + if isinstance(other, Preprocessor): + return self.preprocessor_id == other.preprocessor_id + else: + return False + + def __hash__(self): + return hash(self._preprocessor_id()) # LINT.ThenChange(../proto/metrics_for_slice.proto) @@ -498,128 +519,130 @@ def __hash__(self): class MetricComputation( NamedTuple( - 'MetricComputation', + "MetricComputation", [ - ('keys', List[MetricKey]), - ('preprocessors', List[Preprocessor]), - ('combiner', beam.CombineFn), + ("keys", List[MetricKey]), + ("preprocessors", List[Preprocessor]), + ("combiner", beam.CombineFn), ], ) ): - """MetricComputation represents one or more metric computations. - - The preprocessors are called with a PCollection of extracts (or list of - extracts if query_key is used) to compute the initial combiner input state - which is then passed to the combiner. This needs to be done in two steps - because slicing happens between the call to the preprocessors and the combiner - and this state may end up in multiple slices so we want the representation to - be as efficient as possible. If the preprocessors are None, then - StandardMetricInputs will be passed. - - A MetricComputation is uniquely identified by the combination of the - combiner's name and the keys. Duplicate computations will be removed - automatically. - - Attributes: - keys: List of metric keys associated with computation. 
If the keys are - defined as part of the computation then this may be empty in which case - only the combiner name will be used for identifying computation - uniqueness. - preprocessors: Takes a extracts (or a list of extracts) as input (which - typically will contain labels, predictions, example weights, and - optionally features) and should return the initial state that the combiner - will use as input. The output of a processor should only contain - information needed by the combiner. - combiner: Takes preprocessor output as input and outputs a tuple: (slice, - metric results). The metric results should be a dict from MetricKey to - value (float, int, distribution, ...). - """ - - def __new__( - cls, - keys: List[MetricKey], - preprocessors: Optional[List[Preprocessor]], - combiner: beam.CombineFn, - ): - # if preprocessors are passed as None, it will be initialized as [] - return super(MetricComputation, cls).__new__( - cls, keys, preprocessors or [], combiner - ) + """MetricComputation represents one or more metric computations. + + The preprocessors are called with a PCollection of extracts (or list of + extracts if query_key is used) to compute the initial combiner input state + which is then passed to the combiner. This needs to be done in two steps + because slicing happens between the call to the preprocessors and the combiner + and this state may end up in multiple slices so we want the representation to + be as efficient as possible. If the preprocessors are None, then + StandardMetricInputs will be passed. + + A MetricComputation is uniquely identified by the combination of the + combiner's name and the keys. Duplicate computations will be removed + automatically. + + Attributes + ---------- + keys: List of metric keys associated with computation. If the keys are + defined as part of the computation then this may be empty in which case + only the combiner name will be used for identifying computation + uniqueness. + preprocessors: Takes a extracts (or a list of extracts) as input (which + typically will contain labels, predictions, example weights, and + optionally features) and should return the initial state that the combiner + will use as input. The output of a processor should only contain + information needed by the combiner. + combiner: Takes preprocessor output as input and outputs a tuple: (slice, + metric results). The metric results should be a dict from MetricKey to + value (float, int, distribution, ...). + """ - def _computation_id(self): - # Some computations do not define the keys until the end of the computation - # is complete. In these cases the keys will be empty so we also distinguish - # based on the combiner name used. We don't use __class__ since classes may - # be defined inline which wouldn't compare equal. - return ( - self.combiner.__class__.__name__, - tuple(sorted(self.keys or [])), - tuple(p.preprocessor_id for p in self.preprocessors or []), - ) + def __new__( + cls, + keys: List[MetricKey], + preprocessors: Optional[List[Preprocessor]], + combiner: beam.CombineFn, + ): + # if preprocessors are passed as None, it will be initialized as [] + return super(MetricComputation, cls).__new__( + cls, keys, preprocessors or [], combiner + ) - def __eq__(self, other): - if isinstance(other, MetricComputation): - return self._computation_id() == other._computation_id() - else: - return False + def _computation_id(self): + # Some computations do not define the keys until the end of the computation + # is complete. 
In these cases the keys will be empty so we also distinguish + # based on the combiner name used. We don't use __class__ since classes may + # be defined inline which wouldn't compare equal. + return ( + self.combiner.__class__.__name__, + tuple(sorted(self.keys or [])), + tuple(p.preprocessor_id for p in self.preprocessors or []), + ) + + def __eq__(self, other): + if isinstance(other, MetricComputation): + return self._computation_id() == other._computation_id() + else: + return False - def __hash__(self): - return hash(self._computation_id()) + def __hash__(self): + return hash(self._computation_id()) class DerivedMetricComputation( NamedTuple( - 'DerivedMetricComputation', + "DerivedMetricComputation", [ - ('keys', List[MetricKey]), - ('result', Callable), + ("keys", List[MetricKey]), + ("result", Callable), ], # Dict[MetricKey,Any] -> Dict[MetricKey,Any] ) ): - """DerivedMetricComputation derives its result from other computations. - - When creating derived metric computations it is recommended (but not required) - that the underlying MetricComputations that they depend on are defined at the - same time. This is to avoid having to pre-construct and pass around all the - required dependencies in order to construct a derived metric. The evaluation - pipeline is responsible for de-duplicating overlapping MetricComputations so - that only one computation is actually run. - - A DerivedMetricComputation is uniquely identified by the combination of the - result function's name and the keys. Duplicate computations will be removed - automatically. - - Attributes: - keys: List of metric keys associated with derived computation. If the keys - are defined as part of the computation then this may be empty in which - case only the result function name will be used for identifying - computation uniqueness. - result: Function (called per slice) to compute the result using the results - of other metric computations. - """ - - def __new__( - cls, - keys: List[MetricKey], - result: Callable[[Dict[MetricKey, Any]], Dict[MetricKey, Any]], - ): - return super(DerivedMetricComputation, cls).__new__(cls, keys, result) - - def _computation_id(self): - # Some computations do not define the keys until the end of the computation - # is complete. In these cases the keys will be empty so we also distinguish - # based on the result function name used. We don't use __class__ since - # functions may be defined inline which wouldn't compare equal. - return (self.result.__class__.__name__, tuple(sorted(self.keys or []))) - - def __eq__(self, other): - if isinstance(other, DerivedMetricComputation): - return self._computation_id() == other._computation_id() - else: - return False + """DerivedMetricComputation derives its result from other computations. + + When creating derived metric computations it is recommended (but not required) + that the underlying MetricComputations that they depend on are defined at the + same time. This is to avoid having to pre-construct and pass around all the + required dependencies in order to construct a derived metric. The evaluation + pipeline is responsible for de-duplicating overlapping MetricComputations so + that only one computation is actually run. + + A DerivedMetricComputation is uniquely identified by the combination of the + result function's name and the keys. Duplicate computations will be removed + automatically. + + Attributes + ---------- + keys: List of metric keys associated with derived computation. 
If the keys + are defined as part of the computation then this may be empty in which + case only the result function name will be used for identifying + computation uniqueness. + result: Function (called per slice) to compute the result using the results + of other metric computations. + """ - def __hash__(self): - return hash(self._computation_id()) + def __new__( + cls, + keys: List[MetricKey], + result: Callable[[Dict[MetricKey, Any]], Dict[MetricKey, Any]], + ): + return super(DerivedMetricComputation, cls).__new__(cls, keys, result) + + def _computation_id(self): + # Some computations do not define the keys until the end of the computation + # is complete. In these cases the keys will be empty so we also distinguish + # based on the result function name used. We don't use __class__ since + # functions may be defined inline which wouldn't compare equal. + return (self.result.__class__.__name__, tuple(sorted(self.keys or []))) + + def __eq__(self, other): + if isinstance(other, DerivedMetricComputation): + return self._computation_id() == other._computation_id() + else: + return False + + def __hash__(self): + return hash(self._computation_id()) CrossSliceComparisonCallable = Callable[ @@ -629,79 +652,81 @@ def __hash__(self): class CrossSliceMetricComputation( NamedTuple( - 'CrossSliceMetricComputation', + "CrossSliceMetricComputation", [ - ('keys', List[MetricKey]), - ('cross_slice_comparison', CrossSliceComparisonCallable), + ("keys", List[MetricKey]), + ("cross_slice_comparison", CrossSliceComparisonCallable), ], ) ): - """CrossSliceMetricComputation derives its result from other computations. - - It is used for metrics which are based upon cross slice comparison. - When creating these metric computations it is recommended (but not required) - that the underlying MetricComputations that they depend on are defined at the - same time. This is to avoid having to pre-construct and pass around all the - required dependencies in order to construct a derived metric. The evaluation - pipeline is responsible for de-duplicating overlapping MetricComputations so - that only one computation is actually run. - - A CrossSliceMetricComputation is uniquely identified by the combination of the - result function's name and the keys. Duplicate computations will be removed - automatically. - - Attributes: - keys: List of metric keys associated with derived computation. If the keys - are defined as part of the computation then this may be empty in which - case only the result function name will be used for identifying - computation uniqueness. - cross_slice_comparison: Function called to perform cross slice comparison - using the results of the other metric computations. - """ - - def __new__( - cls, - keys: List[MetricKey], - cross_slice_comparison: CrossSliceComparisonCallable, - ): - return super(CrossSliceMetricComputation, cls).__new__( - cls, keys, cross_slice_comparison - ) - - def _computation_id(self): - # Some computations do not define the keys until the end of the computation - # is complete. In these cases the keys will be empty so we also distinguish - # based on the result function name used. We don't use __class__ since - # functions may be defined inline which wouldn't compare equal. - return ( - self.cross_slice_comparison.__class__.__name__, - tuple(sorted(self.keys or [])), - ) + """CrossSliceMetricComputation derives its result from other computations. + + It is used for metrics which are based upon cross slice comparison. 
+ When creating these metric computations it is recommended (but not required) + that the underlying MetricComputations that they depend on are defined at the + same time. This is to avoid having to pre-construct and pass around all the + required dependencies in order to construct a derived metric. The evaluation + pipeline is responsible for de-duplicating overlapping MetricComputations so + that only one computation is actually run. + + A CrossSliceMetricComputation is uniquely identified by the combination of the + result function's name and the keys. Duplicate computations will be removed + automatically. + + Attributes + ---------- + keys: List of metric keys associated with derived computation. If the keys + are defined as part of the computation then this may be empty in which + case only the result function name will be used for identifying + computation uniqueness. + cross_slice_comparison: Function called to perform cross slice comparison + using the results of the other metric computations. + """ - def __eq__(self, other): - if isinstance(other, CrossSliceMetricComputation): - return self._computation_id() == other._computation_id() - else: - return False + def __new__( + cls, + keys: List[MetricKey], + cross_slice_comparison: CrossSliceComparisonCallable, + ): + return super(CrossSliceMetricComputation, cls).__new__( + cls, keys, cross_slice_comparison + ) - def __hash__(self): - return hash(self._computation_id()) + def _computation_id(self): + # Some computations do not define the keys until the end of the computation + # is complete. In these cases the keys will be empty so we also distinguish + # based on the result function name used. We don't use __class__ since + # functions may be defined inline which wouldn't compare equal. + return ( + self.cross_slice_comparison.__class__.__name__, + tuple(sorted(self.keys or [])), + ) + def __eq__(self, other): + if isinstance(other, CrossSliceMetricComputation): + return self._computation_id() == other._computation_id() + else: + return False -class CIDerivedMetricComputation(DerivedMetricComputation): - """CIDerivedMetricComputation runs after Confidence Interval is computed. + def __hash__(self): + return hash(self._computation_id()) - A CIDerivedMetricComputation is uniquely identified by the combination of - result function's name and the keys. Duplicate computations will be removed - automatically. - Attributes: - keys: List of metric keys associated with derived computation. If the keys - are defined as part of the computation then this may be empty in which - case only the result function name will be used for identifying - computation uniqueness. - result: Function called to perform compute the metrics. - """ +class CIDerivedMetricComputation(DerivedMetricComputation): + """CIDerivedMetricComputation runs after Confidence Interval is computed. + + A CIDerivedMetricComputation is uniquely identified by the combination of + result function's name and the keys. Duplicate computations will be removed + automatically. + + Attributes + ---------- + keys: List of metric keys associated with derived computation. If the keys + are defined as part of the computation then this may be empty in which + case only the result function name will be used for identifying + computation uniqueness. + result: Function called to perform compute the metrics. 
+ """ # MetricComputations is a list of derived and non-derived computations used to @@ -730,317 +755,322 @@ def validate_and_update_create_computations_fn_kwargs( example_weighted: bool = False, query_key: Optional[str] = None, ): - """Validates and updates create_computations_fn kwargs based on arg_names. - - Each metric's create_computations_fn is invoked with a variable set of - parameters, depending on the argument names of the callable. If an argument - name matches one of the reserved names, this function will update the kwargs - with the appropriate value for that arg. - - Args: - arg_names: The arg_names for the create_computations_fn. - kwargs: The existing kwargs for create_computations_fn. - eval_config: The value to use when `eval_config` is in arg_names. - schema: The value to use when `schema` is in arg_names. - model_names: The value to use when `model_names` is in arg_names. - output_names: The value to use when `output_names` is in arg_names. - sub_keys: The value to use when `sub_keys` is in arg_names. - aggregation_type: The value to use when `aggregation_type` is in arg_names. - class_weights: The value to use when `class_weights` is in arg_names. - example_weighted: The value to use when `exampled_weighted` is in arg_names. - query_key: The value to use when `query_key` is in arg_names. - - Returns: - The kwargs passed as input, updated with the appropriate additional args. - - Raises: - ValueError: If arg_names or kwargs don't support a requested parameter. - """ - if 'eval_config' in arg_names: - kwargs['eval_config'] = eval_config - if 'schema' in arg_names: - kwargs['schema'] = schema - if 'model_names' in arg_names: - kwargs['model_names'] = model_names - if 'output_names' in arg_names: - kwargs['output_names'] = output_names - if 'sub_keys' in arg_names: - kwargs['sub_keys'] = sub_keys - if 'aggregation_type' in arg_names: - kwargs['aggregation_type'] = aggregation_type - if 'class_weights' in arg_names: - kwargs['class_weights'] = class_weights - elif class_weights: - raise ValueError( - 'A metric that does not support class_weights is being used with ' - 'class_weights applied. This is likely caused because micro_averaging ' - 'was enabled for a metric that does not support it. ' - f'Metric args={arg_names}, kwargs={kwargs}' - ) - if 'query_key' in arg_names: - kwargs['query_key'] = query_key - if 'example_weighted' in arg_names: - kwargs['example_weighted'] = example_weighted - elif example_weighted: - raise ValueError( - 'A metric that does not support example weights is being used with ' - 'MetricsSpec.example_weights.weighted set to true. Contact the owner ' - 'of the Metric implementation to ask if support can be added. ' - f'Metric args={arg_names}, kwargs={kwargs}' - ) - return kwargs - - -class Metric: - """Metric wraps a set of metric computations. - - This class exists to provide similarity between tfma.metrics.Metric and - tf.keras.metics.Metric. + """Validates and updates create_computations_fn kwargs based on arg_names. - Calling computations creates the metric computations. The parameters passed to - __init__ will be combined with the parameters passed to the computations - method. This allows some of the parameters (e.g. model_names, output_names, - sub_keys) to be set at the time the computations are created instead of when - the metric is defined. - """ - - def __init__( - self, create_computations_fn: Callable[..., MetricComputations], **kwargs - ): - """Initializes metric. 
+ Each metric's create_computations_fn is invoked with a variable set of + parameters, depending on the argument names of the callable. If an argument + name matches one of the reserved names, this function will update the kwargs + with the appropriate value for that arg. Args: - create_computations_fn: Function to create the metrics computations (e.g. - mean_label, etc). This function should take the args passed to __init__ - as as input along with any of eval_config, schema, model_names, - output_names, sub_keys, aggregation_type, or query_key (where needed). - **kwargs: Any additional kwargs to pass to create_computations_fn. These - should only contain primitive types or lists/dicts of primitive types. - The kwargs passed to computations have precendence over these kwargs. - """ - self.create_computations_fn = create_computations_fn - if 'name' in kwargs: - if not kwargs['name'] and self._default_name(): - kwargs['name'] = self._default_name() # pylint: disable=assignment-from-none - name = kwargs['name'] - else: - name = None - self.name = name - self.kwargs = kwargs - if hasattr(inspect, 'getfullargspec'): - self._args = inspect.getfullargspec(self.create_computations_fn).args - else: - self._args = inspect.getargspec(self.create_computations_fn).args # pylint: disable=deprecated-method + ---- + arg_names: The arg_names for the create_computations_fn. + kwargs: The existing kwargs for create_computations_fn. + eval_config: The value to use when `eval_config` is in arg_names. + schema: The value to use when `schema` is in arg_names. + model_names: The value to use when `model_names` is in arg_names. + output_names: The value to use when `output_names` is in arg_names. + sub_keys: The value to use when `sub_keys` is in arg_names. + aggregation_type: The value to use when `aggregation_type` is in arg_names. + class_weights: The value to use when `class_weights` is in arg_names. + example_weighted: The value to use when `exampled_weighted` is in arg_names. + query_key: The value to use when `query_key` is in arg_names. - def _default_name(self) -> Optional[str]: - return None + Returns: + ------- + The kwargs passed as input, updated with the appropriate additional args. - def get_config(self) -> Dict[str, Any]: - """Returns serializable config.""" - return self.kwargs + Raises: + ------ + ValueError: If arg_names or kwargs don't support a requested parameter. + """ + if "eval_config" in arg_names: + kwargs["eval_config"] = eval_config + if "schema" in arg_names: + kwargs["schema"] = schema + if "model_names" in arg_names: + kwargs["model_names"] = model_names + if "output_names" in arg_names: + kwargs["output_names"] = output_names + if "sub_keys" in arg_names: + kwargs["sub_keys"] = sub_keys + if "aggregation_type" in arg_names: + kwargs["aggregation_type"] = aggregation_type + if "class_weights" in arg_names: + kwargs["class_weights"] = class_weights + elif class_weights: + raise ValueError( + "A metric that does not support class_weights is being used with " + "class_weights applied. This is likely caused because micro_averaging " + "was enabled for a metric that does not support it. " + f"Metric args={arg_names}, kwargs={kwargs}" + ) + if "query_key" in arg_names: + kwargs["query_key"] = query_key + if "example_weighted" in arg_names: + kwargs["example_weighted"] = example_weighted + elif example_weighted: + raise ValueError( + "A metric that does not support example weights is being used with " + "MetricsSpec.example_weights.weighted set to true. 
Contact the owner " + "of the Metric implementation to ask if support can be added. " + f"Metric args={arg_names}, kwargs={kwargs}" + ) + return kwargs - @classmethod - def from_config(cls, config: Dict[str, Any]) -> 'Metric': - # `fn` key is unnecessary for wrapper due to - # `create_computation_fn` key serialization. - config.pop('fn', None) - return cls(**config) - @property - def compute_confidence_interval(self) -> bool: - """Whether to compute confidence intervals for this metric. +class Metric: + """Metric wraps a set of metric computations. - Note that this may not completely remove the computational overhead - involved in computing a given metric. This is only respected by the - jackknife confidence interval method. + This class exists to provide similarity between tfma.metrics.Metric and + tf.keras.metics.Metric. - Returns: - Whether to compute confidence intervals for this metric. + Calling computations creates the metric computations. The parameters passed to + __init__ will be combined with the parameters passed to the computations + method. This allows some of the parameters (e.g. model_names, output_names, + sub_keys) to be set at the time the computations are created instead of when + the metric is defined. """ - return True - - def computations( - self, - eval_config: Optional[config_pb2.EvalConfig] = None, - schema: Optional[schema_pb2.Schema] = None, - model_names: Optional[List[str]] = None, - output_names: Optional[List[str]] = None, - sub_keys: Optional[List[Optional[SubKey]]] = None, - aggregation_type: Optional[AggregationType] = None, - class_weights: Optional[Dict[int, float]] = None, - example_weighted: bool = False, - query_key: Optional[str] = None, - ) -> MetricComputations: - """Creates computations associated with metric.""" - updated_kwargs = validate_and_update_create_computations_fn_kwargs( - self._args, - self.kwargs.copy(), - eval_config, - schema, - model_names, - output_names, - sub_keys, - aggregation_type, - class_weights, - example_weighted, - query_key, - ) - return self.create_computations_fn(**updated_kwargs) + + def __init__( + self, create_computations_fn: Callable[..., MetricComputations], **kwargs + ): + """Initializes metric. + + Args: + ---- + create_computations_fn: Function to create the metrics computations (e.g. + mean_label, etc). This function should take the args passed to __init__ + as as input along with any of eval_config, schema, model_names, + output_names, sub_keys, aggregation_type, or query_key (where needed). + **kwargs: Any additional kwargs to pass to create_computations_fn. These + should only contain primitive types or lists/dicts of primitive types. + The kwargs passed to computations have precendence over these kwargs. 
+ """ + self.create_computations_fn = create_computations_fn + if "name" in kwargs: + if not kwargs["name"] and self._default_name(): + kwargs["name"] = self._default_name() # pylint: disable=assignment-from-none + name = kwargs["name"] + else: + name = None + self.name = name + self.kwargs = kwargs + if hasattr(inspect, "getfullargspec"): + self._args = inspect.getfullargspec(self.create_computations_fn).args + else: + self._args = inspect.getargspec(self.create_computations_fn).args # pylint: disable=deprecated-method + + def _default_name(self) -> Optional[str]: + return None + + def get_config(self) -> Dict[str, Any]: + """Returns serializable config.""" + return self.kwargs + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "Metric": + # `fn` key is unnecessary for wrapper due to + # `create_computation_fn` key serialization. + config.pop("fn", None) + return cls(**config) + + @property + def compute_confidence_interval(self) -> bool: + """Whether to compute confidence intervals for this metric. + + Note that this may not completely remove the computational overhead + involved in computing a given metric. This is only respected by the + jackknife confidence interval method. + + Returns + ------- + Whether to compute confidence intervals for this metric. + """ + return True + + def computations( + self, + eval_config: Optional[config_pb2.EvalConfig] = None, + schema: Optional[schema_pb2.Schema] = None, + model_names: Optional[List[str]] = None, + output_names: Optional[List[str]] = None, + sub_keys: Optional[List[Optional[SubKey]]] = None, + aggregation_type: Optional[AggregationType] = None, + class_weights: Optional[Dict[int, float]] = None, + example_weighted: bool = False, + query_key: Optional[str] = None, + ) -> MetricComputations: + """Creates computations associated with metric.""" + updated_kwargs = validate_and_update_create_computations_fn_kwargs( + self._args, + self.kwargs.copy(), + eval_config, + schema, + model_names, + output_names, + sub_keys, + aggregation_type, + class_weights, + example_weighted, + query_key, + ) + return self.create_computations_fn(**updated_kwargs) _METRIC_OBJECTS = {} def register_metric(cls: Type[Metric]): - """Registers metric under the list of standard TFMA metrics.""" - _METRIC_OBJECTS[cls.__name__] = cls + """Registers metric under the list of standard TFMA metrics.""" + _METRIC_OBJECTS[cls.__name__] = cls def registered_metrics() -> Dict[str, Type[Metric]]: - """Returns standard TFMA metrics.""" - return copy.copy(_METRIC_OBJECTS) + """Returns standard TFMA metrics.""" + return copy.copy(_METRIC_OBJECTS) def is_registered_metric(metric_class_name: str) -> bool: - """Returns True if given metric class name is registered.""" - return metric_class_name in _METRIC_OBJECTS + """Returns True if given metric class name is registered.""" + return metric_class_name in _METRIC_OBJECTS class StandardMetricInputs(util.StandardExtracts): - """Standard inputs used by most metric computations. - - StandardMetricInputs is a wrapper around Extracts where only the extracts keys - used by one or more ExtractsPreprocessors will be present. 
- """ - - @property - def label(self) -> Optional[types.TensorValueMaybeMultiLevelDict]: - """Same as labels (DEPRECATED - use labels).""" - return self.labels - - @property - def prediction(self) -> Optional[types.TensorValueMaybeMultiLevelDict]: - """Same as predictions (DEPRECATED - use predictions).""" - return self.predictions - - @property - def example_weight(self) -> Optional[types.TensorValueMaybeMultiLevelDict]: - """Same as example_weights (DEPRECATED - use example_weights).""" - return self.example_weights - - def get_by_key( - self, - key: str, - model_name: Optional[str] = None, - output_name: Optional[str] = None, - ) -> Any: - if key not in self and key.endswith('s'): - # The previous version of StandardMetricInputs was a NamedTuple that - # used label, prediction, and example_weight as the field names. Some - # tests may be creating StandardMetricInputs using these names, so also - # search under the non-pluralized form of the key. - key = key[:-1] - return super().get_by_key(key, model_name, output_name) - - -_DEFAULT_STANDARD_METRIC_INPUT_PREPROCESSOR_NAME = ( - 'standard_metric_input_preprocessor' -) -_DEFAULT_INPUT_PREPROCESSOR_NAME = 'input_preprocessor' -_DEFAULT_FEATURE_PREPROCESSOR_NAME = 'feature_preprocessor' -_DEFAULT_TRANSFORMED_FEATURE_PREPROCESSOR_NAME = ( - 'transformed_feature_preprocessor' -) -_DEFAULT_COMBINED_FEATURE_PREPROCESSOR_NAME = 'combined_feature_preprocessor' -_DEFAULT_ATTRIBUTION_PREPROCESSOR_NAME = 'attribution_preprocessor' + """Standard inputs used by most metric computations. + + StandardMetricInputs is a wrapper around Extracts where only the extracts keys + used by one or more ExtractsPreprocessors will be present. + """ + + @property + def label(self) -> Optional[types.TensorValueMaybeMultiLevelDict]: + """Same as labels (DEPRECATED - use labels).""" + return self.labels + + @property + def prediction(self) -> Optional[types.TensorValueMaybeMultiLevelDict]: + """Same as predictions (DEPRECATED - use predictions).""" + return self.predictions + + @property + def example_weight(self) -> Optional[types.TensorValueMaybeMultiLevelDict]: + """Same as example_weights (DEPRECATED - use example_weights).""" + return self.example_weights + + def get_by_key( + self, + key: str, + model_name: Optional[str] = None, + output_name: Optional[str] = None, + ) -> Any: + if key not in self and key.endswith("s"): + # The previous version of StandardMetricInputs was a NamedTuple that + # used label, prediction, and example_weight as the field names. Some + # tests may be creating StandardMetricInputs using these names, so also + # search under the non-pluralized form of the key. 
+ key = key[:-1] + return super().get_by_key(key, model_name, output_name) + + +_DEFAULT_STANDARD_METRIC_INPUT_PREPROCESSOR_NAME = "standard_metric_input_preprocessor" +_DEFAULT_INPUT_PREPROCESSOR_NAME = "input_preprocessor" +_DEFAULT_FEATURE_PREPROCESSOR_NAME = "feature_preprocessor" +_DEFAULT_TRANSFORMED_FEATURE_PREPROCESSOR_NAME = "transformed_feature_preprocessor" +_DEFAULT_COMBINED_FEATURE_PREPROCESSOR_NAME = "combined_feature_preprocessor" +_DEFAULT_ATTRIBUTION_PREPROCESSOR_NAME = "attribution_preprocessor" _DEFAULT_STANDARD_METRIC_INPUT_PREPROCESSOR_LIST_NAME = ( - 'standard_metric_input_preprocessor_list' + "standard_metric_input_preprocessor_list" ) class StandardMetricInputsPreprocessor(Preprocessor): - """Preprocessor for filtering the extracts used in StandardMetricInputs.""" - - def __init__( - self, - include_filter: Optional[Union[Iterable[str], Dict[str, Any]]] = None, - include_default_inputs: bool = True, - model_names: Optional[Iterable[str]] = None, - output_names: Optional[Iterable[str]] = None, - name: Optional[str] = None, - ): - """Initializes preprocessor. - - Args: - include_filter: Optional list or map of extracts keys to include in - output. If a map of keys is passed then the keys and sub-keys that exist - in the map will be included in the output. An empty dict behaves as a - wildcard matching all keys or the value itself. Since matching on values - is not currently supported, an empty dict must be used to represent the - leaf nodes. For example, {'key1': {'key1-subkey': {}}, 'key2': {}}. - include_default_inputs: True to include default inputs (labels, - predictions, example weights) in addition to any inputs that may be - specified using include_filter. - model_names: Optional model names. Only used if include_default_inputs is - True. If unset all models will be included with the default inputs. - output_names: Optional output names. Only used if include_default_inputs - is True. If unset all outputs will be included with the default inputs. - name: Optional preprocessor name. Used to distinguish with other - preprocessors. 
- """ - super().__init__( - include_filter=include_filter, - include_default_inputs=include_default_inputs, - model_names=model_names, - output_names=output_names, - name=name, - ) - if include_filter is None: - include_filter = {} - if not isinstance(include_filter, MutableMapping): - if isinstance(include_filter, Iterable): - include_filter = {k: {} for k in include_filter or []} - else: - raise ValueError('include_filter must be a list or dict') - - if include_default_inputs: - default_filter = {} - if output_names: - default_filter = {name: default_filter for name in output_names} - if model_names: - default_filter = {name: default_filter for name in model_names} - include_filter = copy.copy(include_filter) - include_filter.update({ - constants.LABELS_KEY: default_filter, - constants.PREDICTIONS_KEY: default_filter, - constants.EXAMPLE_WEIGHTS_KEY: default_filter, - }) - self.include_filter = include_filter - - def _default_name(self) -> str: - return _DEFAULT_STANDARD_METRIC_INPUT_PREPROCESSOR_NAME - - def process(self, extracts: types.Extracts) -> Iterator[types.Extracts]: - if not self.include_filter: - yield {} - result = util.include_filter(self.include_filter, extracts) - yield result + """Preprocessor for filtering the extracts used in StandardMetricInputs.""" + + def __init__( + self, + include_filter: Optional[Union[Iterable[str], Dict[str, Any]]] = None, + include_default_inputs: bool = True, + model_names: Optional[Iterable[str]] = None, + output_names: Optional[Iterable[str]] = None, + name: Optional[str] = None, + ): + """Initializes preprocessor. + + Args: + ---- + include_filter: Optional list or map of extracts keys to include in + output. If a map of keys is passed then the keys and sub-keys that exist + in the map will be included in the output. An empty dict behaves as a + wildcard matching all keys or the value itself. Since matching on values + is not currently supported, an empty dict must be used to represent the + leaf nodes. For example, {'key1': {'key1-subkey': {}}, 'key2': {}}. + include_default_inputs: True to include default inputs (labels, + predictions, example weights) in addition to any inputs that may be + specified using include_filter. + model_names: Optional model names. Only used if include_default_inputs is + True. If unset all models will be included with the default inputs. + output_names: Optional output names. Only used if include_default_inputs + is True. If unset all outputs will be included with the default inputs. + name: Optional preprocessor name. Used to distinguish with other + preprocessors. 
+ """ + super().__init__( + include_filter=include_filter, + include_default_inputs=include_default_inputs, + model_names=model_names, + output_names=output_names, + name=name, + ) + if include_filter is None: + include_filter = {} + if not isinstance(include_filter, MutableMapping): + if isinstance(include_filter, Iterable): + include_filter = {k: {} for k in include_filter or []} + else: + raise ValueError("include_filter must be a list or dict") + + if include_default_inputs: + default_filter = {} + if output_names: + default_filter = {name: default_filter for name in output_names} + if model_names: + default_filter = {name: default_filter for name in model_names} + include_filter = copy.copy(include_filter) + include_filter.update( + { + constants.LABELS_KEY: default_filter, + constants.PREDICTIONS_KEY: default_filter, + constants.EXAMPLE_WEIGHTS_KEY: default_filter, + } + ) + self.include_filter = include_filter + + def _default_name(self) -> str: + return _DEFAULT_STANDARD_METRIC_INPUT_PREPROCESSOR_NAME + + def process(self, extracts: types.Extracts) -> Iterator[types.Extracts]: + if not self.include_filter: + yield {} + result = util.include_filter(self.include_filter, extracts) + yield result def InputPreprocessor( # pylint: disable=invalid-name include_default_inputs: bool = False, ) -> StandardMetricInputsPreprocessor: - """Returns preprocessor for including raw inputs in StandardMetricInputs. + """Returns preprocessor for including raw inputs in StandardMetricInputs. - Args: - include_default_inputs: True to include default inputs (labels, predictions, - example weights) in addition to the inputs. - """ - return StandardMetricInputsPreprocessor( - include_filter={constants.INPUT_KEY: {}}, - include_default_inputs=include_default_inputs, - name=_DEFAULT_INPUT_PREPROCESSOR_NAME, - ) + Args: + ---- + include_default_inputs: True to include default inputs (labels, predictions, + example weights) in addition to the inputs. + """ + return StandardMetricInputsPreprocessor( + include_filter={constants.INPUT_KEY: {}}, + include_default_inputs=include_default_inputs, + name=_DEFAULT_INPUT_PREPROCESSOR_NAME, + ) def FeaturePreprocessor( # pylint: disable=invalid-name @@ -1049,28 +1079,29 @@ def FeaturePreprocessor( # pylint: disable=invalid-name model_names: Optional[Iterable[str]] = None, output_names: Optional[Iterable[str]] = None, ) -> StandardMetricInputsPreprocessor: - """Returns preprocessor for including features in StandardMetricInputs. - - Args: - feature_keys: List of feature keys. An empty list means all. - include_default_inputs: True to include default inputs (labels, predictions, - example weights) in addition to the features. - model_names: Optional model names. Only used if include_default_inputs is - True. If unset all models will be included with the default inputs. - output_names: Optional output names. Only used if include_default_inputs is - True. If unset all outputs will be included with the default inputs. - """ - if feature_keys: - include_features = {k: {} for k in feature_keys} - else: - include_features = {} - return StandardMetricInputsPreprocessor( - include_filter={constants.FEATURES_KEY: include_features}, - include_default_inputs=include_default_inputs, - model_names=model_names, - output_names=output_names, - name=_DEFAULT_FEATURE_PREPROCESSOR_NAME, - ) + """Returns preprocessor for including features in StandardMetricInputs. + + Args: + ---- + feature_keys: List of feature keys. An empty list means all. 
+ include_default_inputs: True to include default inputs (labels, predictions, + example weights) in addition to the features. + model_names: Optional model names. Only used if include_default_inputs is + True. If unset all models will be included with the default inputs. + output_names: Optional output names. Only used if include_default_inputs is + True. If unset all outputs will be included with the default inputs. + """ + if feature_keys: + include_features = {k: {} for k in feature_keys} + else: + include_features = {} + return StandardMetricInputsPreprocessor( + include_filter={constants.FEATURES_KEY: include_features}, + include_default_inputs=include_default_inputs, + model_names=model_names, + output_names=output_names, + name=_DEFAULT_FEATURE_PREPROCESSOR_NAME, + ) def TransformedFeaturePreprocessor( # pylint: disable=invalid-name @@ -1079,30 +1110,31 @@ def TransformedFeaturePreprocessor( # pylint: disable=invalid-name model_names: Optional[Iterable[str]] = None, output_names: Optional[Iterable[str]] = None, ) -> StandardMetricInputsPreprocessor: - """Returns preprocessor for incl transformed features in StandardMetricInputs. - - Args: - feature_keys: List of feature keys. An empty list means all. - include_default_inputs: True to include default inputs (labels, predictions, - example weights) in addition to the transformed features. - model_names: Optional model names (required if transformed_features used - with multi-model evaluations). - output_names: Optional output names. Only used if include_default_inputs is - True. If unset all outputs will be included with the default inputs. - """ - if feature_keys: - include_features = {k: {} for k in feature_keys} - else: - include_features = {} - if model_names: - include_features = {name: include_features for name in model_names} - return StandardMetricInputsPreprocessor( - include_filter={constants.TRANSFORMED_FEATURES_KEY: include_features}, - include_default_inputs=include_default_inputs, - model_names=model_names, - output_names=output_names, - name=_DEFAULT_TRANSFORMED_FEATURE_PREPROCESSOR_NAME, - ) + """Returns preprocessor for incl transformed features in StandardMetricInputs. + + Args: + ---- + feature_keys: List of feature keys. An empty list means all. + include_default_inputs: True to include default inputs (labels, predictions, + example weights) in addition to the transformed features. + model_names: Optional model names (required if transformed_features used + with multi-model evaluations). + output_names: Optional output names. Only used if include_default_inputs is + True. If unset all outputs will be included with the default inputs. + """ + if feature_keys: + include_features = {k: {} for k in feature_keys} + else: + include_features = {} + if model_names: + include_features = {name: include_features for name in model_names} + return StandardMetricInputsPreprocessor( + include_filter={constants.TRANSFORMED_FEATURES_KEY: include_features}, + include_default_inputs=include_default_inputs, + model_names=model_names, + output_names=output_names, + name=_DEFAULT_TRANSFORMED_FEATURE_PREPROCESSOR_NAME, + ) def CombinedFeaturePreprocessor( # pylint: disable=invalid-name @@ -1111,33 +1143,34 @@ def CombinedFeaturePreprocessor( # pylint: disable=invalid-name model_names: Optional[Iterable[str]] = None, output_names: Optional[Iterable[str]] = None, ) -> StandardMetricInputsPreprocessor: - """Returns preprocessor for incl combined features in StandardMetricInputs. - - Args: - feature_keys: List of feature keys. 
An empty list means all. - include_default_inputs: True to include default inputs (labels, predictions, - example weights) in addition to the transformed features. - model_names: Optional model names (required if transformed_features used - with multi-model evaluations). - output_names: Optional output names. Only used if include_default_inputs is - True. If unset all outputs will be included with the default inputs. - """ - if feature_keys: - include_features = {k: {} for k in feature_keys} - else: - include_features = {} - if model_names: - include_features = {name: include_features for name in model_names} - return StandardMetricInputsPreprocessor( - include_filter={ - constants.TRANSFORMED_FEATURES_KEY: include_features, - constants.FEATURES_KEY: include_features, - }, - include_default_inputs=include_default_inputs, - model_names=model_names, - output_names=output_names, - name=_DEFAULT_COMBINED_FEATURE_PREPROCESSOR_NAME, - ) + """Returns preprocessor for incl combined features in StandardMetricInputs. + + Args: + ---- + feature_keys: List of feature keys. An empty list means all. + include_default_inputs: True to include default inputs (labels, predictions, + example weights) in addition to the transformed features. + model_names: Optional model names (required if transformed_features used + with multi-model evaluations). + output_names: Optional output names. Only used if include_default_inputs is + True. If unset all outputs will be included with the default inputs. + """ + if feature_keys: + include_features = {k: {} for k in feature_keys} + else: + include_features = {} + if model_names: + include_features = {name: include_features for name in model_names} + return StandardMetricInputsPreprocessor( + include_filter={ + constants.TRANSFORMED_FEATURES_KEY: include_features, + constants.FEATURES_KEY: include_features, + }, + include_default_inputs=include_default_inputs, + model_names=model_names, + output_names=output_names, + name=_DEFAULT_COMBINED_FEATURE_PREPROCESSOR_NAME, + ) def AttributionPreprocessor( # pylint: disable=invalid-name @@ -1146,55 +1179,57 @@ def AttributionPreprocessor( # pylint: disable=invalid-name model_names: Optional[Iterable[str]] = None, output_names: Optional[Iterable[str]] = None, ) -> StandardMetricInputsPreprocessor: - """Returns preprocessor for including attributions in StandardMetricInputs. - - Args: - feature_keys: List of feature keys under attributions to keep. An empty list - means all. - include_default_inputs: True to include default inputs (labels, predictions, - example weights) in addition to the transformed features. - model_names: Optional model names (required for multi-model evaluations). - output_names: Optional output names (required for multi-output evaluations). - """ - if feature_keys: - include_features = {k: {} for k in feature_keys} - else: - include_features = {} - if output_names: - include_features = {name: include_features for name in output_names} - if model_names: - include_features = {name: include_features for name in model_names} - return StandardMetricInputsPreprocessor( - include_filter={constants.ATTRIBUTIONS_KEY: include_features}, - include_default_inputs=include_default_inputs, - model_names=model_names, - output_names=output_names, - name=_DEFAULT_ATTRIBUTION_PREPROCESSOR_NAME, - ) + """Returns preprocessor for including attributions in StandardMetricInputs. + + Args: + ---- + feature_keys: List of feature keys under attributions to keep. An empty list + means all. 
+ include_default_inputs: True to include default inputs (labels, predictions, + example weights) in addition to the transformed features. + model_names: Optional model names (required for multi-model evaluations). + output_names: Optional output names (required for multi-output evaluations). + """ + if feature_keys: + include_features = {k: {} for k in feature_keys} + else: + include_features = {} + if output_names: + include_features = {name: include_features for name in output_names} + if model_names: + include_features = {name: include_features for name in model_names} + return StandardMetricInputsPreprocessor( + include_filter={constants.ATTRIBUTIONS_KEY: include_features}, + include_default_inputs=include_default_inputs, + model_names=model_names, + output_names=output_names, + name=_DEFAULT_ATTRIBUTION_PREPROCESSOR_NAME, + ) def StandardMetricInputsPreprocessorList( # pylint: disable=invalid-name preprocessors: List[StandardMetricInputsPreprocessor], ) -> StandardMetricInputsPreprocessor: - """Returns preprocessor combining multiple standard preprocessors together. - - Args: - preprocessors: List of StandardMetricInputsPreprocessors. Must be of type - StandardMetricInputsPreprocessor (subclasses not supported). - """ - include_filter = {} - for p in preprocessors: - if type(p) != StandardMetricInputsPreprocessor: # pylint: disable=unidiomatic-typecheck - raise ValueError( - 'Only direct instances of StandardMetricsInputPreprocessor ' - '(excluding sub-classes) are supported' - ) - if not include_filter: - include_filter = p.include_filter - else: - include_filter = util.merge_filters(include_filter, p.include_filter) - return StandardMetricInputsPreprocessor( - include_filter=include_filter, - include_default_inputs=False, - name=_DEFAULT_STANDARD_METRIC_INPUT_PREPROCESSOR_LIST_NAME, - ) + """Returns preprocessor combining multiple standard preprocessors together. + + Args: + ---- + preprocessors: List of StandardMetricInputsPreprocessors. Must be of type + StandardMetricInputsPreprocessor (subclasses not supported). 
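# Illustrative sketch (not part of this change): composing preprocessors with
# StandardMetricInputsPreprocessorList; the feature names are placeholders.
from tensorflow_model_analysis.metrics import metric_types

preprocessor = metric_types.StandardMetricInputsPreprocessorList(
    [
        metric_types.FeaturePreprocessor(feature_keys=["feature1", "feature2"]),
        metric_types.TransformedFeaturePreprocessor(feature_keys=["feature1"]),
    ]
)
# preprocessor.include_filter merges the individual filters: it contains the
# default "labels", "predictions" and "example_weights" keys plus the requested
# "features" and "transformed_features" entries, so only those extracts are
# passed on to the metric computations.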
+ """ + include_filter = {} + for p in preprocessors: + if type(p) != StandardMetricInputsPreprocessor: # pylint: disable=unidiomatic-typecheck + raise ValueError( + "Only direct instances of StandardMetricsInputPreprocessor " + "(excluding sub-classes) are supported" + ) + if not include_filter: + include_filter = p.include_filter + else: + include_filter = util.merge_filters(include_filter, p.include_filter) + return StandardMetricInputsPreprocessor( + include_filter=include_filter, + include_default_inputs=False, + name=_DEFAULT_STANDARD_METRIC_INPUT_PREPROCESSOR_LIST_NAME, + ) diff --git a/tensorflow_model_analysis/metrics/metric_types_test.py b/tensorflow_model_analysis/metrics/metric_types_test.py index 4c041c6013..e0654dbf0b 100644 --- a/tensorflow_model_analysis/metrics/metric_types_test.py +++ b/tensorflow_model_analysis/metrics/metric_types_test.py @@ -14,231 +14,253 @@ """Tests for metric_types.""" import tensorflow as tf + from tensorflow_model_analysis.metrics import metric_types class MetricTypesTest(tf.test.TestCase): + def testMetricKeyStrForMetricKeyWithOneField(self): + self.assertEqual( + str(metric_types.MetricKey(name="metric_name")), + 'name: "metric_name" example_weighted: { }', + ) - def testMetricKeyStrForMetricKeyWithOneField(self): - self.assertEqual( - str(metric_types.MetricKey(name='metric_name')), - 'name: "metric_name" example_weighted: { }') + def testMetricKeyStrForMetricKeyWithAllFields(self): + self.assertEqual( + str( + metric_types.MetricKey( + name="metric_name", + model_name="model_name", + output_name="output_name", + sub_key=metric_types.SubKey(class_id=1), + example_weighted=True, + is_diff=True, + ) + ), + 'name: "metric_name" output_name: "output_name" ' + + 'sub_key: { class_id: { value: 1 } } model_name: "model_name" ' + + "is_diff: true example_weighted: { value: true }", + ) - def testMetricKeyStrForMetricKeyWithAllFields(self): - self.assertEqual( - str( + def testMetricKeyFromProto(self): + metric_keys = [ + metric_types.MetricKey(name="metric_name"), metric_types.MetricKey( - name='metric_name', - model_name='model_name', - output_name='output_name', + name="metric_name", + model_name="model_name", + output_name="output_name", sub_key=metric_types.SubKey(class_id=1), - example_weighted=True, - is_diff=True)), - 'name: "metric_name" output_name: "output_name" ' + - 'sub_key: { class_id: { value: 1 } } model_name: "model_name" ' + - 'is_diff: true example_weighted: { value: true }') - - def testMetricKeyFromProto(self): - metric_keys = [ - metric_types.MetricKey(name='metric_name'), - metric_types.MetricKey( - name='metric_name', - model_name='model_name', - output_name='output_name', - sub_key=metric_types.SubKey(class_id=1), - is_diff=True), - metric_types.MetricKey( - name='metric_name', - model_name='model_name', - output_name='output_name', - sub_key=metric_types.SubKey(top_k=2), - example_weighted=None, - aggregation_type=metric_types.AggregationType(micro_average=True)), - metric_types.MetricKey( - name='metric_name', - model_name='model_name', - output_name='output_name', - example_weighted=False) - ] - for key in metric_keys: - got_key = metric_types.MetricKey.from_proto(key.to_proto()) - self.assertEqual(key, got_key, '{} != {}'.format(key, got_key)) - - def testPlotKeyFromProto(self): - plot_keys = [ - metric_types.PlotKey(name='plot_name'), - metric_types.PlotKey( - name='plot_name', - model_name='model_name', - output_name='output_name', - sub_key=metric_types.SubKey(class_id=1), - ), - metric_types.MetricKey( - 
name='plot_name', - model_name='model_name', - output_name='output_name', - sub_key=metric_types.SubKey(top_k=2), - ), - ] - for key in plot_keys: - got_key = metric_types.PlotKey.from_proto(key.to_proto()) - self.assertEqual(key, got_key, '{} != {}'.format(key, got_key)) - - def testSubKeyStr(self): - self.assertEqual(str(metric_types.SubKey(class_id=1)), 'classId:1') - self.assertEqual(str(metric_types.SubKey(top_k=2)), 'topK:2') - self.assertEqual(str(metric_types.SubKey(k=3)), 'k:3') - self.assertEqual( - str(metric_types.SubKey(class_id=1, top_k=2)), 'classId:1 topK:2' - ) + is_diff=True, + ), + metric_types.MetricKey( + name="metric_name", + model_name="model_name", + output_name="output_name", + sub_key=metric_types.SubKey(top_k=2), + example_weighted=None, + aggregation_type=metric_types.AggregationType(micro_average=True), + ), + metric_types.MetricKey( + name="metric_name", + model_name="model_name", + output_name="output_name", + example_weighted=False, + ), + ] + for key in metric_keys: + got_key = metric_types.MetricKey.from_proto(key.to_proto()) + self.assertEqual(key, got_key, f"{key} != {got_key}") - def testSubKeySetUp(self): - with self.assertRaises( - NotImplementedError, - msg=( - 'A non-existent SubKey should be represented as None, not as ', - 'SubKey(None, None, None).', - ), - ): - str(metric_types.SubKey()) - with self.assertRaises( - ValueError, - msg=('k and top_k cannot both be set at the same time',), - ): - str(metric_types.SubKey(k=2, top_k=2)) - with self.assertRaises( - ValueError, - msg=('k and class_id cannot both be set at the same time',), - ): - str(metric_types.SubKey(k=2, class_id=2)) + def testPlotKeyFromProto(self): + plot_keys = [ + metric_types.PlotKey(name="plot_name"), + metric_types.PlotKey( + name="plot_name", + model_name="model_name", + output_name="output_name", + sub_key=metric_types.SubKey(class_id=1), + ), + metric_types.MetricKey( + name="plot_name", + model_name="model_name", + output_name="output_name", + sub_key=metric_types.SubKey(top_k=2), + ), + ] + for key in plot_keys: + got_key = metric_types.PlotKey.from_proto(key.to_proto()) + self.assertEqual(key, got_key, f"{key} != {got_key}") - def testAggregationTypeLessThan(self): - self.assertLess( - metric_types.AggregationType(macro_average=True), - metric_types.AggregationType(micro_average=True), - ) - self.assertLess( - metric_types.AggregationType(weighted_macro_average=True), - metric_types.AggregationType(macro_average=True), - ) + def testSubKeyStr(self): + self.assertEqual(str(metric_types.SubKey(class_id=1)), "classId:1") + self.assertEqual(str(metric_types.SubKey(top_k=2)), "topK:2") + self.assertEqual(str(metric_types.SubKey(k=3)), "k:3") + self.assertEqual( + str(metric_types.SubKey(class_id=1, top_k=2)), "classId:1 topK:2" + ) - def testPreprocessors(self): - preprocessor = metric_types.StandardMetricInputsPreprocessorList([ - metric_types.FeaturePreprocessor(feature_keys=['feature1', 'feature2']), - metric_types.TransformedFeaturePreprocessor(feature_keys=['feature1']), - metric_types.AttributionPreprocessor(feature_keys=['feature1']), - ]) - self.assertEqual( - preprocessor.include_filter, - { - 'labels': {}, - 'predictions': {}, - 'example_weights': {}, - 'features': { - 'feature1': {}, - 'feature2': {}, - }, - 'transformed_features': { - 'feature1': {}, - }, - 'attributions': { - 'feature1': {}, - }, - }, - ) + def testSubKeySetUp(self): + with self.assertRaises( + NotImplementedError, + msg=( + "A non-existent SubKey should be represented as None, not as ", + 
"SubKey(None, None, None).", + ), + ): + str(metric_types.SubKey()) + with self.assertRaises( + ValueError, + msg=("k and top_k cannot both be set at the same time",), + ): + str(metric_types.SubKey(k=2, top_k=2)) + with self.assertRaises( + ValueError, + msg=("k and class_id cannot both be set at the same time",), + ): + str(metric_types.SubKey(k=2, class_id=2)) - def testPreprocessorsWithoutDefaults(self): - preprocessor = metric_types.StandardMetricInputsPreprocessorList([ - metric_types.FeaturePreprocessor( - feature_keys=['feature1', 'feature2'], - include_default_inputs=False), - metric_types.TransformedFeaturePreprocessor( - feature_keys=['feature1'], include_default_inputs=False), - metric_types.AttributionPreprocessor( - feature_keys=['feature1'], include_default_inputs=False) - ]) - self.assertEqual( - preprocessor.include_filter, { - 'features': { - 'feature1': {}, - 'feature2': {}, - }, - 'transformed_features': { - 'feature1': {}, - }, - 'attributions': { - 'feature1': {}, - }, - }) + def testAggregationTypeLessThan(self): + self.assertLess( + metric_types.AggregationType(macro_average=True), + metric_types.AggregationType(micro_average=True), + ) + self.assertLess( + metric_types.AggregationType(weighted_macro_average=True), + metric_types.AggregationType(macro_average=True), + ) - def testMultiModelMultiOutputPreprocessors(self): - preprocessor = metric_types.StandardMetricInputsPreprocessorList([ - metric_types.FeaturePreprocessor( - feature_keys=['feature1', 'feature2'], - model_names=['model1', 'model2'], - output_names=['output1', 'output2']), - metric_types.TransformedFeaturePreprocessor( - feature_keys=['feature1'], - model_names=['model1', 'model2'], - output_names=['output1', 'output2']), - metric_types.AttributionPreprocessor( - feature_keys=['feature1'], - model_names=['model1'], - output_names=['output2']) - ]) - self.assertEqual( - preprocessor.include_filter, { - 'labels': { - 'model1': { - 'output1': {}, - 'output2': {}, + def testPreprocessors(self): + preprocessor = metric_types.StandardMetricInputsPreprocessorList( + [ + metric_types.FeaturePreprocessor(feature_keys=["feature1", "feature2"]), + metric_types.TransformedFeaturePreprocessor(feature_keys=["feature1"]), + metric_types.AttributionPreprocessor(feature_keys=["feature1"]), + ] + ) + self.assertEqual( + preprocessor.include_filter, + { + "labels": {}, + "predictions": {}, + "example_weights": {}, + "features": { + "feature1": {}, + "feature2": {}, }, - 'model2': { - 'output1': {}, - 'output2': {}, + "transformed_features": { + "feature1": {}, + }, + "attributions": { + "feature1": {}, }, }, - 'predictions': { - 'model1': { - 'output1': {}, - 'output2': {}, + ) + + def testPreprocessorsWithoutDefaults(self): + preprocessor = metric_types.StandardMetricInputsPreprocessorList( + [ + metric_types.FeaturePreprocessor( + feature_keys=["feature1", "feature2"], include_default_inputs=False + ), + metric_types.TransformedFeaturePreprocessor( + feature_keys=["feature1"], include_default_inputs=False + ), + metric_types.AttributionPreprocessor( + feature_keys=["feature1"], include_default_inputs=False + ), + ] + ) + self.assertEqual( + preprocessor.include_filter, + { + "features": { + "feature1": {}, + "feature2": {}, + }, + "transformed_features": { + "feature1": {}, }, - 'model2': { - 'output1': {}, - 'output2': {}, + "attributions": { + "feature1": {}, }, }, - 'example_weights': { - 'model1': { - 'output1': {}, - 'output2': {}, + ) + + def testMultiModelMultiOutputPreprocessors(self): + preprocessor = 
metric_types.StandardMetricInputsPreprocessorList( + [ + metric_types.FeaturePreprocessor( + feature_keys=["feature1", "feature2"], + model_names=["model1", "model2"], + output_names=["output1", "output2"], + ), + metric_types.TransformedFeaturePreprocessor( + feature_keys=["feature1"], + model_names=["model1", "model2"], + output_names=["output1", "output2"], + ), + metric_types.AttributionPreprocessor( + feature_keys=["feature1"], + model_names=["model1"], + output_names=["output2"], + ), + ] + ) + self.assertEqual( + preprocessor.include_filter, + { + "labels": { + "model1": { + "output1": {}, + "output2": {}, + }, + "model2": { + "output1": {}, + "output2": {}, + }, }, - 'model2': { - 'output1': {}, - 'output2': {}, + "predictions": { + "model1": { + "output1": {}, + "output2": {}, + }, + "model2": { + "output1": {}, + "output2": {}, + }, }, - }, - 'features': { - 'feature1': {}, - 'feature2': {}, - }, - 'transformed_features': { - 'model1': { - 'feature1': {}, + "example_weights": { + "model1": { + "output1": {}, + "output2": {}, + }, + "model2": { + "output1": {}, + "output2": {}, + }, }, - 'model2': { - 'feature1': {}, + "features": { + "feature1": {}, + "feature2": {}, }, - }, - 'attributions': { - 'model1': { - 'output2': { - 'feature1': {}, - } + "transformed_features": { + "model1": { + "feature1": {}, + }, + "model2": { + "feature1": {}, + }, + }, + "attributions": { + "model1": { + "output2": { + "feature1": {}, + } + }, }, }, - }) + ) -if __name__ == '__main__': - tf.test.main() +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow_model_analysis/metrics/metric_util.py b/tensorflow_model_analysis/metrics/metric_util.py index 18ee80d419..385dc2a1ad 100644 --- a/tensorflow_model_analysis/metrics/metric_util.py +++ b/tensorflow_model_analysis/metrics/metric_util.py @@ -15,10 +15,23 @@ import inspect import math -from typing import Any, Callable, Dict, Iterable, Iterator, List, Mapping, Optional, Tuple, Union +from typing import ( + Any, + Callable, + Dict, + Iterable, + Iterator, + List, + Mapping, + Optional, + Tuple, + Union, +) import numpy as np import tensorflow as tf +from tensorflow_metadata.proto.v0 import schema_pb2 + from tensorflow_model_analysis import constants from tensorflow_model_analysis.api import types from tensorflow_model_analysis.metrics import metric_types @@ -26,16 +39,14 @@ from tensorflow_model_analysis.utils import util from tensorflow_model_analysis.utils.keras_lib import tf_keras -from tensorflow_metadata.proto.v0 import schema_pb2 - -_ALL_CLASSES = 'all_classes' -_PREDICTIONS = 'predictions' -_LOGISTIC = 'logistic' -_PROBABILITIES = 'probabilities' -_LOGITS = 'logits' +_ALL_CLASSES = "all_classes" +_PREDICTIONS = "predictions" +_LOGISTIC = "logistic" +_PROBABILITIES = "probabilities" +_LOGITS = "logits" -_MEAN_METRIC_WRAPPER = 'MeanMetricWrapper' -_LOSS_FUNCTION_WRAPPER = 'LossFunctionWrapper' +_MEAN_METRIC_WRAPPER = "MeanMetricWrapper" +_LOSS_FUNCTION_WRAPPER = "LossFunctionWrapper" _EPSILON = 1e-7 @@ -47,197 +58,204 @@ def validate_object_detection_arguments( max_num_detections: Optional[int] = None, labels_to_stack: Optional[List[str]] = None, predictions_to_stack: Optional[List[str]] = None, - output_name: Optional[str] = None) -> None: - """Validate the arguments for object detection related functions.""" - if class_id is None: - raise ValueError('class_id must be provided if use object' ' detection.') - if isinstance(class_id, int): - class_id = [class_id] - if class_weight is not None: - if isinstance(class_weight, float): - 
class_weight = [class_weight] - for weight in class_weight: - if weight < 0: - raise ValueError(f'class_weight = {class_weight} must ' - 'not be negative.') - if len(class_id) != len(class_weight): - raise ValueError('Mismatch of length between class_id = ' - f'{class_id} and class_weight = ' - f'{class_weight}.') - if area_range is not None: - if len(area_range) != 2 or area_range[0] > area_range[1]: - raise ValueError(f'area_range = {area_range} must be a valid interval.') - if max_num_detections is not None and max_num_detections <= 0: - raise ValueError(f'max_num_detections = {max_num_detections} must be ' - 'positive.') - if output_name and (labels_to_stack or predictions_to_stack): - raise ValueError('The metric does not support specifying the output name' - ' when there are keys/outputs specified to be stacked.') + output_name: Optional[str] = None, +) -> None: + """Validate the arguments for object detection related functions.""" + if class_id is None: + raise ValueError("class_id must be provided if use object" " detection.") + if isinstance(class_id, int): + class_id = [class_id] + if class_weight is not None: + if isinstance(class_weight, float): + class_weight = [class_weight] + for weight in class_weight: + if weight < 0: + raise ValueError( + f"class_weight = {class_weight} must " "not be negative." + ) + if len(class_id) != len(class_weight): + raise ValueError( + "Mismatch of length between class_id = " + f"{class_id} and class_weight = " + f"{class_weight}." + ) + if area_range is not None: + if len(area_range) != 2 or area_range[0] > area_range[1]: + raise ValueError(f"area_range = {area_range} must be a valid interval.") + if max_num_detections is not None and max_num_detections <= 0: + raise ValueError( + f"max_num_detections = {max_num_detections} must be " "positive." + ) + if output_name and (labels_to_stack or predictions_to_stack): + raise ValueError( + "The metric does not support specifying the output name" + " when there are keys/outputs specified to be stacked." + ) def generate_private_name_from_arguments(name: str, **kwargs) -> str: - """Generate names for used metrics. - - Args: - name: The name user defined for this metric. - **kwargs: (Optional) The dict of arguments with their corresponding names - and values. If the argument value is None, it will be obmitted. - - Returns: - A name for the metric, generated from the specified arguments. - """ - if not name.startswith('_'): - # tfma treats metrics starting with '_' as private and does not show it to - # users. - name = '_' + name - # sort the arguments in alphabetical order so that it generates the same name - # for the same group of arguments. - if kwargs is not None: - name = name + ':' + ','.join(k + '=' + str(kwargs[k]) - for k in sorted(kwargs) - if kwargs[k] is not None) - return name + """Generate names for used metrics. + + Args: + ---- + name: The name user defined for this metric. + **kwargs: (Optional) The dict of arguments with their corresponding names + and values. If the argument value is None, it will be obmitted. + + Returns: + ------- + A name for the metric, generated from the specified arguments. + """ + if not name.startswith("_"): + # tfma treats metrics starting with '_' as private and does not show it to + # users. + name = "_" + name + # sort the arguments in alphabetical order so that it generates the same name + # for the same group of arguments. 
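# Illustrative example (not part of this change), using placeholder values:
#   generate_private_name_from_arguments("iou", class_id=1, threshold=None)
# returns "_iou:class_id=1": the leading underscore marks the metric as
# private, the keyword arguments are sorted alphabetically, and None-valued
# arguments are omitted.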
+ if kwargs is not None: + name = ( + name + + ":" + + ",".join( + k + "=" + str(kwargs[k]) + for k in sorted(kwargs) + if kwargs[k] is not None + ) + ) + return name def within_interval(value: float, left: float, right: float) -> bool: - """Returns true if value is within [left, right].""" - # EPSILON is used to handle rounding errors that may occur if the value was - # created using floating point operations. - return value >= left - _EPSILON and value <= right + _EPSILON + """Returns true if value is within [left, right].""" + # EPSILON is used to handle rounding errors that may occur if the value was + # created using floating point operations. + return value >= left - _EPSILON and value <= right + _EPSILON def serialize_metric( metric: tf_keras.metrics.Metric, use_legacy_format=False ) -> Dict[str, Any]: - """Serializes keras metric.""" - if ( - 'use_legacy_format' - in inspect.getfullargspec(tf_keras.metrics.serialize).args - ): - cfg = tf_keras.metrics.serialize( - metric, use_legacy_format=use_legacy_format - ) - else: - cfg = tf_keras.metrics.serialize(metric) - # If a metric function (vs a class) is passed directly to compile, it - # will be wrapped in a MeanMetricWrapper which is not deserializable. - # If this happens, set the class name to the CamelCase from of the - # function name since most keras metric functions have both forms. - if ('class_name' in cfg and cfg['class_name'] == _MEAN_METRIC_WRAPPER and - 'config' in cfg and 'name' in cfg['config']): - cfg['class_name'] = _camel_case(cfg['config']['name']) - return cfg + """Serializes keras metric.""" + if "use_legacy_format" in inspect.getfullargspec(tf_keras.metrics.serialize).args: + cfg = tf_keras.metrics.serialize(metric, use_legacy_format=use_legacy_format) + else: + cfg = tf_keras.metrics.serialize(metric) + # If a metric function (vs a class) is passed directly to compile, it + # will be wrapped in a MeanMetricWrapper which is not deserializable. + # If this happens, set the class name to the CamelCase from of the + # function name since most keras metric functions have both forms. + if ( + "class_name" in cfg + and cfg["class_name"] == _MEAN_METRIC_WRAPPER + and "config" in cfg + and "name" in cfg["config"] + ): + cfg["class_name"] = _camel_case(cfg["config"]["name"]) + return cfg def serialize_loss( loss: tf_keras.losses.Loss, use_legacy_format=False ) -> Dict[str, Any]: - """Serializes keras loss.""" - if ( - 'use_legacy_format' - in inspect.getfullargspec(tf_keras.losses.serialize).args - ): - cfg = tf_keras.losses.serialize(loss, use_legacy_format=use_legacy_format) - else: - cfg = tf_keras.losses.serialize(loss) - # If a metric function (vs a class) is passed directly to compile, it - # will be wrapped in a LossFunctionWrapper which is not deserializable. - # If this happens, set the class name to the CamelCase from of the - # function name since most keras loss functions have both forms. - if ('class_name' in cfg and cfg['class_name'] == _LOSS_FUNCTION_WRAPPER and - 'config' in cfg and 'name' in cfg['config']): - cfg['class_name'] = _camel_case(cfg['config']['name']) - return cfg + """Serializes keras loss.""" + if "use_legacy_format" in inspect.getfullargspec(tf_keras.losses.serialize).args: + cfg = tf_keras.losses.serialize(loss, use_legacy_format=use_legacy_format) + else: + cfg = tf_keras.losses.serialize(loss) + # If a metric function (vs a class) is passed directly to compile, it + # will be wrapped in a LossFunctionWrapper which is not deserializable. 
+ # If this happens, set the class name to the CamelCase from of the + # function name since most keras loss functions have both forms. + if ( + "class_name" in cfg + and cfg["class_name"] == _LOSS_FUNCTION_WRAPPER + and "config" in cfg + and "name" in cfg["config"] + ): + cfg["class_name"] = _camel_case(cfg["config"]["name"]) + return cfg def deserialize_metric(config, use_legacy_format=False): - if ( - 'use_legacy_format' - in inspect.getfullargspec(tf_keras.metrics.deserialize).args - ): - return tf_keras.metrics.deserialize( - config, use_legacy_format=use_legacy_format - ) - else: - return tf_keras.metrics.deserialize(config) + if "use_legacy_format" in inspect.getfullargspec(tf_keras.metrics.deserialize).args: + return tf_keras.metrics.deserialize(config, use_legacy_format=use_legacy_format) + else: + return tf_keras.metrics.deserialize(config) def deserialize_loss(config, use_legacy_format=False): - if ( - 'use_legacy_format' - in inspect.getfullargspec(tf_keras.losses.deserialize).args - ): - return tf_keras.losses.deserialize( - config, use_legacy_format=use_legacy_format - ) - else: - return tf_keras.losses.deserialize(config) + if "use_legacy_format" in inspect.getfullargspec(tf_keras.losses.deserialize).args: + return tf_keras.losses.deserialize(config, use_legacy_format=use_legacy_format) + else: + return tf_keras.losses.deserialize(config) def serialize_keras_object(obj): - if hasattr(tf_keras.utils, 'legacy'): - return tf_keras.utils.legacy.serialize_keras_object(obj) - else: - return tf_keras.utils.serialize_keras_object(obj) + if hasattr(tf_keras.utils, "legacy"): + return tf_keras.utils.legacy.serialize_keras_object(obj) + else: + return tf_keras.utils.serialize_keras_object(obj) def deserialize_keras_object( config, module_objects=None, custom_objects=None, printable_module_name=None ): - if hasattr(tf_keras.utils, 'legacy'): - return tf_keras.utils.legacy.deserialize_keras_object( - config, custom_objects, module_objects, printable_module_name - ) - else: - return tf_keras.utils.deserialize_keras_object( - config, custom_objects, module_objects, printable_module_name - ) + if hasattr(tf_keras.utils, "legacy"): + return tf_keras.utils.legacy.deserialize_keras_object( + config, custom_objects, module_objects, printable_module_name + ) + else: + return tf_keras.utils.deserialize_keras_object( + config, custom_objects, module_objects, printable_module_name + ) def _camel_case(txt: str) -> str: - return ''.join(s.capitalize() for s in txt.split('_')) - - -def to_scalar(tensor: Optional[Union[types.TensorValue, - tf.compat.v1.SparseTensorValue]], - tensor_name: str = 'unknown') -> Optional[Union[float, int, str]]: - """Returns value as a scalar or raises ValueError.""" - if tensor is None: - return None - if util.is_sparse_or_ragged_tensor_value(tensor): - tensor = tensor.values - if tensor.size != 1: - raise ValueError(f'"{tensor_name}" should have exactly 1 value, but found ' - f'{tensor.size} instead: values={tensor}') - return tensor.item() + return "".join(s.capitalize() for s in txt.split("_")) + + +def to_scalar( + tensor: Optional[Union[types.TensorValue, tf.compat.v1.SparseTensorValue]], + tensor_name: str = "unknown", +) -> Optional[Union[float, int, str]]: + """Returns value as a scalar or raises ValueError.""" + if tensor is None: + return None + if util.is_sparse_or_ragged_tensor_value(tensor): + tensor = tensor.values + if tensor.size != 1: + raise ValueError( + f'"{tensor_name}" should have exactly 1 value, but found ' + f"{tensor.size} instead: 
values={tensor}" + ) + return tensor.item() def safe_to_scalar(arr: Any) -> Any: - """Returns array/list as a scalar, 0.0 if empty else raises ValueError.""" - if isinstance(arr, list): - if not arr: - return 0.0 - else: - raise ValueError('Array should have exactly 1 value to a Python scalar') - else: - if arr.size == 0: - return 0.0 - elif arr.size == 1: - return arr.item() + """Returns array/list as a scalar, 0.0 if empty else raises ValueError.""" + if isinstance(arr, list): + if not arr: + return 0.0 + else: + raise ValueError("Array should have exactly 1 value to a Python scalar") else: - raise ValueError('Array should have exactly 1 value to a Python scalar') + if arr.size == 0: + return 0.0 + elif arr.size == 1: + return arr.item() + else: + raise ValueError("Array should have exactly 1 value to a Python scalar") def pad(arr: np.ndarray, last_dim: int, value: float) -> np.ndarray: - """Pads the given array with value until last dim is of size last_dim.""" - if arr.shape[-1] == last_dim: - return arr - pad_width = [] - for _ in arr.shape[:-1]: - pad_width.append((0, 0)) # Don't pad inner dimensions - pad_width.append((0, last_dim - arr.shape[-1])) # Pad up to last_dim - return np.pad( - arr, pad_width=pad_width, mode='constant', constant_values=value) + """Pads the given array with value until last dim is of size last_dim.""" + if arr.shape[-1] == last_dim: + return arr + pad_width = [] + for _ in arr.shape[:-1]: + pad_width.append((0, 0)) # Don't pad inner dimensions + pad_width.append((0, last_dim - arr.shape[-1])) # Pad up to last_dim + return np.pad(arr, pad_width=pad_width, mode="constant", constant_values=value) def to_standard_metric_inputs( @@ -247,149 +265,174 @@ def to_standard_metric_inputs( include_features: bool = False, include_transformed_features: bool = False, include_any_feature: bool = False, - include_attributions: bool = False) -> metric_types.StandardMetricInputs: - """Verifies extract keys and converts extracts to StandardMetricInputs.""" - if include_labels and constants.LABELS_KEY not in extracts: - raise ValueError(f'"{constants.LABELS_KEY}" key not found in extracts. ' - 'Check that the configuration is setup properly to ' - 'specify the name of label input and that the proper ' - 'extractor has been configured to extract the labels from ' - f'the inputs. Existing keys: {extracts.keys()}') - if include_predictions and constants.PREDICTIONS_KEY not in extracts: - raise ValueError(f'"{constants.PREDICTIONS_KEY}" key not found in ' - 'extracts. Check that the proper extractor has been ' - 'configured to perform model inference.') - if include_features and constants.FEATURES_KEY not in extracts: - raise ValueError(f'"{constants.FEATURES_KEY}" key not found in extracts. ' - 'Check that the proper extractor has been configured to ' - 'extract the features from the inputs. Existing keys: ' - f'{extracts.keys()}') - if (include_transformed_features and - constants.TRANSFORMED_FEATURES_KEY not in extracts): - raise ValueError(f'"{constants.TRANSFORMED_FEATURES_KEY}" key not found in ' - 'extracts. Check that the proper extractor has been ' - 'configured to extract the transformed features from the ' - f'inputs. Existing keys: {extracts.keys()}') - if (include_any_feature and constants.FEATURES_KEY not in extracts and - constants.TRANSFORMED_FEATURES_KEY not in extracts): - raise ValueError( - f'"{constants.FEATURES_KEY}" or {constants.TRANSFORMED_FEATURES_KEY} ' - 'key not found in extracts. 
Check that the proper extractor has been ' - 'configured to extract the attributions from the inputs.' - f'Existing keys: {extracts.keys()}') - if (include_attributions and constants.ATTRIBUTIONS_KEY not in extracts): - raise ValueError(f'"{constants.ATTRIBUTIONS_KEY}" key not found in ' - 'extracts. Check that the proper extractor has been ' - 'configured to extract the attributions from the inputs.' - f'Existing keys: {extracts.keys()}') - return metric_types.StandardMetricInputs(extracts) + include_attributions: bool = False, +) -> metric_types.StandardMetricInputs: + """Verifies extract keys and converts extracts to StandardMetricInputs.""" + if include_labels and constants.LABELS_KEY not in extracts: + raise ValueError( + f'"{constants.LABELS_KEY}" key not found in extracts. ' + "Check that the configuration is setup properly to " + "specify the name of label input and that the proper " + "extractor has been configured to extract the labels from " + f"the inputs. Existing keys: {extracts.keys()}" + ) + if include_predictions and constants.PREDICTIONS_KEY not in extracts: + raise ValueError( + f'"{constants.PREDICTIONS_KEY}" key not found in ' + "extracts. Check that the proper extractor has been " + "configured to perform model inference." + ) + if include_features and constants.FEATURES_KEY not in extracts: + raise ValueError( + f'"{constants.FEATURES_KEY}" key not found in extracts. ' + "Check that the proper extractor has been configured to " + "extract the features from the inputs. Existing keys: " + f"{extracts.keys()}" + ) + if ( + include_transformed_features + and constants.TRANSFORMED_FEATURES_KEY not in extracts + ): + raise ValueError( + f'"{constants.TRANSFORMED_FEATURES_KEY}" key not found in ' + "extracts. Check that the proper extractor has been " + "configured to extract the transformed features from the " + f"inputs. Existing keys: {extracts.keys()}" + ) + if ( + include_any_feature + and constants.FEATURES_KEY not in extracts + and constants.TRANSFORMED_FEATURES_KEY not in extracts + ): + raise ValueError( + f'"{constants.FEATURES_KEY}" or {constants.TRANSFORMED_FEATURES_KEY} ' + "key not found in extracts. Check that the proper extractor has been " + "configured to extract the attributions from the inputs." + f"Existing keys: {extracts.keys()}" + ) + if include_attributions and constants.ATTRIBUTIONS_KEY not in extracts: + raise ValueError( + f'"{constants.ATTRIBUTIONS_KEY}" key not found in ' + "extracts. Check that the proper extractor has been " + "configured to extract the attributions from the inputs." + f"Existing keys: {extracts.keys()}" + ) + return metric_types.StandardMetricInputs(extracts) def top_k_indices( - top_k: int, - scores: Any, - sort: bool = False) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]: - """Returns top_k indices into a list of scores. - - Note that the indices are returned in a form that is useful for assigning - values to the array. If using to select values from an array you may need to - reshape the output. Examples: - - # Assigning values to scores based on indices - indices = top_k_indices(1, scores) - scores[indices] = 0.0 - - # Selecting top_k - indices = top_k_indices(scores) - scores[indices].reshape(scores.shape[:-1] + (top_k,)) - - Args: - top_k: Number of top k values to return. - scores: Array or list of scores for computing the top_k indices. - sort: True if the indices should be sorted (in descending order). - - Returns: - An array of indices into scores that can be used with either 1D or 2D - arrays. 
If sort was True the indices will be returned in descending order of - score (i.e. top score first). - - Raises: - ValueError: If top_k doesn't match scores or input has more than 2 dims. - """ - scores = util.to_numpy(scores) - if scores.shape[-1] < top_k: - raise ValueError( - 'not enough values were provided to perform the requested ' - f'calcuations for top k. The requested value for k is {top_k}, but the ' - f'values are {scores}\n\nThis may be caused by a metric configuration ' - 'error or an error in the pipeline.') - - if len(scores.shape) == 1: - # 1D data - indices = np.argpartition(scores, -top_k)[-top_k:] - if sort: - indices = indices[np.argsort(-scores[indices])] - return indices - elif len(scores.shape) == 2: - # 2D data - indices = np.argpartition(scores, -top_k, axis=-1)[:, -top_k:] - # The above creates an n x top_k matrix where each row in indices matches - # the corresponding row in scores. For example: - # [ - # [, , ...], - # [, , ...], - # ... - # ] - # However numpy indexing wants the index to be be a 2-tuple of where the - # first tuple value contains the row indices (repeated top k times for each - # row) and the second tuple value contains the column values. - # (row1, row1, ..., row2, ...), (row1_top_k_index1, row1_top_index_2,...) - if sort: - for i in range(indices.shape[0]): - indices[i] = indices[i][np.argsort(-scores[i][indices[i]])] - return np.arange(indices.shape[0]).repeat(top_k), indices.flatten() - else: - raise NotImplementedError( - 'top_k not supported for shapes > 2: scores = {}'.format(scores)) + top_k: int, scores: Any, sort: bool = False +) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]: + """Returns top_k indices into a list of scores. + + Note that the indices are returned in a form that is useful for assigning + values to the array. If using to select values from an array you may need to + reshape the output. Examples: + + # Assigning values to scores based on indices + indices = top_k_indices(1, scores) + scores[indices] = 0.0 + + # Selecting top_k + indices = top_k_indices(scores) + scores[indices].reshape(scores.shape[:-1] + (top_k,)) + + Args: + ---- + top_k: Number of top k values to return. + scores: Array or list of scores for computing the top_k indices. + sort: True if the indices should be sorted (in descending order). + + Returns: + ------- + An array of indices into scores that can be used with either 1D or 2D + arrays. If sort was True the indices will be returned in descending order of + score (i.e. top score first). + + Raises: + ------ + ValueError: If top_k doesn't match scores or input has more than 2 dims. + """ + scores = util.to_numpy(scores) + if scores.shape[-1] < top_k: + raise ValueError( + "not enough values were provided to perform the requested " + f"calcuations for top k. The requested value for k is {top_k}, but the " + f"values are {scores}\n\nThis may be caused by a metric configuration " + "error or an error in the pipeline." + ) + + if len(scores.shape) == 1: + # 1D data + indices = np.argpartition(scores, -top_k)[-top_k:] + if sort: + indices = indices[np.argsort(-scores[indices])] + return indices + elif len(scores.shape) == 2: + # 2D data + indices = np.argpartition(scores, -top_k, axis=-1)[:, -top_k:] + # The above creates an n x top_k matrix where each row in indices matches + # the corresponding row in scores. For example: + # [ + # [, , ...], + # [, , ...], + # ... 
+ # ] + # However numpy indexing wants the index to be be a 2-tuple of where the + # first tuple value contains the row indices (repeated top k times for each + # row) and the second tuple value contains the column values. + # (row1, row1, ..., row2, ...), (row1_top_k_index1, row1_top_index_2,...) + if sort: + for i in range(indices.shape[0]): + indices[i] = indices[i][np.argsort(-scores[i][indices[i]])] + return np.arange(indices.shape[0]).repeat(top_k), indices.flatten() + else: + raise NotImplementedError( + f"top_k not supported for shapes > 2: scores = {scores}" + ) def select_indices( - arr: np.ndarray, - indices: Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]) -> np.ndarray: - """Selects values from tensor at given indices. - - Args: - arr: Array to select values from. - indices: Indices that are given either by an np.ndarray (1D) or a tuple of - np.ndarray's where the first value identifies the rows and the second the - columns (2D). - - Returns: - Values with the same shape as tensor except the last dimension will match - the number of indices selected. - """ - values = arr[indices] - if len(arr.shape) == 1: - return values - elif len(arr.shape) == 2: - # The indices[0] contains rows of the form [row1, row1, ..., row2, ...] - # the rows are repeated for each column. Since the first dimension of the - # array tells us the number of rows, dividing the length of indices[0] by - # the number of rows tells us the number of columns we are returning (i.e. - # the size of the last dim). - last_dim = int(len(indices[0]) / arr.shape[0]) - values = values.reshape(arr.shape[:-1] + (last_dim,)) - return values - else: - raise NotImplementedError('select_indices not supported for shapes > 2: ' - 'arr={}, indices={}'.format(arr, indices)) + arr: np.ndarray, indices: Union[np.ndarray, Tuple[np.ndarray, np.ndarray]] +) -> np.ndarray: + """Selects values from tensor at given indices. + + Args: + ---- + arr: Array to select values from. + indices: Indices that are given either by an np.ndarray (1D) or a tuple of + np.ndarray's where the first value identifies the rows and the second the + columns (2D). + + Returns: + ------- + Values with the same shape as tensor except the last dimension will match + the number of indices selected. + """ + values = arr[indices] + if len(arr.shape) == 1: + return values + elif len(arr.shape) == 2: + # The indices[0] contains rows of the form [row1, row1, ..., row2, ...] + # the rows are repeated for each column. Since the first dimension of the + # array tells us the number of rows, dividing the length of indices[0] by + # the number of rows tells us the number of columns we are returning (i.e. + # the size of the last dim). 
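As a side note, the (rows, cols) tuple convention described in the comments above can be sketched with plain NumPy; the score values below are made-up for illustration and are not taken from the codebase or its tests.

    import numpy as np

    # Illustrative 2 x 3 score matrix (assumed values).
    scores = np.array([[0.1, 0.7, 0.2],
                       [0.5, 0.3, 0.9]])
    top_k = 2

    # Per-row top-k column indices (unsorted), as np.argpartition returns them.
    cols = np.argpartition(scores, -top_k, axis=-1)[:, -top_k:]

    # Repeat each row index top_k times and flatten the columns to get the
    # (rows, cols) 2-tuple that numpy fancy indexing expects.
    rows = np.arange(cols.shape[0]).repeat(top_k)
    indices = (rows, cols.flatten())

    # Selecting yields a flat array of 2 * top_k values; reshaping restores a
    # (2, top_k) matrix holding the largest scores of each row.
    selected = scores[indices].reshape(scores.shape[:-1] + (top_k,))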
+ last_dim = int(len(indices[0]) / arr.shape[0]) + values = values.reshape(arr.shape[:-1] + (last_dim,)) + return values + else: + raise NotImplementedError( + "select_indices not supported for shapes > 2: " + f"arr={arr}, indices={indices}" + ) def to_label_prediction_example_weight( inputs: metric_types.StandardMetricInputs, eval_config: Optional[config_pb2.EvalConfig] = None, - model_name: str = '', - output_name: str = '', + model_name: str = "", + output_name: str = "", sub_key: Optional[metric_types.SubKey] = None, aggregation_type: Optional[metric_types.AggregationType] = None, class_weights: Optional[Dict[int, float]] = None, @@ -398,515 +441,561 @@ def to_label_prediction_example_weight( flatten: bool = True, squeeze: bool = True, allow_none: bool = False, - require_single_example_weight: bool = False + require_single_example_weight: bool = False, ) -> Iterator[Tuple[np.ndarray, np.ndarray, np.ndarray]]: - """Yields label, prediction, and example weights for use in calculations. - - Where applicable this function will perform model and output name lookups as - well as any required class ID, top K, etc conversions. It will also apply - prediction keys and label vocabularies given the necessary information is - provided as part of the EvalConfig (or standard estimator based naming is - used). The sparseness of labels will be inferred from the shapes of the labels - and predictions (i.e. if the shapes are different then the labels will be - assumed to be sparse). - - If successful, the final output of calling this function will be a tuple of - numpy arrays representing the label, prediction, and example weight - respectively. Labels and predictions will be returned in the same shape - provided (default behavior) unless (1) flatten is True in which case a series - of values (one per class ID) will be returned with last dimension of size 1 or - (2) a sub_key is used in which case the last dimension may be re-shaped to - match the new number of outputs (1 for class_id or k, top_k for top k with - aggregation). - - Note that for top_k without aggregation, the non-top_k prediction values will - be set to float('-inf'), but for top_k with aggregation the values will be - truncated to only return the top k values. 
- - Examples: - - # default behavior - # - # Binary classification - Input : labels=[1] predictions=[0.6] - Output : (np.array([1]), np.array([0.6]), np.array([1.0])) - # Multi-class classification w/ sparse labels - Input : labels=[2] predictions=[0.3, 0.6, 0.1] - Output: (np.array([2]), np.array([0.3, 0.6, 0.1]), np.array([1.0])) - # Multi-class / multi-label classification w/ dense labels - Input : labels=[0, 1, 1] predictions=[0.3, 0.6, 0.1] - Output : (np.array([0, 1, 1]), np.array([0.3, 0.6, 0.1]), np.array([1.0])) - - # flatten=True - # - # Multi-class classification w/ sparse labels - Input : labels=[2], predictions=[0.3, 0.6, 0.1] - Output : (np.array([0]), np.array([0.3]), np.array([1.0])), - (np.array([0]), np.array([0.6]), np.array([1.0])), - (np.array([1]), np.array([0.1]), np.array([1.0])) - # Multi-class/multi-label classification w/ dense labels - Input : labels=[0, 0, 1], predictions=[0.3, 0.6, 0.1] - Output : (np.array([0]), np.array([0.3]), np.array([1.0])), - (np.array([0]), np.array([0.6]), np.array([1.0])), - (np.array([1]), np.array([0.1]), np.array([1.0])) - - # sub_key.class_id=[2] - # - # Multi-class classification w/ sparse labels - Input : labels=[2] predictions=[0.3, 0.6, 0.1] - Output : (np.array([1]), np.array([0.1]), np.array([1.0])) - # Multi-class classification w/ dense labels - Input : labels=[0, 0, 1] predictions=[0.3, 0.6, 0.1] - Output : (np.array([1]), np.array([0.1]), np.array([1.0])) - - # sub_key.top_k=2 and aggregation_type is None (i.e. binarization of top 2). - # - # Multi-class classification w/ sparse labels - Input : labels=[2] predictions=[0.3, 0.6, 0.1] - Output : (np.array([0, 0, 1]), np.array([0.3, 0.6, -inf]), np.array([1.0])) - # Multi-class classification w/ dense labels - Input : labels=[0, 0, 1] predictions=[0.3, 0.1, 0.6] - Output : (np.array([0, 0, 1]), np.array([0.3, -inf, 0.6]), np.array([1.0])) - - # sub_key.top_k=2 and aggregation_type is not None (i.e. aggregate top 2). - # - # Multi-class classification w/ sparse labels - Input : labels=[2] predictions=[0.3, 0.6, 0.1] - Output : (np.array([0, 1]), np.array([0.3, 0.6]), np.array([1.0])) - # Multi-class classification w/ dense labels - Input : labels=[0, 0, 1] predictions=[0.3, 0.1, 0.6] - Output : (np.array([0, 0]), np.array([0.3, 0.6]), np.array([1.0])) - - # sub_key.k=2 (i.e. binarization by choosing 2nd largest predicted value). - # - # Multi-class classification w/ sparse labels - Input : labels=[0] predictions=[0.3, 0.6, 0.1] - Output : (np.array([1]), np.array([0.3]), np.array([1.0])) - # Multi-class classification w/ dense labels - Input : labels=[0] predictions=[0.3] - Output : (np.array([0]), np.array([0.3]), np.array([1.0])) - - Args: - inputs: Standard metric inputs. - eval_config: Eval config - model_name: Optional model name (if multi-model evaluation). - output_name: Optional output name (if multi-output model type). - sub_key: Optional sub key. - aggregation_type: Optional aggregation type. - class_weights: Optional class weights to apply to multi-class / multi-label - labels and predictions. If used, flatten must also be True. - example_weighted: True if example weights should be applied. 
- fractional_labels: If true, each incoming tuple of (label, prediction, and - example weight) will be split into two tuples as follows (where l, p, w - represent the resulting label, prediction, and example weight values): (1) - l = 0.0, p = prediction, and w = example_weight * (1.0 - label) (2) l = - 1.0, p = prediction, and w = example_weight * label If enabled, an - exception will be raised if labels are not within [0, 1]. The - implementation is such that tuples associated with a weight of zero are - not yielded. This means it is safe to enable fractional_labels even when - the labels only take on the values of 0.0 or 1.0. - flatten: True to flatten the final label and prediction outputs so that the - yielded values are always arrays of size 1. For example, multi-class / - multi-label outputs would be converted into label and prediction pairs - that could then be processed by a binary classification metric in order to - compute a micro average over all classes. If the example weight is not a - scalar, then they will be flattened as well, otherwise the same example - weight value will be output for each pair of labels and predictions. - squeeze: True to squeeze any outputs that have rank > 1. This transforms - outputs such as np.array([[1]]) to np.array([1]). - allow_none: True to allow labels or predictions with None values to be - returned. When used, the values will be returned as empty np.ndarrays. The - example weight will always be non-empty. - require_single_example_weight: True to require that the example_weight be a - single value. - - Yields: - Tuple of (label, prediction, example_weight). - """ - - def fn_call_str(): - return (f'to_label_prediction_example_weight(inputs={inputs}, ' - f'eval_config={eval_config}, model_name={model_name}, ' - f'output_name={output_name}, sub_key={sub_key}, ' - f'aggregation_type={aggregation_type}, ' - f'class_weights={class_weights}, ' - f'fractional_labels={fractional_labels}, flatten={flatten}, ' - f'squeeze={squeeze}, allow_none={allow_none})') - - def optionally_get_by_keys(value: Any, keys: List[str]) -> Any: - - class NotFound(object): - pass - - if isinstance(value, Mapping): - new_value = util.get_by_keys(value, keys, default_value=NotFound()) - if not isinstance(new_value, NotFound): - # Might be None if that's what is in the dict - return new_value - return value - - try: - prediction_key = '' - label_key = '' - if eval_config and eval_config.model_specs: - for spec in eval_config.model_specs: - # To maintain consistency between settings where single models are used, - # always use '' as the model name regardless of whether a name is passed - spec_name = spec.name if len(eval_config.model_specs) > 1 else '' - if spec_name == model_name: - prediction_key = spec.prediction_key - label_key = spec.label_key - break - - label = inputs.label - if label_key: - # This is to support a custom EvalSavedModel where the labels are a dict - # but the keys are not output_names. - label = optionally_get_by_keys(label, [label_key]) - prediction = inputs.prediction - example_weight = inputs.example_weight - if model_name: - if prediction is not None: - prediction = util.get_by_keys(prediction, [model_name]) - # Labels and weights can optionally be keyed by model name. 
- label = optionally_get_by_keys(label, [model_name]) - example_weight = optionally_get_by_keys(example_weight, [model_name]) - if output_name: - if prediction is not None: - prediction = util.get_by_keys(prediction, [output_name]) - # Labels and example weights can optionally be keyed by output name. - label = optionally_get_by_keys(label, [output_name]) - example_weight = optionally_get_by_keys(example_weight, [output_name]) - - if not example_weighted or example_weight is None: - example_weight = np.array( - 1.0, dtype=np.float32 - ) # tf-ranking needs float32 - - if isinstance(label, Mapping): - raise ValueError( - 'unable to prepare label for metric computation because the label is ' - 'a dict with unrecognized keys. If a multi-output model was used ' - f'check that an output name was provided in all the relevant ' - 'settings (ModelSpec.label_keys, MetricsSpec.output_names, etc): ' - f'label={label}, output_name={output_name}') - if isinstance(example_weight, Mapping): - raise ValueError( - 'unable to prepare example_weight for metric computation because the ' - 'example_weight is a dict with unrecognized keys. If a multi-output ' - 'model was used check that an output name was provided in all the ' - 'relevant settings (ModelSpec.example_weight_keys, ' - f'MetricsSpec.output_names, etc): example_weight={example_weight}, ' - f'output_name={output_name}') - - label, prediction = prepare_labels_and_predictions(label, prediction, - prediction_key) - - if not allow_none: - for txt, value in zip(('label', 'prediction'), (label, prediction)): - if value is None: - raise ValueError( - f'no value provided for {txt}\n\n' - 'This may be caused by a configuration error (i.e. label, ' - 'and/or prediction keys were not specified) or an ' - 'error in the pipeline.') - - example_weight = util.to_numpy(example_weight) - if require_single_example_weight and example_weight.size > 1: - example_weight = example_weight.flatten() - if not np.all(example_weight == example_weight[0]): - raise ValueError( - 'if example_weight size > 0, the values must all be the same: ' - f'example_weight={example_weight}\n\n' - 'This is most likely a configuration error.') - example_weight = np.array(example_weight[0]) - - if sub_key is not None and label is not None and prediction is not None: - if sub_key.k is not None: - indices = top_k_indices(sub_key.k, prediction) - if len(prediction.shape) == 1: - indices = indices[0] # 1D - else: - # 2D, take kth values - indices = (indices[0][0::sub_key.k], indices[1][0::sub_key.k]) - if label.shape != prediction.shape: - label = one_hot(label, prediction) - label = select_indices(label, indices) - prediction = select_indices(prediction, indices) - else: - if sub_key.top_k is not None: - # Set all non-top-k predictions to -inf. Note that we do not sort. - indices = top_k_indices(sub_key.top_k, prediction) - if aggregation_type is None: - top_k_predictions = np.full(prediction.shape, float('-inf')) - top_k_predictions[indices] = prediction[indices] - prediction = top_k_predictions - else: - if label.shape != prediction.shape: - label = one_hot(label, prediction) - label = select_indices(label, indices) - prediction = select_indices(prediction, indices) - if sub_key.class_id is not None: - label, prediction = select_class_id( - sub_key.class_id, label, prediction - ) - - # For consistency, make sure all outputs are arrays (i.e. 
convert scalars) - if label is not None and not label.shape: - label = label.reshape((1,)) - if prediction is not None and not prediction.shape: - prediction = prediction.reshape((1,)) - if not example_weight.shape: - example_weight = example_weight.reshape((1,)) - - label = label if label is not None else np.array([]) - prediction = prediction if prediction is not None else np.array([]) - - flatten_size = prediction.size or label.size - if flatten: - if example_weight.size == 1: - example_weight = np.array( - [float(example_weight) for i in range(flatten_size)]) - elif example_weight.size != flatten_size: - raise ValueError( - 'example_weight size does not match the size of labels and ' - 'predictions: label={}, prediction={}, example_weight={}'.format( - label, prediction, example_weight)) - - if class_weights: - if not flatten: - raise ValueError( - 'class_weights can only be used when flatten is also used: ' - f'class_weights={class_weights}, flatten={flatten}\n\n' - 'This is likely caused by a configuration error (i.e. micro ' - "averaging being applied to metrics that don't support micro " - 'averaging') - example_weight = np.array([ - example_weight[i] * class_weights[i] if i in class_weights else 0.0 - for i in range(flatten_size) - ]) - - def yield_results(label, prediction, example_weight): - if (not flatten or (label.size == 0 and prediction.size == 0) or - (label.size == 1 and prediction.size == 1 and - example_weight.size == 1)): - if squeeze: - yield _squeeze(label), _squeeze(prediction), _squeeze(example_weight) - else: - yield label, prediction, example_weight - elif label.size == 0: - for p, w in zip(prediction.flatten(), example_weight.flatten()): - yield label, np.array([p]), np.array([w]) - elif prediction.size == 0: - for l, w in zip(label.flatten(), example_weight.flatten()): - yield np.array([l]), prediction, np.array([w]) - elif label.size == prediction.size and label.size == example_weight.size: - for l, p, w in zip(label.flatten(), prediction.flatten(), - example_weight.flatten()): - yield np.array([l]), np.array([p]), np.array([w]) - elif label.shape[-1] == 1 and prediction.size == example_weight.size: - label = one_hot(label, prediction) - for l, p, w in zip(label.flatten(), prediction.flatten(), - example_weight.flatten()): - yield np.array([l]), np.array([p]), np.array([w]) - else: - raise ValueError( - 'unable to pair labels, predictions, and example weights: ' - f'label={label}, prediction={prediction}, ' - f'example_weight={example_weight}\n\n' - 'This is most likely a configuration error.') - - for result in yield_results(label, prediction, example_weight): - if fractional_labels and label.size: - for new_result in _yield_fractional_labels(*result): - yield new_result - else: - yield result - except Exception as e: - import sys # pylint: disable=g-import-not-at-top - raise type(e)(str(e) + f'\n\n{fn_call_str()}').with_traceback( - sys.exc_info()[2]) + """Yields label, prediction, and example weights for use in calculations. + + Where applicable this function will perform model and output name lookups as + well as any required class ID, top K, etc conversions. It will also apply + prediction keys and label vocabularies given the necessary information is + provided as part of the EvalConfig (or standard estimator based naming is + used). The sparseness of labels will be inferred from the shapes of the labels + and predictions (i.e. if the shapes are different then the labels will be + assumed to be sparse). 
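A minimal sketch of that shape-based sparseness inference, using made-up labels and predictions:

    import numpy as np

    # A sparse label holds a class id, so its shape differs from the
    # prediction vector and the labels are treated as sparse.
    sparse_labels = np.array([2])            # shape (1,)
    predictions = np.array([0.3, 0.6, 0.1])  # shape (3,)
    looks_sparse = sparse_labels.shape != predictions.shape   # True

    # Dense (one-hot / multi-label) labels share the prediction shape.
    dense_labels = np.array([0, 0, 1])       # shape (3,)
    looks_sparse = dense_labels.shape != predictions.shape    # False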
+ + If successful, the final output of calling this function will be a tuple of + numpy arrays representing the label, prediction, and example weight + respectively. Labels and predictions will be returned in the same shape + provided (default behavior) unless (1) flatten is True in which case a series + of values (one per class ID) will be returned with last dimension of size 1 or + (2) a sub_key is used in which case the last dimension may be re-shaped to + match the new number of outputs (1 for class_id or k, top_k for top k with + aggregation). + + Note that for top_k without aggregation, the non-top_k prediction values will + be set to float('-inf'), but for top_k with aggregation the values will be + truncated to only return the top k values. + + Examples: + -------- + # default behavior + # + # Binary classification + Input : labels=[1] predictions=[0.6] + Output : (np.array([1]), np.array([0.6]), np.array([1.0])) + # Multi-class classification w/ sparse labels + Input : labels=[2] predictions=[0.3, 0.6, 0.1] + Output: (np.array([2]), np.array([0.3, 0.6, 0.1]), np.array([1.0])) + # Multi-class / multi-label classification w/ dense labels + Input : labels=[0, 1, 1] predictions=[0.3, 0.6, 0.1] + Output : (np.array([0, 1, 1]), np.array([0.3, 0.6, 0.1]), np.array([1.0])) + + # flatten=True + # + # Multi-class classification w/ sparse labels + Input : labels=[2], predictions=[0.3, 0.6, 0.1] + Output : (np.array([0]), np.array([0.3]), np.array([1.0])), + (np.array([0]), np.array([0.6]), np.array([1.0])), + (np.array([1]), np.array([0.1]), np.array([1.0])) + # Multi-class/multi-label classification w/ dense labels + Input : labels=[0, 0, 1], predictions=[0.3, 0.6, 0.1] + Output : (np.array([0]), np.array([0.3]), np.array([1.0])), + (np.array([0]), np.array([0.6]), np.array([1.0])), + (np.array([1]), np.array([0.1]), np.array([1.0])) + + # sub_key.class_id=[2] + # + # Multi-class classification w/ sparse labels + Input : labels=[2] predictions=[0.3, 0.6, 0.1] + Output : (np.array([1]), np.array([0.1]), np.array([1.0])) + # Multi-class classification w/ dense labels + Input : labels=[0, 0, 1] predictions=[0.3, 0.6, 0.1] + Output : (np.array([1]), np.array([0.1]), np.array([1.0])) + + # sub_key.top_k=2 and aggregation_type is None (i.e. binarization of top 2). + # + # Multi-class classification w/ sparse labels + Input : labels=[2] predictions=[0.3, 0.6, 0.1] + Output : (np.array([0, 0, 1]), np.array([0.3, 0.6, -inf]), np.array([1.0])) + # Multi-class classification w/ dense labels + Input : labels=[0, 0, 1] predictions=[0.3, 0.1, 0.6] + Output : (np.array([0, 0, 1]), np.array([0.3, -inf, 0.6]), np.array([1.0])) + + # sub_key.top_k=2 and aggregation_type is not None (i.e. aggregate top 2). + # + # Multi-class classification w/ sparse labels + Input : labels=[2] predictions=[0.3, 0.6, 0.1] + Output : (np.array([0, 1]), np.array([0.3, 0.6]), np.array([1.0])) + # Multi-class classification w/ dense labels + Input : labels=[0, 0, 1] predictions=[0.3, 0.1, 0.6] + Output : (np.array([0, 0]), np.array([0.3, 0.6]), np.array([1.0])) + + # sub_key.k=2 (i.e. binarization by choosing 2nd largest predicted value). + # + # Multi-class classification w/ sparse labels + Input : labels=[0] predictions=[0.3, 0.6, 0.1] + Output : (np.array([1]), np.array([0.3]), np.array([1.0])) + # Multi-class classification w/ dense labels + Input : labels=[0] predictions=[0.3] + Output : (np.array([0]), np.array([0.3]), np.array([1.0])) + + Args: + ---- + inputs: Standard metric inputs. 
+ eval_config: Eval config + model_name: Optional model name (if multi-model evaluation). + output_name: Optional output name (if multi-output model type). + sub_key: Optional sub key. + aggregation_type: Optional aggregation type. + class_weights: Optional class weights to apply to multi-class / multi-label + labels and predictions. If used, flatten must also be True. + example_weighted: True if example weights should be applied. + fractional_labels: If true, each incoming tuple of (label, prediction, and + example weight) will be split into two tuples as follows (where l, p, w + represent the resulting label, prediction, and example weight values): (1) + l = 0.0, p = prediction, and w = example_weight * (1.0 - label) (2) l = + 1.0, p = prediction, and w = example_weight * label If enabled, an + exception will be raised if labels are not within [0, 1]. The + implementation is such that tuples associated with a weight of zero are + not yielded. This means it is safe to enable fractional_labels even when + the labels only take on the values of 0.0 or 1.0. + flatten: True to flatten the final label and prediction outputs so that the + yielded values are always arrays of size 1. For example, multi-class / + multi-label outputs would be converted into label and prediction pairs + that could then be processed by a binary classification metric in order to + compute a micro average over all classes. If the example weight is not a + scalar, then they will be flattened as well, otherwise the same example + weight value will be output for each pair of labels and predictions. + squeeze: True to squeeze any outputs that have rank > 1. This transforms + outputs such as np.array([[1]]) to np.array([1]). + allow_none: True to allow labels or predictions with None values to be + returned. When used, the values will be returned as empty np.ndarrays. The + example weight will always be non-empty. + require_single_example_weight: True to require that the example_weight be a + single value. + + Yields: + ------ + Tuple of (label, prediction, example_weight). + """ + + def fn_call_str(): + return ( + f"to_label_prediction_example_weight(inputs={inputs}, " + f"eval_config={eval_config}, model_name={model_name}, " + f"output_name={output_name}, sub_key={sub_key}, " + f"aggregation_type={aggregation_type}, " + f"class_weights={class_weights}, " + f"fractional_labels={fractional_labels}, flatten={flatten}, " + f"squeeze={squeeze}, allow_none={allow_none})" + ) + + def optionally_get_by_keys(value: Any, keys: List[str]) -> Any: + class NotFound: + pass + + if isinstance(value, Mapping): + new_value = util.get_by_keys(value, keys, default_value=NotFound()) + if not isinstance(new_value, NotFound): + # Might be None if that's what is in the dict + return new_value + return value + + try: + prediction_key = "" + label_key = "" + if eval_config and eval_config.model_specs: + for spec in eval_config.model_specs: + # To maintain consistency between settings where single models are used, + # always use '' as the model name regardless of whether a name is passed + spec_name = spec.name if len(eval_config.model_specs) > 1 else "" + if spec_name == model_name: + prediction_key = spec.prediction_key + label_key = spec.label_key + break + + label = inputs.label + if label_key: + # This is to support a custom EvalSavedModel where the labels are a dict + # but the keys are not output_names. 
+ label = optionally_get_by_keys(label, [label_key]) + prediction = inputs.prediction + example_weight = inputs.example_weight + if model_name: + if prediction is not None: + prediction = util.get_by_keys(prediction, [model_name]) + # Labels and weights can optionally be keyed by model name. + label = optionally_get_by_keys(label, [model_name]) + example_weight = optionally_get_by_keys(example_weight, [model_name]) + if output_name: + if prediction is not None: + prediction = util.get_by_keys(prediction, [output_name]) + # Labels and example weights can optionally be keyed by output name. + label = optionally_get_by_keys(label, [output_name]) + example_weight = optionally_get_by_keys(example_weight, [output_name]) + + if not example_weighted or example_weight is None: + example_weight = np.array(1.0, dtype=np.float32) # tf-ranking needs float32 + + if isinstance(label, Mapping): + raise ValueError( + "unable to prepare label for metric computation because the label is " + "a dict with unrecognized keys. If a multi-output model was used " + f"check that an output name was provided in all the relevant " + "settings (ModelSpec.label_keys, MetricsSpec.output_names, etc): " + f"label={label}, output_name={output_name}" + ) + if isinstance(example_weight, Mapping): + raise ValueError( + "unable to prepare example_weight for metric computation because the " + "example_weight is a dict with unrecognized keys. If a multi-output " + "model was used check that an output name was provided in all the " + "relevant settings (ModelSpec.example_weight_keys, " + f"MetricsSpec.output_names, etc): example_weight={example_weight}, " + f"output_name={output_name}" + ) + + label, prediction = prepare_labels_and_predictions( + label, prediction, prediction_key + ) + + if not allow_none: + for txt, value in zip(("label", "prediction"), (label, prediction)): + if value is None: + raise ValueError( + f"no value provided for {txt}\n\n" + "This may be caused by a configuration error (i.e. label, " + "and/or prediction keys were not specified) or an " + "error in the pipeline." + ) + + example_weight = util.to_numpy(example_weight) + if require_single_example_weight and example_weight.size > 1: + example_weight = example_weight.flatten() + if not np.all(example_weight == example_weight[0]): + raise ValueError( + "if example_weight size > 0, the values must all be the same: " + f"example_weight={example_weight}\n\n" + "This is most likely a configuration error." + ) + example_weight = np.array(example_weight[0]) + + if sub_key is not None and label is not None and prediction is not None: + if sub_key.k is not None: + indices = top_k_indices(sub_key.k, prediction) + if len(prediction.shape) == 1: + indices = indices[0] # 1D + else: + # 2D, take kth values + indices = (indices[0][0 :: sub_key.k], indices[1][0 :: sub_key.k]) + if label.shape != prediction.shape: + label = one_hot(label, prediction) + label = select_indices(label, indices) + prediction = select_indices(prediction, indices) + else: + if sub_key.top_k is not None: + # Set all non-top-k predictions to -inf. Note that we do not sort. 
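A rough sketch of that masking step with made-up values and top_k = 2; only the two largest predictions survive, the rest become -inf, and no sorting is applied:

    import numpy as np

    prediction = np.array([0.3, 0.6, 0.1])   # assumed values
    top_k = 2
    indices = np.argpartition(prediction, -top_k)[-top_k:]

    masked = np.full(prediction.shape, float("-inf"))
    masked[indices] = prediction[indices]
    # masked -> array([0.3, 0.6, -inf]); surviving values keep their positions.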
+ indices = top_k_indices(sub_key.top_k, prediction) + if aggregation_type is None: + top_k_predictions = np.full(prediction.shape, float("-inf")) + top_k_predictions[indices] = prediction[indices] + prediction = top_k_predictions + else: + if label.shape != prediction.shape: + label = one_hot(label, prediction) + label = select_indices(label, indices) + prediction = select_indices(prediction, indices) + if sub_key.class_id is not None: + label, prediction = select_class_id( + sub_key.class_id, label, prediction + ) + + # For consistency, make sure all outputs are arrays (i.e. convert scalars) + if label is not None and not label.shape: + label = label.reshape((1,)) + if prediction is not None and not prediction.shape: + prediction = prediction.reshape((1,)) + if not example_weight.shape: + example_weight = example_weight.reshape((1,)) + + label = label if label is not None else np.array([]) + prediction = prediction if prediction is not None else np.array([]) + + flatten_size = prediction.size or label.size + if flatten: + if example_weight.size == 1: + example_weight = np.array( + [float(example_weight) for i in range(flatten_size)] + ) + elif example_weight.size != flatten_size: + raise ValueError( + "example_weight size does not match the size of labels and " + f"predictions: label={label}, prediction={prediction}, example_weight={example_weight}" + ) + + if class_weights: + if not flatten: + raise ValueError( + "class_weights can only be used when flatten is also used: " + f"class_weights={class_weights}, flatten={flatten}\n\n" + "This is likely caused by a configuration error (i.e. micro " + "averaging being applied to metrics that don't support micro " + "averaging" + ) + example_weight = np.array( + [ + example_weight[i] * class_weights[i] if i in class_weights else 0.0 + for i in range(flatten_size) + ] + ) + + def yield_results(label, prediction, example_weight): + if ( + not flatten + or (label.size == 0 and prediction.size == 0) + or ( + label.size == 1 + and prediction.size == 1 + and example_weight.size == 1 + ) + ): + if squeeze: + yield ( + _squeeze(label), + _squeeze(prediction), + _squeeze(example_weight), + ) + else: + yield label, prediction, example_weight + elif label.size == 0: + for p, w in zip(prediction.flatten(), example_weight.flatten()): + yield label, np.array([p]), np.array([w]) + elif prediction.size == 0: + for l, w in zip(label.flatten(), example_weight.flatten()): + yield np.array([l]), prediction, np.array([w]) + elif label.size == prediction.size and label.size == example_weight.size: + for l, p, w in zip( + label.flatten(), prediction.flatten(), example_weight.flatten() + ): + yield np.array([l]), np.array([p]), np.array([w]) + elif label.shape[-1] == 1 and prediction.size == example_weight.size: + label = one_hot(label, prediction) + for l, p, w in zip( + label.flatten(), prediction.flatten(), example_weight.flatten() + ): + yield np.array([l]), np.array([p]), np.array([w]) + else: + raise ValueError( + "unable to pair labels, predictions, and example weights: " + f"label={label}, prediction={prediction}, " + f"example_weight={example_weight}\n\n" + "This is most likely a configuration error." 
+ ) + + for result in yield_results(label, prediction, example_weight): + if fractional_labels and label.size: + for new_result in _yield_fractional_labels(*result): + yield new_result + else: + yield result + except Exception as e: + import sys # pylint: disable=g-import-not-at-top + + raise type(e)(str(e) + f"\n\n{fn_call_str()}").with_traceback(sys.exc_info()[2]) def _yield_fractional_labels( label: np.ndarray, prediction: np.ndarray, example_weight: np.ndarray ) -> Iterable[Tuple[np.ndarray, np.ndarray, np.ndarray]]: - """Yields (label, prediction, example_weight) applying fractional labels. - - The incoming label, prediction, and example weights will be split into two - tuples such that if l, p, w represent the resulting tuple values we will get: - (1) l = 0.0, p = prediction, and w = example_weight * (1.0 - label) - (2) l = 1.0, p = prediction, and w = example_weight * label - - Args: - label: Label. - prediction: Prediction. - example_weight: Example weight. - - Raises: - ValueError: If labels are not within [0, 1]. - """ - # Verify that labels are also within [0, 1] - if not within_interval(float(label), 0.0, 1.0): - raise ValueError( - f'label must be within [0, 1]: label={label}, prediction={prediction}, ' - f'example_weight={example_weight}') - for l, w in ((np.array([0], dtype=label.dtype), example_weight * (1 - label)), - (np.array([1], dtype=label.dtype), example_weight * label)): - if not math.isclose(w, 0.0): - yield (l, prediction, w) + """Yields (label, prediction, example_weight) applying fractional labels. + + The incoming label, prediction, and example weights will be split into two + tuples such that if l, p, w represent the resulting tuple values we will get: + (1) l = 0.0, p = prediction, and w = example_weight * (1.0 - label) + (2) l = 1.0, p = prediction, and w = example_weight * label + + Args: + ---- + label: Label. + prediction: Prediction. + example_weight: Example weight. + + Raises: + ------ + ValueError: If labels are not within [0, 1]. + """ + # Verify that labels are also within [0, 1] + if not within_interval(float(label), 0.0, 1.0): + raise ValueError( + f"label must be within [0, 1]: label={label}, prediction={prediction}, " + f"example_weight={example_weight}" + ) + for l, w in ( + (np.array([0], dtype=label.dtype), example_weight * (1 - label)), + (np.array([1], dtype=label.dtype), example_weight * label), + ): + if not math.isclose(w, 0.0): + yield (l, prediction, w) def _squeeze(arr: np.ndarray): - """Squeezes arr while aways returning an array unless 'arr' is a scalar.""" - if arr.shape not in ((), (1,)): - arr = arr.squeeze() - if not arr.shape: - arr = np.expand_dims(arr, axis=0) - return arr + """Squeezes arr while aways returning an array unless 'arr' is a scalar.""" + if arr.shape not in ((), (1,)): + arr = arr.squeeze() + if not arr.shape: + arr = np.expand_dims(arr, axis=0) + return arr def prepare_labels_and_predictions( labels: Any, predictions: Any, prediction_key: Optional[str] = None, - label_vocabulary: Optional[Union[np.ndarray, List[str]]] = None + label_vocabulary: Optional[Union[np.ndarray, List[str]]] = None, ) -> Tuple[np.ndarray, np.ndarray]: - """Prepares labels and predictions for use in calculations. - - If the predictions are a dict (i.e. estimator based output) this function will - apply the necessary lookup based on the prediction_key provided (or using a - default set of common keys such as 'probabilities', etc). 
Note that the - predictions passed as args must be AFTER the model_name and/or output_name - lookups have been performed. This function also applies any label vocabulary - transformations where possible. - - If successful, the final output of calling this function will be a pair of - numpy arrays representing the labels and predictions. - - Args: - labels: List, np.ndarray, or SparseTensorValue of values (1D, 2D, or 3D). - predictions: List or np.ndarray of prediction values (1D, 2D, or 3D) or a - dict of prediction values keyed by prediction_key or common estimator keys - (logistic, probabilties, etc). - prediction_key: Optional predictions key. Used when the predict output is a - dict. - label_vocabulary: Optional label vocabulary to convert label values to ints - (if prediction is a dict containing an 'all_classes' key that will be used - if label_vocabulary is None). - - Returns: - A (labels, predictions) tuple suitable for metric calculations. - - Raises: - ValueError: If the labels or predictions are in an invalid format. - """ - if isinstance(predictions, Mapping): - if label_vocabulary is None: - if _ALL_CLASSES in predictions: - # Check for use of estimator label vocab under ALL_CLASSES. This was - # added in 06/2019 for eval signatures because the CLASSES only contains - # the label for the chosen class. - label_vocabulary = util.to_numpy(predictions[_ALL_CLASSES]) - elif (tf.saved_model.CLASSIFY_OUTPUT_SCORES in predictions and - tf.saved_model.CLASSIFY_OUTPUT_CLASSES in predictions): - # For classification model using the default serving signature, the - # CLASSES contains the full vocabulary. The check for scores is needed - # here to avoid matching CLASSES in the eval case (scores are not used - # in eval). - label_vocabulary = util.to_numpy( - predictions[tf.saved_model.CLASSIFY_OUTPUT_CLASSES]) - if label_vocabulary is not None: - while len(label_vocabulary.shape) > 1: - label_vocabulary = label_vocabulary[0] # Remove the bach dimensions - if not prediction_key: - # Estimator predictions use dicts of scores, probabilities, classes, etc. - for k in (tf.saved_model.CLASSIFY_OUTPUT_SCORES, - tf.saved_model.REGRESS_OUTPUTS, _PREDICTIONS, _LOGISTIC, - _PROBABILITIES, _LOGITS): - if k in predictions: - predictions = predictions[k] - prediction_key = k - break - elif prediction_key in predictions: - predictions = predictions[prediction_key] - - if isinstance(predictions, Mapping): - raise ValueError( - 'unable to prepare prediction for metric computation because the ' - 'prediction is a dict with unrecognized keys. If a multi-output model ' - 'was used check that an output name was provided in all the relevant ' - 'settings (MetricsSpec.output_names, etc). If the model returns a dict ' - 'for its output and the output does not contain one of the common ' - 'prediction keys (e.g. logistic, probabilities, etc), then ' - 'ModelSpec.prediction_key can be used to specify which key to use for ' - f'the predicted value: prediction={predictions}, ' - f'prediction_key={prediction_key}') - - if predictions is not None: - predictions = util.to_numpy(predictions) - - def _maybe_convert_labels_to_one_hot(labels, predictions): - # String lookups that fail result in a -1 label value. Most metrics - # won't accept this as a valid value so we convert to a one_hot value to - # ensure that we are only working with 0's (i.e. -1 maps to all 0's in - # one-hot). 
- if labels.size and np.all(labels == -1): - return one_hot(labels, predictions) - return labels - - if labels is not None: - if (isinstance(labels, types.SparseTensorValue) or - isinstance(labels, tf.compat.v1.SparseTensorValue)): - if predictions is None or predictions.size == 0: - raise ValueError('predictions must also be used if labels are of type ' - f'SparseTensorValue: labels={labels}') - values = labels.values if labels.values is not None else np.array([]) - indices = labels.indices if labels.indices is not None else np.array([]) - if label_vocabulary is not None and values.dtype.kind in ('U', 'S', 'O'): - values = _string_labels_to_class_ids(label_vocabulary, values) - # If vocab is used then the values will be the indices into the vocab - # and we should use multi-hot encoding to store the output. We can - # accomplish this by passing 1's for the values and using the values - # converted from the vocab as the indices to insert the 1's at the - # proper offsets in the resulting multi-hot vector. - labels = _to_dense_tensor( - np.ones(values.shape), values, predictions.shape) - labels = _maybe_convert_labels_to_one_hot(labels, predictions) - else: - labels = _to_dense_tensor(values, indices, predictions.shape) - else: - labels = util.to_numpy(labels) - if label_vocabulary is not None and labels.dtype.kind in ('U', 'S', 'O'): - labels = _string_labels_to_class_ids(label_vocabulary, labels) - labels = _maybe_convert_labels_to_one_hot(labels, predictions) + """Prepares labels and predictions for use in calculations. + + If the predictions are a dict (i.e. estimator based output) this function will + apply the necessary lookup based on the prediction_key provided (or using a + default set of common keys such as 'probabilities', etc). Note that the + predictions passed as args must be AFTER the model_name and/or output_name + lookups have been performed. This function also applies any label vocabulary + transformations where possible. + + If successful, the final output of calling this function will be a pair of + numpy arrays representing the labels and predictions. + + Args: + ---- + labels: List, np.ndarray, or SparseTensorValue of values (1D, 2D, or 3D). + predictions: List or np.ndarray of prediction values (1D, 2D, or 3D) or a + dict of prediction values keyed by prediction_key or common estimator keys + (logistic, probabilties, etc). + prediction_key: Optional predictions key. Used when the predict output is a + dict. + label_vocabulary: Optional label vocabulary to convert label values to ints + (if prediction is a dict containing an 'all_classes' key that will be used + if label_vocabulary is None). + + Returns: + ------- + A (labels, predictions) tuple suitable for metric calculations. + + Raises: + ------ + ValueError: If the labels or predictions are in an invalid format. + """ + if isinstance(predictions, Mapping): + if label_vocabulary is None: + if _ALL_CLASSES in predictions: + # Check for use of estimator label vocab under ALL_CLASSES. This was + # added in 06/2019 for eval signatures because the CLASSES only contains + # the label for the chosen class. + label_vocabulary = util.to_numpy(predictions[_ALL_CLASSES]) + elif ( + tf.saved_model.CLASSIFY_OUTPUT_SCORES in predictions + and tf.saved_model.CLASSIFY_OUTPUT_CLASSES in predictions + ): + # For classification model using the default serving signature, the + # CLASSES contains the full vocabulary. The check for scores is needed + # here to avoid matching CLASSES in the eval case (scores are not used + # in eval). 
+ label_vocabulary = util.to_numpy( + predictions[tf.saved_model.CLASSIFY_OUTPUT_CLASSES] + ) + if label_vocabulary is not None: + while len(label_vocabulary.shape) > 1: + label_vocabulary = label_vocabulary[0] # Remove the bach dimensions + if not prediction_key: + # Estimator predictions use dicts of scores, probabilities, classes, etc. + for k in ( + tf.saved_model.CLASSIFY_OUTPUT_SCORES, + tf.saved_model.REGRESS_OUTPUTS, + _PREDICTIONS, + _LOGISTIC, + _PROBABILITIES, + _LOGITS, + ): + if k in predictions: + predictions = predictions[k] + prediction_key = k + break + elif prediction_key in predictions: + predictions = predictions[prediction_key] + + if isinstance(predictions, Mapping): + raise ValueError( + "unable to prepare prediction for metric computation because the " + "prediction is a dict with unrecognized keys. If a multi-output model " + "was used check that an output name was provided in all the relevant " + "settings (MetricsSpec.output_names, etc). If the model returns a dict " + "for its output and the output does not contain one of the common " + "prediction keys (e.g. logistic, probabilities, etc), then " + "ModelSpec.prediction_key can be used to specify which key to use for " + f"the predicted value: prediction={predictions}, " + f"prediction_key={prediction_key}" + ) + + if predictions is not None: + predictions = util.to_numpy(predictions) + + def _maybe_convert_labels_to_one_hot(labels, predictions): + # String lookups that fail result in a -1 label value. Most metrics + # won't accept this as a valid value so we convert to a one_hot value to + # ensure that we are only working with 0's (i.e. -1 maps to all 0's in + # one-hot). + if labels.size and np.all(labels == -1): + return one_hot(labels, predictions) + return labels + + if labels is not None: + if isinstance(labels, types.SparseTensorValue) or isinstance( + labels, tf.compat.v1.SparseTensorValue + ): + if predictions is None or predictions.size == 0: + raise ValueError( + "predictions must also be used if labels are of type " + f"SparseTensorValue: labels={labels}" + ) + values = labels.values if labels.values is not None else np.array([]) + indices = labels.indices if labels.indices is not None else np.array([]) + if label_vocabulary is not None and values.dtype.kind in ("U", "S", "O"): + values = _string_labels_to_class_ids(label_vocabulary, values) + # If vocab is used then the values will be the indices into the vocab + # and we should use multi-hot encoding to store the output. We can + # accomplish this by passing 1's for the values and using the values + # converted from the vocab as the indices to insert the 1's at the + # proper offsets in the resulting multi-hot vector. 
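A small sketch of that multi-hot construction, assuming a three-class vocabulary and made-up lookup results:

    import numpy as np

    # Suppose the vocab lookup mapped the sparse string labels to class ids
    # 1 and 2 (assumed values). Scattering 1's at those offsets produces the
    # multi-hot encoding that lines up with a 3-class prediction vector.
    class_ids = np.array([1, 2])
    multi_hot = np.zeros(3)
    multi_hot[class_ids] = 1.0
    # multi_hot -> array([0., 1., 1.])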
+ labels = _to_dense_tensor( + np.ones(values.shape), values, predictions.shape + ) + labels = _maybe_convert_labels_to_one_hot(labels, predictions) + else: + labels = _to_dense_tensor(values, indices, predictions.shape) + else: + labels = util.to_numpy(labels) + if label_vocabulary is not None and labels.dtype.kind in ("U", "S", "O"): + labels = _string_labels_to_class_ids(label_vocabulary, labels) + labels = _maybe_convert_labels_to_one_hot(labels, predictions) - return (labels, predictions) + return (labels, predictions) -def _to_dense_tensor(values: np.ndarray, indices: np.ndarray, - dense_shape: Tuple[int, ...]) -> np.ndarray: - """Converts sparse tensor to dense given values, indices, and dense shape.""" - # Squeeze is used on the values, indices, and result to ensure that single - # value inputs that still have the batch dimension such as [1, n_classes] can - # still be indexed properly from SparseTensorValues that don't use batching. - result = _squeeze(np.zeros(dense_shape, dtype=values.dtype)) - for value, index in zip(_squeeze(values), _squeeze(indices)): - result[index] = value - return result.reshape(dense_shape) +def _to_dense_tensor( + values: np.ndarray, indices: np.ndarray, dense_shape: Tuple[int, ...] +) -> np.ndarray: + """Converts sparse tensor to dense given values, indices, and dense shape.""" + # Squeeze is used on the values, indices, and result to ensure that single + # value inputs that still have the batch dimension such as [1, n_classes] can + # still be indexed properly from SparseTensorValues that don't use batching. + result = _squeeze(np.zeros(dense_shape, dtype=values.dtype)) + for value, index in zip(_squeeze(values), _squeeze(indices)): + result[index] = value + return result.reshape(dense_shape) -def _string_labels_to_class_ids(label_vocabulary: Union[np.ndarray, List[str]], - labels: np.ndarray) -> np.ndarray: - """Returns class ID for given string label using given classes or -1.""" +def _string_labels_to_class_ids( + label_vocabulary: Union[np.ndarray, List[str]], labels: np.ndarray +) -> np.ndarray: + """Returns class ID for given string label using given classes or -1.""" - def lookup(label): - for i, c in enumerate(label_vocabulary): - if c == label: - return i - return -1 + def lookup(label): + for i, c in enumerate(label_vocabulary): + if c == label: + return i + return -1 - return np.array([lookup(l) for l in labels.flatten()]).reshape(labels.shape) + return np.array([lookup(l) for l in labels.flatten()]).reshape(labels.shape) def select_class_id( # pytype: disable=annotation-type-mismatch @@ -915,72 +1004,78 @@ def select_class_id( # pytype: disable=annotation-type-mismatch predictions: Any, sparse_labels: bool = None, ) -> Tuple[np.ndarray, np.ndarray]: - """Selects values for given class ID from multi-class labels and predictions. - - Args: - class_id: Class ID to filter the labels and predictions by. - labels: Array or list of processed labels (1D, 2D, or 3D). - predictions: Array or list of processed predictions (1D, 2D, or 3D). - sparse_labels: True if sparse labels are being used. If None then the - sparseness will be inferred from the shapes of the labels and predictions - (i.e. if the shapes are different then the labels will be assumed to be - sparse). - - Returns: - A (labels, predictions) tuple with the predictions returned in the same form - as the originals (except for the last dimension which will be 1). - - Raises: - ValueError: If the labels or predictions cannot be formatted properly. 
- """ - labels = util.to_numpy(labels) - predictions = util.to_numpy(predictions) - if labels.size == 0 or predictions.size == 0: - return (labels, predictions) + """Selects values for given class ID from multi-class labels and predictions. + + Args: + ---- + class_id: Class ID to filter the labels and predictions by. + labels: Array or list of processed labels (1D, 2D, or 3D). + predictions: Array or list of processed predictions (1D, 2D, or 3D). + sparse_labels: True if sparse labels are being used. If None then the + sparseness will be inferred from the shapes of the labels and predictions + (i.e. if the shapes are different then the labels will be assumed to be + sparse). + + Returns: + ------- + A (labels, predictions) tuple with the predictions returned in the same form + as the originals (except for the last dimension which will be 1). + + Raises: + ------ + ValueError: If the labels or predictions cannot be formatted properly. + """ + labels = util.to_numpy(labels) + predictions = util.to_numpy(predictions) + if labels.size == 0 or predictions.size == 0: + return (labels, predictions) + + def lookup(arr, target): + if class_id < 0 or class_id >= len(arr): + raise ValueError(f'class_id "{class_id}" out of range of {target}: {arr}') + return arr[class_id] + + # Convert scalars to arrays + if not labels.shape: + labels = labels.reshape((1,)) + if not predictions.shape: + predictions = predictions.reshape((1,)) + + sparse_labels = _verify_sparse_labels( + labels, predictions, sparse_labels=sparse_labels + ) + if sparse_labels and labels.shape[-1] != 1: + # Convert to [[class_id1], ...] + labels = labels.reshape((-1, 1)) + + labels_out_shape = list(labels.shape) + labels_out_shape[-1] = 1 + predictions_out_shape = list(predictions.shape) + predictions_out_shape[-1] = 1 + + # Convert labels and predictions into the form ([[...], [...]]) + if len(labels.shape) > 1: + # Flatten all but the last dim (a, b, c) -> (a * b, c) + labels = labels.reshape((-1, labels.shape[-1])) + else: + labels = labels.reshape((1, labels.shape[0])) + if len(predictions.shape) > 1: + predictions = predictions.reshape((-1, predictions.shape[-1])) + else: + predictions = predictions.reshape((1, predictions.shape[0])) - def lookup(arr, target): - if class_id < 0 or class_id >= len(arr): - raise ValueError(f'class_id "{class_id}" out of range of {target}: {arr}') - return arr[class_id] - - # Convert scalars to arrays - if not labels.shape: - labels = labels.reshape((1,)) - if not predictions.shape: - predictions = predictions.reshape((1,)) - - sparse_labels = _verify_sparse_labels( - labels, predictions, sparse_labels=sparse_labels) - if sparse_labels and labels.shape[-1] != 1: - # Convert to [[class_id1], ...] - labels = labels.reshape((-1, 1)) - - labels_out_shape = list(labels.shape) - labels_out_shape[-1] = 1 - predictions_out_shape = list(predictions.shape) - predictions_out_shape[-1] = 1 - - # Convert labels and predictions into the form ([[...], [...]]) - if len(labels.shape) > 1: - # Flatten all but the last dim (a, b, c) -> (a * b, c) - labels = labels.reshape((-1, labels.shape[-1])) - else: - labels = labels.reshape((1, labels.shape[0])) - if len(predictions.shape) > 1: - predictions = predictions.reshape((-1, predictions.shape[-1])) - else: - predictions = predictions.reshape((1, predictions.shape[0])) - - if sparse_labels: - # Labels are of the form [[class_id1], [class_id2], ...] 
- labels = np.array([int(l[0] == class_id) for l in labels]) - else: - # Labels are of the form [[0, 0, 1, ...], [0, 0, 0, ...], ...] - labels = np.array([lookup(l, 'labels') for l in labels]) - predictions = np.array([lookup(p, 'predictions') for p in predictions]) - - return (labels.reshape(labels_out_shape), - predictions.reshape(predictions_out_shape)) + if sparse_labels: + # Labels are of the form [[class_id1], [class_id2], ...] + labels = np.array([int(l[0] == class_id) for l in labels]) + else: + # Labels are of the form [[0, 0, 1, ...], [0, 0, 0, ...], ...] + labels = np.array([lookup(l, "labels") for l in labels]) + predictions = np.array([lookup(p, "predictions") for p in predictions]) + + return ( + labels.reshape(labels_out_shape), + predictions.reshape(predictions_out_shape), + ) def _verify_sparse_labels( @@ -988,127 +1083,150 @@ def _verify_sparse_labels( predictions: np.ndarray, sparse_labels: bool = None, ) -> bool: - """Checks if labels are sparse or not. - - Args: - labels: Numpy array of labels. - predictions: Numpy array of predictions. - sparse_labels: True if sparse labels should be used. If None then the - sparseness will be inferred from the shapes of the labels and predictions - (i.e. if the shapes are different then the labels will be assumed to be - sparse). - - Returns: - True if sparse. - - Raises: - ValueError: If the sparse_labels setting does not match labels and - predictions. - """ - if (len(labels.shape) != len(predictions.shape) or - labels.shape[-1] != predictions.shape[-1]): - # Labels are of the form [class_id1, ...] - # - # Note that it is possible that the labels could be multi-label of the form - # [[class_id1, class_id2], [class_id3, ...]]. However, this would require a - # ragged or sparse tensor input to support. Ragged data in np.array is - # encoded as a list object which doesn't accurately reflect the shape (i.e. - # np.array([[1, 2], [3]]) has shape (2,). As such we will assume that all - # multi-label use cases will use one-hot encodings (e.g. [0, 1. 1, ...]) and - # will update this if/when RaggedTensorValue and SparseTensorValue value - # types are supported in addition to np.ndarray. - if sparse_labels is not None and not sparse_labels: - raise ValueError( - 'The number of labels = 1, but sparse labels are not being used\n\n' - 'This is likley caused by a metric configuration error. Change to ' - 'use a non-sparse versions of the metrics or ensure that sparse ' - f'labels are passed as input: labels={labels}, ' - f'predictions={predictions}') - return True - else: - # Labels are of the form [0, 0, 1, ...] (i.e. one-hot). This includes - # regression and binary classification. - # - # Similar to the note above, this is only true if multi-label inputs are - # always encoded as one-hot since it is possible for a multi-label encoding - # using [class_id1, ...] to use all the classes and therefore match the - # prediction's last dimension in length. - if sparse_labels is not None and sparse_labels: - raise ValueError( - 'The number of labels > 1, but sparse labels are being used\n\n' - 'This is likley caused by a metric configuration error. Change to ' - 'use sparse versions of the metrics or ensure that non-sparse labels ' - f'are passed as input: labels={labels}, predictions={predictions}') - return False + """Checks if labels are sparse or not. + + Args: + ---- + labels: Numpy array of labels. + predictions: Numpy array of predictions. + sparse_labels: True if sparse labels should be used. 
If None then the + sparseness will be inferred from the shapes of the labels and predictions + (i.e. if the shapes are different then the labels will be assumed to be + sparse). + + Returns: + ------- + True if sparse. + + Raises: + ------ + ValueError: If the sparse_labels setting does not match labels and + predictions. + """ + if ( + len(labels.shape) != len(predictions.shape) + or labels.shape[-1] != predictions.shape[-1] + ): + # Labels are of the form [class_id1, ...] + # + # Note that it is possible that the labels could be multi-label of the form + # [[class_id1, class_id2], [class_id3, ...]]. However, this would require a + # ragged or sparse tensor input to support. Ragged data in np.array is + # encoded as a list object which doesn't accurately reflect the shape (i.e. + # np.array([[1, 2], [3]]) has shape (2,). As such we will assume that all + # multi-label use cases will use one-hot encodings (e.g. [0, 1. 1, ...]) and + # will update this if/when RaggedTensorValue and SparseTensorValue value + # types are supported in addition to np.ndarray. + if sparse_labels is not None and not sparse_labels: + raise ValueError( + "The number of labels = 1, but sparse labels are not being used\n\n" + "This is likley caused by a metric configuration error. Change to " + "use a non-sparse versions of the metrics or ensure that sparse " + f"labels are passed as input: labels={labels}, " + f"predictions={predictions}" + ) + return True + else: + # Labels are of the form [0, 0, 1, ...] (i.e. one-hot). This includes + # regression and binary classification. + # + # Similar to the note above, this is only true if multi-label inputs are + # always encoded as one-hot since it is possible for a multi-label encoding + # using [class_id1, ...] to use all the classes and therefore match the + # prediction's last dimension in length. + if sparse_labels is not None and sparse_labels: + raise ValueError( + "The number of labels > 1, but sparse labels are being used\n\n" + "This is likley caused by a metric configuration error. Change to " + "use sparse versions of the metrics or ensure that non-sparse labels " + f"are passed as input: labels={labels}, predictions={predictions}" + ) + return False def one_hot(tensor: np.ndarray, target: np.ndarray) -> np.ndarray: - """Convert tensor's last dimension into a one-hot vector of target's shape. - - Args: - tensor: Tensor to convert to one-hot vector. Must have no shape or a final - dimension of size 1. - target: Target tensor to reshape the tensor to. - - Returns: - Tensor with last dimension encoded as a one-hot vector with the overall - shape the same as that of target. - """ - try: - # For values that are OOV (i.e. set to -1) we will use a vector of all 0's. - # When np.eye is indexed by -1, a value of all 0's followed by 1 is used for - # the row. The following handles -1 values by adding an additional column - # for indexing the -1 and then removing it after. - tensor = np.delete( - np.eye(target.shape[-1] + 1)[tensor.astype(int)], -1, axis=-1) - return tensor.reshape(target.shape) - except IndexError as e: - raise ValueError( - f'invalid inputs to one_hot: tensor={tensor}, target={target}') from e + """Convert tensor's last dimension into a one-hot vector of target's shape. + + Args: + ---- + tensor: Tensor to convert to one-hot vector. Must have no shape or a final + dimension of size 1. + target: Target tensor to reshape the tensor to. 
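For reference, the np.eye trick for out-of-vocabulary (-1) ids described in the code above can be sketched as follows, with made-up inputs for a three-class target:

    import numpy as np

    # Assumed class ids; -1 marks an out-of-vocabulary label.
    ids = np.array([1, -1])

    # Build an identity with one extra column so that index -1 lands on the
    # extra row; dropping the last column turns the OOV row into all zeros.
    one_hot = np.delete(np.eye(3 + 1)[ids], -1, axis=-1)
    # one_hot -> [[0., 1., 0.],
    #             [0., 0., 0.]]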
+ + Returns: + ------- + Tensor with last dimension encoded as a one-hot vector with the overall + shape the same as that of target. + """ + try: + # For values that are OOV (i.e. set to -1) we will use a vector of all 0's. + # When np.eye is indexed by -1, a value of all 0's followed by 1 is used for + # the row. The following handles -1 values by adding an additional column + # for indexing the -1 and then removing it after. + tensor = np.delete( + np.eye(target.shape[-1] + 1)[tensor.astype(int)], -1, axis=-1 + ) + return tensor.reshape(target.shape) + except IndexError as e: + raise ValueError( + f"invalid inputs to one_hot: tensor={tensor}, target={target}" + ) from e def merge_per_key_computations( create_computations_fn: Callable[..., metric_types.MetricComputations], ) -> Callable[..., metric_types.MetricComputations]: - """Wraps create_computations_fn to be called separately for each key.""" - - def merge_computations_fn( - eval_config: Optional[config_pb2.EvalConfig] = None, - schema: Optional[schema_pb2.Schema] = None, - model_names: Optional[List[str]] = None, - output_names: Optional[List[str]] = None, - sub_keys: Optional[List[Optional[metric_types.SubKey]]] = None, - aggregation_type: Optional[metric_types.AggregationType] = None, - class_weights: Optional[Dict[int, float]] = None, - example_weighted: bool = False, - query_key: Optional[str] = None, - **kwargs) -> metric_types.MetricComputations: - """Merge computations function.""" - if model_names is None: - model_names = [''] - if output_names is None: - output_names = [''] - if sub_keys is None: - sub_keys = [None] - computations = [] - for model_name in model_names: - for output_name in output_names: - for sub_key in sub_keys: - if hasattr(inspect, 'getfullargspec'): - args = inspect.getfullargspec(create_computations_fn).args - else: - args = inspect.getargspec(create_computations_fn).args # pylint: disable=deprecated-method - updated_kwargs = metric_types.validate_and_update_create_computations_fn_kwargs( - args, kwargs.copy(), eval_config, schema, model_names, - output_names, sub_keys, aggregation_type, class_weights, - example_weighted, query_key) - if 'model_name' in args: - updated_kwargs['model_name'] = model_name - if 'output_name' in args: - updated_kwargs['output_name'] = output_name - if 'sub_key' in args: - updated_kwargs['sub_key'] = sub_key - computations.extend(create_computations_fn(**updated_kwargs)) - return computations - - return merge_computations_fn + """Wraps create_computations_fn to be called separately for each key.""" + + def merge_computations_fn( + eval_config: Optional[config_pb2.EvalConfig] = None, + schema: Optional[schema_pb2.Schema] = None, + model_names: Optional[List[str]] = None, + output_names: Optional[List[str]] = None, + sub_keys: Optional[List[Optional[metric_types.SubKey]]] = None, + aggregation_type: Optional[metric_types.AggregationType] = None, + class_weights: Optional[Dict[int, float]] = None, + example_weighted: bool = False, + query_key: Optional[str] = None, + **kwargs, + ) -> metric_types.MetricComputations: + """Merge computations function.""" + if model_names is None: + model_names = [""] + if output_names is None: + output_names = [""] + if sub_keys is None: + sub_keys = [None] + computations = [] + for model_name in model_names: + for output_name in output_names: + for sub_key in sub_keys: + if hasattr(inspect, "getfullargspec"): + args = inspect.getfullargspec(create_computations_fn).args + else: + args = inspect.getargspec(create_computations_fn).args # pylint: 
disable=deprecated-method + updated_kwargs = ( + metric_types.validate_and_update_create_computations_fn_kwargs( + args, + kwargs.copy(), + eval_config, + schema, + model_names, + output_names, + sub_keys, + aggregation_type, + class_weights, + example_weighted, + query_key, + ) + ) + if "model_name" in args: + updated_kwargs["model_name"] = model_name + if "output_name" in args: + updated_kwargs["output_name"] = output_name + if "sub_key" in args: + updated_kwargs["sub_key"] = sub_key + computations.extend(create_computations_fn(**updated_kwargs)) + return computations + + return merge_computations_fn diff --git a/tensorflow_model_analysis/metrics/metric_util_test.py b/tensorflow_model_analysis/metrics/metric_util_test.py index 759f63f93e..67ede0cc57 100644 --- a/tensorflow_model_analysis/metrics/metric_util_test.py +++ b/tensorflow_model_analysis/metrics/metric_util_test.py @@ -15,898 +15,948 @@ import numpy as np import tensorflow as tf + from tensorflow_model_analysis.api import types -from tensorflow_model_analysis.metrics import metric_types -from tensorflow_model_analysis.metrics import metric_util +from tensorflow_model_analysis.metrics import metric_types, metric_util from tensorflow_model_analysis.proto import config_pb2 class UtilTest(tf.test.TestCase): - - def testNameGeneratorFromArguments(self): - # Basic usage - self.assertEqual( - metric_util.generate_private_name_from_arguments('', threshold=[0.5]), - '_:threshold=[0.5]', - ) - # Private case with private name - self.assertEqual( - metric_util.generate_private_name_from_arguments( - '_private', threshold=[0.5] - ), - '_private:threshold=[0.5]', - ) - # Multiple arguments - self.assertEqual( - metric_util.generate_private_name_from_arguments( - '_private', threshold=[0.5], class_id=[0], class_type=None - ), - '_private:class_id=[0],threshold=[0.5]', - ) - - def testToScalar(self): - self.assertEqual(1, metric_util.to_scalar(np.array([1]))) - self.assertEqual(1.0, metric_util.to_scalar(np.array(1.0))) - self.assertEqual('string', metric_util.to_scalar(np.array([['string']]))) - sparse_tensor = types.SparseTensorValue( - indices=np.array([0]), values=np.array([1]), dense_shape=np.array([1]) - ) - self.assertEqual(1, metric_util.to_scalar(sparse_tensor)) - - def testSafeToScalar(self): - self.assertEqual(1, metric_util.safe_to_scalar(np.array([1]))) - self.assertEqual(1.0, metric_util.safe_to_scalar(np.array(1.0))) - self.assertEqual( - 'string', metric_util.safe_to_scalar(np.array([['string']])) - ) - self.assertEqual(0.0, metric_util.safe_to_scalar(np.array([]))) - self.assertEqual(0.0, metric_util.safe_to_scalar([])) - with self.assertRaisesRegex( - ValueError, 'Array should have exactly 1 value to a Python scalar' - ): - _ = 1, metric_util.safe_to_scalar([1]) - with self.assertRaisesRegex( - ValueError, 'Array should have exactly 1 value to a Python scalar' - ): - _ = metric_util.safe_to_scalar([1, 2]) - with self.assertRaisesRegex( - ValueError, 'Array should have exactly 1 value to a Python scalar' - ): - _ = metric_util.safe_to_scalar(np.array([1, 2])) - - def testPadNoChange(self): - self.assertAllClose( - np.array([1.0, 2.0]), metric_util.pad(np.array([1.0, 2.0]), 2, -1.0) - ) - - def testPad1DSingleValue(self): - self.assertAllClose( - np.array([1.0, -1.0]), metric_util.pad(np.array([1.0]), 2, -1.0) - ) - - def testPad1DMultipleValues(self): - self.assertAllClose( - np.array([1.0, 2.0, -1.0, -1.0]), - metric_util.pad(np.array([1.0, 2.0]), 4, -1.0), - ) - - def testPad2D(self): - self.assertAllClose( - np.array([[1.0, 
2.0, 0.0, 0.0, 0.0], [3.0, 4.0, 0.0, 0.0, 0.0]]), - metric_util.pad(np.array([[1.0, 2.0], [3.0, 4.0]]), 5, 0.0), - ) - - def testStandardMetricInputsToNumpy(self): - example = metric_types.StandardMetricInputs( - label={'output_name': np.array([2])}, - prediction={'output_name': np.array([0, 0.5, 0.3, 0.9])}, - example_weight={'output_name': np.array([1.0])}, - ) - iterator = metric_util.to_label_prediction_example_weight( - example, output_name='output_name' - ) - - for expected_label, expected_prediction in zip( - (0.0, 0.0, 1.0, 0.0), (0.0, 0.5, 0.3, 0.9) - ): - got_label, got_pred, got_example_weight = next(iterator) - self.assertAllClose(got_label, np.array([expected_label])) - self.assertAllClose(got_pred, np.array([expected_prediction])) - self.assertAllClose(got_example_weight, np.array([1.0])) - - def testStandardMetricInputsToNumpyWithoutFlatten(self): - example = metric_types.StandardMetricInputs( - labels={'output_name': np.array([2])}, - predictions={'output_name': np.array([0, 0.5, 0.3, 0.9])}, - example_weights={'output_name': np.array([1.0])}) - got_label, got_pred, got_example_weight = next( - metric_util.to_label_prediction_example_weight( - example, output_name='output_name', flatten=False)) - - self.assertAllClose(got_label, np.array([2])) - self.assertAllClose(got_pred, np.array([0, 0.5, 0.3, 0.9])) - self.assertAllClose(got_example_weight, np.array([1.0])) - - def testStandardMetricInputsToNumpyWithoutFlattenAndWithSqueeze(self): - example = metric_types.StandardMetricInputs( - labels={'output_name': np.array([[2]])}, - predictions={'output_name': np.array([[0, 0.5, 0.3, 0.9]])}, - example_weights={'output_name': np.array([1.0])}) - got_label, got_pred, got_example_weight = next( - metric_util.to_label_prediction_example_weight( - example, output_name='output_name', flatten=False)) - - self.assertAllClose(got_label, np.array([2])) - self.assertAllClose(got_pred, np.array([0, 0.5, 0.3, 0.9])) - self.assertAllClose(got_example_weight, np.array([1.0])) - - def testStandardMetricInputsToNumpyWithoutFlattenAndWithoutSqueeze(self): - example = metric_types.StandardMetricInputs( - labels={'output_name': np.array([[2]])}, - predictions={'output_name': np.array([[0, 0.5, 0.3, 0.9]])}, - example_weights={'output_name': np.array([1.0])}) - got_label, got_pred, got_example_weight = next( - metric_util.to_label_prediction_example_weight( - example, output_name='output_name', flatten=False, squeeze=False)) - - self.assertAllClose(got_label, np.array([[2]])) - self.assertAllClose(got_pred, np.array([[0, 0.5, 0.3, 0.9]])) - self.assertAllClose(got_example_weight, np.array([1.0])) - - def testStandardMetricInputsWithZeroWeightsToNumpy(self): - example = metric_types.StandardMetricInputs( - labels=np.array([2]), - predictions=np.array([0, 0.5, 0.3, 0.9]), - example_weights=np.array([0.0])) - iterator = metric_util.to_label_prediction_example_weight( - example, example_weighted=True) - - for expected_label, expected_prediction in zip((0.0, 0.0, 1.0, 0.0), - (0.0, 0.5, 0.3, 0.9)): - got_label, got_pred, got_example_weight = next(iterator) - self.assertAllClose(got_label, np.array([expected_label])) - self.assertAllClose(got_pred, np.array([expected_prediction])) - self.assertAllClose(got_example_weight, np.array([0.0])) - - def testStandardMetricInputsWithSparseTensorValue(self): - example = metric_types.StandardMetricInputs( - labels=types.SparseTensorValue( - values=np.array([1]), - indices=np.array([2]), - dense_shape=np.array([0, 1])), - predictions=np.array([0, 0.5, 0.3, 0.9]), - 
example_weights=np.array([0.0])) - iterator = metric_util.to_label_prediction_example_weight( - example, example_weighted=True) - - for expected_label, expected_prediction in zip((0.0, 0.0, 1.0, 0.0), - (0.0, 0.5, 0.3, 0.9)): - got_label, got_pred, got_example_weight = next(iterator) - self.assertAllClose(got_label, np.array([expected_label])) - self.assertAllClose(got_pred, np.array([expected_prediction])) - self.assertAllClose(got_example_weight, np.array([0.0])) - - def testStandardMetricInputsWithZeroWeightsToNumpyWithoutFlatten(self): - example = metric_types.StandardMetricInputs( - labels=np.array([2]), - predictions=np.array([0, 0.5, 0.3, 0.9]), - example_weights=np.array([0.0])) - got_label, got_pred, got_example_weight = next( - metric_util.to_label_prediction_example_weight( - example, flatten=False, example_weighted=True)) - - self.assertAllClose(got_label, np.array([2])) - self.assertAllClose(got_pred, np.array([0, 0.5, 0.3, 0.9])) - self.assertAllClose(got_example_weight, np.array([0.0])) - - def testStandardMetricInputsWithClassIDToNumpy(self): - example = metric_types.StandardMetricInputs( - labels={'output_name': np.array([2])}, - predictions={'output_name': np.array([0, 0.5, 0.3, 0.9])}, - example_weights={'output_name': np.array([1.0])}) - got_label, got_pred, got_example_weight = next( - metric_util.to_label_prediction_example_weight( - example, - output_name='output_name', - sub_key=metric_types.SubKey(class_id=2))) - - self.assertAllClose(got_label, np.array([1.0])) - self.assertAllClose(got_pred, np.array([0.3])) - self.assertAllClose(got_example_weight, np.array([1.0])) - - def testStandardMetricInputsWithKToNumpy(self): - example = metric_types.StandardMetricInputs( - labels={'output_name': np.array([2])}, - predictions={'output_name': np.array([0, 0.5, 0.3, 0.9])}, - example_weights={'output_name': np.array([1.0])}) - got_label, got_pred, got_example_weight = next( - metric_util.to_label_prediction_example_weight( + def testNameGeneratorFromArguments(self): + # Basic usage + self.assertEqual( + metric_util.generate_private_name_from_arguments("", threshold=[0.5]), + "_:threshold=[0.5]", + ) + # Private case with private name + self.assertEqual( + metric_util.generate_private_name_from_arguments( + "_private", threshold=[0.5] + ), + "_private:threshold=[0.5]", + ) + # Multiple arguments + self.assertEqual( + metric_util.generate_private_name_from_arguments( + "_private", threshold=[0.5], class_id=[0], class_type=None + ), + "_private:class_id=[0],threshold=[0.5]", + ) + + def testToScalar(self): + self.assertEqual(1, metric_util.to_scalar(np.array([1]))) + self.assertEqual(1.0, metric_util.to_scalar(np.array(1.0))) + self.assertEqual("string", metric_util.to_scalar(np.array([["string"]]))) + sparse_tensor = types.SparseTensorValue( + indices=np.array([0]), values=np.array([1]), dense_shape=np.array([1]) + ) + self.assertEqual(1, metric_util.to_scalar(sparse_tensor)) + + def testSafeToScalar(self): + self.assertEqual(1, metric_util.safe_to_scalar(np.array([1]))) + self.assertEqual(1.0, metric_util.safe_to_scalar(np.array(1.0))) + self.assertEqual("string", metric_util.safe_to_scalar(np.array([["string"]]))) + self.assertEqual(0.0, metric_util.safe_to_scalar(np.array([]))) + self.assertEqual(0.0, metric_util.safe_to_scalar([])) + with self.assertRaisesRegex( + ValueError, "Array should have exactly 1 value to a Python scalar" + ): + _ = 1, metric_util.safe_to_scalar([1]) + with self.assertRaisesRegex( + ValueError, "Array should have exactly 1 value to a Python 
scalar" + ): + _ = metric_util.safe_to_scalar([1, 2]) + with self.assertRaisesRegex( + ValueError, "Array should have exactly 1 value to a Python scalar" + ): + _ = metric_util.safe_to_scalar(np.array([1, 2])) + + def testPadNoChange(self): + self.assertAllClose( + np.array([1.0, 2.0]), metric_util.pad(np.array([1.0, 2.0]), 2, -1.0) + ) + + def testPad1DSingleValue(self): + self.assertAllClose( + np.array([1.0, -1.0]), metric_util.pad(np.array([1.0]), 2, -1.0) + ) + + def testPad1DMultipleValues(self): + self.assertAllClose( + np.array([1.0, 2.0, -1.0, -1.0]), + metric_util.pad(np.array([1.0, 2.0]), 4, -1.0), + ) + + def testPad2D(self): + self.assertAllClose( + np.array([[1.0, 2.0, 0.0, 0.0, 0.0], [3.0, 4.0, 0.0, 0.0, 0.0]]), + metric_util.pad(np.array([[1.0, 2.0], [3.0, 4.0]]), 5, 0.0), + ) + + def testStandardMetricInputsToNumpy(self): + example = metric_types.StandardMetricInputs( + label={"output_name": np.array([2])}, + prediction={"output_name": np.array([0, 0.5, 0.3, 0.9])}, + example_weight={"output_name": np.array([1.0])}, + ) + iterator = metric_util.to_label_prediction_example_weight( + example, output_name="output_name" + ) + + for expected_label, expected_prediction in zip( + (0.0, 0.0, 1.0, 0.0), (0.0, 0.5, 0.3, 0.9) + ): + got_label, got_pred, got_example_weight = next(iterator) + self.assertAllClose(got_label, np.array([expected_label])) + self.assertAllClose(got_pred, np.array([expected_prediction])) + self.assertAllClose(got_example_weight, np.array([1.0])) + + def testStandardMetricInputsToNumpyWithoutFlatten(self): + example = metric_types.StandardMetricInputs( + labels={"output_name": np.array([2])}, + predictions={"output_name": np.array([0, 0.5, 0.3, 0.9])}, + example_weights={"output_name": np.array([1.0])}, + ) + got_label, got_pred, got_example_weight = next( + metric_util.to_label_prediction_example_weight( + example, output_name="output_name", flatten=False + ) + ) + + self.assertAllClose(got_label, np.array([2])) + self.assertAllClose(got_pred, np.array([0, 0.5, 0.3, 0.9])) + self.assertAllClose(got_example_weight, np.array([1.0])) + + def testStandardMetricInputsToNumpyWithoutFlattenAndWithSqueeze(self): + example = metric_types.StandardMetricInputs( + labels={"output_name": np.array([[2]])}, + predictions={"output_name": np.array([[0, 0.5, 0.3, 0.9]])}, + example_weights={"output_name": np.array([1.0])}, + ) + got_label, got_pred, got_example_weight = next( + metric_util.to_label_prediction_example_weight( + example, output_name="output_name", flatten=False + ) + ) + + self.assertAllClose(got_label, np.array([2])) + self.assertAllClose(got_pred, np.array([0, 0.5, 0.3, 0.9])) + self.assertAllClose(got_example_weight, np.array([1.0])) + + def testStandardMetricInputsToNumpyWithoutFlattenAndWithoutSqueeze(self): + example = metric_types.StandardMetricInputs( + labels={"output_name": np.array([[2]])}, + predictions={"output_name": np.array([[0, 0.5, 0.3, 0.9]])}, + example_weights={"output_name": np.array([1.0])}, + ) + got_label, got_pred, got_example_weight = next( + metric_util.to_label_prediction_example_weight( + example, output_name="output_name", flatten=False, squeeze=False + ) + ) + + self.assertAllClose(got_label, np.array([[2]])) + self.assertAllClose(got_pred, np.array([[0, 0.5, 0.3, 0.9]])) + self.assertAllClose(got_example_weight, np.array([1.0])) + + def testStandardMetricInputsWithZeroWeightsToNumpy(self): + example = metric_types.StandardMetricInputs( + labels=np.array([2]), + predictions=np.array([0, 0.5, 0.3, 0.9]), + 
example_weights=np.array([0.0]), + ) + iterator = metric_util.to_label_prediction_example_weight( + example, example_weighted=True + ) + + for expected_label, expected_prediction in zip( + (0.0, 0.0, 1.0, 0.0), (0.0, 0.5, 0.3, 0.9) + ): + got_label, got_pred, got_example_weight = next(iterator) + self.assertAllClose(got_label, np.array([expected_label])) + self.assertAllClose(got_pred, np.array([expected_prediction])) + self.assertAllClose(got_example_weight, np.array([0.0])) + + def testStandardMetricInputsWithSparseTensorValue(self): + example = metric_types.StandardMetricInputs( + labels=types.SparseTensorValue( + values=np.array([1]), + indices=np.array([2]), + dense_shape=np.array([0, 1]), + ), + predictions=np.array([0, 0.5, 0.3, 0.9]), + example_weights=np.array([0.0]), + ) + iterator = metric_util.to_label_prediction_example_weight( + example, example_weighted=True + ) + + for expected_label, expected_prediction in zip( + (0.0, 0.0, 1.0, 0.0), (0.0, 0.5, 0.3, 0.9) + ): + got_label, got_pred, got_example_weight = next(iterator) + self.assertAllClose(got_label, np.array([expected_label])) + self.assertAllClose(got_pred, np.array([expected_prediction])) + self.assertAllClose(got_example_weight, np.array([0.0])) + + def testStandardMetricInputsWithZeroWeightsToNumpyWithoutFlatten(self): + example = metric_types.StandardMetricInputs( + labels=np.array([2]), + predictions=np.array([0, 0.5, 0.3, 0.9]), + example_weights=np.array([0.0]), + ) + got_label, got_pred, got_example_weight = next( + metric_util.to_label_prediction_example_weight( + example, flatten=False, example_weighted=True + ) + ) + + self.assertAllClose(got_label, np.array([2])) + self.assertAllClose(got_pred, np.array([0, 0.5, 0.3, 0.9])) + self.assertAllClose(got_example_weight, np.array([0.0])) + + def testStandardMetricInputsWithClassIDToNumpy(self): + example = metric_types.StandardMetricInputs( + labels={"output_name": np.array([2])}, + predictions={"output_name": np.array([0, 0.5, 0.3, 0.9])}, + example_weights={"output_name": np.array([1.0])}, + ) + got_label, got_pred, got_example_weight = next( + metric_util.to_label_prediction_example_weight( + example, + output_name="output_name", + sub_key=metric_types.SubKey(class_id=2), + ) + ) + + self.assertAllClose(got_label, np.array([1.0])) + self.assertAllClose(got_pred, np.array([0.3])) + self.assertAllClose(got_example_weight, np.array([1.0])) + + def testStandardMetricInputsWithKToNumpy(self): + example = metric_types.StandardMetricInputs( + labels={"output_name": np.array([2])}, + predictions={"output_name": np.array([0, 0.5, 0.3, 0.9])}, + example_weights={"output_name": np.array([1.0])}, + ) + got_label, got_pred, got_example_weight = next( + metric_util.to_label_prediction_example_weight( + example, output_name="output_name", sub_key=metric_types.SubKey(k=2) + ) + ) + + self.assertAllClose(got_label, np.array([0.0])) + self.assertAllClose(got_pred, np.array([0.5])) + self.assertAllClose(got_example_weight, np.array([1.0])) + + def testStandardMetricInputsWithKToNumpy2D(self): + example = metric_types.StandardMetricInputs( + labels={"output_name": np.array([1, 2])}, + predictions={ + "output_name": np.array([[0, 0.5, 0.3, 0.9], [0.1, 0.4, 0.2, 0.3]]) + }, + example_weights={"output_name": np.array([1.0])}, + ) + got_label, got_pred, got_example_weight = next( + metric_util.to_label_prediction_example_weight( + example, + output_name="output_name", + sub_key=metric_types.SubKey(k=2), + flatten=False, + squeeze=False, + ) + ) + + self.assertAllClose(got_label, 
np.array([[1], [0]])) + self.assertAllClose(got_pred, np.array([[0.5], [0.3]])) + self.assertAllClose(got_example_weight, np.array([1.0])) + + def testStandardMetricInputsWithTopKToNumpy(self): + example = metric_types.StandardMetricInputs( + labels={"output_name": np.array([1])}, + predictions={"output_name": np.array([0, 0.5, 0.3, 0.9])}, + example_weights={"output_name": np.array([1.0])}, + ) + iterator = metric_util.to_label_prediction_example_weight( + example, output_name="output_name", sub_key=metric_types.SubKey(top_k=2) + ) + + for expected_label, expected_prediction in zip( + (0.0, 1.0, 0.0, 0.0), (float("-inf"), 0.5, float("-inf"), 0.9) + ): + got_label, got_pred, got_example_weight = next(iterator) + self.assertAllClose(got_label, np.array([expected_label])) + self.assertAllClose(got_pred, np.array([expected_prediction])) + self.assertAllClose(got_example_weight, np.array([1.0])) + + def testStandardMetricInputsWithTopKAndClassIdToNumpy(self): + example = metric_types.StandardMetricInputs( + labels={"output_name": np.array([1])}, + predictions={"output_name": np.array([0, 0.5, 0.3, 0.9])}, + example_weights={"output_name": np.array([1.0])}, + ) + iterator = metric_util.to_label_prediction_example_weight( example, - output_name='output_name', - sub_key=metric_types.SubKey(k=2))) - - self.assertAllClose(got_label, np.array([0.0])) - self.assertAllClose(got_pred, np.array([0.5])) - self.assertAllClose(got_example_weight, np.array([1.0])) - - def testStandardMetricInputsWithKToNumpy2D(self): - example = metric_types.StandardMetricInputs( - labels={'output_name': np.array([1, 2])}, - predictions={ - 'output_name': np.array([[0, 0.5, 0.3, 0.9], [0.1, 0.4, 0.2, 0.3]]) - }, - example_weights={'output_name': np.array([1.0])}) - got_label, got_pred, got_example_weight = next( - metric_util.to_label_prediction_example_weight( + output_name="output_name", + sub_key=metric_types.SubKey(top_k=2, class_id=1), + ) + + expected_label = 1.0 + expected_prediction = 0.5 + got_label, got_pred, got_example_weight = next(iterator) + self.assertAllClose(got_label, np.array([expected_label])) + self.assertAllClose(got_pred, np.array([expected_prediction])) + self.assertAllClose(got_example_weight, np.array([1.0])) + + def testStandardMetricInputsWithTopKAndAggregationTypeToNumpy(self): + example = metric_types.StandardMetricInputs( + labels={"output_name": np.array([1])}, + predictions={"output_name": np.array([0, 0.5, 0.3, 0.9])}, + example_weights={"output_name": np.array([1.0])}, + ) + iterator = metric_util.to_label_prediction_example_weight( example, - output_name='output_name', - sub_key=metric_types.SubKey(k=2), - flatten=False, - squeeze=False)) - - self.assertAllClose(got_label, np.array([[1], [0]])) - self.assertAllClose(got_pred, np.array([[0.5], [0.3]])) - self.assertAllClose(got_example_weight, np.array([1.0])) - - def testStandardMetricInputsWithTopKToNumpy(self): - example = metric_types.StandardMetricInputs( - labels={'output_name': np.array([1])}, - predictions={'output_name': np.array([0, 0.5, 0.3, 0.9])}, - example_weights={'output_name': np.array([1.0])}) - iterator = metric_util.to_label_prediction_example_weight( - example, - output_name='output_name', - sub_key=metric_types.SubKey(top_k=2)) - - for expected_label, expected_prediction in zip( - (0.0, 1.0, 0.0, 0.0), (float('-inf'), 0.5, float('-inf'), 0.9)): - got_label, got_pred, got_example_weight = next(iterator) - self.assertAllClose(got_label, np.array([expected_label])) - self.assertAllClose(got_pred, 
np.array([expected_prediction])) - self.assertAllClose(got_example_weight, np.array([1.0])) - - def testStandardMetricInputsWithTopKAndClassIdToNumpy(self): - example = metric_types.StandardMetricInputs( - labels={'output_name': np.array([1])}, - predictions={'output_name': np.array([0, 0.5, 0.3, 0.9])}, - example_weights={'output_name': np.array([1.0])}, - ) - iterator = metric_util.to_label_prediction_example_weight( - example, - output_name='output_name', - sub_key=metric_types.SubKey(top_k=2, class_id=1), - ) - - expected_label = 1.0 - expected_prediction = 0.5 - got_label, got_pred, got_example_weight = next(iterator) - self.assertAllClose(got_label, np.array([expected_label])) - self.assertAllClose(got_pred, np.array([expected_prediction])) - self.assertAllClose(got_example_weight, np.array([1.0])) - - def testStandardMetricInputsWithTopKAndAggregationTypeToNumpy(self): - example = metric_types.StandardMetricInputs( - labels={'output_name': np.array([1])}, - predictions={'output_name': np.array([0, 0.5, 0.3, 0.9])}, - example_weights={'output_name': np.array([1.0])}) - iterator = metric_util.to_label_prediction_example_weight( - example, - output_name='output_name', - sub_key=metric_types.SubKey(top_k=2), - aggregation_type=metric_types.AggregationType(micro_average=True)) - - for expected_label, expected_prediction in zip((1.0, 0.0), (0.5, 0.9)): - got_label, got_pred, got_example_weight = next(iterator) - self.assertAllClose(got_label, np.array([expected_label])) - self.assertAllClose(got_pred, np.array([expected_prediction])) - self.assertAllClose(got_example_weight, np.array([1.0])) - - def testStandardMetricInputsWithTopKToNumpyWithoutFlatten(self): - example = metric_types.StandardMetricInputs( - labels={'output_name': np.array([1, 2])}, - predictions={ - 'output_name': np.array([[0, 0.5, 0.3, 0.9], [0.1, 0.4, 0.2, 0.3]]) - }, - example_weights={'output_name': np.array([1.0])}) - got_label, got_pred, got_example_weight = next( - metric_util.to_label_prediction_example_weight( - example, - output_name='output_name', + output_name="output_name", sub_key=metric_types.SubKey(top_k=2), - flatten=False)) - - self.assertAllClose(got_label, np.array([1, 2])) - self.assertAllClose( - got_pred, - np.array([[float('-inf'), 0.5, float('-inf'), 0.9], - [float('-inf'), 0.4, float('-inf'), 0.3]])) - self.assertAllClose(got_example_weight, np.array([1.0])) - - def testStandardMetricInputsWithClassWeights(self): - example = metric_types.StandardMetricInputs( - labels={'output_name': np.array([2])}, - predictions={'output_name': np.array([0, 0.5, 0.3, 0.9])}, - example_weights={'output_name': np.array([1.0])}) - iterator = metric_util.to_label_prediction_example_weight( - example, - output_name='output_name', - aggregation_type=metric_types.AggregationType(micro_average=True), - class_weights={ - 0: 1.0, - 1: 0.5, - 2: 0.25, - 3: 1.0 - }, - flatten=True) - - for expected_label, expected_prediction, expected_weight in zip( - (0.0, 0.0, 1.0, 0.0), (0.0, 0.5, 0.3, 0.9), (1.0, 0.5, 0.25, 1.0)): - got_label, got_pred, got_example_weight = next(iterator) - self.assertAllClose(got_label, np.array([expected_label])) - self.assertAllClose(got_pred, np.array([expected_prediction])) - self.assertAllClose(got_example_weight, np.array([expected_weight])) - - def testStandardMetricInputsWithClassWeightsRaisesErrorWithoutFlatten(self): - with self.assertRaises(ValueError): - example = metric_types.StandardMetricInputs( - labels=np.array([2]), - predictions=np.array([0, 0.5, 0.3, 0.9]), - 
example_weights=np.array([1.0])) - next( - metric_util.to_label_prediction_example_weight( - example, class_weights={ - 1: 0.5, - 2: 0.25 - }, flatten=False)) - - def testStandardMetricInputsWithCustomLabelKeys(self): - example = metric_types.StandardMetricInputs( - labels={ - 'custom_label': np.array([2]), - 'other_label': np.array([0]) - }, - predictions={'custom_prediction': np.array([0, 0.5, 0.3, 0.9])}, - example_weights=np.array([1.0])) - eval_config = config_pb2.EvalConfig(model_specs=[ - config_pb2.ModelSpec( - label_key='custom_label', prediction_key='custom_prediction') - ]) - iterator = metric_util.to_label_prediction_example_weight( - example, eval_config=eval_config) - - for expected_label, expected_prediction in zip((0.0, 0.0, 1.0, 0.0), - (0.0, 0.5, 0.3, 0.9)): - got_label, got_pred, got_example_weight = next(iterator) - self.assertAllClose(got_label, np.array([expected_label]), atol=0, rtol=0) - self.assertAllClose( - got_pred, np.array([expected_prediction]), atol=0, rtol=0) - self.assertAllClose(got_example_weight, np.array([1.0]), atol=0, rtol=0) - - def testStandardMetricInputsWithMissingStringLabel(self): - example = metric_types.StandardMetricInputs( - label=np.array(['d']), - prediction={ - 'scores': np.array([0.2, 0.7, 0.1]), - 'classes': np.array(['a', 'b', 'c']) - }, - example_weight=np.array([1.0])) - iterator = metric_util.to_label_prediction_example_weight(example) - - for expected_label, expected_prediction in zip((0.0, 0.0, 0.0), - (0.2, 0.7, 0.1)): - got_label, got_pred, got_example_weight = next(iterator) - self.assertAllClose(got_label, np.array([expected_label]), atol=0, rtol=0) - self.assertAllClose( - got_pred, np.array([expected_prediction]), atol=0, rtol=0) - self.assertAllClose(got_example_weight, np.array([1.0]), atol=0, rtol=0) - - def testStandardMetricInputsWithoutLabels(self): - example = metric_types.StandardMetricInputs( - label={'output_name': np.array([])}, - prediction={'output_name': np.array([0, 0.5, 0.3, 0.9])}, - example_weight={'output_name': np.array([1.0])}) - iterator = metric_util.to_label_prediction_example_weight( - example, output_name='output_name') - - for expected_prediction in (0.0, 0.5, 0.3, 0.9): - got_label, got_pred, got_example_weight = next(iterator) - self.assertAllEqual(got_label, np.array([])) - self.assertAllClose(got_pred, np.array([expected_prediction])) - self.assertAllClose(got_example_weight, np.array([1.0])) - - def testStandardMetricInputsWithoutPredictions(self): - example = metric_types.StandardMetricInputs( - label={'output_name': np.array([0, 0.5, 0.3, 0.9])}, - prediction={'output_name': np.array([])}, - example_weight={'output_name': np.array([1.0])}) - iterator = metric_util.to_label_prediction_example_weight( - example, output_name='output_name') - - for expected_label in (0.0, 0.5, 0.3, 0.9): - got_label, got_pred, got_example_weight = next(iterator) - self.assertAllClose(got_label, np.array([expected_label])) - self.assertAllEqual(got_pred, np.array([])) - self.assertAllClose(got_example_weight, np.array([1.0])) - - def testStandardMetricInputsWithMultipleOutputs(self): - example = metric_types.StandardMetricInputs( - label={ - 'output1': np.array([0, 1]), - 'output2': np.array([1, 1]) - }, - prediction={ - 'output1': np.array([0, 0.5]), - 'output2': np.array([0.2, 0.8]) - }, - example_weight={ - 'output1': np.array([0.5]), - 'output2': np.array([1.0]) - }) - - for output in ('output1', 'output2'): - iterator = metric_util.to_label_prediction_example_weight( - example, output_name=output, 
flatten=False, example_weighted=True) - got_label, got_pred, got_example_weight = next(iterator) - self.assertAllClose(got_label, example.label[output]) - self.assertAllEqual(got_pred, example.prediction[output]) - self.assertAllClose(got_example_weight, example.example_weight[output]) - - def testStandardMetricInputsWithMultipleOutputsNotExampleWeighted(self): - example = metric_types.StandardMetricInputs( - label={ - 'output1': np.array([0, 1]), - 'output2': np.array([1, 1]) - }, - prediction={ - 'output1': np.array([0, 0.5]), - 'output2': np.array([0.2, 0.8]) - }, - example_weight={ - 'output1': np.array([0.5]), - 'output2': np.array([1.0]) - }) - - for output in ('output1', 'output2'): - iterator = metric_util.to_label_prediction_example_weight( - example, output_name=output, flatten=False, example_weighted=False) - got_label, got_pred, got_example_weight = next(iterator) - self.assertAllClose(got_label, example.label[output]) - self.assertAllEqual(got_pred, example.prediction[output]) - self.assertAllClose(got_example_weight, np.array([1.0])) - - def testStandardMetricInputsWithMissingLabelsAndExampleWeights(self): - example = metric_types.StandardMetricInputs(prediction={ - 'output1': np.array([0, 0.5]), - 'output2': np.array([0.2, 0.8]) - }) - - for output in ('output1', 'output2'): - iterator = metric_util.to_label_prediction_example_weight( - example, output_name=output, flatten=False, allow_none=True) - got_label, got_pred, got_example_weight = next(iterator) - self.assertAllEqual(got_label, np.array([])) - self.assertAllEqual(got_pred, example.prediction[output]) - self.assertAllEqual(got_example_weight, np.array([1.0])) - - def testStandardMetricInputsWithMissingLabelKeyRaisesError(self): - example = metric_types.StandardMetricInputs( - label={'output2': np.array([1, 1])}, - prediction={ - 'output1': np.array([0.5]), - 'output2': np.array([0.8]) - }, - example_weight={ - 'output1': np.array([0.5]), - 'output2': np.array([1.0]) - }) - with self.assertRaisesRegex( - ValueError, 'unable to prepare label for metric computation.*'): - next( - metric_util.to_label_prediction_example_weight( - example, output_name='output1')) - - def testStandardMetricInputsWithMissingPredictionRaisesError(self): - example = metric_types.StandardMetricInputs( - label={ - 'output1': np.array([0, 1]), - 'output2': np.array([1, 1]) - }, - prediction={'output2': np.array([0.8])}, - example_weight={ - 'output1': np.array([0.5]), - 'output2': np.array([1.0]) - }) - with self.assertRaisesRegex(ValueError, '"output1" key not found.*'): - next( - metric_util.to_label_prediction_example_weight( - example, output_name='output1')) - - def testStandardMetricInputsWithMissingExampleWeightKeyRaisesError(self): - example = metric_types.StandardMetricInputs( - label={ - 'output1': np.array([0, 1]), - 'output2': np.array([1, 1]) - }, - prediction={ - 'output1': np.array([0.5]), - 'output2': np.array([0.8]) - }, - example_weight={'output2': np.array([1.0])}) - with self.assertRaisesRegex( - ValueError, - 'unable to prepare example_weight for metric computation.*'): - next( - metric_util.to_label_prediction_example_weight( - example, output_name='output1', example_weighted=True)) - - def testStandardMetricInputsWithNonScalarWeights(self): - example = metric_types.StandardMetricInputs( - label={'output_name': np.array([2])}, - prediction={'output_name': np.array([0, 0.5, 0.3, 0.9])}, - example_weight={'output_name': np.array([1.0, 0.0, 1.0, 1.0])}) - iterable = metric_util.to_label_prediction_example_weight( - example, - 
output_name='output_name', - example_weighted=True, - require_single_example_weight=False) - - for expected_label, expected_prediction, expected_weight in zip( - (0.0, 0.0, 1.0, 0.0), (0.0, 0.5, 0.3, 0.9), (1.0, 0.0, 1.0, 1.0)): - got_label, got_pred, got_example_weight = next(iterable) - self.assertAllClose(got_label, np.array([expected_label])) - self.assertAllEqual(got_pred, np.array([expected_prediction])) - self.assertAllClose(got_example_weight, np.array([expected_weight])) - - def testStandardMetricInputsWithNonScalarWeightsNoFlatten(self): - example = metric_types.StandardMetricInputs( - label=np.array([2]), - prediction=np.array([0, 0.5, 0.3, 0.9]), - example_weight=np.array([1.0, 0.0, 1.0, 1.0])) - got_label, got_pred, got_example_weight = next( - metric_util.to_label_prediction_example_weight( + aggregation_type=metric_types.AggregationType(micro_average=True), + ) + + for expected_label, expected_prediction in zip((1.0, 0.0), (0.5, 0.9)): + got_label, got_pred, got_example_weight = next(iterator) + self.assertAllClose(got_label, np.array([expected_label])) + self.assertAllClose(got_pred, np.array([expected_prediction])) + self.assertAllClose(got_example_weight, np.array([1.0])) + + def testStandardMetricInputsWithTopKToNumpyWithoutFlatten(self): + example = metric_types.StandardMetricInputs( + labels={"output_name": np.array([1, 2])}, + predictions={ + "output_name": np.array([[0, 0.5, 0.3, 0.9], [0.1, 0.4, 0.2, 0.3]]) + }, + example_weights={"output_name": np.array([1.0])}, + ) + got_label, got_pred, got_example_weight = next( + metric_util.to_label_prediction_example_weight( + example, + output_name="output_name", + sub_key=metric_types.SubKey(top_k=2), + flatten=False, + ) + ) + + self.assertAllClose(got_label, np.array([1, 2])) + self.assertAllClose( + got_pred, + np.array( + [ + [float("-inf"), 0.5, float("-inf"), 0.9], + [float("-inf"), 0.4, float("-inf"), 0.3], + ] + ), + ) + self.assertAllClose(got_example_weight, np.array([1.0])) + + def testStandardMetricInputsWithClassWeights(self): + example = metric_types.StandardMetricInputs( + labels={"output_name": np.array([2])}, + predictions={"output_name": np.array([0, 0.5, 0.3, 0.9])}, + example_weights={"output_name": np.array([1.0])}, + ) + iterator = metric_util.to_label_prediction_example_weight( + example, + output_name="output_name", + aggregation_type=metric_types.AggregationType(micro_average=True), + class_weights={0: 1.0, 1: 0.5, 2: 0.25, 3: 1.0}, + flatten=True, + ) + + for expected_label, expected_prediction, expected_weight in zip( + (0.0, 0.0, 1.0, 0.0), (0.0, 0.5, 0.3, 0.9), (1.0, 0.5, 0.25, 1.0) + ): + got_label, got_pred, got_example_weight = next(iterator) + self.assertAllClose(got_label, np.array([expected_label])) + self.assertAllClose(got_pred, np.array([expected_prediction])) + self.assertAllClose(got_example_weight, np.array([expected_weight])) + + def testStandardMetricInputsWithClassWeightsRaisesErrorWithoutFlatten(self): + with self.assertRaises(ValueError): + example = metric_types.StandardMetricInputs( + labels=np.array([2]), + predictions=np.array([0, 0.5, 0.3, 0.9]), + example_weights=np.array([1.0]), + ) + next( + metric_util.to_label_prediction_example_weight( + example, class_weights={1: 0.5, 2: 0.25}, flatten=False + ) + ) + + def testStandardMetricInputsWithCustomLabelKeys(self): + example = metric_types.StandardMetricInputs( + labels={"custom_label": np.array([2]), "other_label": np.array([0])}, + predictions={"custom_prediction": np.array([0, 0.5, 0.3, 0.9])}, + 
example_weights=np.array([1.0]), + ) + eval_config = config_pb2.EvalConfig( + model_specs=[ + config_pb2.ModelSpec( + label_key="custom_label", prediction_key="custom_prediction" + ) + ] + ) + iterator = metric_util.to_label_prediction_example_weight( + example, eval_config=eval_config + ) + + for expected_label, expected_prediction in zip( + (0.0, 0.0, 1.0, 0.0), (0.0, 0.5, 0.3, 0.9) + ): + got_label, got_pred, got_example_weight = next(iterator) + self.assertAllClose(got_label, np.array([expected_label]), atol=0, rtol=0) + self.assertAllClose( + got_pred, np.array([expected_prediction]), atol=0, rtol=0 + ) + self.assertAllClose(got_example_weight, np.array([1.0]), atol=0, rtol=0) + + def testStandardMetricInputsWithMissingStringLabel(self): + example = metric_types.StandardMetricInputs( + label=np.array(["d"]), + prediction={ + "scores": np.array([0.2, 0.7, 0.1]), + "classes": np.array(["a", "b", "c"]), + }, + example_weight=np.array([1.0]), + ) + iterator = metric_util.to_label_prediction_example_weight(example) + + for expected_label, expected_prediction in zip( + (0.0, 0.0, 0.0), (0.2, 0.7, 0.1) + ): + got_label, got_pred, got_example_weight = next(iterator) + self.assertAllClose(got_label, np.array([expected_label]), atol=0, rtol=0) + self.assertAllClose( + got_pred, np.array([expected_prediction]), atol=0, rtol=0 + ) + self.assertAllClose(got_example_weight, np.array([1.0]), atol=0, rtol=0) + + def testStandardMetricInputsWithoutLabels(self): + example = metric_types.StandardMetricInputs( + label={"output_name": np.array([])}, + prediction={"output_name": np.array([0, 0.5, 0.3, 0.9])}, + example_weight={"output_name": np.array([1.0])}, + ) + iterator = metric_util.to_label_prediction_example_weight( + example, output_name="output_name" + ) + + for expected_prediction in (0.0, 0.5, 0.3, 0.9): + got_label, got_pred, got_example_weight = next(iterator) + self.assertAllEqual(got_label, np.array([])) + self.assertAllClose(got_pred, np.array([expected_prediction])) + self.assertAllClose(got_example_weight, np.array([1.0])) + + def testStandardMetricInputsWithoutPredictions(self): + example = metric_types.StandardMetricInputs( + label={"output_name": np.array([0, 0.5, 0.3, 0.9])}, + prediction={"output_name": np.array([])}, + example_weight={"output_name": np.array([1.0])}, + ) + iterator = metric_util.to_label_prediction_example_weight( + example, output_name="output_name" + ) + + for expected_label in (0.0, 0.5, 0.3, 0.9): + got_label, got_pred, got_example_weight = next(iterator) + self.assertAllClose(got_label, np.array([expected_label])) + self.assertAllEqual(got_pred, np.array([])) + self.assertAllClose(got_example_weight, np.array([1.0])) + + def testStandardMetricInputsWithMultipleOutputs(self): + example = metric_types.StandardMetricInputs( + label={"output1": np.array([0, 1]), "output2": np.array([1, 1])}, + prediction={"output1": np.array([0, 0.5]), "output2": np.array([0.2, 0.8])}, + example_weight={"output1": np.array([0.5]), "output2": np.array([1.0])}, + ) + + for output in ("output1", "output2"): + iterator = metric_util.to_label_prediction_example_weight( + example, output_name=output, flatten=False, example_weighted=True + ) + got_label, got_pred, got_example_weight = next(iterator) + self.assertAllClose(got_label, example.label[output]) + self.assertAllEqual(got_pred, example.prediction[output]) + self.assertAllClose(got_example_weight, example.example_weight[output]) + + def testStandardMetricInputsWithMultipleOutputsNotExampleWeighted(self): + example = 
metric_types.StandardMetricInputs( + label={"output1": np.array([0, 1]), "output2": np.array([1, 1])}, + prediction={"output1": np.array([0, 0.5]), "output2": np.array([0.2, 0.8])}, + example_weight={"output1": np.array([0.5]), "output2": np.array([1.0])}, + ) + + for output in ("output1", "output2"): + iterator = metric_util.to_label_prediction_example_weight( + example, output_name=output, flatten=False, example_weighted=False + ) + got_label, got_pred, got_example_weight = next(iterator) + self.assertAllClose(got_label, example.label[output]) + self.assertAllEqual(got_pred, example.prediction[output]) + self.assertAllClose(got_example_weight, np.array([1.0])) + + def testStandardMetricInputsWithMissingLabelsAndExampleWeights(self): + example = metric_types.StandardMetricInputs( + prediction={"output1": np.array([0, 0.5]), "output2": np.array([0.2, 0.8])} + ) + + for output in ("output1", "output2"): + iterator = metric_util.to_label_prediction_example_weight( + example, output_name=output, flatten=False, allow_none=True + ) + got_label, got_pred, got_example_weight = next(iterator) + self.assertAllEqual(got_label, np.array([])) + self.assertAllEqual(got_pred, example.prediction[output]) + self.assertAllEqual(got_example_weight, np.array([1.0])) + + def testStandardMetricInputsWithMissingLabelKeyRaisesError(self): + example = metric_types.StandardMetricInputs( + label={"output2": np.array([1, 1])}, + prediction={"output1": np.array([0.5]), "output2": np.array([0.8])}, + example_weight={"output1": np.array([0.5]), "output2": np.array([1.0])}, + ) + with self.assertRaisesRegex( + ValueError, "unable to prepare label for metric computation.*" + ): + next( + metric_util.to_label_prediction_example_weight( + example, output_name="output1" + ) + ) + + def testStandardMetricInputsWithMissingPredictionRaisesError(self): + example = metric_types.StandardMetricInputs( + label={"output1": np.array([0, 1]), "output2": np.array([1, 1])}, + prediction={"output2": np.array([0.8])}, + example_weight={"output1": np.array([0.5]), "output2": np.array([1.0])}, + ) + with self.assertRaisesRegex(ValueError, '"output1" key not found.*'): + next( + metric_util.to_label_prediction_example_weight( + example, output_name="output1" + ) + ) + + def testStandardMetricInputsWithMissingExampleWeightKeyRaisesError(self): + example = metric_types.StandardMetricInputs( + label={"output1": np.array([0, 1]), "output2": np.array([1, 1])}, + prediction={"output1": np.array([0.5]), "output2": np.array([0.8])}, + example_weight={"output2": np.array([1.0])}, + ) + with self.assertRaisesRegex( + ValueError, "unable to prepare example_weight for metric computation.*" + ): + next( + metric_util.to_label_prediction_example_weight( + example, output_name="output1", example_weighted=True + ) + ) + + def testStandardMetricInputsWithNonScalarWeights(self): + example = metric_types.StandardMetricInputs( + label={"output_name": np.array([2])}, + prediction={"output_name": np.array([0, 0.5, 0.3, 0.9])}, + example_weight={"output_name": np.array([1.0, 0.0, 1.0, 1.0])}, + ) + iterable = metric_util.to_label_prediction_example_weight( example, - flatten=False, + output_name="output_name", example_weighted=True, - require_single_example_weight=False)) - self.assertAllClose(got_label, np.array([2])) - self.assertAllEqual(got_pred, np.array([0, 0.5, 0.3, 0.9])) - self.assertAllClose(got_example_weight, np.array([1.0, 0.0, 1.0, 1.0])) - - def testStandardMetricInputsWithMismatchedExampleWeightsRaisesError(self): - with 
self.assertRaises(ValueError): - example = metric_types.StandardMetricInputs( - labels=np.array([2]), - predictions=np.array([0, 0.5, 0.3, 0.9]), - example_weights=np.array([1.0, 0.0])) - next( - metric_util.to_label_prediction_example_weight( - example, - flatten=True, - example_weighted=True, - require_single_example_weight=False)) - - def testStandardMetricInputsRequiringSingleExampleWeightRaisesError(self): - with self.assertRaises(ValueError): - example = metric_types.StandardMetricInputs( - labels=np.array([2]), - predictions=np.array([0, 0.5, 0.3, 0.9]), - example_weights=np.array([1.0, 0.0])) - next( - metric_util.to_label_prediction_example_weight( - example, - example_weighted=True, - require_single_example_weight=True)) - - def testPrepareLabelsAndPredictions(self): - labels = [0] - preds = { - 'logistic': np.array([0.8]), - } - got_labels, got_preds = metric_util.prepare_labels_and_predictions( - labels, preds) - - self.assertAllClose(got_labels, np.array([0])) - self.assertAllClose(got_preds, np.array([0.8])) - - def testPrepareLabelsAndPredictionsClassNotFound(self): - labels = ['d'] - preds = { - 'scores': np.array([0.2, 0.7, 0.1]), - 'all_classes': np.array(['a', 'b', 'c']) - } - got_labels, got_preds = metric_util.prepare_labels_and_predictions( - labels, preds) - - self.assertAllClose(got_labels, np.array([0, 0, 0])) - self.assertAllClose(got_preds, np.array([0.2, 0.7, 0.1])) - - def testPrepareLabelsAndPredictionsBatched(self): - labels = [['b']] - preds = { - 'logistic': np.array([[0.8]]), - 'all_classes': np.array([['a', 'b', 'c']]) - } - got_labels, got_preds = metric_util.prepare_labels_and_predictions( - labels, preds) - - self.assertAllClose(got_labels, np.array([[1]])) - self.assertAllClose(got_preds, np.array([[0.8]])) - - def testPrepareLabelsAndPredictionsMixedBatching(self): - labels = np.array([1]) - preds = { - 'predictions': np.array([[0.8]]), - } - got_labels, got_preds = metric_util.prepare_labels_and_predictions( - labels, preds) - - self.assertAllClose(got_labels, np.array([1])) - self.assertAllClose(got_preds, np.array([[0.8]])) - - def testPrepareMultipleLabelsAndPredictions(self): - labels = np.array(['b', 'c', 'a']) - preds = { - 'scores': np.array([0.2, 0.7, 0.1]), - 'classes': np.array(['a', 'b', 'c']) - } - got_labels, got_preds = metric_util.prepare_labels_and_predictions( - labels, preds) - - self.assertAllClose(got_labels, np.array([1, 2, 0])) - self.assertAllClose(got_preds, np.array([0.2, 0.7, 0.1])) - - def testPrepareMultipleLabelsAndPredictionsPythonList(self): - labels = ['b', 'c', 'a'] - preds = {'probabilities': [0.2, 0.7, 0.1], 'all_classes': ['a', 'b', 'c']} - got_labels, got_preds = metric_util.prepare_labels_and_predictions( - labels, preds) - - self.assertAllClose(got_labels, np.array([1, 2, 0])) - self.assertAllClose(got_preds, np.array([0.2, 0.7, 0.1])) - - def testPrepareLabelsAndPredictionsSparseTensorValue(self): - labels = types.SparseTensorValue( - indices=np.array([1, 2]), - values=np.array([1, 1]), - dense_shape=np.array([1, 2])) - preds = {'probabilities': [0.2, 0.7, 0.1], 'all_classes': ['a', 'b', 'c']} - got_labels, got_preds = metric_util.prepare_labels_and_predictions( - labels, preds) - - self.assertAllClose(got_labels, np.array([0, 1, 1])) - self.assertAllClose(got_preds, np.array([0.2, 0.7, 0.1])) - - def testPrepareLabelsAndPredictionsEmptySparseTensorValue(self): - labels = types.SparseTensorValue( - values=np.array([]), indices=np.array([]), dense_shape=np.array([0, 2])) - preds = {'probabilities': [0.2, 0.7, 
0.1], 'all_classes': ['a', 'b', 'c']} - got_labels, got_preds = metric_util.prepare_labels_and_predictions( - labels, preds) - - self.assertAllClose(got_labels, np.array([0, 0, 0])) - self.assertAllClose(got_preds, np.array([0.2, 0.7, 0.1])) - - def testPrepareLabelsAndPredictionsSparseTensorValueWithBatching(self): - labels = types.SparseTensorValue( - indices=np.array([1, 2]), - values=np.array([1, 1]), - dense_shape=np.array([1, 2])) - preds = { - 'probabilities': [[0.2, 0.7, 0.1]], - 'all_classes': [['a', 'b', 'c']] - } - got_labels, got_preds = metric_util.prepare_labels_and_predictions( - labels, preds) - - self.assertAllClose(got_labels, np.array([[0, 1, 1]])) - self.assertAllClose(got_preds, np.array([[0.2, 0.7, 0.1]])) - - def testPrepareMultipleLabelsAndPredictionsMultiDimension(self): - labels = [[0], [1]] - preds = {'probabilities': [[0.2, 0.8], [0.3, 0.7]]} - got_labels, got_preds = metric_util.prepare_labels_and_predictions( - labels, preds) - - self.assertAllClose(got_labels, np.array([[0], [1]])) - self.assertAllClose(got_preds, np.array([[0.2, 0.8], [0.3, 0.7]])) - - def testPrepareLabelsAndPredictionsEmpty(self): - labels = [] - preds = {'logistic': [], 'all_classes': ['a', 'b', 'c']} - got_labels, got_preds = metric_util.prepare_labels_and_predictions( - labels, preds) - - self.assertAllClose(got_labels, np.array([])) - self.assertAllClose(got_preds, np.array([])) - - def testPrepareLabelsAndPredictionsWithVocab(self): - labels = np.array(['e', 'f']) - preds = {'probabilities': [0.2, 0.8], 'all_classes': ['a', 'b', 'c']} - got_labels, got_preds = metric_util.prepare_labels_and_predictions( - labels, preds, label_vocabulary=['e', 'f']) - - self.assertAllClose(got_labels, np.array([0, 1])) - self.assertAllClose(got_preds, np.array([0.2, 0.8])) - - def testPrepareLabelsAndPredictionsWithVocabUsingObjectType(self): - labels = np.array(['e', 'f'], dtype=object) - preds = {'probabilities': [0.2, 0.8], 'all_classes': ['a', 'b', 'c']} - got_labels, got_preds = metric_util.prepare_labels_and_predictions( - labels, preds, label_vocabulary=['e', 'f']) - - self.assertAllClose(got_labels, np.array([0, 1])) - self.assertAllClose(got_preds, np.array([0.2, 0.8])) - - def testPrepareLabelsAndPredictionsSparseTensorValueAndVocab(self): - labels = types.SparseTensorValue( - indices=np.array([0, 2]), - values=np.array(['c', 'a']), - dense_shape=np.array([1, 2])) - preds = {'probabilities': [0.2, 0.7, 0.1], 'all_classes': ['a', 'b', 'c']} - got_labels, got_preds = metric_util.prepare_labels_and_predictions( - labels, preds) - - self.assertAllClose(got_labels, np.array([1, 0, 1])) - self.assertAllClose(got_preds, np.array([0.2, 0.7, 0.1])) - - def testPrepareLabelsAndPredictionsUsingBinaryScores(self): - labels = np.array([[0], [1]]) - preds = { - 'scores': np.array([[0.9, 0.2], [0.3, 0.7]]), - 'classes': np.array([['a', 'b'], ['a', 'b']]) - } - got_labels, got_preds = metric_util.prepare_labels_and_predictions( - labels, preds) - - self.assertAllClose(got_labels, np.array([[0], [1]])) - self.assertAllClose(got_preds, np.array([[0.9, 0.2], [0.3, 0.7]])) - - def testPrepareLabelsAndPredictionsUsingBinaryScoresSparse(self): - labels = np.array([1, 0]) - preds = { - 'scores': np.array([[0.9, 0.2], [0.3, 0.7]]), - 'classes': np.array([['a', 'b'], ['a', 'b']]) - } - got_labels, got_preds = metric_util.prepare_labels_and_predictions( - labels, preds) - - self.assertAllClose(got_labels, np.array([1, 0])) - self.assertAllClose(got_preds, np.array([[0.9, 0.2], [0.3, 0.7]])) - - def 
testPrepareLabelsAndPredictionsUsingBinaryScoresUnbatched(self): - labels = np.array([1]) - preds = {'scores': np.array([0.3, 0.7]), 'classes': np.array(['a', 'b'])} - got_labels, got_preds = metric_util.prepare_labels_and_predictions( - labels, preds) - - self.assertAllClose(got_labels, np.array([1])) - self.assertAllClose(got_preds, np.array([0.3, 0.7])) - - def testSelectClassIDSparse(self): - labels = np.array([2]) - preds = np.array([0.2, 0.7, 0.1]) - got_labels, got_preds = metric_util.select_class_id(1, labels, preds) - - self.assertAllClose(got_labels, np.array([0])) - self.assertAllClose(got_preds, np.array([0.7])) - - def testSelectClassIDSparseNoShape(self): - labels = np.array(2) - preds = np.array([0.2, 0.7, 0.1]) - got_labels, got_preds = metric_util.select_class_id(1, labels, preds) - - self.assertAllClose(got_labels, np.array([0])) - self.assertAllClose(got_preds, np.array([0.7])) - - def testSelectClassIDSparseWithMultipleValues(self): - labels = np.array([0, 2, 1]) - preds = np.array([[0.2, 0.7, 0.1], [0.3, 0.6, 0.1], [0.1, 0.2, 0.7]]) - got_labels, got_preds = metric_util.select_class_id(1, labels, preds) - - self.assertAllClose(got_labels, np.array([[0], [0], [1]])) - self.assertAllClose(got_preds, np.array([[0.7], [0.6], [0.2]])) - - def testSelectClassIDSparseBatched(self): - labels = np.array([[0], [2], [1]]) - preds = np.array([[0.2, 0.7, 0.1], [0.3, 0.6, 0.1], [0.1, 0.2, 0.7]]) - got_labels, got_preds = metric_util.select_class_id(1, labels, preds) - - self.assertAllClose(got_labels, np.array([[0], [0], [1]])) - self.assertAllClose(got_preds, np.array([[0.7], [0.6], [0.2]])) - - def testSelectClassIDSparseMultiDim(self): - labels = np.array([[[0]], [[2]], [[1]]]) - preds = np.array([[[0.2, 0.7, 0.1]], [[0.3, 0.6, 0.1]], [[0.1, 0.2, 0.7]]]) - got_labels, got_preds = metric_util.select_class_id(1, labels, preds) - - self.assertAllClose(got_labels, np.array([[[0]], [[0]], [[1]]])) - self.assertAllClose(got_preds, np.array([[[0.7]], [[0.6]], [[0.2]]])) - - def testRaisesErrorForInvalidSparseSettings(self): - with self.assertRaises(ValueError): - labels = np.array([[0, 0, 1]]) - preds = np.array([[0.2, 0.7, 0.1]]) - metric_util.select_class_id(1, labels, preds, sparse_labels=True) - - def testSelectClassID(self): - labels = np.array([0, 0, 1]) - preds = np.array([0.2, 0.7, 0.1]) - got_labels, got_preds = metric_util.select_class_id(1, labels, preds) - - self.assertAllClose(got_labels, np.array([0])) - self.assertAllClose(got_preds, np.array([0.7])) - - def testSelectClassIDWithMultipleValues(self): - labels = np.array([[0, 0, 1], [0, 0, 1], [0, 1, 0]]) - preds = np.array([[0.2, 0.7, 0.1], [0.3, 0.6, 0.1], [0.1, 0.2, 0.7]]) - got_labels, got_preds = metric_util.select_class_id(1, labels, preds) - - self.assertAllClose(got_labels, np.array([[0], [0], [1]])) - self.assertAllClose(got_preds, np.array([[0.7], [0.6], [0.2]])) - - def testSelectClassIDBatched(self): - labels = np.array([[0, 0, 1]]) - preds = np.array([[0.2, 0.7, 0.1]]) - got_labels, got_preds = metric_util.select_class_id(1, labels, preds) - - self.assertAllClose(got_labels, np.array([[0]])) - self.assertAllClose(got_preds, np.array([[0.7]])) - - def testSelectClassIDMultiDim(self): - labels = np.array([[[0, 0, 1]]]) - preds = np.array([[[0.2, 0.7, 0.1]]]) - got_labels, got_preds = metric_util.select_class_id(1, labels, preds) - - self.assertAllClose(got_labels, np.array([[[0]]])) - self.assertAllClose(got_preds, np.array([[[0.7]]])) - - def testRaisesErrorForInvalidNonSparseSettings(self): - with 
self.assertRaises(ValueError): - labels = np.array([5]) - preds = np.array([0.2, 0.7, 0.1]) - metric_util.select_class_id(1, labels, preds, sparse_labels=False) - - def testSelectClassIDEmpty(self): - labels = np.array(np.array([])) - preds = np.array(np.array([])) - got_labels, got_preds = metric_util.select_class_id(1, labels, preds) - - self.assertAllClose(got_labels, np.array([])) - self.assertAllClose(got_preds, np.array([])) - - def testTopKIndices(self): - scores = np.array([0.4, 0.1, 0.2, 0.3]) - got = metric_util.top_k_indices(2, scores) - # Indices could be in any order, test by overwritting the original scores - scores[got] = -1.0 - self.assertAllClose(scores, np.array([-1.0, 0.1, 0.2, -1.0])) - - def testTopKIndicesSorted(self): - scores = np.array([0.1, 0.3, 0.4, 0.2]) - got = metric_util.top_k_indices(2, scores, sort=True) - self.assertAllClose(got, np.array([2, 1])) - self.assertAllClose(scores[got], np.array([0.4, 0.3])) - - def testTopKIndices2D(self): - scores = np.array([[0.4, 0.1, 0.2, 0.3], [0.1, 0.2, 0.1, 0.6]]) - got = metric_util.top_k_indices(2, scores) - scores[got] = -1.0 - self.assertAllClose( - scores, np.array([[-1.0, 0.1, 0.2, -1.0], [0.1, -1.0, 0.1, -1.0]])) - - def testTopKIndices2DSorted(self): - scores = np.array([[0.3, 0.1, 0.4, 0.2], [0.1, 0.2, 0.3, 0.6]]) - got = metric_util.top_k_indices(2, scores, sort=True) - # Indices are in ([row_index,...], [col_index, ...]) format. - self.assertAllClose(got, (np.array([0, 0, 1, 1]), np.array([2, 0, 3, 2]))) - self.assertAllClose(scores[got], np.array([0.4, 0.3, 0.6, 0.3])) - - def testTopKIndicesWithBinaryClassification(self): - scores = np.array([0.2, 0.8]) - got = metric_util.top_k_indices(1, scores) - self.assertAllClose(got, np.array([1])) - self.assertAllClose(scores[got], np.array([0.8])) - - -if __name__ == '__main__': - tf.test.main() + require_single_example_weight=False, + ) + + for expected_label, expected_prediction, expected_weight in zip( + (0.0, 0.0, 1.0, 0.0), (0.0, 0.5, 0.3, 0.9), (1.0, 0.0, 1.0, 1.0) + ): + got_label, got_pred, got_example_weight = next(iterable) + self.assertAllClose(got_label, np.array([expected_label])) + self.assertAllEqual(got_pred, np.array([expected_prediction])) + self.assertAllClose(got_example_weight, np.array([expected_weight])) + + def testStandardMetricInputsWithNonScalarWeightsNoFlatten(self): + example = metric_types.StandardMetricInputs( + label=np.array([2]), + prediction=np.array([0, 0.5, 0.3, 0.9]), + example_weight=np.array([1.0, 0.0, 1.0, 1.0]), + ) + got_label, got_pred, got_example_weight = next( + metric_util.to_label_prediction_example_weight( + example, + flatten=False, + example_weighted=True, + require_single_example_weight=False, + ) + ) + self.assertAllClose(got_label, np.array([2])) + self.assertAllEqual(got_pred, np.array([0, 0.5, 0.3, 0.9])) + self.assertAllClose(got_example_weight, np.array([1.0, 0.0, 1.0, 1.0])) + + def testStandardMetricInputsWithMismatchedExampleWeightsRaisesError(self): + with self.assertRaises(ValueError): + example = metric_types.StandardMetricInputs( + labels=np.array([2]), + predictions=np.array([0, 0.5, 0.3, 0.9]), + example_weights=np.array([1.0, 0.0]), + ) + next( + metric_util.to_label_prediction_example_weight( + example, + flatten=True, + example_weighted=True, + require_single_example_weight=False, + ) + ) + + def testStandardMetricInputsRequiringSingleExampleWeightRaisesError(self): + with self.assertRaises(ValueError): + example = metric_types.StandardMetricInputs( + labels=np.array([2]), + 
predictions=np.array([0, 0.5, 0.3, 0.9]), + example_weights=np.array([1.0, 0.0]), + ) + next( + metric_util.to_label_prediction_example_weight( + example, example_weighted=True, require_single_example_weight=True + ) + ) + + def testPrepareLabelsAndPredictions(self): + labels = [0] + preds = { + "logistic": np.array([0.8]), + } + got_labels, got_preds = metric_util.prepare_labels_and_predictions( + labels, preds + ) + + self.assertAllClose(got_labels, np.array([0])) + self.assertAllClose(got_preds, np.array([0.8])) + + def testPrepareLabelsAndPredictionsClassNotFound(self): + labels = ["d"] + preds = { + "scores": np.array([0.2, 0.7, 0.1]), + "all_classes": np.array(["a", "b", "c"]), + } + got_labels, got_preds = metric_util.prepare_labels_and_predictions( + labels, preds + ) + + self.assertAllClose(got_labels, np.array([0, 0, 0])) + self.assertAllClose(got_preds, np.array([0.2, 0.7, 0.1])) + + def testPrepareLabelsAndPredictionsBatched(self): + labels = [["b"]] + preds = { + "logistic": np.array([[0.8]]), + "all_classes": np.array([["a", "b", "c"]]), + } + got_labels, got_preds = metric_util.prepare_labels_and_predictions( + labels, preds + ) + + self.assertAllClose(got_labels, np.array([[1]])) + self.assertAllClose(got_preds, np.array([[0.8]])) + + def testPrepareLabelsAndPredictionsMixedBatching(self): + labels = np.array([1]) + preds = { + "predictions": np.array([[0.8]]), + } + got_labels, got_preds = metric_util.prepare_labels_and_predictions( + labels, preds + ) + + self.assertAllClose(got_labels, np.array([1])) + self.assertAllClose(got_preds, np.array([[0.8]])) + + def testPrepareMultipleLabelsAndPredictions(self): + labels = np.array(["b", "c", "a"]) + preds = { + "scores": np.array([0.2, 0.7, 0.1]), + "classes": np.array(["a", "b", "c"]), + } + got_labels, got_preds = metric_util.prepare_labels_and_predictions( + labels, preds + ) + + self.assertAllClose(got_labels, np.array([1, 2, 0])) + self.assertAllClose(got_preds, np.array([0.2, 0.7, 0.1])) + + def testPrepareMultipleLabelsAndPredictionsPythonList(self): + labels = ["b", "c", "a"] + preds = {"probabilities": [0.2, 0.7, 0.1], "all_classes": ["a", "b", "c"]} + got_labels, got_preds = metric_util.prepare_labels_and_predictions( + labels, preds + ) + + self.assertAllClose(got_labels, np.array([1, 2, 0])) + self.assertAllClose(got_preds, np.array([0.2, 0.7, 0.1])) + + def testPrepareLabelsAndPredictionsSparseTensorValue(self): + labels = types.SparseTensorValue( + indices=np.array([1, 2]), + values=np.array([1, 1]), + dense_shape=np.array([1, 2]), + ) + preds = {"probabilities": [0.2, 0.7, 0.1], "all_classes": ["a", "b", "c"]} + got_labels, got_preds = metric_util.prepare_labels_and_predictions( + labels, preds + ) + + self.assertAllClose(got_labels, np.array([0, 1, 1])) + self.assertAllClose(got_preds, np.array([0.2, 0.7, 0.1])) + + def testPrepareLabelsAndPredictionsEmptySparseTensorValue(self): + labels = types.SparseTensorValue( + values=np.array([]), indices=np.array([]), dense_shape=np.array([0, 2]) + ) + preds = {"probabilities": [0.2, 0.7, 0.1], "all_classes": ["a", "b", "c"]} + got_labels, got_preds = metric_util.prepare_labels_and_predictions( + labels, preds + ) + + self.assertAllClose(got_labels, np.array([0, 0, 0])) + self.assertAllClose(got_preds, np.array([0.2, 0.7, 0.1])) + + def testPrepareLabelsAndPredictionsSparseTensorValueWithBatching(self): + labels = types.SparseTensorValue( + indices=np.array([1, 2]), + values=np.array([1, 1]), + dense_shape=np.array([1, 2]), + ) + preds = {"probabilities": [[0.2, 0.7, 
0.1]], "all_classes": [["a", "b", "c"]]} + got_labels, got_preds = metric_util.prepare_labels_and_predictions( + labels, preds + ) + + self.assertAllClose(got_labels, np.array([[0, 1, 1]])) + self.assertAllClose(got_preds, np.array([[0.2, 0.7, 0.1]])) + + def testPrepareMultipleLabelsAndPredictionsMultiDimension(self): + labels = [[0], [1]] + preds = {"probabilities": [[0.2, 0.8], [0.3, 0.7]]} + got_labels, got_preds = metric_util.prepare_labels_and_predictions( + labels, preds + ) + + self.assertAllClose(got_labels, np.array([[0], [1]])) + self.assertAllClose(got_preds, np.array([[0.2, 0.8], [0.3, 0.7]])) + + def testPrepareLabelsAndPredictionsEmpty(self): + labels = [] + preds = {"logistic": [], "all_classes": ["a", "b", "c"]} + got_labels, got_preds = metric_util.prepare_labels_and_predictions( + labels, preds + ) + + self.assertAllClose(got_labels, np.array([])) + self.assertAllClose(got_preds, np.array([])) + + def testPrepareLabelsAndPredictionsWithVocab(self): + labels = np.array(["e", "f"]) + preds = {"probabilities": [0.2, 0.8], "all_classes": ["a", "b", "c"]} + got_labels, got_preds = metric_util.prepare_labels_and_predictions( + labels, preds, label_vocabulary=["e", "f"] + ) + + self.assertAllClose(got_labels, np.array([0, 1])) + self.assertAllClose(got_preds, np.array([0.2, 0.8])) + + def testPrepareLabelsAndPredictionsWithVocabUsingObjectType(self): + labels = np.array(["e", "f"], dtype=object) + preds = {"probabilities": [0.2, 0.8], "all_classes": ["a", "b", "c"]} + got_labels, got_preds = metric_util.prepare_labels_and_predictions( + labels, preds, label_vocabulary=["e", "f"] + ) + + self.assertAllClose(got_labels, np.array([0, 1])) + self.assertAllClose(got_preds, np.array([0.2, 0.8])) + + def testPrepareLabelsAndPredictionsSparseTensorValueAndVocab(self): + labels = types.SparseTensorValue( + indices=np.array([0, 2]), + values=np.array(["c", "a"]), + dense_shape=np.array([1, 2]), + ) + preds = {"probabilities": [0.2, 0.7, 0.1], "all_classes": ["a", "b", "c"]} + got_labels, got_preds = metric_util.prepare_labels_and_predictions( + labels, preds + ) + + self.assertAllClose(got_labels, np.array([1, 0, 1])) + self.assertAllClose(got_preds, np.array([0.2, 0.7, 0.1])) + + def testPrepareLabelsAndPredictionsUsingBinaryScores(self): + labels = np.array([[0], [1]]) + preds = { + "scores": np.array([[0.9, 0.2], [0.3, 0.7]]), + "classes": np.array([["a", "b"], ["a", "b"]]), + } + got_labels, got_preds = metric_util.prepare_labels_and_predictions( + labels, preds + ) + + self.assertAllClose(got_labels, np.array([[0], [1]])) + self.assertAllClose(got_preds, np.array([[0.9, 0.2], [0.3, 0.7]])) + + def testPrepareLabelsAndPredictionsUsingBinaryScoresSparse(self): + labels = np.array([1, 0]) + preds = { + "scores": np.array([[0.9, 0.2], [0.3, 0.7]]), + "classes": np.array([["a", "b"], ["a", "b"]]), + } + got_labels, got_preds = metric_util.prepare_labels_and_predictions( + labels, preds + ) + + self.assertAllClose(got_labels, np.array([1, 0])) + self.assertAllClose(got_preds, np.array([[0.9, 0.2], [0.3, 0.7]])) + + def testPrepareLabelsAndPredictionsUsingBinaryScoresUnbatched(self): + labels = np.array([1]) + preds = {"scores": np.array([0.3, 0.7]), "classes": np.array(["a", "b"])} + got_labels, got_preds = metric_util.prepare_labels_and_predictions( + labels, preds + ) + + self.assertAllClose(got_labels, np.array([1])) + self.assertAllClose(got_preds, np.array([0.3, 0.7])) + + def testSelectClassIDSparse(self): + labels = np.array([2]) + preds = np.array([0.2, 0.7, 0.1]) + got_labels, 
got_preds = metric_util.select_class_id(1, labels, preds) + + self.assertAllClose(got_labels, np.array([0])) + self.assertAllClose(got_preds, np.array([0.7])) + + def testSelectClassIDSparseNoShape(self): + labels = np.array(2) + preds = np.array([0.2, 0.7, 0.1]) + got_labels, got_preds = metric_util.select_class_id(1, labels, preds) + + self.assertAllClose(got_labels, np.array([0])) + self.assertAllClose(got_preds, np.array([0.7])) + + def testSelectClassIDSparseWithMultipleValues(self): + labels = np.array([0, 2, 1]) + preds = np.array([[0.2, 0.7, 0.1], [0.3, 0.6, 0.1], [0.1, 0.2, 0.7]]) + got_labels, got_preds = metric_util.select_class_id(1, labels, preds) + + self.assertAllClose(got_labels, np.array([[0], [0], [1]])) + self.assertAllClose(got_preds, np.array([[0.7], [0.6], [0.2]])) + + def testSelectClassIDSparseBatched(self): + labels = np.array([[0], [2], [1]]) + preds = np.array([[0.2, 0.7, 0.1], [0.3, 0.6, 0.1], [0.1, 0.2, 0.7]]) + got_labels, got_preds = metric_util.select_class_id(1, labels, preds) + + self.assertAllClose(got_labels, np.array([[0], [0], [1]])) + self.assertAllClose(got_preds, np.array([[0.7], [0.6], [0.2]])) + + def testSelectClassIDSparseMultiDim(self): + labels = np.array([[[0]], [[2]], [[1]]]) + preds = np.array([[[0.2, 0.7, 0.1]], [[0.3, 0.6, 0.1]], [[0.1, 0.2, 0.7]]]) + got_labels, got_preds = metric_util.select_class_id(1, labels, preds) + + self.assertAllClose(got_labels, np.array([[[0]], [[0]], [[1]]])) + self.assertAllClose(got_preds, np.array([[[0.7]], [[0.6]], [[0.2]]])) + + def testRaisesErrorForInvalidSparseSettings(self): + with self.assertRaises(ValueError): + labels = np.array([[0, 0, 1]]) + preds = np.array([[0.2, 0.7, 0.1]]) + metric_util.select_class_id(1, labels, preds, sparse_labels=True) + + def testSelectClassID(self): + labels = np.array([0, 0, 1]) + preds = np.array([0.2, 0.7, 0.1]) + got_labels, got_preds = metric_util.select_class_id(1, labels, preds) + + self.assertAllClose(got_labels, np.array([0])) + self.assertAllClose(got_preds, np.array([0.7])) + + def testSelectClassIDWithMultipleValues(self): + labels = np.array([[0, 0, 1], [0, 0, 1], [0, 1, 0]]) + preds = np.array([[0.2, 0.7, 0.1], [0.3, 0.6, 0.1], [0.1, 0.2, 0.7]]) + got_labels, got_preds = metric_util.select_class_id(1, labels, preds) + + self.assertAllClose(got_labels, np.array([[0], [0], [1]])) + self.assertAllClose(got_preds, np.array([[0.7], [0.6], [0.2]])) + + def testSelectClassIDBatched(self): + labels = np.array([[0, 0, 1]]) + preds = np.array([[0.2, 0.7, 0.1]]) + got_labels, got_preds = metric_util.select_class_id(1, labels, preds) + + self.assertAllClose(got_labels, np.array([[0]])) + self.assertAllClose(got_preds, np.array([[0.7]])) + + def testSelectClassIDMultiDim(self): + labels = np.array([[[0, 0, 1]]]) + preds = np.array([[[0.2, 0.7, 0.1]]]) + got_labels, got_preds = metric_util.select_class_id(1, labels, preds) + + self.assertAllClose(got_labels, np.array([[[0]]])) + self.assertAllClose(got_preds, np.array([[[0.7]]])) + + def testRaisesErrorForInvalidNonSparseSettings(self): + with self.assertRaises(ValueError): + labels = np.array([5]) + preds = np.array([0.2, 0.7, 0.1]) + metric_util.select_class_id(1, labels, preds, sparse_labels=False) + + def testSelectClassIDEmpty(self): + labels = np.array(np.array([])) + preds = np.array(np.array([])) + got_labels, got_preds = metric_util.select_class_id(1, labels, preds) + + self.assertAllClose(got_labels, np.array([])) + self.assertAllClose(got_preds, np.array([])) + + def testTopKIndices(self): + scores = 
np.array([0.4, 0.1, 0.2, 0.3]) + got = metric_util.top_k_indices(2, scores) + # Indices could be in any order, test by overwritting the original scores + scores[got] = -1.0 + self.assertAllClose(scores, np.array([-1.0, 0.1, 0.2, -1.0])) + + def testTopKIndicesSorted(self): + scores = np.array([0.1, 0.3, 0.4, 0.2]) + got = metric_util.top_k_indices(2, scores, sort=True) + self.assertAllClose(got, np.array([2, 1])) + self.assertAllClose(scores[got], np.array([0.4, 0.3])) + + def testTopKIndices2D(self): + scores = np.array([[0.4, 0.1, 0.2, 0.3], [0.1, 0.2, 0.1, 0.6]]) + got = metric_util.top_k_indices(2, scores) + scores[got] = -1.0 + self.assertAllClose( + scores, np.array([[-1.0, 0.1, 0.2, -1.0], [0.1, -1.0, 0.1, -1.0]]) + ) + + def testTopKIndices2DSorted(self): + scores = np.array([[0.3, 0.1, 0.4, 0.2], [0.1, 0.2, 0.3, 0.6]]) + got = metric_util.top_k_indices(2, scores, sort=True) + # Indices are in ([row_index,...], [col_index, ...]) format. + self.assertAllClose(got, (np.array([0, 0, 1, 1]), np.array([2, 0, 3, 2]))) + self.assertAllClose(scores[got], np.array([0.4, 0.3, 0.6, 0.3])) + + def testTopKIndicesWithBinaryClassification(self): + scores = np.array([0.2, 0.8]) + got = metric_util.top_k_indices(1, scores) + self.assertAllClose(got, np.array([1])) + self.assertAllClose(scores[got], np.array([0.8])) + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow_model_analysis/metrics/min_label_position.py b/tensorflow_model_analysis/metrics/min_label_position.py index 5dd7dfb1f1..4e0666527e 100644 --- a/tensorflow_model_analysis/metrics/min_label_position.py +++ b/tensorflow_model_analysis/metrics/min_label_position.py @@ -17,144 +17,156 @@ import apache_beam as beam import numpy as np -from tensorflow_model_analysis.metrics import metric_types -from tensorflow_model_analysis.metrics import metric_util + +from tensorflow_model_analysis.metrics import metric_types, metric_util from tensorflow_model_analysis.proto import config_pb2 from tensorflow_model_analysis.utils import util -MIN_LABEL_POSITION_NAME = 'min_label_position' +MIN_LABEL_POSITION_NAME = "min_label_position" class MinLabelPosition(metric_types.Metric): - """Min label position metric. - - Calculates the least index in a query which has a positive label. The final - returned value is the weighted average over all queries in the evaluation set - which have at least one labeled entry. Note, ranking is indexed from one, so - the optimal value for this metric is one. If there are no labeled rows in the - evaluation set, the final output will be zero. - - This is a query/ranking based metric so a query_key must also be provided in - the associated metrics spec. - """ - - def __init__(self, - name=MIN_LABEL_POSITION_NAME, - label_key: Optional[str] = None): - """Initializes min label position metric. - - Args: - name: Metric name. - label_key: Optional label key to override default label. + """Min label position metric. + + Calculates the least index in a query which has a positive label. The final + returned value is the weighted average over all queries in the evaluation set + which have at least one labeled entry. Note, ranking is indexed from one, so + the optimal value for this metric is one. If there are no labeled rows in the + evaluation set, the final output will be zero. + + This is a query/ranking based metric so a query_key must also be provided in + the associated metrics spec. 
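(Editor's aside, not part of the patch: the behavior this docstring describes can be sketched without Beam. The snippet below mirrors the combiner defined later in this file, sorting a query's examples by prediction score descending, taking the 1-indexed position of the first positive label, then weight-averaging across queries; the query data here is invented purely for illustration.)

import numpy as np

def min_label_position_for_query(labels, predictions):
    # 1-indexed position of the first positive label after sorting the
    # query's examples by prediction score, descending; None if no positives.
    order = np.argsort(predictions)[::-1]
    for i, label in enumerate(labels[order]):
        if np.sum(label) > 0:
            return i + 1
    return None

# Two invented queries: (labels, predictions, example_weight).
queries = [
    (np.array([1.0, 0.0]), np.array([0.2, 0.8]), 1.0),            # positive ranked 2nd
    (np.array([1.0, 0.0, 0.0]), np.array([0.9, 0.1, 0.5]), 2.0),  # positive ranked 1st
]

total_position = 0.0
total_weight = 0.0
for labels, predictions, weight in queries:
    position = min_label_position_for_query(labels, predictions)
    if position is not None:
        total_position += position * weight
        total_weight += weight

# Weighted average: (2 * 1.0 + 1 * 2.0) / (1.0 + 2.0) = 1.333...
print(total_position / total_weight if total_weight else float("nan"))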
""" - super().__init__(_min_label_position, name=name, label_key=label_key) + + def __init__(self, name=MIN_LABEL_POSITION_NAME, label_key: Optional[str] = None): + """Initializes min label position metric. + + Args: + ---- + name: Metric name. + label_key: Optional label key to override default label. + """ + super().__init__(_min_label_position, name=name, label_key=label_key) metric_types.register_metric(MinLabelPosition) -def _min_label_position(name: str = MIN_LABEL_POSITION_NAME, - label_key: Optional[str] = None, - eval_config: Optional[config_pb2.EvalConfig] = None, - model_names: Optional[List[str]] = None, - output_names: Optional[List[str]] = None, - example_weighted: bool = False, - query_key: str = '') -> metric_types.MetricComputations: - """Returns metric computations for min label position.""" - if not query_key: - raise ValueError('a query_key is required to use MinLabelPosition metric') - if model_names is None: - model_names = [''] - if output_names is None: - output_names = [''] - keys = [] - computations = [] - preprocessors = None - if label_key: - preprocessors = [metric_types.FeaturePreprocessor(feature_keys=[label_key])] - for model_name in model_names: - for output_name in output_names: - key = metric_types.MetricKey( - name=name, - model_name=model_name, - output_name=output_name, - example_weighted=example_weighted) - keys.append(key) - computations.append( - metric_types.MetricComputation( - keys=[key], - preprocessors=preprocessors, - combiner=_MinLabelPositionCombiner(key, eval_config, - example_weighted, label_key))) - return computations +def _min_label_position( + name: str = MIN_LABEL_POSITION_NAME, + label_key: Optional[str] = None, + eval_config: Optional[config_pb2.EvalConfig] = None, + model_names: Optional[List[str]] = None, + output_names: Optional[List[str]] = None, + example_weighted: bool = False, + query_key: str = "", +) -> metric_types.MetricComputations: + """Returns metric computations for min label position.""" + if not query_key: + raise ValueError("a query_key is required to use MinLabelPosition metric") + if model_names is None: + model_names = [""] + if output_names is None: + output_names = [""] + keys = [] + computations = [] + preprocessors = None + if label_key: + preprocessors = [metric_types.FeaturePreprocessor(feature_keys=[label_key])] + for model_name in model_names: + for output_name in output_names: + key = metric_types.MetricKey( + name=name, + model_name=model_name, + output_name=output_name, + example_weighted=example_weighted, + ) + keys.append(key) + computations.append( + metric_types.MetricComputation( + keys=[key], + preprocessors=preprocessors, + combiner=_MinLabelPositionCombiner( + key, eval_config, example_weighted, label_key + ), + ) + ) + return computations class _MinLabelPositionAccumulator: - """Min label position accumulator.""" - __slots__ = ['total_min_position', 'total_weighted_examples'] + """Min label position accumulator.""" + + __slots__ = ["total_min_position", "total_weighted_examples"] - def __init__(self): - self.total_min_position = 0.0 - self.total_weighted_examples = 0.0 + def __init__(self): + self.total_min_position = 0.0 + self.total_weighted_examples = 0.0 class _MinLabelPositionCombiner(beam.CombineFn): - """Computes min label position metric.""" - - def __init__(self, key: metric_types.MetricKey, - eval_config: Optional[config_pb2.EvalConfig], - example_weighted: bool, label_key: Optional[str]): - self._key = key - self._eval_config = eval_config - self._example_weighted = 
example_weighted - self._label_key = label_key - - def create_accumulator(self) -> _MinLabelPositionAccumulator: - return _MinLabelPositionAccumulator() - - def add_input( - self, accumulator: _MinLabelPositionAccumulator, - element: metric_types.StandardMetricInputs - ) -> _MinLabelPositionAccumulator: - labels, predictions, example_weight = next( - metric_util.to_label_prediction_example_weight( - element, - eval_config=self._eval_config, - model_name=self._key.model_name, - output_name=self._key.output_name, - example_weighted=self._example_weighted, - flatten=False, - allow_none=True, - require_single_example_weight=True)) # pytype: disable=wrong-arg-types - if self._label_key: - labels = util.get_by_keys(element.features, [self._label_key]) - if labels is not None: - min_label_pos = None - for i, l in enumerate(labels[np.argsort(predictions)[::-1]]): - if np.sum(l) > 0: - min_label_pos = i + 1 # Use 1-indexed positions - break - if min_label_pos: - accumulator.total_min_position += min_label_pos * float(example_weight) - accumulator.total_weighted_examples += float(example_weight) - return accumulator - - def merge_accumulators( - self, accumulators: Iterable[_MinLabelPositionAccumulator] - ) -> _MinLabelPositionAccumulator: - accumulators = iter(accumulators) - result = next(accumulators) - for accumulator in accumulators: - result.total_min_position += accumulator.total_min_position - result.total_weighted_examples += accumulator.total_weighted_examples - return result - - def extract_output( - self, accumulator: _MinLabelPositionAccumulator - ) -> Dict[metric_types.MetricKey, float]: - if accumulator.total_weighted_examples > 0: - value = ( - accumulator.total_min_position / accumulator.total_weighted_examples) - else: - value = float('nan') - return {self._key: value} + """Computes min label position metric.""" + + def __init__( + self, + key: metric_types.MetricKey, + eval_config: Optional[config_pb2.EvalConfig], + example_weighted: bool, + label_key: Optional[str], + ): + self._key = key + self._eval_config = eval_config + self._example_weighted = example_weighted + self._label_key = label_key + + def create_accumulator(self) -> _MinLabelPositionAccumulator: + return _MinLabelPositionAccumulator() + + def add_input( + self, + accumulator: _MinLabelPositionAccumulator, + element: metric_types.StandardMetricInputs, + ) -> _MinLabelPositionAccumulator: + labels, predictions, example_weight = next( + metric_util.to_label_prediction_example_weight( + element, + eval_config=self._eval_config, + model_name=self._key.model_name, + output_name=self._key.output_name, + example_weighted=self._example_weighted, + flatten=False, + allow_none=True, + require_single_example_weight=True, + ) + ) # pytype: disable=wrong-arg-types + if self._label_key: + labels = util.get_by_keys(element.features, [self._label_key]) + if labels is not None: + min_label_pos = None + for i, l in enumerate(labels[np.argsort(predictions)[::-1]]): + if np.sum(l) > 0: + min_label_pos = i + 1 # Use 1-indexed positions + break + if min_label_pos: + accumulator.total_min_position += min_label_pos * float(example_weight) + accumulator.total_weighted_examples += float(example_weight) + return accumulator + + def merge_accumulators( + self, accumulators: Iterable[_MinLabelPositionAccumulator] + ) -> _MinLabelPositionAccumulator: + accumulators = iter(accumulators) + result = next(accumulators) + for accumulator in accumulators: + result.total_min_position += accumulator.total_min_position + result.total_weighted_examples += 
accumulator.total_weighted_examples + return result + + def extract_output( + self, accumulator: _MinLabelPositionAccumulator + ) -> Dict[metric_types.MetricKey, float]: + if accumulator.total_weighted_examples > 0: + value = accumulator.total_min_position / accumulator.total_weighted_examples + else: + value = float("nan") + return {self._key: value} diff --git a/tensorflow_model_analysis/metrics/min_label_position_test.py b/tensorflow_model_analysis/metrics/min_label_position_test.py index 2a67962e1c..f4d510fccf 100644 --- a/tensorflow_model_analysis/metrics/min_label_position_test.py +++ b/tensorflow_model_analysis/metrics/min_label_position_test.py @@ -15,14 +15,17 @@ import math -from absl.testing import parameterized import apache_beam as beam -from apache_beam.testing import util import numpy as np import tensorflow as tf -from tensorflow_model_analysis.metrics import metric_types -from tensorflow_model_analysis.metrics import metric_util -from tensorflow_model_analysis.metrics import min_label_position +from absl.testing import parameterized +from apache_beam.testing import util + +from tensorflow_model_analysis.metrics import ( + metric_types, + metric_util, + min_label_position, +) from tensorflow_model_analysis.utils import test_util from tensorflow_model_analysis.utils import util as tfma_util @@ -30,186 +33,191 @@ class MinLabelPositionTest( test_util.TensorflowModelAnalysisTest, parameterized.TestCase ): - - def testRaisesErrorIfNoQueryKey(self): - with self.assertRaises(ValueError): - min_label_position.MinLabelPosition().computations() - - def testRaisesErrorWhenExampleWeightsDiffer(self): - with self.assertRaises(ValueError): - metric = min_label_position.MinLabelPosition().computations( - query_key='query', example_weighted=True)[0] - - query1_example1 = { - 'labels': np.array([0.0]), - 'predictions': np.array([0.2]), - 'example_weights': np.array([1.0]), - 'features': { - 'query': np.array(['query1']) - } - } - query1_example2 = { - 'labels': np.array([1.0]), - 'predictions': np.array([0.8]), - 'example_weights': np.array([0.5]), - 'features': { - 'query': np.array(['query1']) - } - } - - with beam.Pipeline() as pipeline: - # pylint: disable=no-value-for-parameter - _ = ( - pipeline - | 'Create' >> beam.Create( - [tfma_util.merge_extracts([query1_example1, query1_example2])]) - | 'Process' >> beam.Map(metric_util.to_standard_metric_inputs, True) - | 'AddSlice' >> beam.Map(lambda x: ((), x)) - | 'Combine' >> beam.CombinePerKey(metric.combiner)) - - @parameterized.named_parameters(('default_label', None), - ('custom_label', 'custom_label')) - def testMinLabelPosition(self, label_key): - metric = min_label_position.MinLabelPosition( - label_key=label_key).computations( - query_key='query', example_weighted=True)[0] - - query1_example1 = { - 'labels': np.array([1.0]), - 'predictions': np.array([0.2]), - 'example_weights': np.array([1.0]), - 'features': { - 'custom_label': np.array([0.0]), - 'query': np.array(['query1']) + def testRaisesErrorIfNoQueryKey(self): + with self.assertRaises(ValueError): + min_label_position.MinLabelPosition().computations() + + def testRaisesErrorWhenExampleWeightsDiffer(self): + with self.assertRaises(ValueError): + metric = min_label_position.MinLabelPosition().computations( + query_key="query", example_weighted=True + )[0] + + query1_example1 = { + "labels": np.array([0.0]), + "predictions": np.array([0.2]), + "example_weights": np.array([1.0]), + "features": {"query": np.array(["query1"])}, + } + query1_example2 = { + "labels": 
np.array([1.0]), + "predictions": np.array([0.8]), + "example_weights": np.array([0.5]), + "features": {"query": np.array(["query1"])}, + } + + with beam.Pipeline() as pipeline: + # pylint: disable=no-value-for-parameter + _ = ( + pipeline + | "Create" + >> beam.Create( + [tfma_util.merge_extracts([query1_example1, query1_example2])] + ) + | "Process" >> beam.Map(metric_util.to_standard_metric_inputs, True) + | "AddSlice" >> beam.Map(lambda x: ((), x)) + | "Combine" >> beam.CombinePerKey(metric.combiner) + ) + + @parameterized.named_parameters( + ("default_label", None), ("custom_label", "custom_label") + ) + def testMinLabelPosition(self, label_key): + metric = min_label_position.MinLabelPosition(label_key=label_key).computations( + query_key="query", example_weighted=True + )[0] + + query1_example1 = { + "labels": np.array([1.0]), + "predictions": np.array([0.2]), + "example_weights": np.array([1.0]), + "features": { + "custom_label": np.array([0.0]), + "query": np.array(["query1"]), + }, } - } - query1_example2 = { - 'labels': np.array([0.0]), - 'predictions': np.array([0.8]), - 'example_weights': np.array([1.0]), - 'features': { - 'custom_label': np.array([1.0]), - 'query': np.array(['query1']) + query1_example2 = { + "labels": np.array([0.0]), + "predictions": np.array([0.8]), + "example_weights": np.array([1.0]), + "features": { + "custom_label": np.array([1.0]), + "query": np.array(["query1"]), + }, } - } - query2_example1 = { - 'labels': np.array([1.0]), - 'predictions': np.array([0.9]), - 'example_weights': np.array([2.0]), - 'features': { - 'custom_label': np.array([0.0]), - 'query': np.array(['query2']) + query2_example1 = { + "labels": np.array([1.0]), + "predictions": np.array([0.9]), + "example_weights": np.array([2.0]), + "features": { + "custom_label": np.array([0.0]), + "query": np.array(["query2"]), + }, } - } - query2_example2 = { - 'labels': np.array([0.0]), - 'predictions': np.array([0.1]), - 'example_weights': np.array([2.0]), - 'features': { - 'custom_label': np.array([1.0]), - 'query': np.array(['query2']) + query2_example2 = { + "labels": np.array([0.0]), + "predictions": np.array([0.1]), + "example_weights": np.array([2.0]), + "features": { + "custom_label": np.array([1.0]), + "query": np.array(["query2"]), + }, } - } - query2_example3 = { - 'labels': np.array([0.0]), - 'predictions': np.array([0.5]), - 'example_weights': np.array([2.0]), - 'features': { - 'custom_label': np.array([0.0]), - 'query': np.array(['query2']) + query2_example3 = { + "labels": np.array([0.0]), + "predictions": np.array([0.5]), + "example_weights": np.array([2.0]), + "features": { + "custom_label": np.array([0.0]), + "query": np.array(["query2"]), + }, } - } - query3_example1 = { - 'labels': np.array([1.0]), - 'predictions': np.array([0.9]), - 'example_weights': np.array([3.0]), - 'features': { - 'custom_label': np.array([0.0]), - 'query': np.array(['query3']) + query3_example1 = { + "labels": np.array([1.0]), + "predictions": np.array([0.9]), + "example_weights": np.array([3.0]), + "features": { + "custom_label": np.array([0.0]), + "query": np.array(["query3"]), + }, } - } - examples = [ - tfma_util.merge_extracts([query1_example1, query1_example2]), - tfma_util.merge_extracts( - [query2_example1, query2_example2, query2_example3]), - tfma_util.merge_extracts([query3_example1]) - ] - - if label_key: - self.assertIsNotNone(metric.preprocessors) - - with beam.Pipeline() as pipeline: - # pylint: disable=no-value-for-parameter - result = ( - pipeline - | 'Create' >> beam.Create(examples) - | 
'Process' >> beam.Map( - metric_util.to_standard_metric_inputs, include_features=True) - | 'AddSlice' >> beam.Map(lambda x: ((), x)) - | 'Combine' >> beam.CombinePerKey(metric.combiner)) - - # pylint: enable=no-value-for-parameter - - def check_result(got): - try: - self.assertLen(got, 1) - got_slice_key, got_metrics = got[0] - self.assertEqual(got_slice_key, ()) - key = metric_types.MetricKey( - name='min_label_position', example_weighted=True) - self.assertIn(key, got_metrics) - if label_key == 'custom_label': - # (1*1.0 + 3*2.0) / (1.0 + 2.0) = 2.333333 - self.assertAllClose(got_metrics[key], 2.333333) - else: - # (2*1.0 + 1*2.0 + 1*3.0) / (1.0 + 2.0 + 3.0) = 1.166666 - self.assertAllClose(got_metrics[key], 1.166666) - - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that(result, check_result, label='result') - - def testMinLabelPositionWithNoWeightedExamples(self): - metric = min_label_position.MinLabelPosition().computations( - query_key='query', example_weighted=True)[0] - - query1_example1 = { - 'labels': np.array([1.0]), - 'predictions': np.array([0.2]), - 'example_weights': np.array([0.0]), - 'features': { - 'query': np.array(['query1']) + examples = [ + tfma_util.merge_extracts([query1_example1, query1_example2]), + tfma_util.merge_extracts( + [query2_example1, query2_example2, query2_example3] + ), + tfma_util.merge_extracts([query3_example1]), + ] + + if label_key: + self.assertIsNotNone(metric.preprocessors) + + with beam.Pipeline() as pipeline: + # pylint: disable=no-value-for-parameter + result = ( + pipeline + | "Create" >> beam.Create(examples) + | "Process" + >> beam.Map( + metric_util.to_standard_metric_inputs, include_features=True + ) + | "AddSlice" >> beam.Map(lambda x: ((), x)) + | "Combine" >> beam.CombinePerKey(metric.combiner) + ) + + # pylint: enable=no-value-for-parameter + + def check_result(got): + try: + self.assertLen(got, 1) + got_slice_key, got_metrics = got[0] + self.assertEqual(got_slice_key, ()) + key = metric_types.MetricKey( + name="min_label_position", example_weighted=True + ) + self.assertIn(key, got_metrics) + if label_key == "custom_label": + # (1*1.0 + 3*2.0) / (1.0 + 2.0) = 2.333333 + self.assertAllClose(got_metrics[key], 2.333333) + else: + # (2*1.0 + 1*2.0 + 1*3.0) / (1.0 + 2.0 + 3.0) = 1.166666 + self.assertAllClose(got_metrics[key], 1.166666) + + except AssertionError as err: + raise util.BeamAssertException(err) + + util.assert_that(result, check_result, label="result") + + def testMinLabelPositionWithNoWeightedExamples(self): + metric = min_label_position.MinLabelPosition().computations( + query_key="query", example_weighted=True + )[0] + + query1_example1 = { + "labels": np.array([1.0]), + "predictions": np.array([0.2]), + "example_weights": np.array([0.0]), + "features": {"query": np.array(["query1"])}, } - } - - with beam.Pipeline() as pipeline: - # pylint: disable=no-value-for-parameter - result = ( - pipeline - | - 'Create' >> beam.Create([tfma_util.merge_extracts([query1_example1])]) - | 'Process' >> beam.Map(metric_util.to_standard_metric_inputs, True) - | 'AddSlice' >> beam.Map(lambda x: ((), x)) - | 'Combine' >> beam.CombinePerKey(metric.combiner)) - - # pylint: enable=no-value-for-parameter - - def check_result(got): - try: - self.assertLen(got, 1) - got_slice_key, got_metrics = got[0] - self.assertEqual(got_slice_key, ()) - key = metric_types.MetricKey( - name='min_label_position', example_weighted=True) - self.assertIn(key, got_metrics) - self.assertTrue(math.isnan(got_metrics[key])) - - 
except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that(result, check_result, label='result') - - -if __name__ == '__main__': - tf.test.main() + + with beam.Pipeline() as pipeline: + # pylint: disable=no-value-for-parameter + result = ( + pipeline + | "Create" >> beam.Create([tfma_util.merge_extracts([query1_example1])]) + | "Process" >> beam.Map(metric_util.to_standard_metric_inputs, True) + | "AddSlice" >> beam.Map(lambda x: ((), x)) + | "Combine" >> beam.CombinePerKey(metric.combiner) + ) + + # pylint: enable=no-value-for-parameter + + def check_result(got): + try: + self.assertLen(got, 1) + got_slice_key, got_metrics = got[0] + self.assertEqual(got_slice_key, ()) + key = metric_types.MetricKey( + name="min_label_position", example_weighted=True + ) + self.assertIn(key, got_metrics) + self.assertTrue(math.isnan(got_metrics[key])) + + except AssertionError as err: + raise util.BeamAssertException(err) + + util.assert_that(result, check_result, label="result") + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow_model_analysis/metrics/model_cosine_similarity.py b/tensorflow_model_analysis/metrics/model_cosine_similarity.py index 0db7ffb659..29ec6dbb23 100644 --- a/tensorflow_model_analysis/metrics/model_cosine_similarity.py +++ b/tensorflow_model_analysis/metrics/model_cosine_similarity.py @@ -13,178 +13,179 @@ # limitations under the License. """Model cosine similiarty metrics.""" -from collections.abc import Iterable import dataclasses +from collections.abc import Iterable from typing import Any, Optional import apache_beam as beam import numpy as np -from tensorflow_model_analysis.metrics import metric_types -from tensorflow_model_analysis.metrics import metric_util + +from tensorflow_model_analysis.metrics import metric_types, metric_util from tensorflow_model_analysis.proto import config_pb2 from tensorflow_model_analysis.utils import model_util -_COSINE_SIMILARITY_METRIC_NAME = 'model_cosine_similarity' +_COSINE_SIMILARITY_METRIC_NAME = "model_cosine_similarity" def _compute_cosine_similarity( baseline_prediction: np.ndarray[Any, Any], candidate_prediction: np.ndarray[Any, Any], ) -> float: - """Computes cosine similarity between two predictions of np.ndarrays.""" - return np.dot(baseline_prediction, candidate_prediction) / ( - np.linalg.norm(baseline_prediction) * np.linalg.norm(candidate_prediction) - ) + """Computes cosine similarity between two predictions of np.ndarrays.""" + return np.dot(baseline_prediction, candidate_prediction) / ( + np.linalg.norm(baseline_prediction) * np.linalg.norm(candidate_prediction) + ) @dataclasses.dataclass class _CosineSimilarityAccumulator: - """Accumulator for computing average CosineSimilarity.""" + """Accumulator for computing average CosineSimilarity.""" - num_examples: int = 0 - sum_cosine_similarity: float = 0.0 + num_examples: int = 0 + sum_cosine_similarity: float = 0.0 - def merge(self, other: '_CosineSimilarityAccumulator'): - self.num_examples += other.num_examples - self.sum_cosine_similarity += other.sum_cosine_similarity + def merge(self, other: "_CosineSimilarityAccumulator"): + self.num_examples += other.num_examples + self.sum_cosine_similarity += other.sum_cosine_similarity - def get_average(self) -> float: - if self.num_examples == 0: - return np.nan - return self.sum_cosine_similarity / self.num_examples + def get_average(self) -> float: + if self.num_examples == 0: + return np.nan + return self.sum_cosine_similarity / self.num_examples class 
ModelCosineSimilarity(metric_types.Metric): - """ModelCosineSimilarity compares predictions from baseline and candidate models using cosine similarity.""" - - def __init__(self, name: str = _COSINE_SIMILARITY_METRIC_NAME): - super().__init__(self._metric_computation, name=name) - - def _metric_computation( - self, - name: str, - eval_config: config_pb2.EvalConfig, - model_names: Iterable[str], - output_names: Optional[Iterable[str]] = ('',), - sub_keys: Optional[Iterable[metric_types.SubKey]] = None, - ) -> metric_types.MetricComputations: - """Returns the metric computations for calculating the cosine similarity. - - Args: - name: Metric name for individual flip rate. - eval_config: The EvalConfig for this TFMA evaluation. This is used to - identify which model is the baseline. - model_names: The name of the baseline model and the candidate model. - output_names: The set of output names for which to compute this metric. - sub_keys: The set of sub_key settings for which to compute this metric. - """ - computations = [] - - # Get the baseline model name. - baseline_spec = model_util.get_baseline_model_spec(eval_config) - baseline_model_name = baseline_spec.name if baseline_spec else None - - for candidate_model_name in model_names: - if candidate_model_name == baseline_model_name: - continue - for output_name in output_names: - for sub_key in sub_keys or (None,): - # Define the metric key. - key = metric_types.MetricKey( - name=name, - model_name=candidate_model_name, - output_name=output_name, - sub_key=sub_key, - is_diff=True, - ) - - # Append cosine similarity calculation to computations. - computations.append( - metric_types.MetricComputation( - keys=[key], - preprocessors=None, - combiner=_ModelCosineSimilarityCombiner( - metric_key=key, - eval_config=eval_config, - baseline_model_name=baseline_model_name, - model_name=candidate_model_name, - output_name=output_name, - ), - ) - ) - - return computations + """ModelCosineSimilarity compares predictions from baseline and candidate models using cosine similarity.""" + + def __init__(self, name: str = _COSINE_SIMILARITY_METRIC_NAME): + super().__init__(self._metric_computation, name=name) + + def _metric_computation( + self, + name: str, + eval_config: config_pb2.EvalConfig, + model_names: Iterable[str], + output_names: Optional[Iterable[str]] = ("",), + sub_keys: Optional[Iterable[metric_types.SubKey]] = None, + ) -> metric_types.MetricComputations: + """Returns the metric computations for calculating the cosine similarity. + + Args: + ---- + name: Metric name for individual flip rate. + eval_config: The EvalConfig for this TFMA evaluation. This is used to + identify which model is the baseline. + model_names: The name of the baseline model and the candidate model. + output_names: The set of output names for which to compute this metric. + sub_keys: The set of sub_key settings for which to compute this metric. + """ + computations = [] + + # Get the baseline model name. + baseline_spec = model_util.get_baseline_model_spec(eval_config) + baseline_model_name = baseline_spec.name if baseline_spec else None + + for candidate_model_name in model_names: + if candidate_model_name == baseline_model_name: + continue + for output_name in output_names: + for sub_key in sub_keys or (None,): + # Define the metric key. + key = metric_types.MetricKey( + name=name, + model_name=candidate_model_name, + output_name=output_name, + sub_key=sub_key, + is_diff=True, + ) + + # Append cosine similarity calculation to computations. 
+ computations.append( + metric_types.MetricComputation( + keys=[key], + preprocessors=None, + combiner=_ModelCosineSimilarityCombiner( + metric_key=key, + eval_config=eval_config, + baseline_model_name=baseline_model_name, + model_name=candidate_model_name, + output_name=output_name, + ), + ) + ) + + return computations class _ModelCosineSimilarityCombiner(beam.CombineFn): - """A combiner for computing the cosine similarity between models.""" - - def __init__( - self, - metric_key: metric_types.MetricKey, - eval_config: config_pb2.EvalConfig, - baseline_model_name: str, - model_name: str, - output_name: str, - ): - self._metric_key = metric_key - self._eval_config = eval_config - self._baseline_model_name = baseline_model_name - self._model_name = model_name - self._output_name = output_name - - def create_accumulator(self) -> _CosineSimilarityAccumulator: - return _CosineSimilarityAccumulator() - - def add_input( - self, - accumulator: _CosineSimilarityAccumulator, - element: metric_types.StandardMetricInputs, - ) -> _CosineSimilarityAccumulator: - _, baseline_prediction, _ = next( - metric_util.to_label_prediction_example_weight( - inputs=element, - eval_config=self._eval_config, - model_name=self._baseline_model_name, - output_name=self._output_name, - flatten=False, - allow_none=True, + """A combiner for computing the cosine similarity between models.""" + + def __init__( + self, + metric_key: metric_types.MetricKey, + eval_config: config_pb2.EvalConfig, + baseline_model_name: str, + model_name: str, + output_name: str, + ): + self._metric_key = metric_key + self._eval_config = eval_config + self._baseline_model_name = baseline_model_name + self._model_name = model_name + self._output_name = output_name + + def create_accumulator(self) -> _CosineSimilarityAccumulator: + return _CosineSimilarityAccumulator() + + def add_input( + self, + accumulator: _CosineSimilarityAccumulator, + element: metric_types.StandardMetricInputs, + ) -> _CosineSimilarityAccumulator: + _, baseline_prediction, _ = next( + metric_util.to_label_prediction_example_weight( + inputs=element, + eval_config=self._eval_config, + model_name=self._baseline_model_name, + output_name=self._output_name, + flatten=False, + allow_none=True, + ) ) - ) - _, candidate_prediction, _ = next( - metric_util.to_label_prediction_example_weight( - inputs=element, - eval_config=self._eval_config, - model_name=self._model_name, - output_name=self._output_name, - flatten=False, - allow_none=True, + _, candidate_prediction, _ = next( + metric_util.to_label_prediction_example_weight( + inputs=element, + eval_config=self._eval_config, + model_name=self._model_name, + output_name=self._output_name, + flatten=False, + allow_none=True, + ) ) - ) - accumulator.merge( - _CosineSimilarityAccumulator( - num_examples=1, - sum_cosine_similarity=_compute_cosine_similarity( - baseline_prediction, candidate_prediction - ), + accumulator.merge( + _CosineSimilarityAccumulator( + num_examples=1, + sum_cosine_similarity=_compute_cosine_similarity( + baseline_prediction, candidate_prediction + ), + ) ) - ) - return accumulator + return accumulator - def merge_accumulators( - self, accumulators: Iterable[_CosineSimilarityAccumulator] - ) -> _CosineSimilarityAccumulator: - result = next(iter(accumulators)) - for accumulator in accumulators: - result.merge(accumulator) - return result + def merge_accumulators( + self, accumulators: Iterable[_CosineSimilarityAccumulator] + ) -> _CosineSimilarityAccumulator: + result = next(iter(accumulators)) + for accumulator 
in accumulators: + result.merge(accumulator) + return result - def extract_output( - self, accumulator: _CosineSimilarityAccumulator - ) -> dict[metric_types.MetricKey, float]: - return {self._metric_key: accumulator.get_average()} + def extract_output( + self, accumulator: _CosineSimilarityAccumulator + ) -> dict[metric_types.MetricKey, float]: + return {self._metric_key: accumulator.get_average()} # Register Model Cosine Similarity metric. diff --git a/tensorflow_model_analysis/metrics/model_cosine_similarity_test.py b/tensorflow_model_analysis/metrics/model_cosine_similarity_test.py index 56ecc464fe..e9c10df90e 100644 --- a/tensorflow_model_analysis/metrics/model_cosine_similarity_test.py +++ b/tensorflow_model_analysis/metrics/model_cosine_similarity_test.py @@ -13,72 +13,72 @@ # limitations under the License. """Tests for model cosine similiarty metrics.""" -from absl.testing import absltest -from absl.testing import parameterized import apache_beam as beam -from apache_beam.testing import util import numpy as np +from absl.testing import absltest, parameterized +from apache_beam.testing import util +from google.protobuf import text_format + from tensorflow_model_analysis import constants -from tensorflow_model_analysis.metrics import metric_types -from tensorflow_model_analysis.metrics import metric_util -from tensorflow_model_analysis.metrics import model_cosine_similarity +from tensorflow_model_analysis.metrics import ( + metric_types, + metric_util, + model_cosine_similarity, +) from tensorflow_model_analysis.proto import config_pb2 -from google.protobuf import text_format - _PREDICTION_A = np.array([1.0, 0.5, 0.5, 1.0]) _PREDICTION_B = np.array([0.5, 1.0, 1.0, 0.5]) _PREDICTION_C = np.array([0.25, 0.1, 0.9, 0.75]) class ModelCosineSimilarityMetricsTest(parameterized.TestCase): - - @parameterized.named_parameters( - dict( - testcase_name='no_change', - prediction_pairs=[ - (_PREDICTION_A, _PREDICTION_A), - (_PREDICTION_B, _PREDICTION_B), - (_PREDICTION_C, _PREDICTION_C), - ], - # cs(p1, p2): - # np.dot(p1, p2) / (np.linalg.norm(p1) * np.linalg.norm(p2)) - # cs(_PREDICTION_A/B/C, _PREDICTION_A/B/C) = 1.0 - expected_average_cosine_similarity=1.0, - ), - dict( - testcase_name='small_change', - prediction_pairs=[ - (_PREDICTION_A, _PREDICTION_A), - (_PREDICTION_B, _PREDICTION_A), - (_PREDICTION_A, _PREDICTION_B), - ], - # cs(_PREDICTION_A, _PREDICTION_A) = 1.0 - # cs(_PREDICTION_B, _PREDICTION_A) = 0.8 - # cs(_PREDICTION_A, _PREDICTION_B) = 0.8 - expected_average_cosine_similarity=0.8666666666666666, - ), - dict( - testcase_name='large_change', - prediction_pairs=[ - (_PREDICTION_C, _PREDICTION_A), - (_PREDICTION_A, _PREDICTION_B), - (_PREDICTION_B, _PREDICTION_C), - ], - # cs(_PREDICTION_C, _PREDICTION_A) = 0.7892004626469845 - # cs(_PREDICTION_A, _PREDICTION_B) = 0.8 - # cs(_PREDICTION_B, _PREDICTION_C) = 0.7892004626469845 - expected_average_cosine_similarity=0.7928003084313229, - ), - ) - def test_cosine_similarity( - self, prediction_pairs, expected_average_cosine_similarity - ): - baseline_model_name = 'baseline' - candidate_model_name = 'candidate' - - eval_config = text_format.Parse( - """ + @parameterized.named_parameters( + dict( + testcase_name="no_change", + prediction_pairs=[ + (_PREDICTION_A, _PREDICTION_A), + (_PREDICTION_B, _PREDICTION_B), + (_PREDICTION_C, _PREDICTION_C), + ], + # cs(p1, p2): + # np.dot(p1, p2) / (np.linalg.norm(p1) * np.linalg.norm(p2)) + # cs(_PREDICTION_A/B/C, _PREDICTION_A/B/C) = 1.0 + expected_average_cosine_similarity=1.0, + ), + dict( + 
testcase_name="small_change", + prediction_pairs=[ + (_PREDICTION_A, _PREDICTION_A), + (_PREDICTION_B, _PREDICTION_A), + (_PREDICTION_A, _PREDICTION_B), + ], + # cs(_PREDICTION_A, _PREDICTION_A) = 1.0 + # cs(_PREDICTION_B, _PREDICTION_A) = 0.8 + # cs(_PREDICTION_A, _PREDICTION_B) = 0.8 + expected_average_cosine_similarity=0.8666666666666666, + ), + dict( + testcase_name="large_change", + prediction_pairs=[ + (_PREDICTION_C, _PREDICTION_A), + (_PREDICTION_A, _PREDICTION_B), + (_PREDICTION_B, _PREDICTION_C), + ], + # cs(_PREDICTION_C, _PREDICTION_A) = 0.7892004626469845 + # cs(_PREDICTION_A, _PREDICTION_B) = 0.8 + # cs(_PREDICTION_B, _PREDICTION_C) = 0.7892004626469845 + expected_average_cosine_similarity=0.7928003084313229, + ), + ) + def test_cosine_similarity( + self, prediction_pairs, expected_average_cosine_similarity + ): + baseline_model_name = "baseline" + candidate_model_name = "candidate" + + eval_config = text_format.Parse( + """ model_specs { name: "baseline" is_baseline: true @@ -87,64 +87,66 @@ def test_cosine_similarity( name: "candidate" } """, - config_pb2.EvalConfig(), - ) - - computations = model_cosine_similarity.ModelCosineSimilarity().computations( - eval_config=eval_config, - model_names=['baseline', 'candidate'], - output_names=[''], - ) - self.assertLen(computations, 1) - cosine_similarity = computations[0] - - examples = [] - for baseline_prediction, candidate_prediction in prediction_pairs: - examples.append({ - constants.LABELS_KEY: { - baseline_model_name: None, - candidate_model_name: None, - }, - constants.PREDICTIONS_KEY: { - baseline_model_name: baseline_prediction, - candidate_model_name: candidate_prediction, - }, - }) - - with beam.Pipeline() as pipeline: - result = ( - pipeline - | 'Create' >> beam.Create(examples) - | 'Process' >> beam.Map(metric_util.to_standard_metric_inputs) - | 'AddSlice' >> beam.Map(lambda x: ((), x)) - | 'ComputeMetric' >> beam.CombinePerKey(cosine_similarity.combiner) - ) - - def check_result(got): - try: - self.assertLen(got, 1) - got_slice_key, got_metrics = got[0] - self.assertEqual(got_slice_key, ()) - - metric_key = metric_types.MetricKey( - name=model_cosine_similarity._COSINE_SIMILARITY_METRIC_NAME, - model_name=candidate_model_name, - output_name='', - is_diff=True, - ) - - self.assertIn(metric_key, got_metrics) - self.assertIsInstance(got_metrics[metric_key], float) - self.assertAlmostEqual( - got_metrics[metric_key], - expected_average_cosine_similarity, - ) - - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that(result, check_result, label='result') - - -if __name__ == '__main__': - absltest.main() + config_pb2.EvalConfig(), + ) + + computations = model_cosine_similarity.ModelCosineSimilarity().computations( + eval_config=eval_config, + model_names=["baseline", "candidate"], + output_names=[""], + ) + self.assertLen(computations, 1) + cosine_similarity = computations[0] + + examples = [] + for baseline_prediction, candidate_prediction in prediction_pairs: + examples.append( + { + constants.LABELS_KEY: { + baseline_model_name: None, + candidate_model_name: None, + }, + constants.PREDICTIONS_KEY: { + baseline_model_name: baseline_prediction, + candidate_model_name: candidate_prediction, + }, + } + ) + + with beam.Pipeline() as pipeline: + result = ( + pipeline + | "Create" >> beam.Create(examples) + | "Process" >> beam.Map(metric_util.to_standard_metric_inputs) + | "AddSlice" >> beam.Map(lambda x: ((), x)) + | "ComputeMetric" >> beam.CombinePerKey(cosine_similarity.combiner) + ) + + 
def check_result(got): + try: + self.assertLen(got, 1) + got_slice_key, got_metrics = got[0] + self.assertEqual(got_slice_key, ()) + + metric_key = metric_types.MetricKey( + name=model_cosine_similarity._COSINE_SIMILARITY_METRIC_NAME, + model_name=candidate_model_name, + output_name="", + is_diff=True, + ) + + self.assertIn(metric_key, got_metrics) + self.assertIsInstance(got_metrics[metric_key], float) + self.assertAlmostEqual( + got_metrics[metric_key], + expected_average_cosine_similarity, + ) + + except AssertionError as err: + raise util.BeamAssertException(err) + + util.assert_that(result, check_result, label="result") + + +if __name__ == "__main__": + absltest.main() diff --git a/tensorflow_model_analysis/metrics/multi_class_confusion_matrix_metrics.py b/tensorflow_model_analysis/metrics/multi_class_confusion_matrix_metrics.py index 84b802c8b6..26279ba1ad 100644 --- a/tensorflow_model_analysis/metrics/multi_class_confusion_matrix_metrics.py +++ b/tensorflow_model_analysis/metrics/multi_class_confusion_matrix_metrics.py @@ -13,48 +13,63 @@ # limitations under the License. """Multi-class confusion matrix metrics at thresholds.""" -from typing import Any, Callable, Dict, Iterable, Iterator, List, NamedTuple, Optional, Tuple +from typing import ( + Any, + Callable, + Dict, + Iterable, + Iterator, + List, + NamedTuple, + Optional, + Tuple, +) import apache_beam as beam import numpy as np + from tensorflow_model_analysis.api import types -from tensorflow_model_analysis.metrics import metric_types -from tensorflow_model_analysis.metrics import metric_util -from tensorflow_model_analysis.proto import config_pb2 -from tensorflow_model_analysis.proto import metrics_for_slice_pb2 +from tensorflow_model_analysis.metrics import metric_types, metric_util +from tensorflow_model_analysis.proto import config_pb2, metrics_for_slice_pb2 MULTI_CLASS_CONFUSION_MATRIX_AT_THRESHOLDS_NAME = ( - 'multi_class_confusion_matrix_at_thresholds') + "multi_class_confusion_matrix_at_thresholds" +) class MultiClassConfusionMatrixAtThresholds(metric_types.Metric): - """Multi-class confusion matrix metrics at thresholds. - - Computes weighted example counts for all combinations of actual / (top) - predicted classes. - - The inputs are assumed to contain a single positive label per example (i.e. - only one class can be true at a time) while the predictions are assumed to sum - to 1.0. - """ + """Multi-class confusion matrix metrics at thresholds. - def __init__(self, - thresholds: Optional[List[float]] = None, - name: str = MULTI_CLASS_CONFUSION_MATRIX_AT_THRESHOLDS_NAME): - """Initializes multi-class confusion matrix. + Computes weighted example counts for all combinations of actual / (top) + predicted classes. - Args: - thresholds: Optional thresholds, defaults to 0.5 if not specified. If the - top prediction is less than a threshold then the associated example will - be assumed to have no prediction associated with it (the - predicted_class_id will be set to NO_PREDICTED_CLASS_ID). - name: Metric name. + The inputs are assumed to contain a single positive label per example (i.e. + only one class can be true at a time) while the predictions are assumed to sum + to 1.0. 
""" - super().__init__( - metric_util.merge_per_key_computations( - _multi_class_confusion_matrix_at_thresholds), - thresholds=thresholds, - name=name) # pytype: disable=wrong-arg-types + + def __init__( + self, + thresholds: Optional[List[float]] = None, + name: str = MULTI_CLASS_CONFUSION_MATRIX_AT_THRESHOLDS_NAME, + ): + """Initializes multi-class confusion matrix. + + Args: + ---- + thresholds: Optional thresholds, defaults to 0.5 if not specified. If the + top prediction is less than a threshold then the associated example will + be assumed to have no prediction associated with it (the + predicted_class_id will be set to NO_PREDICTED_CLASS_ID). + name: Metric name. + """ + super().__init__( + metric_util.merge_per_key_computations( + _multi_class_confusion_matrix_at_thresholds + ), + thresholds=thresholds, + name=name, + ) # pytype: disable=wrong-arg-types metric_types.register_metric(MultiClassConfusionMatrixAtThresholds) @@ -64,43 +79,51 @@ def _multi_class_confusion_matrix_at_thresholds( thresholds: Optional[List[float]] = None, name: str = MULTI_CLASS_CONFUSION_MATRIX_AT_THRESHOLDS_NAME, eval_config: Optional[config_pb2.EvalConfig] = None, - model_name: str = '', - output_name: str = '', - example_weighted: bool = False) -> metric_types.MetricComputations: - """Returns computations for multi-class confusion matrix at thresholds.""" - if not thresholds: - thresholds = [0.5] - - key = metric_types.MetricKey( - name=name, - model_name=model_name, - output_name=output_name, - example_weighted=example_weighted) - - # Make sure matrices are calculated. - matrices_computations = multi_class_confusion_matrices( - thresholds=thresholds, - eval_config=eval_config, - model_name=model_name, - output_name=output_name, - example_weighted=example_weighted) - matrices_key = matrices_computations[-1].keys[-1] - - def result( - metrics: Dict[metric_types.MetricKey, - metrics_for_slice_pb2.MultiClassConfusionMatrixAtThresholds] - ) -> Dict[metric_types.MetricKey, - metrics_for_slice_pb2.MultiClassConfusionMatrixAtThresholds]: - return {key: metrics[matrices_key]} - - derived_computation = metric_types.DerivedMetricComputation( - keys=[key], result=result) - computations = matrices_computations - computations.append(derived_computation) - return computations - - -MULTI_CLASS_CONFUSION_MATRICES = '_multi_class_confusion_matrices' + model_name: str = "", + output_name: str = "", + example_weighted: bool = False, +) -> metric_types.MetricComputations: + """Returns computations for multi-class confusion matrix at thresholds.""" + if not thresholds: + thresholds = [0.5] + + key = metric_types.MetricKey( + name=name, + model_name=model_name, + output_name=output_name, + example_weighted=example_weighted, + ) + + # Make sure matrices are calculated. 
+ matrices_computations = multi_class_confusion_matrices( + thresholds=thresholds, + eval_config=eval_config, + model_name=model_name, + output_name=output_name, + example_weighted=example_weighted, + ) + matrices_key = matrices_computations[-1].keys[-1] + + def result( + metrics: Dict[ + metric_types.MetricKey, + metrics_for_slice_pb2.MultiClassConfusionMatrixAtThresholds, + ], + ) -> Dict[ + metric_types.MetricKey, + metrics_for_slice_pb2.MultiClassConfusionMatrixAtThresholds, + ]: + return {key: metrics[matrices_key]} + + derived_computation = metric_types.DerivedMetricComputation( + keys=[key], result=result + ) + computations = matrices_computations + computations.append(derived_computation) + return computations + + +MULTI_CLASS_CONFUSION_MATRICES = "_multi_class_confusion_matrices" _EPSILON = 1e-7 @@ -113,193 +136,211 @@ def multi_class_confusion_matrices( thresholds: Optional[List[float]] = None, num_thresholds: Optional[int] = None, name: str = MULTI_CLASS_CONFUSION_MATRICES, - extract_label_prediction_and_weight: Optional[Callable[..., Iterator[Tuple[ - np.ndarray, np.ndarray, - np.ndarray]]]] = metric_util.to_label_prediction_example_weight, + extract_label_prediction_and_weight: Optional[ + Callable[..., Iterator[Tuple[np.ndarray, np.ndarray, np.ndarray]]] + ] = metric_util.to_label_prediction_example_weight, eval_config: Optional[config_pb2.EvalConfig] = None, - model_name: str = '', - output_name: str = '', - example_weighted: bool = False) -> metric_types.MetricComputations: - """Returns computations for multi-class confusion matrices. - - Args: - thresholds: A specific set of thresholds to use. The caller is responsible - for marking the bondaires with +/-epsilon if desired. Only one of - num_thresholds or thresholds should be used. - num_thresholds: Number of thresholds to use. Thresholds will be calculated - using linear interpolation between 0.0 and 1.0 with equidistant values and - bondardaries at -epsilon and 1.0+epsilon. Values must be > 0. Only one of - num_thresholds or thresholds should be used. - name: Metric name. - extract_label_prediction_and_weight: User-provided function argument that - yields label, prediction, and example weights for use in calculations. - eval_config: Eval config. - model_name: Optional model name (if multi-model evaluation). - output_name: Optional output name (if multi-output model type). - example_weighted: True if example weights should be applied. - - Raises: - ValueError: If both num_thresholds and thresholds are set at the same time. - """ - if num_thresholds is not None and thresholds is not None: - raise ValueError( - 'only one of thresholds or num_thresholds can be set at a time') - if num_thresholds is None and thresholds is None: - thresholds = [0.0] - if num_thresholds is not None: - thresholds = [ - (i + 1) * 1.0 / (num_thresholds - 1) for i in range(num_thresholds - 2) + model_name: str = "", + output_name: str = "", + example_weighted: bool = False, +) -> metric_types.MetricComputations: + """Returns computations for multi-class confusion matrices. + + Args: + ---- + thresholds: A specific set of thresholds to use. The caller is responsible + for marking the bondaires with +/-epsilon if desired. Only one of + num_thresholds or thresholds should be used. + num_thresholds: Number of thresholds to use. Thresholds will be calculated + using linear interpolation between 0.0 and 1.0 with equidistant values and + bondardaries at -epsilon and 1.0+epsilon. Values must be > 0. 
Only one of + num_thresholds or thresholds should be used. + name: Metric name. + extract_label_prediction_and_weight: User-provided function argument that + yields label, prediction, and example weights for use in calculations. + eval_config: Eval config. + model_name: Optional model name (if multi-model evaluation). + output_name: Optional output name (if multi-output model type). + example_weighted: True if example weights should be applied. + + Raises: + ------ + ValueError: If both num_thresholds and thresholds are set at the same time. + """ + if num_thresholds is not None and thresholds is not None: + raise ValueError( + "only one of thresholds or num_thresholds can be set at a time" + ) + if num_thresholds is None and thresholds is None: + thresholds = [0.0] + if num_thresholds is not None: + thresholds = [ + (i + 1) * 1.0 / (num_thresholds - 1) for i in range(num_thresholds - 2) + ] + thresholds = [-_EPSILON] + thresholds + [1.0 + _EPSILON] + + key = metric_types.MetricKey( + name=name, + model_name=model_name, + output_name=output_name, + example_weighted=example_weighted, + ) + return [ + metric_types.MetricComputation( + keys=[key], + preprocessors=None, + combiner=_MultiClassConfusionMatrixCombiner( + key=key, + eval_config=eval_config, + example_weighted=example_weighted, + thresholds=thresholds, + extract_label_prediction_and_weight=extract_label_prediction_and_weight, + ), + ) ] - thresholds = [-_EPSILON] + thresholds + [1.0 + _EPSILON] - - key = metric_types.MetricKey( - name=name, - model_name=model_name, - output_name=output_name, - example_weighted=example_weighted) - return [ - metric_types.MetricComputation( - keys=[key], - preprocessors=None, - combiner=_MultiClassConfusionMatrixCombiner( - key=key, - eval_config=eval_config, - example_weighted=example_weighted, - thresholds=thresholds, - extract_label_prediction_and_weight=extract_label_prediction_and_weight - )) - ] - - -MatrixEntryKey = NamedTuple('MatrixEntryKey', [('actual_class_id', int), - ('predicted_class_id', int)]) + + +class MatrixEntryKey(NamedTuple): + actual_class_id: int + predicted_class_id: int class Matrices(types.StructuredMetricValue, dict): - """A Matrices object wraps a Dict[float, Dict[MatrixEntryKey, float]]. 
- - A specific confusion matrix entry can be accessed for a threshold, - actual_class and predicted_class with - - instance[threshold][MatrixEntryKey(actual_class_id, predicted_class_id)] - """ - - def _apply_binary_op_elementwise( - self, other: 'Matrices', op: Callable[[float, float], - float]) -> 'Matrices': - result = Matrices() - all_thresholds = set(self).union(other) - for threshold in all_thresholds: - self_entries = self.get(threshold, {}) - other_entries = other.get(threshold, {}) - result[threshold] = {} - all_entry_keys = set(self_entries).union(other_entries) - for entry_key in all_entry_keys: - self_count = self_entries.get(entry_key, 0) - other_count = other_entries.get(entry_key, 0) - result[threshold][entry_key] = op(self_count, other_count) - return result - - def _apply_binary_op_broadcast( - self, other: float, op: Callable[[float, float], float]) -> 'Matrices': - result = Matrices() - for threshold, self_entries in self.items(): - result[threshold] = {} - for entry_key, self_count in self_entries.items(): - result[threshold][entry_key] = op(self_count, other) - return result - - def to_proto(self) -> metrics_for_slice_pb2.MetricValue: - result = metrics_for_slice_pb2.MetricValue() - multi_class_confusion_matrices_at_thresholds_proto = ( - result.multi_class_confusion_matrix_at_thresholds) - for threshold in sorted(self): - # Convert -epsilon and 1.0+epsilon back to 0.0 and 1.0. - if threshold == -_EPSILON: - t = 0.0 - elif threshold == 1.0 + _EPSILON: - t = 1.0 - else: - t = threshold - matrix = multi_class_confusion_matrices_at_thresholds_proto.matrices.add( - threshold=t) - for k in sorted(self[threshold]): - matrix.entries.add( - actual_class_id=k.actual_class_id, - predicted_class_id=k.predicted_class_id, - num_weighted_examples=self[threshold][k]) - return result + """A Matrices object wraps a Dict[float, Dict[MatrixEntryKey, float]]. + + A specific confusion matrix entry can be accessed for a threshold, + actual_class and predicted_class with + + instance[threshold][MatrixEntryKey(actual_class_id, predicted_class_id)] + """ + + def _apply_binary_op_elementwise( + self, other: "Matrices", op: Callable[[float, float], float] + ) -> "Matrices": + result = Matrices() + all_thresholds = set(self).union(other) + for threshold in all_thresholds: + self_entries = self.get(threshold, {}) + other_entries = other.get(threshold, {}) + result[threshold] = {} + all_entry_keys = set(self_entries).union(other_entries) + for entry_key in all_entry_keys: + self_count = self_entries.get(entry_key, 0) + other_count = other_entries.get(entry_key, 0) + result[threshold][entry_key] = op(self_count, other_count) + return result + + def _apply_binary_op_broadcast( + self, other: float, op: Callable[[float, float], float] + ) -> "Matrices": + result = Matrices() + for threshold, self_entries in self.items(): + result[threshold] = {} + for entry_key, self_count in self_entries.items(): + result[threshold][entry_key] = op(self_count, other) + return result + + def to_proto(self) -> metrics_for_slice_pb2.MetricValue: + result = metrics_for_slice_pb2.MetricValue() + multi_class_confusion_matrices_at_thresholds_proto = ( + result.multi_class_confusion_matrix_at_thresholds + ) + for threshold in sorted(self): + # Convert -epsilon and 1.0+epsilon back to 0.0 and 1.0. 
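To make the epsilon bookkeeping referenced in the comment above concrete, here is a standalone sketch (it only mirrors arithmetic visible in this module; `_EPSILON = 1e-7` as defined earlier, and `num_thresholds=5` is an illustrative value).

```python
_EPSILON = 1e-7

# Expansion performed by multi_class_confusion_matrices(num_thresholds=5):
num_thresholds = 5
inner = [(i + 1) * 1.0 / (num_thresholds - 1) for i in range(num_thresholds - 2)]
thresholds = [-_EPSILON] + inner + [1.0 + _EPSILON]
# thresholds == [-1e-07, 0.25, 0.5, 0.75, 1.0000001]

# Matrices.to_proto() maps the sentinel boundaries back to 0.0 and 1.0, so the
# serialized matrices report thresholds 0.0, 0.25, 0.5, 0.75, 1.0.
reported = [
    0.0 if t == -_EPSILON else 1.0 if t == 1.0 + _EPSILON else t
    for t in thresholds
]
# reported == [0.0, 0.25, 0.5, 0.75, 1.0]
```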
+ if threshold == -_EPSILON: + t = 0.0 + elif threshold == 1.0 + _EPSILON: + t = 1.0 + else: + t = threshold + matrix = multi_class_confusion_matrices_at_thresholds_proto.matrices.add( + threshold=t + ) + for k in sorted(self[threshold]): + matrix.entries.add( + actual_class_id=k.actual_class_id, + predicted_class_id=k.predicted_class_id, + num_weighted_examples=self[threshold][k], + ) + return result class _MultiClassConfusionMatrixCombiner(beam.CombineFn): - """Creates multi-class confusion matrix at thresholds from standard inputs.""" - - def __init__(self, key: metric_types.MetricKey, - eval_config: Optional[config_pb2.EvalConfig], - example_weighted: bool, thresholds: List[float], - extract_label_prediction_and_weight: Optional[Callable[..., - Any]]): - self._key = key - self._eval_config = eval_config - self._example_weighted = example_weighted - self._thresholds = thresholds if thresholds else [0.0] - self._extract_label_prediction_and_weight = ( - extract_label_prediction_and_weight) - - def create_accumulator(self) -> Matrices: - return Matrices() - - def add_input(self, accumulator: Matrices, - element: metric_types.StandardMetricInputs) -> Matrices: - label, predictions, example_weight = next( - self._extract_label_prediction_and_weight( - element, - eval_config=self._eval_config, - model_name=self._key.model_name, - output_name=self._key.output_name, - example_weighted=self._example_weighted, - flatten=False, - require_single_example_weight=True)) # pytype: disable=wrong-arg-types - if not label.shape: - raise ValueError( - 'Label missing from example: StandardMetricInputs={}'.format(element)) - if predictions.shape in ((), (1,)): - raise ValueError( - 'Predictions shape must be > 1 for multi-class confusion matrix: ' - 'shape={}, StandardMetricInputs={}'.format(predictions.shape, - element)) - if label.size > 1: - actual_class_id = np.argmax(label) - else: - actual_class_id = int(label) - predicted_class_id = np.argmax(predictions) - example_weight = float(example_weight) - for threshold in self._thresholds: - if threshold not in accumulator: - accumulator[threshold] = {} - if predictions[predicted_class_id] <= threshold: - predicted_class_id = NO_PREDICTED_CLASS_ID - matrix_key = MatrixEntryKey(actual_class_id, predicted_class_id) - if matrix_key in accumulator[threshold]: - accumulator[threshold][matrix_key] += example_weight - else: - accumulator[threshold][matrix_key] = example_weight - return accumulator - - def merge_accumulators(self, accumulators: Iterable[Matrices]) -> Matrices: - accumulators = iter(accumulators) - result = next(accumulators) - for accumulator in accumulators: - for threshold, matrix in accumulator.items(): - if threshold not in result: - result[threshold] = {} - for k, v in matrix.items(): - if k in result[threshold]: - result[threshold][k] += v - else: - result[threshold][k] = v - return result - - def extract_output( - self, accumulator: Matrices) -> Dict[metric_types.MetricKey, Matrices]: - return {self._key: accumulator} + """Creates multi-class confusion matrix at thresholds from standard inputs.""" + + def __init__( + self, + key: metric_types.MetricKey, + eval_config: Optional[config_pb2.EvalConfig], + example_weighted: bool, + thresholds: List[float], + extract_label_prediction_and_weight: Optional[Callable[..., Any]], + ): + self._key = key + self._eval_config = eval_config + self._example_weighted = example_weighted + self._thresholds = thresholds if thresholds else [0.0] + self._extract_label_prediction_and_weight = 
extract_label_prediction_and_weight + + def create_accumulator(self) -> Matrices: + return Matrices() + + def add_input( + self, accumulator: Matrices, element: metric_types.StandardMetricInputs + ) -> Matrices: + label, predictions, example_weight = next( + self._extract_label_prediction_and_weight( + element, + eval_config=self._eval_config, + model_name=self._key.model_name, + output_name=self._key.output_name, + example_weighted=self._example_weighted, + flatten=False, + require_single_example_weight=True, + ) + ) # pytype: disable=wrong-arg-types + if not label.shape: + raise ValueError( + f"Label missing from example: StandardMetricInputs={element}" + ) + if predictions.shape in ((), (1,)): + raise ValueError( + "Predictions shape must be > 1 for multi-class confusion matrix: " + f"shape={predictions.shape}, StandardMetricInputs={element}" + ) + if label.size > 1: + actual_class_id = np.argmax(label) + else: + actual_class_id = int(label) + predicted_class_id = np.argmax(predictions) + example_weight = float(example_weight) + for threshold in self._thresholds: + if threshold not in accumulator: + accumulator[threshold] = {} + if predictions[predicted_class_id] <= threshold: + predicted_class_id = NO_PREDICTED_CLASS_ID + matrix_key = MatrixEntryKey(actual_class_id, predicted_class_id) + if matrix_key in accumulator[threshold]: + accumulator[threshold][matrix_key] += example_weight + else: + accumulator[threshold][matrix_key] = example_weight + return accumulator + + def merge_accumulators(self, accumulators: Iterable[Matrices]) -> Matrices: + accumulators = iter(accumulators) + result = next(accumulators) + for accumulator in accumulators: + for threshold, matrix in accumulator.items(): + if threshold not in result: + result[threshold] = {} + for k, v in matrix.items(): + if k in result[threshold]: + result[threshold][k] += v + else: + result[threshold][k] = v + return result + + def extract_output( + self, accumulator: Matrices + ) -> Dict[metric_types.MetricKey, Matrices]: + return {self._key: accumulator} diff --git a/tensorflow_model_analysis/metrics/multi_class_confusion_matrix_metrics_test.py b/tensorflow_model_analysis/metrics/multi_class_confusion_matrix_metrics_test.py index 1b23aedb93..32e6fdb1f5 100644 --- a/tensorflow_model_analysis/metrics/multi_class_confusion_matrix_metrics_test.py +++ b/tensorflow_model_analysis/metrics/multi_class_confusion_matrix_metrics_test.py @@ -13,391 +13,416 @@ # limitations under the License. 
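The accumulation logic in `_MultiClassConfusionMatrixCombiner.add_input` above can be summarized with a small standalone sketch. The `update_cell` helper is hypothetical (introduced only for illustration), and the `-1` sentinel is assumed to match the `predicted_class_id=-1` entries that the tests below expect for `NO_PREDICTED_CLASS_ID`.

```python
import numpy as np

NO_PREDICTED_CLASS_ID = -1  # assumed sentinel value, defined elsewhere in TFMA


def update_cell(matrix, label, predictions, weight, threshold):
    """Rough equivalent of one add_input step for a single threshold."""
    actual = int(np.argmax(label)) if label.size > 1 else int(label.item())
    predicted = int(np.argmax(predictions))
    # A top prediction at or below the threshold counts as "no prediction".
    if predictions[predicted] <= threshold:
        predicted = NO_PREDICTED_CLASS_ID
    matrix[(actual, predicted)] = matrix.get((actual, predicted), 0.0) + weight
    return matrix


m = update_cell({}, np.array([2.0]), np.array([0.2, 0.3, 0.5]), 0.5, 0.5)
# The top probability 0.5 does not exceed the 0.5 threshold, so the weight
# lands in the "no prediction" cell: m == {(2, -1): 0.5}
```

This mirrors the first example in the test below, which expects the entry `actual_class_id=2, predicted_class_id=-1` with weight 0.5.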
"""Tests for multi-class confusion matrix metrics at thresholds.""" -from absl.testing import parameterized import apache_beam as beam -from apache_beam.testing import util import numpy as np import tensorflow as tf -from tensorflow_model_analysis.metrics import metric_types -from tensorflow_model_analysis.metrics import metric_util -from tensorflow_model_analysis.metrics import multi_class_confusion_matrix_metrics +from absl.testing import parameterized +from apache_beam.testing import util + +from tensorflow_model_analysis.metrics import ( + metric_types, + metric_util, + multi_class_confusion_matrix_metrics, +) from tensorflow_model_analysis.utils import test_util class MultiClassConfusionMatrixMetricsTest( test_util.TensorflowModelAnalysisTest, parameterized.TestCase ): + @parameterized.named_parameters( + { + "testcase_name": "_empty_thresholds", + "left": multi_class_confusion_matrix_metrics.Matrices({}), + "right": multi_class_confusion_matrix_metrics.Matrices({}), + "expected": multi_class_confusion_matrix_metrics.Matrices({}), + }, + { + "testcase_name": "_empty_entries", + "left": multi_class_confusion_matrix_metrics.Matrices({0.5: {}}), + "right": multi_class_confusion_matrix_metrics.Matrices({0.5: {}}), + "expected": multi_class_confusion_matrix_metrics.Matrices({0.5: {}}), + }, + { + "testcase_name": "_different_thresholds", + "left": multi_class_confusion_matrix_metrics.Matrices( + { + 0.5: { + multi_class_confusion_matrix_metrics.MatrixEntryKey( + actual_class_id=0, predicted_class_id=0 + ): 1.0 + } + } + ), + "right": multi_class_confusion_matrix_metrics.Matrices( + { + 0.75: { + multi_class_confusion_matrix_metrics.MatrixEntryKey( + actual_class_id=0, predicted_class_id=0 + ): 2.0 + } + } + ), + "expected": multi_class_confusion_matrix_metrics.Matrices( + { + 0.5: { + multi_class_confusion_matrix_metrics.MatrixEntryKey( + actual_class_id=0, predicted_class_id=0 + ): 1.0 + }, + 0.75: { + multi_class_confusion_matrix_metrics.MatrixEntryKey( + actual_class_id=0, predicted_class_id=0 + ): 2.0 + }, + } + ), + }, + { + "testcase_name": "_different_entries", + "left": multi_class_confusion_matrix_metrics.Matrices( + { + 0.5: { + multi_class_confusion_matrix_metrics.MatrixEntryKey( + actual_class_id=0, predicted_class_id=0 + ): 1.0 + } + } + ), + "right": multi_class_confusion_matrix_metrics.Matrices( + { + 0.5: { + multi_class_confusion_matrix_metrics.MatrixEntryKey( + actual_class_id=0, predicted_class_id=1 + ): 2.0 + } + } + ), + "expected": multi_class_confusion_matrix_metrics.Matrices( + { + 0.5: { + multi_class_confusion_matrix_metrics.MatrixEntryKey( + actual_class_id=0, predicted_class_id=0 + ): 1.0, + multi_class_confusion_matrix_metrics.MatrixEntryKey( + actual_class_id=0, predicted_class_id=1 + ): 2.0, + } + } + ), + }, + { + "testcase_name": "_same_thresholds_and_entries", + "left": multi_class_confusion_matrix_metrics.Matrices( + { + 0.5: { + multi_class_confusion_matrix_metrics.MatrixEntryKey( + actual_class_id=0, predicted_class_id=0 + ): 1.0, + multi_class_confusion_matrix_metrics.MatrixEntryKey( + actual_class_id=0, predicted_class_id=1 + ): 2.0, + multi_class_confusion_matrix_metrics.MatrixEntryKey( + actual_class_id=1, predicted_class_id=0 + ): 3.0, + multi_class_confusion_matrix_metrics.MatrixEntryKey( + actual_class_id=1, predicted_class_id=1 + ): 4.0, + }, + 0.75: { + multi_class_confusion_matrix_metrics.MatrixEntryKey( + actual_class_id=0, predicted_class_id=0 + ): 2.0, + multi_class_confusion_matrix_metrics.MatrixEntryKey( + actual_class_id=0, 
predicted_class_id=1 + ): 4.0, + multi_class_confusion_matrix_metrics.MatrixEntryKey( + actual_class_id=1, predicted_class_id=0 + ): 6.0, + multi_class_confusion_matrix_metrics.MatrixEntryKey( + actual_class_id=1, predicted_class_id=1 + ): 8.0, + }, + } + ), + "right": multi_class_confusion_matrix_metrics.Matrices( + { + 0.5: { + multi_class_confusion_matrix_metrics.MatrixEntryKey( + actual_class_id=0, predicted_class_id=0 + ): 1.0, + multi_class_confusion_matrix_metrics.MatrixEntryKey( + actual_class_id=0, predicted_class_id=1 + ): 3.0, + multi_class_confusion_matrix_metrics.MatrixEntryKey( + actual_class_id=1, predicted_class_id=0 + ): 5.0, + multi_class_confusion_matrix_metrics.MatrixEntryKey( + actual_class_id=1, predicted_class_id=1 + ): 7.0, + }, + 0.75: { + multi_class_confusion_matrix_metrics.MatrixEntryKey( + actual_class_id=0, predicted_class_id=0 + ): 2.0, + multi_class_confusion_matrix_metrics.MatrixEntryKey( + actual_class_id=0, predicted_class_id=1 + ): 6.0, + multi_class_confusion_matrix_metrics.MatrixEntryKey( + actual_class_id=1, predicted_class_id=0 + ): 10.0, + multi_class_confusion_matrix_metrics.MatrixEntryKey( + actual_class_id=1, predicted_class_id=1 + ): 14.0, + }, + } + ), + "expected": multi_class_confusion_matrix_metrics.Matrices( + { + 0.5: { + multi_class_confusion_matrix_metrics.MatrixEntryKey( + actual_class_id=0, predicted_class_id=0 + ): 2.0, + multi_class_confusion_matrix_metrics.MatrixEntryKey( + actual_class_id=0, predicted_class_id=1 + ): 5.0, + multi_class_confusion_matrix_metrics.MatrixEntryKey( + actual_class_id=1, predicted_class_id=0 + ): 8.0, + multi_class_confusion_matrix_metrics.MatrixEntryKey( + actual_class_id=1, predicted_class_id=1 + ): 11.0, + }, + 0.75: { + multi_class_confusion_matrix_metrics.MatrixEntryKey( + actual_class_id=0, predicted_class_id=0 + ): 4.0, + multi_class_confusion_matrix_metrics.MatrixEntryKey( + actual_class_id=0, predicted_class_id=1 + ): 10.0, + multi_class_confusion_matrix_metrics.MatrixEntryKey( + actual_class_id=1, predicted_class_id=0 + ): 16.0, + multi_class_confusion_matrix_metrics.MatrixEntryKey( + actual_class_id=1, predicted_class_id=1 + ): 22.0, + }, + } + ), + }, + { + "testcase_name": "_empty_thresholds_broadcast", + "left": multi_class_confusion_matrix_metrics.Matrices({}), + "right": 1.0, + "expected": multi_class_confusion_matrix_metrics.Matrices({}), + }, + { + "testcase_name": "_empty_entries_broadcast", + "left": multi_class_confusion_matrix_metrics.Matrices({0.5: {}}), + "right": 1.0, + "expected": multi_class_confusion_matrix_metrics.Matrices({0.5: {}}), + }, + { + "testcase_name": "_nonempty_thresholds_and_entries_broadcast", + "left": multi_class_confusion_matrix_metrics.Matrices( + { + 0.5: { + multi_class_confusion_matrix_metrics.MatrixEntryKey( + actual_class_id=0, predicted_class_id=0 + ): 1.0, + multi_class_confusion_matrix_metrics.MatrixEntryKey( + actual_class_id=0, predicted_class_id=1 + ): 2.0, + }, + } + ), + "right": 3.0, + "expected": multi_class_confusion_matrix_metrics.Matrices( + { + 0.5: { + multi_class_confusion_matrix_metrics.MatrixEntryKey( + actual_class_id=0, predicted_class_id=0 + ): 4.0, + multi_class_confusion_matrix_metrics.MatrixEntryKey( + actual_class_id=0, predicted_class_id=1 + ): 5.0, + }, + } + ), + }, + ) + def testAddMatrices(self, left, right, expected): + self.assertEqual(expected, left + right) - @parameterized.named_parameters( - { - 'testcase_name': '_empty_thresholds', - 'left': multi_class_confusion_matrix_metrics.Matrices({}), - 'right': 
multi_class_confusion_matrix_metrics.Matrices({}), - 'expected': multi_class_confusion_matrix_metrics.Matrices({}) - }, { - 'testcase_name': '_empty_entries', - 'left': multi_class_confusion_matrix_metrics.Matrices({0.5: {}}), - 'right': multi_class_confusion_matrix_metrics.Matrices({0.5: {}}), - 'expected': multi_class_confusion_matrix_metrics.Matrices({0.5: {}}) - }, { - 'testcase_name': - '_different_thresholds', - 'left': - multi_class_confusion_matrix_metrics.Matrices({ - 0.5: { - multi_class_confusion_matrix_metrics.MatrixEntryKey( - actual_class_id=0, predicted_class_id=0): - 1.0 - } - }), - 'right': - multi_class_confusion_matrix_metrics.Matrices({ - 0.75: { - multi_class_confusion_matrix_metrics.MatrixEntryKey( - actual_class_id=0, predicted_class_id=0): - 2.0 - } - }), - 'expected': - multi_class_confusion_matrix_metrics.Matrices({ - 0.5: { - multi_class_confusion_matrix_metrics.MatrixEntryKey( - actual_class_id=0, predicted_class_id=0): - 1.0 - }, - 0.75: { - multi_class_confusion_matrix_metrics.MatrixEntryKey( - actual_class_id=0, predicted_class_id=0): - 2.0 - } - }), - }, { - 'testcase_name': - '_different_entries', - 'left': - multi_class_confusion_matrix_metrics.Matrices({ - 0.5: { - multi_class_confusion_matrix_metrics.MatrixEntryKey( - actual_class_id=0, predicted_class_id=0): - 1.0 - } - }), - 'right': - multi_class_confusion_matrix_metrics.Matrices({ - 0.5: { - multi_class_confusion_matrix_metrics.MatrixEntryKey( - actual_class_id=0, predicted_class_id=1): - 2.0 - } - }), - 'expected': - multi_class_confusion_matrix_metrics.Matrices({ - 0.5: { - multi_class_confusion_matrix_metrics.MatrixEntryKey( - actual_class_id=0, predicted_class_id=0): - 1.0, - multi_class_confusion_matrix_metrics.MatrixEntryKey( - actual_class_id=0, predicted_class_id=1): - 2.0 - } - }), - }, { - 'testcase_name': - '_same_thresholds_and_entries', - 'left': - multi_class_confusion_matrix_metrics.Matrices({ - 0.5: { - multi_class_confusion_matrix_metrics.MatrixEntryKey( - actual_class_id=0, predicted_class_id=0): - 1.0, - multi_class_confusion_matrix_metrics.MatrixEntryKey( - actual_class_id=0, predicted_class_id=1): - 2.0, - multi_class_confusion_matrix_metrics.MatrixEntryKey( - actual_class_id=1, predicted_class_id=0): - 3.0, - multi_class_confusion_matrix_metrics.MatrixEntryKey( - actual_class_id=1, predicted_class_id=1): - 4.0, - }, - 0.75: { - multi_class_confusion_matrix_metrics.MatrixEntryKey( - actual_class_id=0, predicted_class_id=0): - 2.0, - multi_class_confusion_matrix_metrics.MatrixEntryKey( - actual_class_id=0, predicted_class_id=1): - 4.0, - multi_class_confusion_matrix_metrics.MatrixEntryKey( - actual_class_id=1, predicted_class_id=0): - 6.0, - multi_class_confusion_matrix_metrics.MatrixEntryKey( - actual_class_id=1, predicted_class_id=1): - 8.0, - } - }), - 'right': - multi_class_confusion_matrix_metrics.Matrices({ - 0.5: { - multi_class_confusion_matrix_metrics.MatrixEntryKey( - actual_class_id=0, predicted_class_id=0): - 1.0, - multi_class_confusion_matrix_metrics.MatrixEntryKey( - actual_class_id=0, predicted_class_id=1): - 3.0, - multi_class_confusion_matrix_metrics.MatrixEntryKey( - actual_class_id=1, predicted_class_id=0): - 5.0, - multi_class_confusion_matrix_metrics.MatrixEntryKey( - actual_class_id=1, predicted_class_id=1): - 7.0, - }, - 0.75: { - multi_class_confusion_matrix_metrics.MatrixEntryKey( - actual_class_id=0, predicted_class_id=0): - 2.0, - multi_class_confusion_matrix_metrics.MatrixEntryKey( - actual_class_id=0, predicted_class_id=1): - 6.0, - 
multi_class_confusion_matrix_metrics.MatrixEntryKey( - actual_class_id=1, predicted_class_id=0): - 10.0, - multi_class_confusion_matrix_metrics.MatrixEntryKey( - actual_class_id=1, predicted_class_id=1): - 14.0, - } - }), - 'expected': - multi_class_confusion_matrix_metrics.Matrices({ - 0.5: { - multi_class_confusion_matrix_metrics.MatrixEntryKey( - actual_class_id=0, predicted_class_id=0): - 2.0, - multi_class_confusion_matrix_metrics.MatrixEntryKey( - actual_class_id=0, predicted_class_id=1): - 5.0, - multi_class_confusion_matrix_metrics.MatrixEntryKey( - actual_class_id=1, predicted_class_id=0): - 8.0, - multi_class_confusion_matrix_metrics.MatrixEntryKey( - actual_class_id=1, predicted_class_id=1): - 11.0, - }, - 0.75: { - multi_class_confusion_matrix_metrics.MatrixEntryKey( - actual_class_id=0, predicted_class_id=0): - 4.0, - multi_class_confusion_matrix_metrics.MatrixEntryKey( - actual_class_id=0, predicted_class_id=1): - 10.0, - multi_class_confusion_matrix_metrics.MatrixEntryKey( - actual_class_id=1, predicted_class_id=0): - 16.0, - multi_class_confusion_matrix_metrics.MatrixEntryKey( - actual_class_id=1, predicted_class_id=1): - 22.0, - } - }), - }, { - 'testcase_name': '_empty_thresholds_broadcast', - 'left': multi_class_confusion_matrix_metrics.Matrices({}), - 'right': 1.0, - 'expected': multi_class_confusion_matrix_metrics.Matrices({}) - }, { - 'testcase_name': '_empty_entries_broadcast', - 'left': multi_class_confusion_matrix_metrics.Matrices({0.5: {}}), - 'right': 1.0, - 'expected': multi_class_confusion_matrix_metrics.Matrices({0.5: {}}) - }, { - 'testcase_name': - '_nonempty_thresholds_and_entries_broadcast', - 'left': - multi_class_confusion_matrix_metrics.Matrices({ - 0.5: { - multi_class_confusion_matrix_metrics.MatrixEntryKey( - actual_class_id=0, predicted_class_id=0): - 1.0, - multi_class_confusion_matrix_metrics.MatrixEntryKey( - actual_class_id=0, predicted_class_id=1): - 2.0, - }, - }), - 'right': - 3.0, - 'expected': - multi_class_confusion_matrix_metrics.Matrices({ - 0.5: { - multi_class_confusion_matrix_metrics.MatrixEntryKey( - actual_class_id=0, predicted_class_id=0): - 4.0, - multi_class_confusion_matrix_metrics.MatrixEntryKey( - actual_class_id=0, predicted_class_id=1): - 5.0, - }, - }), - }) - def testAddMatrices(self, left, right, expected): - self.assertEqual(expected, left + right) - - @parameterized.named_parameters(('using_default_thresholds', {}), - ('setting_thresholds', { - 'thresholds': [0.5] - })) - def testMultiClassConfusionMatrixAtThresholds(self, kwargs): - computations = ( - multi_class_confusion_matrix_metrics - .MultiClassConfusionMatrixAtThresholds(**kwargs).computations( - example_weighted=True)) - matrices = computations[0] - metrics = computations[1] + @parameterized.named_parameters( + ("using_default_thresholds", {}), ("setting_thresholds", {"thresholds": [0.5]}) + ) + def testMultiClassConfusionMatrixAtThresholds(self, kwargs): + computations = ( + multi_class_confusion_matrix_metrics.MultiClassConfusionMatrixAtThresholds( + **kwargs + ).computations(example_weighted=True) + ) + matrices = computations[0] + metrics = computations[1] - example1 = { - 'labels': np.array([2.0]), - 'predictions': np.array([0.2, 0.3, 0.5]), - 'example_weights': np.array([0.5]) - } - example2 = { - 'labels': np.array([0.0]), - 'predictions': np.array([0.1, 0.3, 0.6]), - 'example_weights': np.array([1.0]) - } - example3 = { - 'labels': np.array([1.0]), - 'predictions': np.array([0.3, 0.1, 0.6]), - 'example_weights': np.array([0.25]) - } - example4 = { - 
'labels': np.array([1.0]), - 'predictions': np.array([0.1, 0.9, 0.0]), - 'example_weights': np.array([1.0]) - } - example5 = { - 'labels': np.array([1.0]), - 'predictions': np.array([0.1, 0.8, 0.1]), - 'example_weights': np.array([1.0]) - } - example6 = { - 'labels': np.array([2.0]), - 'predictions': np.array([0.3, 0.1, 0.6]), - 'example_weights': np.array([1.0]) - } + example1 = { + "labels": np.array([2.0]), + "predictions": np.array([0.2, 0.3, 0.5]), + "example_weights": np.array([0.5]), + } + example2 = { + "labels": np.array([0.0]), + "predictions": np.array([0.1, 0.3, 0.6]), + "example_weights": np.array([1.0]), + } + example3 = { + "labels": np.array([1.0]), + "predictions": np.array([0.3, 0.1, 0.6]), + "example_weights": np.array([0.25]), + } + example4 = { + "labels": np.array([1.0]), + "predictions": np.array([0.1, 0.9, 0.0]), + "example_weights": np.array([1.0]), + } + example5 = { + "labels": np.array([1.0]), + "predictions": np.array([0.1, 0.8, 0.1]), + "example_weights": np.array([1.0]), + } + example6 = { + "labels": np.array([2.0]), + "predictions": np.array([0.3, 0.1, 0.6]), + "example_weights": np.array([1.0]), + } - with beam.Pipeline() as pipeline: - # pylint: disable=no-value-for-parameter - result = ( - pipeline - | 'Create' >> beam.Create( - [example1, example2, example3, example4, example5, example6]) - | 'Process' >> beam.Map(metric_util.to_standard_metric_inputs) - | 'AddSlice' >> beam.Map(lambda x: ((), x)) - | 'ComputeMatrices' >> beam.CombinePerKey(matrices.combiner) - | - 'ComputeMetrics' >> beam.Map(lambda x: (x[0], metrics.result(x[1])))) + with beam.Pipeline() as pipeline: + # pylint: disable=no-value-for-parameter + result = ( + pipeline + | "Create" + >> beam.Create( + [example1, example2, example3, example4, example5, example6] + ) + | "Process" >> beam.Map(metric_util.to_standard_metric_inputs) + | "AddSlice" >> beam.Map(lambda x: ((), x)) + | "ComputeMatrices" >> beam.CombinePerKey(matrices.combiner) + | "ComputeMetrics" >> beam.Map(lambda x: (x[0], metrics.result(x[1]))) + ) - # pylint: enable=no-value-for-parameter + # pylint: enable=no-value-for-parameter - def check_result(got): - try: - self.assertLen(got, 1) - got_slice_key, got_metrics = got[0] - self.assertEqual(got_slice_key, ()) - self.assertLen(got_metrics, 1) - key = metric_types.MetricKey( - name='multi_class_confusion_matrix_at_thresholds', - example_weighted=True) - got_matrix = got_metrics[key] - self.assertEqual( - multi_class_confusion_matrix_metrics.Matrices({ - 0.5: { - multi_class_confusion_matrix_metrics.MatrixEntryKey( - actual_class_id=0, predicted_class_id=2): - 1.0, - multi_class_confusion_matrix_metrics.MatrixEntryKey( - actual_class_id=1, predicted_class_id=1): - 2.0, - multi_class_confusion_matrix_metrics.MatrixEntryKey( - actual_class_id=1, predicted_class_id=2): - 0.25, - multi_class_confusion_matrix_metrics.MatrixEntryKey( - actual_class_id=2, predicted_class_id=-1): - 0.5, - multi_class_confusion_matrix_metrics.MatrixEntryKey( - actual_class_id=2, predicted_class_id=2): - 1.0 - } - }), got_matrix) - except AssertionError as err: - raise util.BeamAssertException(err) + def check_result(got): + try: + self.assertLen(got, 1) + got_slice_key, got_metrics = got[0] + self.assertEqual(got_slice_key, ()) + self.assertLen(got_metrics, 1) + key = metric_types.MetricKey( + name="multi_class_confusion_matrix_at_thresholds", + example_weighted=True, + ) + got_matrix = got_metrics[key] + self.assertEqual( + multi_class_confusion_matrix_metrics.Matrices( + { + 0.5: { + 
multi_class_confusion_matrix_metrics.MatrixEntryKey( + actual_class_id=0, predicted_class_id=2 + ): 1.0, + multi_class_confusion_matrix_metrics.MatrixEntryKey( + actual_class_id=1, predicted_class_id=1 + ): 2.0, + multi_class_confusion_matrix_metrics.MatrixEntryKey( + actual_class_id=1, predicted_class_id=2 + ): 0.25, + multi_class_confusion_matrix_metrics.MatrixEntryKey( + actual_class_id=2, predicted_class_id=-1 + ): 0.5, + multi_class_confusion_matrix_metrics.MatrixEntryKey( + actual_class_id=2, predicted_class_id=2 + ): 1.0, + } + } + ), + got_matrix, + ) + except AssertionError as err: + raise util.BeamAssertException(err) - util.assert_that(result, check_result, label='result') + util.assert_that(result, check_result, label="result") - def testMultiClassConfusionMatrixAtThresholdsWithStringLabels(self): - computations = ( - multi_class_confusion_matrix_metrics - .MultiClassConfusionMatrixAtThresholds().computations( - example_weighted=True)) - matrices = computations[0] - metrics = computations[1] + def testMultiClassConfusionMatrixAtThresholdsWithStringLabels(self): + computations = multi_class_confusion_matrix_metrics.MultiClassConfusionMatrixAtThresholds().computations( + example_weighted=True + ) + matrices = computations[0] + metrics = computations[1] - example1 = { - 'labels': np.array([['unacc']]), - 'predictions': { - 'probabilities': - np.array([[ - 1.0000000e+00, 6.9407083e-24, 2.7419115e-38, 0.0000000e+00 - ]]), - 'all_classes': - np.array([['unacc', 'acc', 'vgood', 'good']]), - }, - 'example_weights': np.array([0.5]) - } - example2 = { - 'labels': np.array([['vgood']]), - 'predictions': { - 'probabilities': np.array([[0.2, 0.3, 0.4, 0.1]]), - 'all_classes': np.array([['unacc', 'acc', 'vgood', 'good']]), - }, - 'example_weights': np.array([1.0]) - } + example1 = { + "labels": np.array([["unacc"]]), + "predictions": { + "probabilities": np.array( + [[1.0000000e00, 6.9407083e-24, 2.7419115e-38, 0.0000000e00]] + ), + "all_classes": np.array([["unacc", "acc", "vgood", "good"]]), + }, + "example_weights": np.array([0.5]), + } + example2 = { + "labels": np.array([["vgood"]]), + "predictions": { + "probabilities": np.array([[0.2, 0.3, 0.4, 0.1]]), + "all_classes": np.array([["unacc", "acc", "vgood", "good"]]), + }, + "example_weights": np.array([1.0]), + } - with beam.Pipeline() as pipeline: - # pylint: disable=no-value-for-parameter - result = ( - pipeline - | 'Create' >> beam.Create([example1, example2]) - | 'Process' >> beam.Map(metric_util.to_standard_metric_inputs) - | 'AddSlice' >> beam.Map(lambda x: ((), x)) - | 'ComputeMatrices' >> beam.CombinePerKey(matrices.combiner) - | - 'ComputeMetrics' >> beam.Map(lambda x: (x[0], metrics.result(x[1])))) + with beam.Pipeline() as pipeline: + # pylint: disable=no-value-for-parameter + result = ( + pipeline + | "Create" >> beam.Create([example1, example2]) + | "Process" >> beam.Map(metric_util.to_standard_metric_inputs) + | "AddSlice" >> beam.Map(lambda x: ((), x)) + | "ComputeMatrices" >> beam.CombinePerKey(matrices.combiner) + | "ComputeMetrics" >> beam.Map(lambda x: (x[0], metrics.result(x[1]))) + ) - # pylint: enable=no-value-for-parameter + # pylint: enable=no-value-for-parameter - def check_result(got): - try: - self.assertLen(got, 1) - got_slice_key, got_metrics = got[0] - self.assertEqual(got_slice_key, ()) - self.assertLen(got_metrics, 1) - key = metric_types.MetricKey( - name='multi_class_confusion_matrix_at_thresholds', - example_weighted=True) - got_matrix = got_metrics[key] - self.assertEqual( - 
multi_class_confusion_matrix_metrics.Matrices({ - 0.5: { - multi_class_confusion_matrix_metrics.MatrixEntryKey( - actual_class_id=0, predicted_class_id=0): - 0.5, - multi_class_confusion_matrix_metrics.MatrixEntryKey( - actual_class_id=2, predicted_class_id=-1): - 1.0 - } - }), got_matrix) - except AssertionError as err: - raise util.BeamAssertException(err) + def check_result(got): + try: + self.assertLen(got, 1) + got_slice_key, got_metrics = got[0] + self.assertEqual(got_slice_key, ()) + self.assertLen(got_metrics, 1) + key = metric_types.MetricKey( + name="multi_class_confusion_matrix_at_thresholds", + example_weighted=True, + ) + got_matrix = got_metrics[key] + self.assertEqual( + multi_class_confusion_matrix_metrics.Matrices( + { + 0.5: { + multi_class_confusion_matrix_metrics.MatrixEntryKey( + actual_class_id=0, predicted_class_id=0 + ): 0.5, + multi_class_confusion_matrix_metrics.MatrixEntryKey( + actual_class_id=2, predicted_class_id=-1 + ): 1.0, + } + } + ), + got_matrix, + ) + except AssertionError as err: + raise util.BeamAssertException(err) - util.assert_that(result, check_result, label='result') + util.assert_that(result, check_result, label="result") -if __name__ == '__main__': - tf.test.main() +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow_model_analysis/metrics/multi_class_confusion_matrix_plot.py b/tensorflow_model_analysis/metrics/multi_class_confusion_matrix_plot.py index 3b0bd38d2d..959b74e258 100644 --- a/tensorflow_model_analysis/metrics/multi_class_confusion_matrix_plot.py +++ b/tensorflow_model_analysis/metrics/multi_class_confusion_matrix_plot.py @@ -15,52 +15,56 @@ from typing import Dict, List, Optional -from tensorflow_model_analysis.metrics import metric_types -from tensorflow_model_analysis.metrics import metric_util -from tensorflow_model_analysis.metrics import multi_class_confusion_matrix_metrics -from tensorflow_model_analysis.proto import config_pb2 -from tensorflow_model_analysis.proto import metrics_for_slice_pb2 +from tensorflow_model_analysis.metrics import ( + metric_types, + metric_util, + multi_class_confusion_matrix_metrics, +) +from tensorflow_model_analysis.proto import config_pb2, metrics_for_slice_pb2 -MULTI_CLASS_CONFUSION_MATRIX_PLOT_NAME = ('multi_class_confusion_matrix_plot') +MULTI_CLASS_CONFUSION_MATRIX_PLOT_NAME = "multi_class_confusion_matrix_plot" class MultiClassConfusionMatrixPlot(metric_types.Metric): - """Multi-class confusion matrix plot. - - Computes weighted example counts for all combinations of actual / (top) - predicted classes. - - The inputs are assumed to contain a single positive label per example (i.e. - only one class can be true at a time) while the predictions are assumed to sum - to 1.0. - """ - - def __init__(self, - thresholds: Optional[List[float]] = None, - num_thresholds: Optional[int] = None, - name: str = MULTI_CLASS_CONFUSION_MATRIX_PLOT_NAME): - """Initializes multi-class confusion matrix. - - Args: - thresholds: Optional thresholds. If the top prediction is less than a - threshold then the associated example will be assumed to have no - prediction associated with it (the predicted_class_id will be set to - tfma.metrics.NO_PREDICTED_CLASS_ID). Only one of - either thresholds or num_thresholds should be used. If both are unset, - then [0.0] will be assumed. - num_thresholds: Number of thresholds to use. The thresholds will be evenly - spaced between 0.0 and 1.0 and inclusive of the boundaries (i.e. 
to - configure the thresholds to [0.0, 0.25, 0.5, 0.75, 1.0], the parameter - should be set to 5). Only one of either thresholds or num_thresholds - should be used. - name: Metric name. + """Multi-class confusion matrix plot. + + Computes weighted example counts for all combinations of actual / (top) + predicted classes. + + The inputs are assumed to contain a single positive label per example (i.e. + only one class can be true at a time) while the predictions are assumed to sum + to 1.0. """ - super().__init__( - metric_util.merge_per_key_computations( - _multi_class_confusion_matrix_plot), - thresholds=thresholds, - num_thresholds=num_thresholds, - name=name) + + def __init__( + self, + thresholds: Optional[List[float]] = None, + num_thresholds: Optional[int] = None, + name: str = MULTI_CLASS_CONFUSION_MATRIX_PLOT_NAME, + ): + """Initializes multi-class confusion matrix. + + Args: + ---- + thresholds: Optional thresholds. If the top prediction is less than a + threshold then the associated example will be assumed to have no + prediction associated with it (the predicted_class_id will be set to + tfma.metrics.NO_PREDICTED_CLASS_ID). Only one of + either thresholds or num_thresholds should be used. If both are unset, + then [0.0] will be assumed. + num_thresholds: Number of thresholds to use. The thresholds will be evenly + spaced between 0.0 and 1.0 and inclusive of the boundaries (i.e. to + configure the thresholds to [0.0, 0.25, 0.5, 0.75, 1.0], the parameter + should be set to 5). Only one of either thresholds or num_thresholds + should be used. + name: Metric name. + """ + super().__init__( + metric_util.merge_per_key_computations(_multi_class_confusion_matrix_plot), + thresholds=thresholds, + num_thresholds=num_thresholds, + name=name, + ) metric_types.register_metric(MultiClassConfusionMatrixPlot) @@ -71,43 +75,51 @@ def _multi_class_confusion_matrix_plot( num_thresholds: Optional[int] = None, name: str = MULTI_CLASS_CONFUSION_MATRIX_PLOT_NAME, eval_config: Optional[config_pb2.EvalConfig] = None, - model_name: str = '', - output_name: str = '', - example_weighted: bool = False) -> metric_types.MetricComputations: - """Returns computations for multi-class confusion matrix plot.""" - if num_thresholds is None and thresholds is None: - thresholds = [0.0] - - key = metric_types.PlotKey( - name=name, - model_name=model_name, - output_name=output_name, - example_weighted=example_weighted) - - # Make sure matrices are calculated. - matrices_computations = ( - multi_class_confusion_matrix_metrics.multi_class_confusion_matrices( - thresholds=thresholds, - num_thresholds=num_thresholds, - eval_config=eval_config, - model_name=model_name, - output_name=output_name, - example_weighted=example_weighted)) - matrices_key = matrices_computations[-1].keys[-1] - - def result( - metrics: Dict[metric_types.MetricKey, - multi_class_confusion_matrix_metrics.Matrices] - ) -> Dict[metric_types.PlotKey, - metrics_for_slice_pb2.MultiClassConfusionMatrixAtThresholds]: - return { - key: - metrics[matrices_key].to_proto() + model_name: str = "", + output_name: str = "", + example_weighted: bool = False, +) -> metric_types.MetricComputations: + """Returns computations for multi-class confusion matrix plot.""" + if num_thresholds is None and thresholds is None: + thresholds = [0.0] + + key = metric_types.PlotKey( + name=name, + model_name=model_name, + output_name=output_name, + example_weighted=example_weighted, + ) + + # Make sure matrices are calculated. 
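For reviewers, a declarative configuration sketch for the plot metric above (not part of this diff). It assumes the usual top-level proto aliases (`tfma.EvalConfig`, `tfma.MetricsSpec`, `tfma.MetricConfig`, `tfma.ModelSpec`) and a placeholder `"label"` key; per the docstring, `num_thresholds=5` corresponds to thresholds [0.0, 0.25, 0.5, 0.75, 1.0].

```python
import json

import tensorflow_model_analysis as tfma

eval_config = tfma.EvalConfig(
    model_specs=[tfma.ModelSpec(label_key="label")],  # placeholder label key
    metrics_specs=[
        tfma.MetricsSpec(
            metrics=[
                tfma.MetricConfig(
                    class_name="MultiClassConfusionMatrixPlot",
                    config=json.dumps({"num_thresholds": 5}),
                )
            ]
        )
    ],
)
```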
+ matrices_computations = ( + multi_class_confusion_matrix_metrics.multi_class_confusion_matrices( + thresholds=thresholds, + num_thresholds=num_thresholds, + eval_config=eval_config, + model_name=model_name, + output_name=output_name, + example_weighted=example_weighted, + ) + ) + matrices_key = matrices_computations[-1].keys[-1] + + def result( + metrics: Dict[ + metric_types.MetricKey, multi_class_confusion_matrix_metrics.Matrices + ], + ) -> Dict[ + metric_types.PlotKey, + metrics_for_slice_pb2.MultiClassConfusionMatrixAtThresholds, + ]: + return { + key: metrics[matrices_key] + .to_proto() .multi_class_confusion_matrix_at_thresholds - } - - derived_computation = metric_types.DerivedMetricComputation( - keys=[key], result=result) - computations = matrices_computations - computations.append(derived_computation) - return computations + } + + derived_computation = metric_types.DerivedMetricComputation( + keys=[key], result=result + ) + computations = matrices_computations + computations.append(derived_computation) + return computations diff --git a/tensorflow_model_analysis/metrics/multi_class_confusion_matrix_plot_test.py b/tensorflow_model_analysis/metrics/multi_class_confusion_matrix_plot_test.py index 9b834c1c69..8a3eec2c72 100644 --- a/tensorflow_model_analysis/metrics/multi_class_confusion_matrix_plot_test.py +++ b/tensorflow_model_analysis/metrics/multi_class_confusion_matrix_plot_test.py @@ -13,83 +13,89 @@ # limitations under the License. """Tests for multi-class confusion matrix plot at thresholds.""" -from absl.testing import parameterized import apache_beam as beam -from apache_beam.testing import util import numpy as np import tensorflow as tf -from tensorflow_model_analysis.metrics import metric_types -from tensorflow_model_analysis.metrics import metric_util -from tensorflow_model_analysis.metrics import multi_class_confusion_matrix_plot +from absl.testing import parameterized +from apache_beam.testing import util + +from tensorflow_model_analysis.metrics import ( + metric_types, + metric_util, + multi_class_confusion_matrix_plot, +) from tensorflow_model_analysis.utils import test_util class MultiClassConfusionMatrixPlotTest( test_util.TensorflowModelAnalysisTest, parameterized.TestCase ): + def testMultiClassConfusionMatrixPlot(self): + computations = multi_class_confusion_matrix_plot.MultiClassConfusionMatrixPlot().computations( + example_weighted=True + ) + matrices = computations[0] + plot = computations[1] - def testMultiClassConfusionMatrixPlot(self): - computations = ( - multi_class_confusion_matrix_plot.MultiClassConfusionMatrixPlot() - .computations(example_weighted=True)) - matrices = computations[0] - plot = computations[1] - - example1 = { - 'labels': np.array([2.0]), - 'predictions': np.array([0.2, 0.3, 0.5]), - 'example_weights': np.array([0.5]) - } - example2 = { - 'labels': np.array([0.0]), - 'predictions': np.array([0.1, 0.4, 0.5]), - 'example_weights': np.array([1.0]) - } - example3 = { - 'labels': np.array([1.0]), - 'predictions': np.array([0.3, 0.2, 0.5]), - 'example_weights': np.array([0.25]) - } - example4 = { - 'labels': np.array([1.0]), - 'predictions': np.array([0.1, 0.9, 0.0]), - 'example_weights': np.array([1.0]) - } - example5 = { - 'labels': np.array([1.0]), - 'predictions': np.array([0.1, 0.8, 0.1]), - 'example_weights': np.array([1.0]) - } - example6 = { - 'labels': np.array([2.0]), - 'predictions': np.array([0.3, 0.2, 0.5]), - 'example_weights': np.array([1.0]) - } + example1 = { + "labels": np.array([2.0]), + "predictions": np.array([0.2, 
0.3, 0.5]), + "example_weights": np.array([0.5]), + } + example2 = { + "labels": np.array([0.0]), + "predictions": np.array([0.1, 0.4, 0.5]), + "example_weights": np.array([1.0]), + } + example3 = { + "labels": np.array([1.0]), + "predictions": np.array([0.3, 0.2, 0.5]), + "example_weights": np.array([0.25]), + } + example4 = { + "labels": np.array([1.0]), + "predictions": np.array([0.1, 0.9, 0.0]), + "example_weights": np.array([1.0]), + } + example5 = { + "labels": np.array([1.0]), + "predictions": np.array([0.1, 0.8, 0.1]), + "example_weights": np.array([1.0]), + } + example6 = { + "labels": np.array([2.0]), + "predictions": np.array([0.3, 0.2, 0.5]), + "example_weights": np.array([1.0]), + } - with beam.Pipeline() as pipeline: - # pylint: disable=no-value-for-parameter - result = ( - pipeline - | 'Create' >> beam.Create( - [example1, example2, example3, example4, example5, example6]) - | 'Process' >> beam.Map(metric_util.to_standard_metric_inputs) - | 'AddSlice' >> beam.Map(lambda x: ((), x)) - | 'ComputeMatrices' >> beam.CombinePerKey(matrices.combiner) - | 'ComputePlot' >> beam.Map(lambda x: (x[0], plot.result(x[1])))) + with beam.Pipeline() as pipeline: + # pylint: disable=no-value-for-parameter + result = ( + pipeline + | "Create" + >> beam.Create( + [example1, example2, example3, example4, example5, example6] + ) + | "Process" >> beam.Map(metric_util.to_standard_metric_inputs) + | "AddSlice" >> beam.Map(lambda x: ((), x)) + | "ComputeMatrices" >> beam.CombinePerKey(matrices.combiner) + | "ComputePlot" >> beam.Map(lambda x: (x[0], plot.result(x[1]))) + ) - # pylint: enable=no-value-for-parameter + # pylint: enable=no-value-for-parameter - def check_result(got): - try: - self.assertLen(got, 1) - got_slice_key, got_plots = got[0] - self.assertEqual(got_slice_key, ()) - self.assertLen(got_plots, 1) - key = metric_types.PlotKey( - name='multi_class_confusion_matrix_plot', example_weighted=True) - got_matrix = got_plots[key] - self.assertProtoEquals( - """ + def check_result(got): + try: + self.assertLen(got, 1) + got_slice_key, got_plots = got[0] + self.assertEqual(got_slice_key, ()) + self.assertLen(got_plots, 1) + key = metric_types.PlotKey( + name="multi_class_confusion_matrix_plot", example_weighted=True + ) + got_matrix = got_plots[key] + self.assertProtoEquals( + """ matrices { threshold: 0.0 entries { @@ -113,79 +119,83 @@ def check_result(got): num_weighted_examples: 1.5 } } - """, got_matrix) + """, + got_matrix, + ) - except AssertionError as err: - raise util.BeamAssertException(err) + except AssertionError as err: + raise util.BeamAssertException(err) - util.assert_that(result, check_result, label='result') + util.assert_that(result, check_result, label="result") - @parameterized.named_parameters(('using_num_thresholds', { - 'num_thresholds': 3 - }), ('using_thresholds', { - 'thresholds': [0.0, 0.5, 1.0] - })) - def testMultiClassConfusionMatrixPlotWithThresholds(self, kwargs): - computations = ( - multi_class_confusion_matrix_plot.MultiClassConfusionMatrixPlot( - **kwargs).computations()) - matrices = computations[0] - plot = computations[1] + @parameterized.named_parameters( + ("using_num_thresholds", {"num_thresholds": 3}), + ("using_thresholds", {"thresholds": [0.0, 0.5, 1.0]}), + ) + def testMultiClassConfusionMatrixPlotWithThresholds(self, kwargs): + computations = multi_class_confusion_matrix_plot.MultiClassConfusionMatrixPlot( + **kwargs + ).computations() + matrices = computations[0] + plot = computations[1] - example1 = { - 'labels': np.array([2.0]), - 
'predictions': np.array([0.2, 0.35, 0.45]), - 'example_weights': np.array([1.0]) - } - example2 = { - 'labels': np.array([0.0]), - 'predictions': np.array([0.1, 0.35, 0.55]), - 'example_weights': np.array([1.0]) - } - example3 = { - 'labels': np.array([1.0]), - 'predictions': np.array([0.3, 0.25, 0.45]), - 'example_weights': np.array([1.0]) - } - example4 = { - 'labels': np.array([1.0]), - 'predictions': np.array([0.1, 0.9, 0.0]), - 'example_weights': np.array([1.0]) - } - example5 = { - 'labels': np.array([1.0]), - 'predictions': np.array([0.1, 0.8, 0.1]), - 'example_weights': np.array([1.0]) - } - example6 = { - 'labels': np.array([2.0]), - 'predictions': np.array([0.3, 0.25, 0.45]), - 'example_weights': np.array([1.0]) - } + example1 = { + "labels": np.array([2.0]), + "predictions": np.array([0.2, 0.35, 0.45]), + "example_weights": np.array([1.0]), + } + example2 = { + "labels": np.array([0.0]), + "predictions": np.array([0.1, 0.35, 0.55]), + "example_weights": np.array([1.0]), + } + example3 = { + "labels": np.array([1.0]), + "predictions": np.array([0.3, 0.25, 0.45]), + "example_weights": np.array([1.0]), + } + example4 = { + "labels": np.array([1.0]), + "predictions": np.array([0.1, 0.9, 0.0]), + "example_weights": np.array([1.0]), + } + example5 = { + "labels": np.array([1.0]), + "predictions": np.array([0.1, 0.8, 0.1]), + "example_weights": np.array([1.0]), + } + example6 = { + "labels": np.array([2.0]), + "predictions": np.array([0.3, 0.25, 0.45]), + "example_weights": np.array([1.0]), + } - with beam.Pipeline() as pipeline: - # pylint: disable=no-value-for-parameter - result = ( - pipeline - | 'Create' >> beam.Create( - [example1, example2, example3, example4, example5, example6]) - | 'Process' >> beam.Map(metric_util.to_standard_metric_inputs) - | 'AddSlice' >> beam.Map(lambda x: ((), x)) - | 'ComputeMatrices' >> beam.CombinePerKey(matrices.combiner) - | 'ComputePlot' >> beam.Map(lambda x: (x[0], plot.result(x[1])))) + with beam.Pipeline() as pipeline: + # pylint: disable=no-value-for-parameter + result = ( + pipeline + | "Create" + >> beam.Create( + [example1, example2, example3, example4, example5, example6] + ) + | "Process" >> beam.Map(metric_util.to_standard_metric_inputs) + | "AddSlice" >> beam.Map(lambda x: ((), x)) + | "ComputeMatrices" >> beam.CombinePerKey(matrices.combiner) + | "ComputePlot" >> beam.Map(lambda x: (x[0], plot.result(x[1]))) + ) - # pylint: enable=no-value-for-parameter + # pylint: enable=no-value-for-parameter - def check_result(got): - try: - self.assertLen(got, 1) - got_slice_key, got_plots = got[0] - self.assertEqual(got_slice_key, ()) - self.assertLen(got_plots, 1) - key = metric_types.PlotKey(name='multi_class_confusion_matrix_plot') - got_matrix = got_plots[key] - self.assertProtoEquals( - """ + def check_result(got): + try: + self.assertLen(got, 1) + got_slice_key, got_plots = got[0] + self.assertEqual(got_slice_key, ()) + self.assertLen(got_plots, 1) + key = metric_types.PlotKey(name="multi_class_confusion_matrix_plot") + got_matrix = got_plots[key] + self.assertProtoEquals( + """ matrices { threshold: 0.0 entries { @@ -249,65 +259,67 @@ def check_result(got): num_weighted_examples: 2.0 } } - """, got_matrix) + """, + got_matrix, + ) - except AssertionError as err: - raise util.BeamAssertException(err) + except AssertionError as err: + raise util.BeamAssertException(err) - util.assert_that(result, check_result, label='result') + util.assert_that(result, check_result, label="result") - def testMultiClassConfusionMatrixPlotWithStringLabels(self): 
- computations = ( - multi_class_confusion_matrix_plot.MultiClassConfusionMatrixPlot() - .computations(example_weighted=True)) - matrices = computations[0] - plot = computations[1] + def testMultiClassConfusionMatrixPlotWithStringLabels(self): + computations = multi_class_confusion_matrix_plot.MultiClassConfusionMatrixPlot().computations( + example_weighted=True + ) + matrices = computations[0] + plot = computations[1] - # Examples from b/149558504. - example1 = { - 'labels': np.array([['unacc']]), - 'predictions': { - 'probabilities': - np.array([[ - 1.0000000e+00, 6.9407083e-24, 2.7419115e-38, 0.0000000e+00 - ]]), - 'all_classes': - np.array([['unacc', 'acc', 'vgood', 'good']]), - }, - 'example_weights': np.array([0.5]) - } - example2 = { - 'labels': np.array([['vgood']]), - 'predictions': { - 'probabilities': np.array([[0.2, 0.3, 0.4, 0.1]]), - 'all_classes': np.array([['unacc', 'acc', 'vgood', 'good']]), - }, - 'example_weights': np.array([1.0]) - } + # Examples from b/149558504. + example1 = { + "labels": np.array([["unacc"]]), + "predictions": { + "probabilities": np.array( + [[1.0000000e00, 6.9407083e-24, 2.7419115e-38, 0.0000000e00]] + ), + "all_classes": np.array([["unacc", "acc", "vgood", "good"]]), + }, + "example_weights": np.array([0.5]), + } + example2 = { + "labels": np.array([["vgood"]]), + "predictions": { + "probabilities": np.array([[0.2, 0.3, 0.4, 0.1]]), + "all_classes": np.array([["unacc", "acc", "vgood", "good"]]), + }, + "example_weights": np.array([1.0]), + } - with beam.Pipeline() as pipeline: - # pylint: disable=no-value-for-parameter - result = ( - pipeline - | 'Create' >> beam.Create([example1, example2]) - | 'Process' >> beam.Map(metric_util.to_standard_metric_inputs) - | 'AddSlice' >> beam.Map(lambda x: ((), x)) - | 'ComputeMatrices' >> beam.CombinePerKey(matrices.combiner) - | 'ComputePlot' >> beam.Map(lambda x: (x[0], plot.result(x[1])))) + with beam.Pipeline() as pipeline: + # pylint: disable=no-value-for-parameter + result = ( + pipeline + | "Create" >> beam.Create([example1, example2]) + | "Process" >> beam.Map(metric_util.to_standard_metric_inputs) + | "AddSlice" >> beam.Map(lambda x: ((), x)) + | "ComputeMatrices" >> beam.CombinePerKey(matrices.combiner) + | "ComputePlot" >> beam.Map(lambda x: (x[0], plot.result(x[1]))) + ) - # pylint: enable=no-value-for-parameter + # pylint: enable=no-value-for-parameter - def check_result(got): - try: - self.assertLen(got, 1) - got_slice_key, got_plots = got[0] - self.assertEqual(got_slice_key, ()) - self.assertLen(got_plots, 1) - key = metric_types.PlotKey( - name='multi_class_confusion_matrix_plot', example_weighted=True) - got_matrix = got_plots[key] - self.assertProtoEquals( - """ + def check_result(got): + try: + self.assertLen(got, 1) + got_slice_key, got_plots = got[0] + self.assertEqual(got_slice_key, ()) + self.assertLen(got_plots, 1) + key = metric_types.PlotKey( + name="multi_class_confusion_matrix_plot", example_weighted=True + ) + got_matrix = got_plots[key] + self.assertProtoEquals( + """ matrices { threshold: 0.0 entries { @@ -321,13 +333,15 @@ def check_result(got): num_weighted_examples: 1.0 } } - """, got_matrix) + """, + got_matrix, + ) - except AssertionError as err: - raise util.BeamAssertException(err) + except AssertionError as err: + raise util.BeamAssertException(err) - util.assert_that(result, check_result, label='result') + util.assert_that(result, check_result, label="result") -if __name__ == '__main__': - tf.test.main() +if __name__ == "__main__": + tf.test.main() diff --git 
a/tensorflow_model_analysis/metrics/multi_label_confusion_matrix_plot.py b/tensorflow_model_analysis/metrics/multi_label_confusion_matrix_plot.py index 1c7e23e1a1..94a46e25b0 100644 --- a/tensorflow_model_analysis/metrics/multi_label_confusion_matrix_plot.py +++ b/tensorflow_model_analysis/metrics/multi_label_confusion_matrix_plot.py @@ -13,86 +13,88 @@ # limitations under the License. """Multi-label confusion matrix at thresholds.""" -from typing import Dict, Iterable, List, Optional, NamedTuple +from typing import Dict, Iterable, List, NamedTuple, Optional import apache_beam as beam -from tensorflow_model_analysis.metrics import metric_types -from tensorflow_model_analysis.metrics import metric_util -from tensorflow_model_analysis.proto import config_pb2 -from tensorflow_model_analysis.proto import metrics_for_slice_pb2 -MULTI_LABEL_CONFUSION_MATRIX_PLOT_NAME = ('multi_label_confusion_matrix_plot') +from tensorflow_model_analysis.metrics import metric_types, metric_util +from tensorflow_model_analysis.proto import config_pb2, metrics_for_slice_pb2 + +MULTI_LABEL_CONFUSION_MATRIX_PLOT_NAME = "multi_label_confusion_matrix_plot" _EPSILON = 1e-7 class MultiLabelConfusionMatrixPlot(metric_types.Metric): - """Multi-label confusion matrix. - - For each actual class (positive label) a confusion matrix is computed for each - class based on the associated predicted values such that: - - TP = positive_prediction_class_label & positive_prediction - TN = negative_prediction_class_label & negative_prediction - FP = negative_prediction_class_label & positive_prediction - FN = positive_prediction_class_label & negative_prediction - - For example, given classes 0, 1 and a given threshold, the following matrices - will be computed: - - Actual: class_0 - Predicted: class_0 - TP = is_class_0 & is_class_0 & predict_class_0 - TN = is_class_0 & not_class_0 & predict_not_class_0 - FN = is_class_0 & is_class_0 & predict_not_class_0 - FP = is_class_0 & not_class_0 & predict_class_0 - Actual: class_0 - Predicted: class_1 - TP = is_class_0 & is_class_1 & predict_class_1 - TN = is_class_0 & not_class_1 & predict_not_class_1 - FN = is_class_0 & is_class_1 & predict_not_class_1 - FP = is_class_0 & not_class_1 & predict_class_1 - Actual: class_1 - Predicted: class_0 - TP = is_class_1 & is_class_0 & predict_class_0 - TN = is_class_1 & not_class_0 & predict_not_class_0 - FN = is_class_1 & is_class_0 & predict_not_class_0 - FP = is_class_1 & not_class_0 & predict_class_0 - Actual: class_1 - Predicted: class_1 - TP = is_class_1 & is_class_1 & predict_class_1 - TN = is_class_1 & not_class_1 & predict_not_class_1 - FN = is_class_1 & is_class_1 & predict_not_class_1 - FP = is_class_1 & not_class_1 & predict_class_1 - - Note that unlike the multi-class confusion matrix, the inputs are assumed to - be multi-label whereby the predictions may not necessarily sum to 1.0 and - multiple classes can be true as the same time. - """ - - def __init__(self, - thresholds: Optional[List[float]] = None, - num_thresholds: Optional[int] = None, - name: str = MULTI_LABEL_CONFUSION_MATRIX_PLOT_NAME): - """Initializes multi-label confusion matrix. - - Args: - thresholds: Optional thresholds. Only one of either thresholds or - num_thresholds should be used. If both are unset, then [0.5] will be - assumed. - num_thresholds: Number of thresholds to use. The thresholds will be evenly - spaced between 0.0 and 1.0 and inclusive of the boundaries (i.e. to - configure the thresholds to [0.0, 0.25, 0.5, 0.75, 1.0], the parameter - should be set to 5). 
Only one of either thresholds or num_thresholds - should be used. - name: Metric name. + """Multi-label confusion matrix. + + For each actual class (positive label) a confusion matrix is computed for each + class based on the associated predicted values such that: + + TP = positive_prediction_class_label & positive_prediction + TN = negative_prediction_class_label & negative_prediction + FP = negative_prediction_class_label & positive_prediction + FN = positive_prediction_class_label & negative_prediction + + For example, given classes 0, 1 and a given threshold, the following matrices + will be computed: + + Actual: class_0 + Predicted: class_0 + TP = is_class_0 & is_class_0 & predict_class_0 + TN = is_class_0 & not_class_0 & predict_not_class_0 + FN = is_class_0 & is_class_0 & predict_not_class_0 + FP = is_class_0 & not_class_0 & predict_class_0 + Actual: class_0 + Predicted: class_1 + TP = is_class_0 & is_class_1 & predict_class_1 + TN = is_class_0 & not_class_1 & predict_not_class_1 + FN = is_class_0 & is_class_1 & predict_not_class_1 + FP = is_class_0 & not_class_1 & predict_class_1 + Actual: class_1 + Predicted: class_0 + TP = is_class_1 & is_class_0 & predict_class_0 + TN = is_class_1 & not_class_0 & predict_not_class_0 + FN = is_class_1 & is_class_0 & predict_not_class_0 + FP = is_class_1 & not_class_0 & predict_class_0 + Actual: class_1 + Predicted: class_1 + TP = is_class_1 & is_class_1 & predict_class_1 + TN = is_class_1 & not_class_1 & predict_not_class_1 + FN = is_class_1 & is_class_1 & predict_not_class_1 + FP = is_class_1 & not_class_1 & predict_class_1 + + Note that unlike the multi-class confusion matrix, the inputs are assumed to + be multi-label whereby the predictions may not necessarily sum to 1.0 and + multiple classes can be true as the same time. """ - super().__init__( - metric_util.merge_per_key_computations( - _multi_label_confusion_matrix_plot), - thresholds=thresholds, - num_thresholds=num_thresholds, - name=name) + + def __init__( + self, + thresholds: Optional[List[float]] = None, + num_thresholds: Optional[int] = None, + name: str = MULTI_LABEL_CONFUSION_MATRIX_PLOT_NAME, + ): + """Initializes multi-label confusion matrix. + + Args: + ---- + thresholds: Optional thresholds. Only one of either thresholds or + num_thresholds should be used. If both are unset, then [0.5] will be + assumed. + num_thresholds: Number of thresholds to use. The thresholds will be evenly + spaced between 0.0 and 1.0 and inclusive of the boundaries (i.e. to + configure the thresholds to [0.0, 0.25, 0.5, 0.75, 1.0], the parameter + should be set to 5). Only one of either thresholds or num_thresholds + should be used. + name: Metric name. 
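For example (an illustrative note, not part of the original docstring), `MultiLabelConfusionMatrixPlot(num_thresholds=5)` evaluates the matrices at [0.0, 0.25, 0.5, 0.75, 1.0] -- internally the two boundary thresholds are offset by a small epsilon and mapped back to 0.0 and 1.0 when the plot proto is written out -- while `MultiLabelConfusionMatrixPlot(thresholds=[0.0, 0.5, 1.0])` uses exactly the listed values; if neither argument is set, a single threshold of 0.5 is assumed.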
+ """ + super().__init__( + metric_util.merge_per_key_computations(_multi_label_confusion_matrix_plot), + thresholds=thresholds, + num_thresholds=num_thresholds, + name=name, + ) metric_types.register_metric(MultiLabelConfusionMatrixPlot) @@ -103,54 +105,63 @@ def _multi_label_confusion_matrix_plot( num_thresholds: Optional[int] = None, name: str = MULTI_LABEL_CONFUSION_MATRIX_PLOT_NAME, eval_config: Optional[config_pb2.EvalConfig] = None, - model_name: str = '', - output_name: str = '', + model_name: str = "", + output_name: str = "", example_weighted: bool = False, ) -> metric_types.MetricComputations: - """Returns computations for multi-label confusion matrix at thresholds.""" - if num_thresholds is not None and thresholds is not None: - raise ValueError( - 'only one of thresholds or num_thresholds can be set at a time') - if num_thresholds is None and thresholds is None: - thresholds = [0.5] - if num_thresholds is not None: - thresholds = [ - (i + 1) * 1.0 / (num_thresholds - 1) for i in range(num_thresholds - 2) + """Returns computations for multi-label confusion matrix at thresholds.""" + if num_thresholds is not None and thresholds is not None: + raise ValueError( + "only one of thresholds or num_thresholds can be set at a time" + ) + if num_thresholds is None and thresholds is None: + thresholds = [0.5] + if num_thresholds is not None: + thresholds = [ + (i + 1) * 1.0 / (num_thresholds - 1) for i in range(num_thresholds - 2) + ] + thresholds = [-_EPSILON] + thresholds + [1.0 + _EPSILON] + + key = metric_types.PlotKey( + name=name, + model_name=model_name, + output_name=output_name, + example_weighted=example_weighted, + ) + return [ + metric_types.MetricComputation( + keys=[key], + preprocessors=None, + combiner=_MultiLabelConfusionMatrixPlotCombiner( + key=key, + eval_config=eval_config, + example_weighted=example_weighted, + thresholds=thresholds, + ), + ) ] - thresholds = [-_EPSILON] + thresholds + [1.0 + _EPSILON] - - key = metric_types.PlotKey( - name=name, - model_name=model_name, - output_name=output_name, - example_weighted=example_weighted) - return [ - metric_types.MetricComputation( - keys=[key], - preprocessors=None, - combiner=_MultiLabelConfusionMatrixPlotCombiner( - key=key, - eval_config=eval_config, - example_weighted=example_weighted, - thresholds=thresholds)) - ] -_MatrixEntryKey = NamedTuple('_MatrixEntryKey', [('actual_class_id', int), - ('predicted_class_id', int)]) +class _MatrixEntryKey(NamedTuple): + actual_class_id: int + predicted_class_id: int class _ConfusionMatrix: - """Confusion matrix.""" - __slots__ = [ - 'false_negatives', 'true_negatives', 'false_positives', 'true_positives' - ] + """Confusion matrix.""" + + __slots__ = [ + "false_negatives", + "true_negatives", + "false_positives", + "true_positives", + ] - def __init__(self): - self.false_negatives = 0.0 - self.true_negatives = 0.0 - self.false_positives = 0.0 - self.true_positives = 0.0 + def __init__(self): + self.false_negatives = 0.0 + self.true_negatives = 0.0 + self.false_positives = 0.0 + self.true_positives = 0.0 # Thresholds -> entry -> confusion matrix @@ -158,110 +169,124 @@ def __init__(self): class _MultiLabelConfusionMatrixPlotCombiner(beam.CombineFn): - """Creates multi-label confusion matrix at thresholds from standard inputs.""" - - def __init__(self, key: metric_types.PlotKey, - eval_config: Optional[config_pb2.EvalConfig], - example_weighted: bool, thresholds: List[float]): - self._key = key - self._eval_config = eval_config - self._example_weighted = example_weighted - 
self._thresholds = thresholds if thresholds else [0.5] - - def create_accumulator(self) -> _Matrices: - return {} - - def add_input(self, accumulator: _Matrices, - element: metric_types.StandardMetricInputs) -> _Matrices: - labels, predictions, example_weight = next( - metric_util.to_label_prediction_example_weight( - element, - eval_config=self._eval_config, - model_name=self._key.model_name, - output_name=self._key.output_name, - example_weighted=self._example_weighted, - flatten=False, - require_single_example_weight=True)) - if not labels.shape: - raise ValueError( - 'Labels missing from example: StandardMetricInputs={}'.format( - element)) - if predictions.shape in ((), (1,)): - raise ValueError( - 'Predictions shape must be > 1 for multi-label confusion matrix: ' - 'shape={}, StandardMetricInputs={}'.format(predictions.shape, - element)) - # If the label and prediction shapes are different then assume the labels - # are sparse and convert them to dense. - if (len(labels.shape) != len(predictions.shape) or - labels.shape[-1] != predictions.shape[-1]): - labels = metric_util.one_hot(labels, predictions) - example_weight = float(example_weight) - for threshold in self._thresholds: - if threshold not in accumulator: - accumulator[threshold] = {} - for actual_class_id, label in enumerate(labels): - if not label: - continue - for class_id, prediction in enumerate(predictions): - matrix_key = _MatrixEntryKey(actual_class_id, class_id) - fn = (labels[class_id] and prediction <= threshold) * example_weight - fp = (not labels[class_id] and - prediction > threshold) * example_weight - tn = ((not labels[class_id] and prediction <= threshold) * - example_weight) - tp = (labels[class_id] and prediction > threshold) * example_weight - if matrix_key in accumulator[threshold]: - accumulator[threshold][matrix_key].false_negatives += fn - accumulator[threshold][matrix_key].true_negatives += tn - accumulator[threshold][matrix_key].false_positives += fp - accumulator[threshold][matrix_key].true_positives += tp - else: - matrix = _ConfusionMatrix() - matrix.false_negatives = fn - matrix.true_negatives = tn - matrix.false_positives = fp - matrix.true_positives = tp - accumulator[threshold][matrix_key] = matrix - return accumulator - - def merge_accumulators(self, accumulators: Iterable[_Matrices]) -> _Matrices: - accumulators = iter(accumulators) - result = next(accumulators) - for accumulator in accumulators: - for threshold, matrix in accumulator.items(): - if threshold not in result: - result[threshold] = {} - for k, v in matrix.items(): - if k in result[threshold]: - result[threshold][k].false_negatives += v.false_negatives - result[threshold][k].true_negatives += v.true_negatives - result[threshold][k].false_positives += v.false_positives - result[threshold][k].true_positives += v.true_positives - else: - result[threshold][k] = v - return result - - def extract_output( - self, accumulator: _Matrices - ) -> Dict[metric_types.PlotKey, - metrics_for_slice_pb2.MultiLabelConfusionMatrixAtThresholds]: - pb = metrics_for_slice_pb2.MultiLabelConfusionMatrixAtThresholds() - for threshold in sorted(accumulator): - # Convert -epsilon and 1.0+epsilon back to 0.0 and 1.0. 
- if threshold == -_EPSILON: - t = 0.0 - elif threshold == 1.0 + _EPSILON: - t = 1.0 - else: - t = threshold - matrix = pb.matrices.add(threshold=t) - for k in sorted(accumulator[threshold]): - matrix.entries.add( - actual_class_id=k.actual_class_id, - predicted_class_id=k.predicted_class_id, - false_negatives=accumulator[threshold][k].false_negatives, - true_negatives=accumulator[threshold][k].true_negatives, - false_positives=accumulator[threshold][k].false_positives, - true_positives=accumulator[threshold][k].true_positives) - return {self._key: pb} + """Creates multi-label confusion matrix at thresholds from standard inputs.""" + + def __init__( + self, + key: metric_types.PlotKey, + eval_config: Optional[config_pb2.EvalConfig], + example_weighted: bool, + thresholds: List[float], + ): + self._key = key + self._eval_config = eval_config + self._example_weighted = example_weighted + self._thresholds = thresholds if thresholds else [0.5] + + def create_accumulator(self) -> _Matrices: + return {} + + def add_input( + self, accumulator: _Matrices, element: metric_types.StandardMetricInputs + ) -> _Matrices: + labels, predictions, example_weight = next( + metric_util.to_label_prediction_example_weight( + element, + eval_config=self._eval_config, + model_name=self._key.model_name, + output_name=self._key.output_name, + example_weighted=self._example_weighted, + flatten=False, + require_single_example_weight=True, + ) + ) + if not labels.shape: + raise ValueError( + f"Labels missing from example: StandardMetricInputs={element}" + ) + if predictions.shape in ((), (1,)): + raise ValueError( + "Predictions shape must be > 1 for multi-label confusion matrix: " + f"shape={predictions.shape}, StandardMetricInputs={element}" + ) + # If the label and prediction shapes are different then assume the labels + # are sparse and convert them to dense. 
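# (E.g. a sparse label such as np.array([2]) evaluated against a three-class
# prediction is expected to end up here as the dense indicator [0., 0., 1.],
# so that labels[class_id] can be tested per class in the loop below.)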
+ if ( + len(labels.shape) != len(predictions.shape) + or labels.shape[-1] != predictions.shape[-1] + ): + labels = metric_util.one_hot(labels, predictions) + example_weight = float(example_weight) + for threshold in self._thresholds: + if threshold not in accumulator: + accumulator[threshold] = {} + for actual_class_id, label in enumerate(labels): + if not label: + continue + for class_id, prediction in enumerate(predictions): + matrix_key = _MatrixEntryKey(actual_class_id, class_id) + fn = (labels[class_id] and prediction <= threshold) * example_weight + fp = ( + not labels[class_id] and prediction > threshold + ) * example_weight + tn = ( + not labels[class_id] and prediction <= threshold + ) * example_weight + tp = (labels[class_id] and prediction > threshold) * example_weight + if matrix_key in accumulator[threshold]: + accumulator[threshold][matrix_key].false_negatives += fn + accumulator[threshold][matrix_key].true_negatives += tn + accumulator[threshold][matrix_key].false_positives += fp + accumulator[threshold][matrix_key].true_positives += tp + else: + matrix = _ConfusionMatrix() + matrix.false_negatives = fn + matrix.true_negatives = tn + matrix.false_positives = fp + matrix.true_positives = tp + accumulator[threshold][matrix_key] = matrix + return accumulator + + def merge_accumulators(self, accumulators: Iterable[_Matrices]) -> _Matrices: + accumulators = iter(accumulators) + result = next(accumulators) + for accumulator in accumulators: + for threshold, matrix in accumulator.items(): + if threshold not in result: + result[threshold] = {} + for k, v in matrix.items(): + if k in result[threshold]: + result[threshold][k].false_negatives += v.false_negatives + result[threshold][k].true_negatives += v.true_negatives + result[threshold][k].false_positives += v.false_positives + result[threshold][k].true_positives += v.true_positives + else: + result[threshold][k] = v + return result + + def extract_output( + self, accumulator: _Matrices + ) -> Dict[ + metric_types.PlotKey, + metrics_for_slice_pb2.MultiLabelConfusionMatrixAtThresholds, + ]: + pb = metrics_for_slice_pb2.MultiLabelConfusionMatrixAtThresholds() + for threshold in sorted(accumulator): + # Convert -epsilon and 1.0+epsilon back to 0.0 and 1.0. + if threshold == -_EPSILON: + t = 0.0 + elif threshold == 1.0 + _EPSILON: + t = 1.0 + else: + t = threshold + matrix = pb.matrices.add(threshold=t) + for k in sorted(accumulator[threshold]): + matrix.entries.add( + actual_class_id=k.actual_class_id, + predicted_class_id=k.predicted_class_id, + false_negatives=accumulator[threshold][k].false_negatives, + true_negatives=accumulator[threshold][k].true_negatives, + false_positives=accumulator[threshold][k].false_positives, + true_positives=accumulator[threshold][k].true_positives, + ) + return {self._key: pb} diff --git a/tensorflow_model_analysis/metrics/multi_label_confusion_matrix_plot_test.py b/tensorflow_model_analysis/metrics/multi_label_confusion_matrix_plot_test.py index 53b0ce59e1..8edc53726e 100644 --- a/tensorflow_model_analysis/metrics/multi_label_confusion_matrix_plot_test.py +++ b/tensorflow_model_analysis/metrics/multi_label_confusion_matrix_plot_test.py @@ -13,68 +13,71 @@ # limitations under the License. 
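To make the bookkeeping in `add_input` above concrete, the following is a minimal, self-contained sketch of the same per-example update (it mirrors the combiner's logic directly rather than calling TFMA), using the first example from the test below:

```python
import numpy as np

# First example from the multi-label confusion matrix plot test below.
labels = np.array([1.0, 1.0, 0.0])
predictions = np.array([0.7, 0.5, 0.2])
threshold, weight = 0.5, 1.0

entries = {}
for actual_class_id, label in enumerate(labels):
    if not label:
        continue  # only actual (positive) classes get a row of matrices
    for class_id, prediction in enumerate(predictions):
        tp = (labels[class_id] and prediction > threshold) * weight
        fn = (labels[class_id] and prediction <= threshold) * weight
        fp = (not labels[class_id] and prediction > threshold) * weight
        tn = (not labels[class_id] and prediction <= threshold) * weight
        entries[(actual_class_id, class_id)] = (tp, fn, fp, tn)

# e.g. entries[(0, 0)] is a true positive (prediction 0.7 > 0.5 for a
# positive class), while entries[(0, 1)] is a false negative because the
# prediction 0.5 is not strictly above the 0.5 threshold.
print(entries)
```

Each positive label contributes one confusion matrix per predicted class, which is exactly the structure asserted in the expected plot protos in these tests.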
"""Tests for multi-label confusion matrix at thresholds.""" -from absl.testing import parameterized import apache_beam as beam -from apache_beam.testing import util import numpy as np import tensorflow as tf -from tensorflow_model_analysis.metrics import metric_types -from tensorflow_model_analysis.metrics import metric_util -from tensorflow_model_analysis.metrics import multi_label_confusion_matrix_plot +from absl.testing import parameterized +from apache_beam.testing import util + +from tensorflow_model_analysis.metrics import ( + metric_types, + metric_util, + multi_label_confusion_matrix_plot, +) from tensorflow_model_analysis.utils import test_util class MultiLabelConfusionMatrixPlotTest( test_util.TensorflowModelAnalysisTest, parameterized.TestCase ): + def testMultiLabelConfusionMatrixPlot(self): + computation = multi_label_confusion_matrix_plot.MultiLabelConfusionMatrixPlot().computations()[ + 0 + ] - def testMultiLabelConfusionMatrixPlot(self): - computation = ( - multi_label_confusion_matrix_plot.MultiLabelConfusionMatrixPlot() - .computations()[0]) - - example1 = { - 'labels': np.array([1.0, 1.0, 0.0]), - 'predictions': np.array([0.7, 0.5, 0.2]), - 'example_weights': np.array([1.0]) - } - example2 = { - 'labels': np.array([0.0, 1.0, 0.0]), - 'predictions': np.array([0.3, 0.6, 0.1]), - 'example_weights': np.array([1.0]) - } - example3 = { - 'labels': np.array([0.0, 0.0, 0.0]), - 'predictions': np.array([0.2, 0.4, 0.5]), - 'example_weights': np.array([1.0]) - } - example4 = { - 'labels': np.array([1.0, 0.0, 0.0]), - 'predictions': np.array([1.0, 0.4, 0.1]), - 'example_weights': np.array([1.0]) - } + example1 = { + "labels": np.array([1.0, 1.0, 0.0]), + "predictions": np.array([0.7, 0.5, 0.2]), + "example_weights": np.array([1.0]), + } + example2 = { + "labels": np.array([0.0, 1.0, 0.0]), + "predictions": np.array([0.3, 0.6, 0.1]), + "example_weights": np.array([1.0]), + } + example3 = { + "labels": np.array([0.0, 0.0, 0.0]), + "predictions": np.array([0.2, 0.4, 0.5]), + "example_weights": np.array([1.0]), + } + example4 = { + "labels": np.array([1.0, 0.0, 0.0]), + "predictions": np.array([1.0, 0.4, 0.1]), + "example_weights": np.array([1.0]), + } - with beam.Pipeline() as pipeline: - # pylint: disable=no-value-for-parameter - result = ( - pipeline - | 'Create' >> beam.Create([example1, example2, example3, example4]) - | 'Process' >> beam.Map(metric_util.to_standard_metric_inputs) - | 'AddSlice' >> beam.Map(lambda x: ((), x)) - | 'ComputePlot' >> beam.CombinePerKey(computation.combiner)) + with beam.Pipeline() as pipeline: + # pylint: disable=no-value-for-parameter + result = ( + pipeline + | "Create" >> beam.Create([example1, example2, example3, example4]) + | "Process" >> beam.Map(metric_util.to_standard_metric_inputs) + | "AddSlice" >> beam.Map(lambda x: ((), x)) + | "ComputePlot" >> beam.CombinePerKey(computation.combiner) + ) - # pylint: enable=no-value-for-parameter + # pylint: enable=no-value-for-parameter - def check_result(got): - try: - self.assertLen(got, 1) - got_slice_key, got_plots = got[0] - self.assertEqual(got_slice_key, ()) - self.assertLen(got_plots, 1) - key = metric_types.PlotKey(name='multi_label_confusion_matrix_plot') - got_matrix = got_plots[key] - self.assertProtoEquals( - """ + def check_result(got): + try: + self.assertLen(got, 1) + got_slice_key, got_plots = got[0] + self.assertEqual(got_slice_key, ()) + self.assertLen(got_plots, 1) + key = metric_types.PlotKey(name="multi_label_confusion_matrix_plot") + got_matrix = got_plots[key] + 
self.assertProtoEquals( + """ matrices { threshold: 0.5 entries { @@ -126,66 +129,69 @@ def check_result(got): true_positives: 0.0 } } - """, got_matrix) + """, + got_matrix, + ) - except AssertionError as err: - raise util.BeamAssertException(err) + except AssertionError as err: + raise util.BeamAssertException(err) - util.assert_that(result, check_result, label='result') + util.assert_that(result, check_result, label="result") - @parameterized.named_parameters(('using_num_thresholds', { - 'num_thresholds': 3 - }), ('using_thresholds', { - 'thresholds': [0.0, 0.5, 1.0] - })) - def testMultiLabelConfusionMatrixPlotWithThresholds(self, kwargs): - computation = ( - multi_label_confusion_matrix_plot.MultiLabelConfusionMatrixPlot( - **kwargs).computations(example_weighted=True)[0]) + @parameterized.named_parameters( + ("using_num_thresholds", {"num_thresholds": 3}), + ("using_thresholds", {"thresholds": [0.0, 0.5, 1.0]}), + ) + def testMultiLabelConfusionMatrixPlotWithThresholds(self, kwargs): + computation = multi_label_confusion_matrix_plot.MultiLabelConfusionMatrixPlot( + **kwargs + ).computations(example_weighted=True)[0] - example1 = { - 'labels': np.array([1.0, 1.0, 0.0]), - 'predictions': np.array([0.7, 0.5, 0.2]), - 'example_weights': np.array([0.25]) - } - example2 = { - 'labels': np.array([0.0, 1.0, 0.0]), - 'predictions': np.array([0.3, 0.6, 0.1]), - 'example_weights': np.array([0.5]) - } - example3 = { - 'labels': np.array([0.0, 0.0, 0.0]), - 'predictions': np.array([0.2, 0.4, 0.5]), - 'example_weights': np.array([0.75]) - } - example4 = { - 'labels': np.array([1.0, 0.0, 0.0]), - 'predictions': np.array([1.0, 0.4, 0.1]), - 'example_weights': np.array([1.0]) - } + example1 = { + "labels": np.array([1.0, 1.0, 0.0]), + "predictions": np.array([0.7, 0.5, 0.2]), + "example_weights": np.array([0.25]), + } + example2 = { + "labels": np.array([0.0, 1.0, 0.0]), + "predictions": np.array([0.3, 0.6, 0.1]), + "example_weights": np.array([0.5]), + } + example3 = { + "labels": np.array([0.0, 0.0, 0.0]), + "predictions": np.array([0.2, 0.4, 0.5]), + "example_weights": np.array([0.75]), + } + example4 = { + "labels": np.array([1.0, 0.0, 0.0]), + "predictions": np.array([1.0, 0.4, 0.1]), + "example_weights": np.array([1.0]), + } - with beam.Pipeline() as pipeline: - # pylint: disable=no-value-for-parameter - result = ( - pipeline - | 'Create' >> beam.Create([example1, example2, example3, example4]) - | 'Process' >> beam.Map(metric_util.to_standard_metric_inputs) - | 'AddSlice' >> beam.Map(lambda x: ((), x)) - | 'ComputePlot' >> beam.CombinePerKey(computation.combiner)) + with beam.Pipeline() as pipeline: + # pylint: disable=no-value-for-parameter + result = ( + pipeline + | "Create" >> beam.Create([example1, example2, example3, example4]) + | "Process" >> beam.Map(metric_util.to_standard_metric_inputs) + | "AddSlice" >> beam.Map(lambda x: ((), x)) + | "ComputePlot" >> beam.CombinePerKey(computation.combiner) + ) - # pylint: enable=no-value-for-parameter + # pylint: enable=no-value-for-parameter - def check_result(got): - try: - self.assertLen(got, 1) - got_slice_key, got_plots = got[0] - self.assertEqual(got_slice_key, ()) - self.assertLen(got_plots, 1) - key = metric_types.PlotKey( - name='multi_label_confusion_matrix_plot', example_weighted=True) - got_matrix = got_plots[key] - self.assertProtoEquals( - """ + def check_result(got): + try: + self.assertLen(got, 1) + got_slice_key, got_plots = got[0] + self.assertEqual(got_slice_key, ()) + self.assertLen(got_plots, 1) + key = metric_types.PlotKey( 
+ name="multi_label_confusion_matrix_plot", example_weighted=True + ) + got_matrix = got_plots[key] + self.assertProtoEquals( + """ matrices { threshold: 0.0 entries { @@ -318,13 +324,15 @@ def check_result(got): true_negatives: 0.75 } } - """, got_matrix) + """, + got_matrix, + ) - except AssertionError as err: - raise util.BeamAssertException(err) + except AssertionError as err: + raise util.BeamAssertException(err) - util.assert_that(result, check_result, label='result') + util.assert_that(result, check_result, label="result") -if __name__ == '__main__': - tf.test.main() +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow_model_analysis/metrics/ndcg.py b/tensorflow_model_analysis/metrics/ndcg.py index acb97e0865..f126dbe13a 100644 --- a/tensorflow_model_analysis/metrics/ndcg.py +++ b/tensorflow_model_analysis/metrics/ndcg.py @@ -17,48 +17,49 @@ import apache_beam as beam import numpy as np -from tensorflow_model_analysis.metrics import metric_types -from tensorflow_model_analysis.metrics import metric_util + +from tensorflow_model_analysis.metrics import metric_types, metric_util from tensorflow_model_analysis.proto import config_pb2 from tensorflow_model_analysis.utils import util -NDCG_NAME = 'ndcg' +NDCG_NAME = "ndcg" class NDCG(metric_types.Metric): - """NDCG (normalized discounted cumulative gain) metric. - - Calculates NDCG@k for a given set of top_k values calculated from a list of - gains (relevance scores) that are sorted based on the associated predictions. - The top_k_list can be passed as part of the NDCG metric config or using - tfma.MetricsSpec.binarize.top_k_list if configuring multiple top_k metrics. - The gain (relevance score) is determined from the value stored in the - 'gain_key' feature. The value of NDCG@k returned is a weighted average of - NDCG@k over the set of queries using the example weights. - - NDCG@k = (DCG@k for the given rank)/(DCG@k - DCG@k = sum_{i=1}^k gain_i/log_2(i+1), where gain_i is the gain (relevance - score) of the i^th ranked response, indexed from 1. - - This is a query/ranking based metric so a query_key must also be provided in - the associated tfma.MetricsSpec. - """ - - def __init__( - self, - gain_key: str, - top_k_list: Optional[List[int]] = None, - name: str = NDCG_NAME, - ): - """Initializes NDCG. - - Args: - gain_key: Key of feature in features dictionary that holds gain values. - top_k_list: Values for top k. This can also be set using the - tfma.MetricsSpec.binarize.top_k_list associated with the metric. - name: Metric name. + """NDCG (normalized discounted cumulative gain) metric. + + Calculates NDCG@k for a given set of top_k values calculated from a list of + gains (relevance scores) that are sorted based on the associated predictions. + The top_k_list can be passed as part of the NDCG metric config or using + tfma.MetricsSpec.binarize.top_k_list if configuring multiple top_k metrics. + The gain (relevance score) is determined from the value stored in the + 'gain_key' feature. The value of NDCG@k returned is a weighted average of + NDCG@k over the set of queries using the example weights. + + NDCG@k = (DCG@k for the given rank)/(DCG@k + DCG@k = sum_{i=1}^k gain_i/log_2(i+1), where gain_i is the gain (relevance + score) of the i^th ranked response, indexed from 1. + + This is a query/ranking based metric so a query_key must also be provided in + the associated tfma.MetricsSpec. 
""" - super().__init__(_ndcg, gain_key=gain_key, top_k_list=top_k_list, name=name) + + def __init__( + self, + gain_key: str, + top_k_list: Optional[List[int]] = None, + name: str = NDCG_NAME, + ): + """Initializes NDCG. + + Args: + ---- + gain_key: Key of feature in features dictionary that holds gain values. + top_k_list: Values for top k. This can also be set using the + tfma.MetricsSpec.binarize.top_k_list associated with the metric. + name: Metric name. + """ + super().__init__(_ndcg, gain_key=gain_key, top_k_list=top_k_list, name=name) metric_types.register_metric(NDCG) @@ -73,227 +74,215 @@ def _ndcg( output_names: Optional[List[str]] = None, sub_keys: Optional[List[metric_types.SubKey]] = None, example_weighted: bool = False, - query_key: str = '', + query_key: str = "", ) -> metric_types.MetricComputations: - """Returns metric computations for NDCG.""" - if not query_key: - raise ValueError('a query_key is required to use NDCG metric') - sub_keys = [k for k in sub_keys if k is not None] - if top_k_list: - if sub_keys is None: - sub_keys = [] - for k in top_k_list: - if not any([sub_key.top_k == k for sub_key in sub_keys]): - sub_keys.append(metric_types.SubKey(top_k=k)) - if not sub_keys or any([sub_key.top_k is None for sub_key in sub_keys]): - raise ValueError( - 'top_k values are required to use NDCG metric: {}'.format(sub_keys) - ) - computations = [] - for model_name in model_names if model_names else ['']: - for output_name in output_names if output_names else ['']: - keys = [] - for sub_key in sub_keys: - keys.append( - metric_types.MetricKey( - name, - model_name=model_name, - output_name=output_name, - sub_key=sub_key, - example_weighted=example_weighted, + """Returns metric computations for NDCG.""" + if not query_key: + raise ValueError("a query_key is required to use NDCG metric") + sub_keys = [k for k in sub_keys if k is not None] + if top_k_list: + if sub_keys is None: + sub_keys = [] + for k in top_k_list: + if not any([sub_key.top_k == k for sub_key in sub_keys]): + sub_keys.append(metric_types.SubKey(top_k=k)) + if not sub_keys or any([sub_key.top_k is None for sub_key in sub_keys]): + raise ValueError(f"top_k values are required to use NDCG metric: {sub_keys}") + computations = [] + for model_name in model_names if model_names else [""]: + for output_name in output_names if output_names else [""]: + keys = [] + for sub_key in sub_keys: + keys.append( + metric_types.MetricKey( + name, + model_name=model_name, + output_name=output_name, + sub_key=sub_key, + example_weighted=example_weighted, + ) + ) + computations.append( + metric_types.MetricComputation( + keys=keys, + preprocessors=[ + metric_types.CombinedFeaturePreprocessor( + feature_keys=[query_key, gain_key] + ) + ], + combiner=_NDCGCombiner( + metric_keys=keys, + eval_config=eval_config, + model_name=model_name, + output_name=output_name, + example_weighted=example_weighted, + query_key=query_key, + gain_key=gain_key, + ), + ) ) - ) - computations.append( - metric_types.MetricComputation( - keys=keys, - preprocessors=[ - metric_types.CombinedFeaturePreprocessor( - feature_keys=[query_key, gain_key] - ) - ], - combiner=_NDCGCombiner( - metric_keys=keys, - eval_config=eval_config, - model_name=model_name, - output_name=output_name, - example_weighted=example_weighted, - query_key=query_key, - gain_key=gain_key, - ), - ) - ) - return computations + return computations class _NDCGAccumulator: - """NDCG accumulator.""" + """NDCG accumulator.""" - __slots__ = ['ndcg', 'total_weighted_examples'] + __slots__ = 
["ndcg", "total_weighted_examples"] - def __init__(self, size: int): - self.ndcg = [0.0] * size - self.total_weighted_examples = 0.0 + def __init__(self, size: int): + self.ndcg = [0.0] * size + self.total_weighted_examples = 0.0 class _NDCGCombiner(beam.CombineFn): - """Computes NDCG (normalized discounted cumulative gain).""" - - def __init__( - self, - metric_keys: List[metric_types.MetricKey], - eval_config: Optional[config_pb2.EvalConfig], - model_name: str, - output_name: str, - example_weighted: bool, - query_key: str, - gain_key: str, - ): - """Initialize. - - Args: - metric_keys: Metric keys. - eval_config: Eval config. - model_name: Model name. - output_name: Output name. - example_weighted: True if example weights should be applied. - query_key: Query key. - gain_key: Key of feature in features dictionary that holds gain values. - """ - self._metric_keys = metric_keys - self._eval_config = eval_config - self._model_name = model_name - self._output_name = output_name - self._example_weighted = example_weighted - self._query_key = query_key - self._gain_key = gain_key - - def _query( - self, element: metric_types.StandardMetricInputs - ) -> Union[float, int, str]: - query = util.get_by_keys( - element.combined_features, [self._query_key] - ).flatten() - if query.size == 0 or not np.all(query == query[0]): - raise ValueError( - 'missing query value or not all values are the same: value={}, ' - 'metric_keys={}, StandardMetricInputs={}'.format( - query, self._metric_keys, element - ) - ) - return query[0] - - def _to_gains_example_weight( - self, element: metric_types.StandardMetricInputs - ) -> Tuple[np.ndarray, float]: - """Returns gains and example_weight sorted by prediction.""" - _, predictions, example_weight = next( - metric_util.to_label_prediction_example_weight( - element, - eval_config=self._eval_config, - model_name=self._model_name, - output_name=self._output_name, - example_weighted=self._example_weighted, - flatten=False, - require_single_example_weight=True, - ) - ) # pytype: disable=wrong-arg-types - gains = util.get_by_keys(element.combined_features, [self._gain_key]) - if gains.size != predictions.size: - raise ValueError( - 'expected {} to be same size as predictions {} != {}: ' - 'gains={}, metric_keys={}, ' - 'StandardMetricInputs={}'.format( - self._gain_key, - gains.size, - predictions.size, - gains, - self._metric_keys, - element, - ) - ) - gains = gains.reshape(predictions.shape) - # Ignore non-positive gains. - if gains.max() <= 0: - example_weight = 0.0 - return (gains[np.argsort(predictions)[::-1]], float(example_weight)) - - def _calculate_dcg_at_k(self, k: int, sorted_values: List[float]) -> float: - """Calculate the value of DCG@k. - - Args: - k: The last position to consider. - sorted_values: A list of gain values assumed to be sorted in the desired - ranking order. - - Returns: - The value of DCG@k. - """ - return np.sum( - np.array(sorted_values)[:k] / np.log2(np.array(range(2, k + 2))) - ) - - def _calculate_ndcg(self, values: List[Tuple[int, float]], k: int) -> float: - """Calculate NDCG@k, based on given rank and gain values. - - Args: - values: A list of tuples representing rank order and gain values. - k: The maximum position to consider in calculating nDCG - - Returns: - The value of NDCG@k, for the given list of values. 
- """ - max_rank = min(k, len(values)) - ranked_values = [ - gain for _, gain in sorted(values, key=lambda x: x[0], reverse=False) - ] - optimal_values = [ - gain for _, gain in sorted(values, key=lambda x: x[1], reverse=True) - ] - dcg = self._calculate_dcg_at_k(max_rank, ranked_values) - optimal_dcg = self._calculate_dcg_at_k(max_rank, optimal_values) - if optimal_dcg > 0: - return dcg / optimal_dcg - else: - return 0 - - def create_accumulator(self): - return _NDCGAccumulator(len(self._metric_keys)) - - def add_input( - self, - accumulator: _NDCGAccumulator, - element: metric_types.StandardMetricInputs, - ) -> _NDCGAccumulator: - gains, example_weight = self._to_gains_example_weight(element) - rank_gain = [(pos + 1, gain) for pos, gain in enumerate(gains)] - for i, key in enumerate(self._metric_keys): - if not key.sub_key or key.sub_key.top_k is None: - raise ValueError( - 'top_k values are required to use NDCG metric: {}'.format(key) - ) - accumulator.ndcg[i] += ( - self._calculate_ndcg(rank_gain, key.sub_key.top_k) * example_weight - ) - accumulator.total_weighted_examples += float(example_weight) - return accumulator - - def merge_accumulators( - self, accumulators: Iterable[_NDCGAccumulator] - ) -> _NDCGAccumulator: - accumulators = iter(accumulators) - result = next(accumulators) - for accumulator in accumulators: - result.ndcg = [a + b for a, b in zip(result.ndcg, accumulator.ndcg)] - result.total_weighted_examples += accumulator.total_weighted_examples - return result - - def extract_output( - self, accumulator: _NDCGAccumulator - ) -> Dict[metric_types.MetricKey, float]: - output = {} - for i, key in enumerate(self._metric_keys): - if accumulator.total_weighted_examples > 0: - output[key] = accumulator.ndcg[i] / accumulator.total_weighted_examples - else: - output[key] = float('nan') - return output + """Computes NDCG (normalized discounted cumulative gain).""" + + def __init__( + self, + metric_keys: List[metric_types.MetricKey], + eval_config: Optional[config_pb2.EvalConfig], + model_name: str, + output_name: str, + example_weighted: bool, + query_key: str, + gain_key: str, + ): + """Initialize. + + Args: + ---- + metric_keys: Metric keys. + eval_config: Eval config. + model_name: Model name. + output_name: Output name. + example_weighted: True if example weights should be applied. + query_key: Query key. + gain_key: Key of feature in features dictionary that holds gain values. 
+ """ + self._metric_keys = metric_keys + self._eval_config = eval_config + self._model_name = model_name + self._output_name = output_name + self._example_weighted = example_weighted + self._query_key = query_key + self._gain_key = gain_key + + def _query( + self, element: metric_types.StandardMetricInputs + ) -> Union[float, int, str]: + query = util.get_by_keys(element.combined_features, [self._query_key]).flatten() + if query.size == 0 or not np.all(query == query[0]): + raise ValueError( + f"missing query value or not all values are the same: value={query}, " + f"metric_keys={self._metric_keys}, StandardMetricInputs={element}" + ) + return query[0] + + def _to_gains_example_weight( + self, element: metric_types.StandardMetricInputs + ) -> Tuple[np.ndarray, float]: + """Returns gains and example_weight sorted by prediction.""" + _, predictions, example_weight = next( + metric_util.to_label_prediction_example_weight( + element, + eval_config=self._eval_config, + model_name=self._model_name, + output_name=self._output_name, + example_weighted=self._example_weighted, + flatten=False, + require_single_example_weight=True, + ) + ) # pytype: disable=wrong-arg-types + gains = util.get_by_keys(element.combined_features, [self._gain_key]) + if gains.size != predictions.size: + raise ValueError( + f"expected {self._gain_key} to be same size as predictions {gains.size} != {predictions.size}: " + f"gains={gains}, metric_keys={self._metric_keys}, " + f"StandardMetricInputs={element}" + ) + gains = gains.reshape(predictions.shape) + # Ignore non-positive gains. + if gains.max() <= 0: + example_weight = 0.0 + return (gains[np.argsort(predictions)[::-1]], float(example_weight)) + + def _calculate_dcg_at_k(self, k: int, sorted_values: List[float]) -> float: + """Calculate the value of DCG@k. + + Args: + ---- + k: The last position to consider. + sorted_values: A list of gain values assumed to be sorted in the desired + ranking order. + + Returns: + ------- + The value of DCG@k. + """ + return np.sum(np.array(sorted_values)[:k] / np.log2(np.array(range(2, k + 2)))) + + def _calculate_ndcg(self, values: List[Tuple[int, float]], k: int) -> float: + """Calculate NDCG@k, based on given rank and gain values. + + Args: + ---- + values: A list of tuples representing rank order and gain values. + k: The maximum position to consider in calculating nDCG + + Returns: + ------- + The value of NDCG@k, for the given list of values. 
+ """ + max_rank = min(k, len(values)) + ranked_values = [ + gain for _, gain in sorted(values, key=lambda x: x[0], reverse=False) + ] + optimal_values = [ + gain for _, gain in sorted(values, key=lambda x: x[1], reverse=True) + ] + dcg = self._calculate_dcg_at_k(max_rank, ranked_values) + optimal_dcg = self._calculate_dcg_at_k(max_rank, optimal_values) + if optimal_dcg > 0: + return dcg / optimal_dcg + else: + return 0 + + def create_accumulator(self): + return _NDCGAccumulator(len(self._metric_keys)) + + def add_input( + self, + accumulator: _NDCGAccumulator, + element: metric_types.StandardMetricInputs, + ) -> _NDCGAccumulator: + gains, example_weight = self._to_gains_example_weight(element) + rank_gain = [(pos + 1, gain) for pos, gain in enumerate(gains)] + for i, key in enumerate(self._metric_keys): + if not key.sub_key or key.sub_key.top_k is None: + raise ValueError(f"top_k values are required to use NDCG metric: {key}") + accumulator.ndcg[i] += ( + self._calculate_ndcg(rank_gain, key.sub_key.top_k) * example_weight + ) + accumulator.total_weighted_examples += float(example_weight) + return accumulator + + def merge_accumulators( + self, accumulators: Iterable[_NDCGAccumulator] + ) -> _NDCGAccumulator: + accumulators = iter(accumulators) + result = next(accumulators) + for accumulator in accumulators: + result.ndcg = [a + b for a, b in zip(result.ndcg, accumulator.ndcg)] + result.total_weighted_examples += accumulator.total_weighted_examples + return result + + def extract_output( + self, accumulator: _NDCGAccumulator + ) -> Dict[metric_types.MetricKey, float]: + output = {} + for i, key in enumerate(self._metric_keys): + if accumulator.total_weighted_examples > 0: + output[key] = accumulator.ndcg[i] / accumulator.total_weighted_examples + else: + output[key] = float("nan") + return output diff --git a/tensorflow_model_analysis/metrics/ndcg_test.py b/tensorflow_model_analysis/metrics/ndcg_test.py index 988002cb16..b451fdcc4d 100644 --- a/tensorflow_model_analysis/metrics/ndcg_test.py +++ b/tensorflow_model_analysis/metrics/ndcg_test.py @@ -13,158 +13,135 @@ # limitations under the License. """Tests for NDCG metric.""" -from absl.testing import parameterized import apache_beam as beam -from apache_beam.testing import util import numpy as np import tensorflow as tf -from tensorflow_model_analysis.metrics import metric_types -from tensorflow_model_analysis.metrics import metric_util -from tensorflow_model_analysis.metrics import ndcg +from absl.testing import parameterized +from apache_beam.testing import util + +from tensorflow_model_analysis.metrics import metric_types, metric_util, ndcg from tensorflow_model_analysis.utils import test_util from tensorflow_model_analysis.utils import util as tfma_util -class NDCGMetricsTest( - test_util.TensorflowModelAnalysisTest, parameterized.TestCase -): - - @parameterized.named_parameters( - ('feature_key', 'features'), - ('transformed_features_key', 'transformed_features')) - def testNDCG(self, feature_key): - # SubKeys will be a merger of top_k_list and sub_keys. 
- metric = ndcg.NDCG( - gain_key='gain', top_k_list=[1, 2]).computations( - sub_keys=[ - None, - metric_types.SubKey(top_k=1), - metric_types.SubKey(top_k=2) - ], - query_key='query', - example_weighted=True)[0] +class NDCGMetricsTest(test_util.TensorflowModelAnalysisTest, parameterized.TestCase): + @parameterized.named_parameters( + ("feature_key", "features"), + ("transformed_features_key", "transformed_features"), + ) + def testNDCG(self, feature_key): + # SubKeys will be a merger of top_k_list and sub_keys. + metric = ndcg.NDCG(gain_key="gain", top_k_list=[1, 2]).computations( + sub_keys=[None, metric_types.SubKey(top_k=1), metric_types.SubKey(top_k=2)], + query_key="query", + example_weighted=True, + )[0] - query1_example1 = { - 'labels': np.array([1.0]), - 'predictions': np.array([0.2]), - 'example_weights': np.array([1.0]), - feature_key: { - 'query': np.array(['query1']), - 'gain': np.array([1.0]) + query1_example1 = { + "labels": np.array([1.0]), + "predictions": np.array([0.2]), + "example_weights": np.array([1.0]), + feature_key: {"query": np.array(["query1"]), "gain": np.array([1.0])}, } - } - query1_example2 = { - 'labels': np.array([0.0]), - 'predictions': np.array([0.8]), - 'example_weights': np.array([1.0]), - feature_key: { - 'query': np.array(['query1']), - 'gain': np.array([0.5]) + query1_example2 = { + "labels": np.array([0.0]), + "predictions": np.array([0.8]), + "example_weights": np.array([1.0]), + feature_key: {"query": np.array(["query1"]), "gain": np.array([0.5])}, } - } - query2_example1 = { - 'labels': np.array([0.0]), - 'predictions': np.array([0.5]), - 'example_weights': np.array([2.0]), - feature_key: { - 'query': np.array(['query2']), - 'gain': np.array([0.5]) + query2_example1 = { + "labels": np.array([0.0]), + "predictions": np.array([0.5]), + "example_weights": np.array([2.0]), + feature_key: {"query": np.array(["query2"]), "gain": np.array([0.5])}, } - } - query2_example2 = { - 'labels': np.array([1.0]), - 'predictions': np.array([0.9]), - 'example_weights': np.array([2.0]), - feature_key: { - 'query': np.array(['query2']), - 'gain': np.array([1.0]) + query2_example2 = { + "labels": np.array([1.0]), + "predictions": np.array([0.9]), + "example_weights": np.array([2.0]), + feature_key: {"query": np.array(["query2"]), "gain": np.array([1.0])}, } - } - query2_example3 = { - 'labels': np.array([0.0]), - 'predictions': np.array([0.1]), - 'example_weights': np.array([2.0]), - feature_key: { - 'query': np.array(['query2']), - 'gain': np.array([0.1]) + query2_example3 = { + "labels": np.array([0.0]), + "predictions": np.array([0.1]), + "example_weights": np.array([2.0]), + feature_key: {"query": np.array(["query2"]), "gain": np.array([0.1])}, } - } - query3_example1 = { - 'labels': np.array([1.0]), - 'predictions': np.array([0.9]), - 'example_weights': np.array([3.0]), - feature_key: { - 'query': np.array(['query3']), - 'gain': np.array([1.0]) + query3_example1 = { + "labels": np.array([1.0]), + "predictions": np.array([0.9]), + "example_weights": np.array([3.0]), + feature_key: {"query": np.array(["query3"]), "gain": np.array([1.0])}, } - } - query4_example1 = { - 'labels': np.array([1.0]), - 'predictions': np.array([0.9]), - 'example_weights': np.array([3.0]), - feature_key: { - 'query': np.array(['query4']), - 'gain': - np.array([0.0]) # 0 gain is ignored + query4_example1 = { + "labels": np.array([1.0]), + "predictions": np.array([0.9]), + "example_weights": np.array([3.0]), + feature_key: { + "query": np.array(["query4"]), + "gain": np.array([0.0]), # 0 
gain is ignored + }, } - } - examples = [ - tfma_util.merge_extracts([query1_example1, query1_example2]), - tfma_util.merge_extracts( - [query2_example1, query2_example2, query2_example3]), - tfma_util.merge_extracts([query3_example1]), - tfma_util.merge_extracts([query4_example1]) - ] + examples = [ + tfma_util.merge_extracts([query1_example1, query1_example2]), + tfma_util.merge_extracts( + [query2_example1, query2_example2, query2_example3] + ), + tfma_util.merge_extracts([query3_example1]), + tfma_util.merge_extracts([query4_example1]), + ] - with beam.Pipeline() as pipeline: - # pylint: disable=no-value-for-parameter - result = ( - pipeline - | 'Create' >> beam.Create(examples) - | 'Process' >> beam.Map(metric_util.to_standard_metric_inputs, True) - | 'AddSlice' >> beam.Map(lambda x: ((), x)) - | 'Combine' >> beam.CombinePerKey(metric.combiner)) + with beam.Pipeline() as pipeline: + # pylint: disable=no-value-for-parameter + result = ( + pipeline + | "Create" >> beam.Create(examples) + | "Process" >> beam.Map(metric_util.to_standard_metric_inputs, True) + | "AddSlice" >> beam.Map(lambda x: ((), x)) + | "Combine" >> beam.CombinePerKey(metric.combiner) + ) - # pylint: enable=no-value-for-parameter + # pylint: enable=no-value-for-parameter - def check_result(got): - try: - self.assertLen(got, 1) - got_slice_key, got_metrics = got[0] - self.assertEqual(got_slice_key, ()) - ndcg1_key = metric_types.MetricKey( - name='ndcg', - sub_key=metric_types.SubKey(top_k=1), - example_weighted=True) - ndcg2_key = metric_types.MetricKey( - name='ndcg', - sub_key=metric_types.SubKey(top_k=2), - example_weighted=True) - # Query1 (weight=1): (p=0.8, g=0.5) (p=0.2, g=1.0) - # Query2 (weight=2): (p=0.9, g=1.0) (p=0.5, g=0.5) (p=0.1, g=0.1) - # Query3 (weight=3): (p=0.9, g=1.0) - # - # DCG@1: 0.5, 1.0, 1.0 - # NDCG@1: 0.5, 1.0, 1.0 - # Average NDCG@1: (1 * 0.5 + 2 * 1.0 + 3 * 1.0) / (1 + 2 + 3) ~ 0.92 - # - # DCG@2: (0.5 + 1.0/log(3), (1.0 + 0.5/log(3), (1.0) - # NDCG@2: (0.5 + 1.0/log(3)) / (1.0 + 0.5/log(3)), - # (1.0 + 0.5/log(3)) / (1.0 + 0.5/log(3)), - # 1.0 - # Average NDCG@2: (1 * 0.860 + 2 * 1.0 + 3 * 1.0) / (1 + 2 + 3) ~ 0.97 - self.assertDictElementsAlmostEqual( - got_metrics, { - ndcg1_key: 0.9166667, - ndcg2_key: 0.9766198 - }, - places=5) + def check_result(got): + try: + self.assertLen(got, 1) + got_slice_key, got_metrics = got[0] + self.assertEqual(got_slice_key, ()) + ndcg1_key = metric_types.MetricKey( + name="ndcg", + sub_key=metric_types.SubKey(top_k=1), + example_weighted=True, + ) + ndcg2_key = metric_types.MetricKey( + name="ndcg", + sub_key=metric_types.SubKey(top_k=2), + example_weighted=True, + ) + # Query1 (weight=1): (p=0.8, g=0.5) (p=0.2, g=1.0) + # Query2 (weight=2): (p=0.9, g=1.0) (p=0.5, g=0.5) (p=0.1, g=0.1) + # Query3 (weight=3): (p=0.9, g=1.0) + # + # DCG@1: 0.5, 1.0, 1.0 + # NDCG@1: 0.5, 1.0, 1.0 + # Average NDCG@1: (1 * 0.5 + 2 * 1.0 + 3 * 1.0) / (1 + 2 + 3) ~ 0.92 + # + # DCG@2: (0.5 + 1.0/log(3), (1.0 + 0.5/log(3), (1.0) + # NDCG@2: (0.5 + 1.0/log(3)) / (1.0 + 0.5/log(3)), + # (1.0 + 0.5/log(3)) / (1.0 + 0.5/log(3)), + # 1.0 + # Average NDCG@2: (1 * 0.860 + 2 * 1.0 + 3 * 1.0) / (1 + 2 + 3) ~ 0.97 + self.assertDictElementsAlmostEqual( + got_metrics, + {ndcg1_key: 0.9166667, ndcg2_key: 0.9766198}, + places=5, + ) - except AssertionError as err: - raise util.BeamAssertException(err) + except AssertionError as err: + raise util.BeamAssertException(err) - util.assert_that(result, check_result, label='result') + util.assert_that(result, check_result, label="result") -if 
__name__ == '__main__': - tf.test.main() +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow_model_analysis/metrics/object_detection_confusion_matrix_metrics.py b/tensorflow_model_analysis/metrics/object_detection_confusion_matrix_metrics.py index 9fd9470a99..2eb60d7ee5 100644 --- a/tensorflow_model_analysis/metrics/object_detection_confusion_matrix_metrics.py +++ b/tensorflow_model_analysis/metrics/object_detection_confusion_matrix_metrics.py @@ -15,767 +15,801 @@ from typing import List, Optional, Tuple, Union -from tensorflow_model_analysis.metrics import confusion_matrix_metrics -from tensorflow_model_analysis.metrics import metric_types -from tensorflow_model_analysis.metrics import metric_util -from tensorflow_model_analysis.metrics import preprocessors +from tensorflow_model_analysis.metrics import ( + confusion_matrix_metrics, + metric_types, + metric_util, + preprocessors, +) from tensorflow_model_analysis.proto import config_pb2 # default values for object detection settings _DEFAULT_IOU_THRESHOLD = 0.5 -_DEFAULT_AREA_RANGE = (0, float('inf')) +_DEFAULT_AREA_RANGE = (0, float("inf")) -OBJECT_DETECTION_MAX_RECALL_NAME = 'object_detection_max_recall' -OBJECT_DETECTION_PRECISION_AT_RECALL_NAME = 'object_detection_precision_at_recall' -OBJECT_DETECTION_RECALL_NAME = 'object_detection_recall' -OBJECT_DETECTION_PRECISION_NAME = 'object_detection_precision' -OBJECT_DETECTION_THRESHOLD_AT_RECALL_NAME = 'object_detection_threshold_at_recall' +OBJECT_DETECTION_MAX_RECALL_NAME = "object_detection_max_recall" +OBJECT_DETECTION_PRECISION_AT_RECALL_NAME = "object_detection_precision_at_recall" +OBJECT_DETECTION_RECALL_NAME = "object_detection_recall" +OBJECT_DETECTION_PRECISION_NAME = "object_detection_precision" +OBJECT_DETECTION_THRESHOLD_AT_RECALL_NAME = "object_detection_threshold_at_recall" class ObjectDetectionPrecisionAtRecall( # pytype: disable=signature-mismatch # always-use-return-annotations confusion_matrix_metrics.PrecisionAtRecall ): - """Computes best precision where recall is >= specified value. - - The threshold for the given recall value is computed and used to evaluate the - corresponding precision. - - If `sample_weight` is `None`, weights default to 1. - Use `sample_weight` of 0 to mask values. - """ - - def __init__(self, - recall: Union[float, List[float]], - thresholds: Optional[List[float]] = None, - num_thresholds: Optional[int] = None, - name: Optional[str] = None, - iou_threshold: Optional[float] = None, - class_id: Optional[int] = None, - class_weight: Optional[float] = None, - area_range: Optional[Tuple[float, float]] = None, - max_num_detections: Optional[int] = None, - labels_to_stack: Optional[List[str]] = None, - predictions_to_stack: Optional[List[str]] = None, - num_detections_key: Optional[str] = None, - allow_missing_key: bool = False): - """Initializes PrecisionAtRecall metric. - - The metric supports using multiple outputs to form the labels/predictions if - the user specifies the label/predcition keys to stack. In this case, the - metric is not expected to work with multi-outputs. The metric only supports - multi outputs if the output of model is already pre-stacked in the expected - format, i.e. ['xmin', 'ymin', 'xmax', 'ymax', 'class_id'] for labels and - ['xmin', 'ymin', 'xmax', 'ymax', 'class_id', 'confidence scores'] for - predictions. - - Args: - recall: A scalar or a list of scalar values in range `[0, 1]`. - thresholds: (Optional) Thresholds to use for calculating the matrices. Use - one of either thresholds or num_thresholds. 
- num_thresholds: (Optional) Defaults to 1000. The number of thresholds to - use for matching the given recall. - name: (Optional) string name of the metric instance. - iou_threshold: (Optional) Thresholds for a detection and ground truth pair - with specific iou to be considered as a match. Default to 0.5 - class_id: (Optional) The class id for calculating metrics. - class_weight: (Optional) The weight associated with the object class id. - area_range: (Optional) A tuple (inclusive) representing the area-range for - objects to be considered for metrics. Default to (0, inf). - max_num_detections: (Optional) The maximum number of detections for a - single image. Default to None. - labels_to_stack: (Optional) Keys for columns to be stacked as a single - numpy array as the labels. It is searched under the key labels, features - and transformed features. The desired format is [left bounadary, top - boudnary, right boundary, bottom boundary, class id]. e.g. ['xmin', - 'ymin', 'xmax', 'ymax', 'class_id'] - predictions_to_stack: (Optional) Output names for columns to be stacked as - a single numpy array as the prediction. It should be the model's output - names. The desired format is [left bounadary, top boudnary, right - boundary, bottom boundary, class id, confidence score]. e.g. ['xmin', - 'ymin', 'xmax', 'ymax', 'class_id', 'scores'] - num_detections_key: (Optional) An output name in which to find the number - of detections to use for evaluation for a given example. It does nothing - if predictions_to_stack is not set. The value for this output should be - a scalar value or a single-value tensor. The stacked predicitions will - be truncated with the specified number of detections. - allow_missing_key: (Optional) If true, the preprocessor will return empty - array instead of raising errors. + """Computes best precision where recall is >= specified value. + + The threshold for the given recall value is computed and used to evaluate the + corresponding precision. + + If `sample_weight` is `None`, weights default to 1. + Use `sample_weight` of 0 to mask values. """ - for r in [recall] if isinstance(recall, float) else recall: - if r < 0 or r > 1: - raise ValueError('Argument `recall` must be in the range [0, 1]. 
' - f'Received: recall={r}') - - super().__init__( - thresholds=thresholds, - num_thresholds=num_thresholds, - recall=recall, - name=name, - class_id=class_id, - iou_threshold=iou_threshold, - area_range=area_range, - max_num_detections=max_num_detections, - class_weight=class_weight, - labels_to_stack=labels_to_stack, - predictions_to_stack=predictions_to_stack, - num_detections_key=num_detections_key, - allow_missing_key=allow_missing_key) - - def _default_name(self) -> str: - return OBJECT_DETECTION_PRECISION_AT_RECALL_NAME - - def _metric_computations( - self, - thresholds: Optional[Union[float, List[float]]] = None, - num_thresholds: Optional[int] = None, - iou_threshold: Optional[float] = None, - class_id: Optional[int] = None, - class_weight: Optional[float] = None, - area_range: Optional[Tuple[float, float]] = None, - max_num_detections: Optional[int] = None, - name: Optional[str] = None, - eval_config: Optional[config_pb2.EvalConfig] = None, - model_name: str = '', - output_name: str = '', - labels_to_stack: Optional[List[str]] = None, - predictions_to_stack: Optional[List[str]] = None, - num_detections_key: Optional[str] = None, - allow_missing_key: bool = False, - example_weighted: bool = False, - **kwargs, - ) -> metric_types.MetricComputations: - metric_util.validate_object_detection_arguments( - class_id=class_id, - class_weight=class_weight, - area_range=area_range, - max_num_detections=max_num_detections, - labels_to_stack=labels_to_stack, - predictions_to_stack=predictions_to_stack, - output_name=output_name, - ) - - key = metric_types.MetricKey( - name=name, - model_name=model_name, - output_name=output_name, - sub_key=metric_types.SubKey(class_id=class_id), - example_weighted=example_weighted, - aggregation_type=None, - ) - - preprocessor = preprocessors.BoundingBoxMatchPreprocessor( - class_id=class_id, - iou_threshold=iou_threshold, - area_range=area_range, - max_num_detections=max_num_detections, - class_weight=class_weight, - labels_to_stack=labels_to_stack, - predictions_to_stack=predictions_to_stack, - num_detections_key=num_detections_key, - allow_missing_key=allow_missing_key, - model_name=model_name) - return super()._metric_computations( - thresholds=thresholds, - num_thresholds=num_thresholds, - name=name, - eval_config=eval_config, - model_name=model_name, - output_name=output_name, - preprocessors=[preprocessor], - metric_key=key, - example_weighted=example_weighted, + + def __init__( + self, + recall: Union[float, List[float]], + thresholds: Optional[List[float]] = None, + num_thresholds: Optional[int] = None, + name: Optional[str] = None, + iou_threshold: Optional[float] = None, + class_id: Optional[int] = None, + class_weight: Optional[float] = None, + area_range: Optional[Tuple[float, float]] = None, + max_num_detections: Optional[int] = None, + labels_to_stack: Optional[List[str]] = None, + predictions_to_stack: Optional[List[str]] = None, + num_detections_key: Optional[str] = None, + allow_missing_key: bool = False, + ): + """Initializes PrecisionAtRecall metric. + + The metric supports using multiple outputs to form the labels/predictions if + the user specifies the label/predcition keys to stack. In this case, the + metric is not expected to work with multi-outputs. The metric only supports + multi outputs if the output of model is already pre-stacked in the expected + format, i.e. ['xmin', 'ymin', 'xmax', 'ymax', 'class_id'] for labels and + ['xmin', 'ymin', 'xmax', 'ymax', 'class_id', 'confidence scores'] for + predictions. 
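        A typical single-class configuration might look like the following
        sketch (argument values are illustrative; the column names are the
        same ones used as examples in the Args below):

            ObjectDetectionPrecisionAtRecall(
                recall=0.7,
                class_id=1,
                iou_threshold=0.5,
                max_num_detections=100,
                labels_to_stack=['xmin', 'ymin', 'xmax', 'ymax', 'class_id'],
                predictions_to_stack=['xmin', 'ymin', 'xmax', 'ymax',
                                      'class_id', 'scores'],
            )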
+ + Args: + ---- + recall: A scalar or a list of scalar values in range `[0, 1]`. + thresholds: (Optional) Thresholds to use for calculating the matrices. Use + one of either thresholds or num_thresholds. + num_thresholds: (Optional) Defaults to 1000. The number of thresholds to + use for matching the given recall. + name: (Optional) string name of the metric instance. + iou_threshold: (Optional) Thresholds for a detection and ground truth pair + with specific iou to be considered as a match. Default to 0.5 + class_id: (Optional) The class id for calculating metrics. + class_weight: (Optional) The weight associated with the object class id. + area_range: (Optional) A tuple (inclusive) representing the area-range for + objects to be considered for metrics. Default to (0, inf). + max_num_detections: (Optional) The maximum number of detections for a + single image. Default to None. + labels_to_stack: (Optional) Keys for columns to be stacked as a single + numpy array as the labels. It is searched under the key labels, features + and transformed features. The desired format is [left bounadary, top + boudnary, right boundary, bottom boundary, class id]. e.g. ['xmin', + 'ymin', 'xmax', 'ymax', 'class_id'] + predictions_to_stack: (Optional) Output names for columns to be stacked as + a single numpy array as the prediction. It should be the model's output + names. The desired format is [left bounadary, top boudnary, right + boundary, bottom boundary, class id, confidence score]. e.g. ['xmin', + 'ymin', 'xmax', 'ymax', 'class_id', 'scores'] + num_detections_key: (Optional) An output name in which to find the number + of detections to use for evaluation for a given example. It does nothing + if predictions_to_stack is not set. The value for this output should be + a scalar value or a single-value tensor. The stacked predicitions will + be truncated with the specified number of detections. + allow_missing_key: (Optional) If true, the preprocessor will return empty + array instead of raising errors. + """ + for r in [recall] if isinstance(recall, float) else recall: + if r < 0 or r > 1: + raise ValueError( + "Argument `recall` must be in the range [0, 1]. 
" + f"Received: recall={r}" + ) + + super().__init__( + thresholds=thresholds, + num_thresholds=num_thresholds, + recall=recall, + name=name, + class_id=class_id, + iou_threshold=iou_threshold, + area_range=area_range, + max_num_detections=max_num_detections, + class_weight=class_weight, + labels_to_stack=labels_to_stack, + predictions_to_stack=predictions_to_stack, + num_detections_key=num_detections_key, + allow_missing_key=allow_missing_key, + ) + + def _default_name(self) -> str: + return OBJECT_DETECTION_PRECISION_AT_RECALL_NAME + + def _metric_computations( + self, + thresholds: Optional[Union[float, List[float]]] = None, + num_thresholds: Optional[int] = None, + iou_threshold: Optional[float] = None, + class_id: Optional[int] = None, + class_weight: Optional[float] = None, + area_range: Optional[Tuple[float, float]] = None, + max_num_detections: Optional[int] = None, + name: Optional[str] = None, + eval_config: Optional[config_pb2.EvalConfig] = None, + model_name: str = "", + output_name: str = "", + labels_to_stack: Optional[List[str]] = None, + predictions_to_stack: Optional[List[str]] = None, + num_detections_key: Optional[str] = None, + allow_missing_key: bool = False, + example_weighted: bool = False, **kwargs, - ) + ) -> metric_types.MetricComputations: + metric_util.validate_object_detection_arguments( + class_id=class_id, + class_weight=class_weight, + area_range=area_range, + max_num_detections=max_num_detections, + labels_to_stack=labels_to_stack, + predictions_to_stack=predictions_to_stack, + output_name=output_name, + ) + + key = metric_types.MetricKey( + name=name, + model_name=model_name, + output_name=output_name, + sub_key=metric_types.SubKey(class_id=class_id), + example_weighted=example_weighted, + aggregation_type=None, + ) + + preprocessor = preprocessors.BoundingBoxMatchPreprocessor( + class_id=class_id, + iou_threshold=iou_threshold, + area_range=area_range, + max_num_detections=max_num_detections, + class_weight=class_weight, + labels_to_stack=labels_to_stack, + predictions_to_stack=predictions_to_stack, + num_detections_key=num_detections_key, + allow_missing_key=allow_missing_key, + model_name=model_name, + ) + return super()._metric_computations( + thresholds=thresholds, + num_thresholds=num_thresholds, + name=name, + eval_config=eval_config, + model_name=model_name, + output_name=output_name, + preprocessors=[preprocessor], + metric_key=key, + example_weighted=example_weighted, + **kwargs, + ) metric_types.register_metric(ObjectDetectionPrecisionAtRecall) class ObjectDetectionRecall(confusion_matrix_metrics.Recall): - """Computes the recall of the predictions with respect to the labels. - - The metric uses true positives and false negatives to compute recall by - dividing the true positives by the sum of true positives and false negatives. - - If `sample_weight` is `None`, weights default to 1. - Use `sample_weight` of 0 to mask values. - """ - - def __init__(self, - thresholds: Optional[Union[float, List[float]]] = None, - name: Optional[str] = None, - iou_threshold: Optional[float] = None, - class_id: Optional[int] = None, - class_weight: Optional[float] = None, - area_range: Optional[Tuple[float, float]] = None, - max_num_detections: Optional[int] = None, - labels_to_stack: Optional[List[str]] = None, - predictions_to_stack: Optional[List[str]] = None, - num_detections_key: Optional[str] = None, - allow_missing_key: bool = False): - """Initializes Recall metric. 
- - The metric supports using multiple outputs to form the labels/predictions if - the user specifies the label/predcition keys to stack. In this case, the - metric is not expected to work with multi-outputs. The metric only supports - multi outputs if the output of model is already pre-stacked in the expected - format, i.e. ['xmin', 'ymin', 'xmax', 'ymax', 'class_id'] for labels and - ['xmin', 'ymin', 'xmax', 'ymax', 'class_id', 'confidence scores'] for - predictions. - - Args: - thresholds: (Optional) A float value or a python list/tuple of float - threshold values in [0, 1]. A threshold is compared with prediction - values to determine the truth value of predictions (i.e., above the - threshold is `true`, below is `false`). One metric value is generated - for each threshold value. The default is to calculate recall with - `thresholds=0.5`. - name: (Optional) string name of the metric instance. - iou_threshold: (Optional) Thresholds for a detection and ground truth pair - with specific iou to be considered as a match. Default to 0.5 - class_id: (Optional) The class id for calculating metrics. - class_weight: (Optional) The weight associated with the object class id. - area_range: (Optional) A tuple (inclusive) representing the area-range for - objects to be considered for metrics. Default to (0, inf). - max_num_detections: (Optional) The maximum number of detections for a - single image. Default to None. - labels_to_stack: (Optional) Keys for columns to be stacked as a single - numpy array as the labels. It is searched under the key labels, features - and transformed features. The desired format is [left bounadary, top - boudnary, right boundary, bottom boundary, class id]. e.g. ['xmin', - 'ymin', 'xmax', 'ymax', 'class_id'] - predictions_to_stack: (Optional) Output names for columns to be stacked as - a single numpy array as the prediction. It should be the model's output - names. The desired format is [left bounadary, top boudnary, right - boundary, bottom boundary, class id, confidence score]. e.g. ['xmin', - 'ymin', 'xmax', 'ymax', 'class_id', 'scores'] - num_detections_key: (Optional) An output name in which to find the number - of detections to use for evaluation for a given example. It does nothing - if predictions_to_stack is not set. The value for this output should be - a scalar value or a single-value tensor. The stacked predicitions will - be truncated with the specified number of detections. - allow_missing_key: (Optional) If true, the preprocessor will return empty - array instead of raising errors. + """Computes the recall of the predictions with respect to the labels. + + The metric uses true positives and false negatives to compute recall by + dividing the true positives by the sum of true positives and false negatives. + + If `sample_weight` is `None`, weights default to 1. + Use `sample_weight` of 0 to mask values. 
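For orientation, a minimal construction sketch for the reformatted ObjectDetectionPrecisionAtRecall above; the argument values and the "xmin"/"ymin"/"xmax"/"ymax"/"class_id"/"scores" key names are illustrative placeholders, and the module path is inferred from the test file path that appears later in this diff:

```
# Sketch only: mirrors the __init__ signature shown above. The stacking key
# names are assumptions; use the feature/output names of the model actually
# being evaluated.
from tensorflow_model_analysis.metrics import (
    object_detection_confusion_matrix_metrics as od_metrics,
)

precision_at_recall = od_metrics.ObjectDetectionPrecisionAtRecall(
    recall=0.5,              # report precision at this recall level
    iou_threshold=0.5,       # detection/ground-truth pairs match at IoU >= 0.5
    class_id=0,              # evaluate a single object class
    max_num_detections=100,  # keep at most 100 detections per image
    labels_to_stack=["xmin", "ymin", "xmax", "ymax", "class_id"],
    predictions_to_stack=["xmin", "ymin", "xmax", "ymax", "class_id", "scores"],
)
```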
""" - super().__init__( - thresholds=thresholds, - name=name, - class_id=class_id, - iou_threshold=iou_threshold, - area_range=area_range, - max_num_detections=max_num_detections, - class_weight=class_weight, - labels_to_stack=labels_to_stack, - predictions_to_stack=predictions_to_stack, - num_detections_key=num_detections_key, - allow_missing_key=allow_missing_key) - - def _default_name(self) -> str: - return OBJECT_DETECTION_RECALL_NAME - - def _metric_computations( - self, - thresholds: Optional[Union[float, List[float]]] = None, - num_thresholds: Optional[int] = None, - iou_threshold: Optional[float] = None, - class_id: Optional[int] = None, - class_weight: Optional[float] = None, - area_range: Optional[Tuple[float, float]] = None, - max_num_detections: Optional[int] = None, - name: Optional[str] = None, - eval_config: Optional[config_pb2.EvalConfig] = None, - model_name: str = '', - output_name: str = '', - labels_to_stack: Optional[List[str]] = None, - predictions_to_stack: Optional[List[str]] = None, - num_detections_key: Optional[str] = None, - allow_missing_key: bool = False, - example_weighted: bool = False, - **kwargs, - ) -> metric_types.MetricComputations: - metric_util.validate_object_detection_arguments( - class_id=class_id, - class_weight=class_weight, - area_range=area_range, - max_num_detections=max_num_detections, - labels_to_stack=labels_to_stack, - predictions_to_stack=predictions_to_stack, - output_name=output_name) - - key = metric_types.MetricKey( - name=name, - model_name=model_name, - output_name=output_name, - sub_key=metric_types.SubKey(class_id=class_id), - example_weighted=example_weighted, - aggregation_type=None, - ) - - preprocessor = preprocessors.BoundingBoxMatchPreprocessor( - class_id=class_id, - iou_threshold=iou_threshold, - area_range=area_range, - max_num_detections=max_num_detections, - class_weight=class_weight, - labels_to_stack=labels_to_stack, - predictions_to_stack=predictions_to_stack, - num_detections_key=num_detections_key, - allow_missing_key=allow_missing_key, - model_name=model_name) - return super()._metric_computations( - thresholds=thresholds, - num_thresholds=num_thresholds, - name=name, - eval_config=eval_config, - model_name=model_name, - output_name=output_name, - preprocessors=[preprocessor], - metric_key=key, - example_weighted=example_weighted, - ) + + def __init__( + self, + thresholds: Optional[Union[float, List[float]]] = None, + name: Optional[str] = None, + iou_threshold: Optional[float] = None, + class_id: Optional[int] = None, + class_weight: Optional[float] = None, + area_range: Optional[Tuple[float, float]] = None, + max_num_detections: Optional[int] = None, + labels_to_stack: Optional[List[str]] = None, + predictions_to_stack: Optional[List[str]] = None, + num_detections_key: Optional[str] = None, + allow_missing_key: bool = False, + ): + """Initializes Recall metric. + + The metric supports using multiple outputs to form the labels/predictions if + the user specifies the label/predcition keys to stack. In this case, the + metric is not expected to work with multi-outputs. The metric only supports + multi outputs if the output of model is already pre-stacked in the expected + format, i.e. ['xmin', 'ymin', 'xmax', 'ymax', 'class_id'] for labels and + ['xmin', 'ymin', 'xmax', 'ymax', 'class_id', 'confidence scores'] for + predictions. + + Args: + ---- + thresholds: (Optional) A float value or a python list/tuple of float + threshold values in [0, 1]. 
A threshold is compared with prediction + values to determine the truth value of predictions (i.e., above the + threshold is `true`, below is `false`). One metric value is generated + for each threshold value. The default is to calculate recall with + `thresholds=0.5`. + name: (Optional) string name of the metric instance. + iou_threshold: (Optional) Thresholds for a detection and ground truth pair + with specific iou to be considered as a match. Default to 0.5 + class_id: (Optional) The class id for calculating metrics. + class_weight: (Optional) The weight associated with the object class id. + area_range: (Optional) A tuple (inclusive) representing the area-range for + objects to be considered for metrics. Default to (0, inf). + max_num_detections: (Optional) The maximum number of detections for a + single image. Default to None. + labels_to_stack: (Optional) Keys for columns to be stacked as a single + numpy array as the labels. It is searched under the key labels, features + and transformed features. The desired format is [left bounadary, top + boudnary, right boundary, bottom boundary, class id]. e.g. ['xmin', + 'ymin', 'xmax', 'ymax', 'class_id'] + predictions_to_stack: (Optional) Output names for columns to be stacked as + a single numpy array as the prediction. It should be the model's output + names. The desired format is [left bounadary, top boudnary, right + boundary, bottom boundary, class id, confidence score]. e.g. ['xmin', + 'ymin', 'xmax', 'ymax', 'class_id', 'scores'] + num_detections_key: (Optional) An output name in which to find the number + of detections to use for evaluation for a given example. It does nothing + if predictions_to_stack is not set. The value for this output should be + a scalar value or a single-value tensor. The stacked predicitions will + be truncated with the specified number of detections. + allow_missing_key: (Optional) If true, the preprocessor will return empty + array instead of raising errors. 
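The [left, top, right, bottom, class id(, score)] layout that labels_to_stack and predictions_to_stack are expected to produce can be illustrated in plain NumPy; the arrays below reuse the first example from the metrics test later in this diff:

```
import numpy as np

# One image, two ground-truth boxes in [xmin, ymin, xmax, ymax, class_id]
# order, i.e. the stacked-label layout the docstring above describes.
labels = np.asarray([[[30, 100, 70, 300, 0], [50, 100, 80, 200, 1]]])

# Detections append a confidence score:
# [xmin, ymin, xmax, ymax, class_id, score].
predictions = np.asarray(
    [
        [
            [20, 130, 60, 290, 0, 0.7],
            [30, 100, 70, 300, 0, 0.3],
            [500, 100, 800, 300, 1, 0.1],
        ]
    ]
)

print(labels.shape)       # (1, 2, 5): batch, boxes, box fields
print(predictions.shape)  # (1, 3, 6): batch, boxes, box fields + score
```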
+ """ + super().__init__( + thresholds=thresholds, + name=name, + class_id=class_id, + iou_threshold=iou_threshold, + area_range=area_range, + max_num_detections=max_num_detections, + class_weight=class_weight, + labels_to_stack=labels_to_stack, + predictions_to_stack=predictions_to_stack, + num_detections_key=num_detections_key, + allow_missing_key=allow_missing_key, + ) + + def _default_name(self) -> str: + return OBJECT_DETECTION_RECALL_NAME + + def _metric_computations( + self, + thresholds: Optional[Union[float, List[float]]] = None, + num_thresholds: Optional[int] = None, + iou_threshold: Optional[float] = None, + class_id: Optional[int] = None, + class_weight: Optional[float] = None, + area_range: Optional[Tuple[float, float]] = None, + max_num_detections: Optional[int] = None, + name: Optional[str] = None, + eval_config: Optional[config_pb2.EvalConfig] = None, + model_name: str = "", + output_name: str = "", + labels_to_stack: Optional[List[str]] = None, + predictions_to_stack: Optional[List[str]] = None, + num_detections_key: Optional[str] = None, + allow_missing_key: bool = False, + example_weighted: bool = False, + **kwargs, + ) -> metric_types.MetricComputations: + metric_util.validate_object_detection_arguments( + class_id=class_id, + class_weight=class_weight, + area_range=area_range, + max_num_detections=max_num_detections, + labels_to_stack=labels_to_stack, + predictions_to_stack=predictions_to_stack, + output_name=output_name, + ) + + key = metric_types.MetricKey( + name=name, + model_name=model_name, + output_name=output_name, + sub_key=metric_types.SubKey(class_id=class_id), + example_weighted=example_weighted, + aggregation_type=None, + ) + + preprocessor = preprocessors.BoundingBoxMatchPreprocessor( + class_id=class_id, + iou_threshold=iou_threshold, + area_range=area_range, + max_num_detections=max_num_detections, + class_weight=class_weight, + labels_to_stack=labels_to_stack, + predictions_to_stack=predictions_to_stack, + num_detections_key=num_detections_key, + allow_missing_key=allow_missing_key, + model_name=model_name, + ) + return super()._metric_computations( + thresholds=thresholds, + num_thresholds=num_thresholds, + name=name, + eval_config=eval_config, + model_name=model_name, + output_name=output_name, + preprocessors=[preprocessor], + metric_key=key, + example_weighted=example_weighted, + ) metric_types.register_metric(ObjectDetectionRecall) class ObjectDetectionPrecision(confusion_matrix_metrics.Precision): - """Computes the precision of the predictions with respect to the labels. - - The metric uses true positives and false positives to compute precision by - dividing the true positives by the sum of true positives and false positives. - - If `sample_weight` is `None`, weights default to 1. - Use `sample_weight` of 0 to mask values. - """ - - def __init__(self, - thresholds: Optional[Union[float, List[float]]] = None, - name: Optional[str] = None, - iou_threshold: Optional[float] = None, - class_id: Optional[int] = None, - class_weight: Optional[float] = None, - area_range: Optional[Tuple[float, float]] = None, - max_num_detections: Optional[int] = None, - labels_to_stack: Optional[List[str]] = None, - predictions_to_stack: Optional[List[str]] = None, - num_detections_key: Optional[str] = None, - allow_missing_key: bool = False): - """Initializes Recall metric. - - The metric supports using multiple outputs to form the labels/predictions if - the user specifies the label/predcition keys to stack. 
In this case, the - metric is not expected to work with multi-outputs. The metric only supports - multi outputs if the output of model is already pre-stacked in the expected - format, i.e. ['xmin', 'ymin', 'xmax', 'ymax', 'class_id'] for labels and - ['xmin', 'ymin', 'xmax', 'ymax', 'class_id', 'confidence scores'] for - predictions. - - Args: - thresholds: (Optional) A float value or a python list/tuple of float - threshold values in [0, 1]. A threshold is compared with prediction - values to determine the truth value of predictions (i.e., above the - threshold is `true`, below is `false`). One metric value is generated - for each threshold value. The default is to calculate precision with - `thresholds=0.5`. - name: (Optional) string name of the metric instance. - iou_threshold: (Optional) Thresholds for a detection and ground truth pair - with specific iou to be considered as a match. Default to 0.5 - class_id: (Optional) The class id for calculating metrics. - class_weight: (Optional) The weight associated with the object class id. - area_range: (Optional) A tuple (inclusive) representing the area-range for - objects to be considered for metrics. Default to (0, inf). - max_num_detections: (Optional) The maximum number of detections for a - single image. Default to None. - labels_to_stack: (Optional) Keys for columns to be stacked as a single - numpy array as the labels. It is searched under the key labels, features - and transformed features. The desired format is [left bounadary, top - boudnary, right boundary, bottom boundary, class id]. e.g. ['xmin', - 'ymin', 'xmax', 'ymax', 'class_id'] - predictions_to_stack: (Optional) Output names for columns to be stacked as - a single numpy array as the prediction. It should be the model's output - names. The desired format is [left bounadary, top boudnary, right - boundary, bottom boundary, class id, confidence score]. e.g. ['xmin', - 'ymin', 'xmax', 'ymax', 'class_id', 'scores'] - num_detections_key: (Optional) An output name in which to find the number - of detections to use for evaluation for a given example. It does nothing - if predictions_to_stack is not set. The value for this output should be - a scalar value or a single-value tensor. The stacked predicitions will - be truncated with the specified number of detections. - allow_missing_key: (Optional) If true, the preprocessor will return empty - array instead of raising errors. + """Computes the precision of the predictions with respect to the labels. + + The metric uses true positives and false positives to compute precision by + dividing the true positives by the sum of true positives and false positives. + + If `sample_weight` is `None`, weights default to 1. + Use `sample_weight` of 0 to mask values. 
""" - super().__init__( - thresholds=thresholds, - name=name, - class_id=class_id, - iou_threshold=iou_threshold, - area_range=area_range, - max_num_detections=max_num_detections, - class_weight=class_weight, - labels_to_stack=labels_to_stack, - predictions_to_stack=predictions_to_stack, - num_detections_key=num_detections_key, - allow_missing_key=allow_missing_key) - - def _default_name(self) -> str: - return OBJECT_DETECTION_PRECISION_NAME - - def _metric_computations( - self, - thresholds: Optional[Union[float, List[float]]] = None, - num_thresholds: Optional[int] = None, - iou_threshold: Optional[float] = None, - class_id: Optional[int] = None, - class_weight: Optional[float] = None, - area_range: Optional[Tuple[float, float]] = None, - max_num_detections: Optional[int] = None, - name: Optional[str] = None, - eval_config: Optional[config_pb2.EvalConfig] = None, - model_name: str = '', - output_name: str = '', - labels_to_stack: Optional[List[str]] = None, - predictions_to_stack: Optional[List[str]] = None, - num_detections_key: Optional[str] = None, - allow_missing_key: bool = False, - example_weighted: bool = False, - **kwargs, - ) -> metric_types.MetricComputations: - metric_util.validate_object_detection_arguments( - class_id=class_id, - class_weight=class_weight, - area_range=area_range, - max_num_detections=max_num_detections, - labels_to_stack=labels_to_stack, - predictions_to_stack=predictions_to_stack, - output_name=output_name) - - key = metric_types.MetricKey( - name=name, - model_name=model_name, - output_name=output_name, - sub_key=metric_types.SubKey(class_id=class_id), - example_weighted=example_weighted, - aggregation_type=None, - ) - - preprocessor = preprocessors.BoundingBoxMatchPreprocessor( - class_id=class_id, - iou_threshold=iou_threshold, - area_range=area_range, - max_num_detections=max_num_detections, - class_weight=class_weight, - labels_to_stack=labels_to_stack, - predictions_to_stack=predictions_to_stack, - num_detections_key=num_detections_key, - allow_missing_key=allow_missing_key, - model_name=model_name) - return super()._metric_computations( - thresholds=thresholds, - num_thresholds=num_thresholds, - name=name, - eval_config=eval_config, - model_name=model_name, - output_name=output_name, - preprocessors=[preprocessor], - metric_key=key, - example_weighted=example_weighted, - ) + + def __init__( + self, + thresholds: Optional[Union[float, List[float]]] = None, + name: Optional[str] = None, + iou_threshold: Optional[float] = None, + class_id: Optional[int] = None, + class_weight: Optional[float] = None, + area_range: Optional[Tuple[float, float]] = None, + max_num_detections: Optional[int] = None, + labels_to_stack: Optional[List[str]] = None, + predictions_to_stack: Optional[List[str]] = None, + num_detections_key: Optional[str] = None, + allow_missing_key: bool = False, + ): + """Initializes Recall metric. + + The metric supports using multiple outputs to form the labels/predictions if + the user specifies the label/predcition keys to stack. In this case, the + metric is not expected to work with multi-outputs. The metric only supports + multi outputs if the output of model is already pre-stacked in the expected + format, i.e. ['xmin', 'ymin', 'xmax', 'ymax', 'class_id'] for labels and + ['xmin', 'ymin', 'xmax', 'ymax', 'class_id', 'confidence scores'] for + predictions. + + Args: + ---- + thresholds: (Optional) A float value or a python list/tuple of float + threshold values in [0, 1]. 
A threshold is compared with prediction + values to determine the truth value of predictions (i.e., above the + threshold is `true`, below is `false`). One metric value is generated + for each threshold value. The default is to calculate precision with + `thresholds=0.5`. + name: (Optional) string name of the metric instance. + iou_threshold: (Optional) Thresholds for a detection and ground truth pair + with specific iou to be considered as a match. Default to 0.5 + class_id: (Optional) The class id for calculating metrics. + class_weight: (Optional) The weight associated with the object class id. + area_range: (Optional) A tuple (inclusive) representing the area-range for + objects to be considered for metrics. Default to (0, inf). + max_num_detections: (Optional) The maximum number of detections for a + single image. Default to None. + labels_to_stack: (Optional) Keys for columns to be stacked as a single + numpy array as the labels. It is searched under the key labels, features + and transformed features. The desired format is [left bounadary, top + boudnary, right boundary, bottom boundary, class id]. e.g. ['xmin', + 'ymin', 'xmax', 'ymax', 'class_id'] + predictions_to_stack: (Optional) Output names for columns to be stacked as + a single numpy array as the prediction. It should be the model's output + names. The desired format is [left bounadary, top boudnary, right + boundary, bottom boundary, class id, confidence score]. e.g. ['xmin', + 'ymin', 'xmax', 'ymax', 'class_id', 'scores'] + num_detections_key: (Optional) An output name in which to find the number + of detections to use for evaluation for a given example. It does nothing + if predictions_to_stack is not set. The value for this output should be + a scalar value or a single-value tensor. The stacked predicitions will + be truncated with the specified number of detections. + allow_missing_key: (Optional) If true, the preprocessor will return empty + array instead of raising errors. 
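To wire the precision/recall variants above into an evaluation, one option is tfma.metrics.specs_from_metrics; the following is a sketch only, with placeholder threshold and class values, assuming instance-based metric specs are acceptable for the pipeline in question:

```
import tensorflow_model_analysis as tfma
from tensorflow_model_analysis.metrics import (
    object_detection_confusion_matrix_metrics as od_metrics,
)

# Sketch only: builds metrics_specs from metric instances; all argument values
# here are placeholders.
metrics_specs = tfma.metrics.specs_from_metrics(
    [
        od_metrics.ObjectDetectionPrecision(
            thresholds=0.5, class_id=0, iou_threshold=0.5, max_num_detections=100
        ),
        od_metrics.ObjectDetectionRecall(
            thresholds=0.5, class_id=0, iou_threshold=0.5, max_num_detections=100
        ),
    ]
)
eval_config = tfma.EvalConfig(metrics_specs=metrics_specs)
```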
+ """ + super().__init__( + thresholds=thresholds, + name=name, + class_id=class_id, + iou_threshold=iou_threshold, + area_range=area_range, + max_num_detections=max_num_detections, + class_weight=class_weight, + labels_to_stack=labels_to_stack, + predictions_to_stack=predictions_to_stack, + num_detections_key=num_detections_key, + allow_missing_key=allow_missing_key, + ) + + def _default_name(self) -> str: + return OBJECT_DETECTION_PRECISION_NAME + + def _metric_computations( + self, + thresholds: Optional[Union[float, List[float]]] = None, + num_thresholds: Optional[int] = None, + iou_threshold: Optional[float] = None, + class_id: Optional[int] = None, + class_weight: Optional[float] = None, + area_range: Optional[Tuple[float, float]] = None, + max_num_detections: Optional[int] = None, + name: Optional[str] = None, + eval_config: Optional[config_pb2.EvalConfig] = None, + model_name: str = "", + output_name: str = "", + labels_to_stack: Optional[List[str]] = None, + predictions_to_stack: Optional[List[str]] = None, + num_detections_key: Optional[str] = None, + allow_missing_key: bool = False, + example_weighted: bool = False, + **kwargs, + ) -> metric_types.MetricComputations: + metric_util.validate_object_detection_arguments( + class_id=class_id, + class_weight=class_weight, + area_range=area_range, + max_num_detections=max_num_detections, + labels_to_stack=labels_to_stack, + predictions_to_stack=predictions_to_stack, + output_name=output_name, + ) + + key = metric_types.MetricKey( + name=name, + model_name=model_name, + output_name=output_name, + sub_key=metric_types.SubKey(class_id=class_id), + example_weighted=example_weighted, + aggregation_type=None, + ) + + preprocessor = preprocessors.BoundingBoxMatchPreprocessor( + class_id=class_id, + iou_threshold=iou_threshold, + area_range=area_range, + max_num_detections=max_num_detections, + class_weight=class_weight, + labels_to_stack=labels_to_stack, + predictions_to_stack=predictions_to_stack, + num_detections_key=num_detections_key, + allow_missing_key=allow_missing_key, + model_name=model_name, + ) + return super()._metric_computations( + thresholds=thresholds, + num_thresholds=num_thresholds, + name=name, + eval_config=eval_config, + model_name=model_name, + output_name=output_name, + preprocessors=[preprocessor], + metric_key=key, + example_weighted=example_weighted, + ) metric_types.register_metric(ObjectDetectionPrecision) class ObjectDetectionMaxRecall(confusion_matrix_metrics.MaxRecall): - """Computes the max recall of the predictions with respect to the labels. - - The metric uses true positives and false negatives to compute recall by - dividing the true positives by the sum of true positives and false negatives. - - Effectively the recall at threshold = epsilon(1.0e-12). It is equilvalent - to the recall defined in COCO metrics. - - If `sample_weight` is `None`, weights default to 1. - Use `sample_weight` of 0 to mask values. - """ - - def __init__(self, - name: Optional[str] = None, - iou_threshold: Optional[float] = None, - class_id: Optional[int] = None, - class_weight: Optional[float] = None, - area_range: Optional[Tuple[float, float]] = None, - max_num_detections: Optional[int] = None, - labels_to_stack: Optional[List[str]] = None, - predictions_to_stack: Optional[List[str]] = None, - num_detections_key: Optional[str] = None, - allow_missing_key: bool = False): - """Initializes MaxRecall metrics, it calculates the maximum recall. 
- - The metric supports using multiple outputs to form the labels/predictions if - the user specifies the label/predcition keys to stack. In this case, the - metric is not expected to work with multi-outputs. The metric only supports - multi outputs if the output of model is already pre-stacked in the expected - format, i.e. ['xmin', 'ymin', 'xmax', 'ymax', 'class_id'] for labels and - ['xmin', 'ymin', 'xmax', 'ymax', 'class_id', 'confidence scores'] for - predictions. - - Args: - name: (Optional) string name of the metric instance. - iou_threshold: (Optional) Thresholds for a detection and ground truth pair - with specific iou to be considered as a match. Default to 0.5 - class_id: (Optional) The class id for calculating metrics. - class_weight: (Optional) The weight associated with the object class id. - area_range: (Optional) A tuple (inclusive) representing the area-range for - objects to be considered for metrics. Default to (0, inf). - max_num_detections: (Optional) The maximum number of detections for a - single image. Default to None. - labels_to_stack: (Optional) Keys for columns to be stacked as a single - numpy array as the labels. It is searched under the key labels, features - and transformed features. The desired format is [left bounadary, top - boudnary, right boundary, bottom boundary, class id]. e.g. ['xmin', - 'ymin', 'xmax', 'ymax', 'class_id'] - predictions_to_stack: (Optional) Output names for columns to be stacked as - a single numpy array as the prediction. It should be the model's output - names. The desired format is [left bounadary, top boudnary, right - boundary, bottom boundary, class id, confidence score]. e.g. ['xmin', - 'ymin', 'xmax', 'ymax', 'class_id', 'scores'] - num_detections_key: (Optional) An output name in which to find the number - of detections to use for evaluation for a given example. It does nothing - if predictions_to_stack is not set. The value for this output should be - a scalar value or a single-value tensor. The stacked predicitions will - be truncated with the specified number of detections. - allow_missing_key: (Optional) If true, the preprocessor will return empty - array instead of raising errors. + """Computes the max recall of the predictions with respect to the labels. + + The metric uses true positives and false negatives to compute recall by + dividing the true positives by the sum of true positives and false negatives. + + Effectively the recall at threshold = epsilon(1.0e-12). It is equilvalent + to the recall defined in COCO metrics. + + If `sample_weight` is `None`, weights default to 1. + Use `sample_weight` of 0 to mask values. 
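The test file later in this diff drives these metrics from EvalConfig text protos rather than Python instances; a standalone sketch of that style for ObjectDetectionMaxRecall, under the assumption that the JSON config keys map one-to-one onto the constructor arguments above:

```
from google.protobuf import text_format

from tensorflow_model_analysis.proto import config_pb2

# Sketch only: the JSON keys in `config` are assumed to be forwarded as keyword
# arguments to ObjectDetectionMaxRecall.__init__ above.
eval_config = text_format.Parse(
    """
    metrics_specs {
      metrics {
        class_name: "ObjectDetectionMaxRecall"
        config: '"class_id": 0, "iou_threshold": 0.5, '
                '"max_num_detections": 100, "name": "maxrecall"'
      }
    }
    """,
    config_pb2.EvalConfig(),
)
```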
""" - super().__init__( - name=name, - class_id=class_id, - iou_threshold=iou_threshold, - area_range=area_range, - max_num_detections=max_num_detections, - class_weight=class_weight, - labels_to_stack=labels_to_stack, - predictions_to_stack=predictions_to_stack, - num_detections_key=num_detections_key, - allow_missing_key=allow_missing_key) - - def _default_name(self) -> str: - return OBJECT_DETECTION_MAX_RECALL_NAME - - def _metric_computations( - self, - thresholds: Optional[Union[float, List[float]]] = None, - num_thresholds: Optional[int] = None, - iou_threshold: Optional[float] = None, - class_id: Optional[int] = None, - class_weight: Optional[float] = None, - area_range: Optional[Tuple[float, float]] = None, - max_num_detections: Optional[int] = None, - name: Optional[str] = None, - eval_config: Optional[config_pb2.EvalConfig] = None, - model_name: str = '', - output_name: str = '', - labels_to_stack: Optional[List[str]] = None, - predictions_to_stack: Optional[List[str]] = None, - num_detections_key: Optional[str] = None, - allow_missing_key: bool = False, - example_weighted: bool = False, - **kwargs, - ) -> metric_types.MetricComputations: - metric_util.validate_object_detection_arguments( - class_id=class_id, - class_weight=class_weight, - area_range=area_range, - max_num_detections=max_num_detections, - labels_to_stack=labels_to_stack, - predictions_to_stack=predictions_to_stack) - - key = metric_types.MetricKey( - name=name, - model_name=model_name, - output_name=output_name, - sub_key=metric_types.SubKey(class_id=class_id), - example_weighted=example_weighted, - aggregation_type=None, - ) - - preprocessor = preprocessors.BoundingBoxMatchPreprocessor( - class_id=class_id, - iou_threshold=iou_threshold, - area_range=area_range, - max_num_detections=max_num_detections, - class_weight=class_weight, - labels_to_stack=labels_to_stack, - predictions_to_stack=predictions_to_stack, - num_detections_key=num_detections_key, - allow_missing_key=allow_missing_key, - model_name=model_name) - return super()._metric_computations( - thresholds=thresholds, - num_thresholds=num_thresholds, - name=name, - eval_config=eval_config, - model_name=model_name, - output_name=output_name, - preprocessors=[preprocessor], - metric_key=key, - example_weighted=example_weighted, + + def __init__( + self, + name: Optional[str] = None, + iou_threshold: Optional[float] = None, + class_id: Optional[int] = None, + class_weight: Optional[float] = None, + area_range: Optional[Tuple[float, float]] = None, + max_num_detections: Optional[int] = None, + labels_to_stack: Optional[List[str]] = None, + predictions_to_stack: Optional[List[str]] = None, + num_detections_key: Optional[str] = None, + allow_missing_key: bool = False, + ): + """Initializes MaxRecall metrics, it calculates the maximum recall. + + The metric supports using multiple outputs to form the labels/predictions if + the user specifies the label/predcition keys to stack. In this case, the + metric is not expected to work with multi-outputs. The metric only supports + multi outputs if the output of model is already pre-stacked in the expected + format, i.e. ['xmin', 'ymin', 'xmax', 'ymax', 'class_id'] for labels and + ['xmin', 'ymin', 'xmax', 'ymax', 'class_id', 'confidence scores'] for + predictions. + + Args: + ---- + name: (Optional) string name of the metric instance. + iou_threshold: (Optional) Thresholds for a detection and ground truth pair + with specific iou to be considered as a match. 
Default to 0.5 + class_id: (Optional) The class id for calculating metrics. + class_weight: (Optional) The weight associated with the object class id. + area_range: (Optional) A tuple (inclusive) representing the area-range for + objects to be considered for metrics. Default to (0, inf). + max_num_detections: (Optional) The maximum number of detections for a + single image. Default to None. + labels_to_stack: (Optional) Keys for columns to be stacked as a single + numpy array as the labels. It is searched under the key labels, features + and transformed features. The desired format is [left bounadary, top + boudnary, right boundary, bottom boundary, class id]. e.g. ['xmin', + 'ymin', 'xmax', 'ymax', 'class_id'] + predictions_to_stack: (Optional) Output names for columns to be stacked as + a single numpy array as the prediction. It should be the model's output + names. The desired format is [left bounadary, top boudnary, right + boundary, bottom boundary, class id, confidence score]. e.g. ['xmin', + 'ymin', 'xmax', 'ymax', 'class_id', 'scores'] + num_detections_key: (Optional) An output name in which to find the number + of detections to use for evaluation for a given example. It does nothing + if predictions_to_stack is not set. The value for this output should be + a scalar value or a single-value tensor. The stacked predicitions will + be truncated with the specified number of detections. + allow_missing_key: (Optional) If true, the preprocessor will return empty + array instead of raising errors. + """ + super().__init__( + name=name, + class_id=class_id, + iou_threshold=iou_threshold, + area_range=area_range, + max_num_detections=max_num_detections, + class_weight=class_weight, + labels_to_stack=labels_to_stack, + predictions_to_stack=predictions_to_stack, + num_detections_key=num_detections_key, + allow_missing_key=allow_missing_key, + ) + + def _default_name(self) -> str: + return OBJECT_DETECTION_MAX_RECALL_NAME + + def _metric_computations( + self, + thresholds: Optional[Union[float, List[float]]] = None, + num_thresholds: Optional[int] = None, + iou_threshold: Optional[float] = None, + class_id: Optional[int] = None, + class_weight: Optional[float] = None, + area_range: Optional[Tuple[float, float]] = None, + max_num_detections: Optional[int] = None, + name: Optional[str] = None, + eval_config: Optional[config_pb2.EvalConfig] = None, + model_name: str = "", + output_name: str = "", + labels_to_stack: Optional[List[str]] = None, + predictions_to_stack: Optional[List[str]] = None, + num_detections_key: Optional[str] = None, + allow_missing_key: bool = False, + example_weighted: bool = False, **kwargs, - ) + ) -> metric_types.MetricComputations: + metric_util.validate_object_detection_arguments( + class_id=class_id, + class_weight=class_weight, + area_range=area_range, + max_num_detections=max_num_detections, + labels_to_stack=labels_to_stack, + predictions_to_stack=predictions_to_stack, + ) + + key = metric_types.MetricKey( + name=name, + model_name=model_name, + output_name=output_name, + sub_key=metric_types.SubKey(class_id=class_id), + example_weighted=example_weighted, + aggregation_type=None, + ) + + preprocessor = preprocessors.BoundingBoxMatchPreprocessor( + class_id=class_id, + iou_threshold=iou_threshold, + area_range=area_range, + max_num_detections=max_num_detections, + class_weight=class_weight, + labels_to_stack=labels_to_stack, + predictions_to_stack=predictions_to_stack, + num_detections_key=num_detections_key, + allow_missing_key=allow_missing_key, + 
model_name=model_name, + ) + return super()._metric_computations( + thresholds=thresholds, + num_thresholds=num_thresholds, + name=name, + eval_config=eval_config, + model_name=model_name, + output_name=output_name, + preprocessors=[preprocessor], + metric_key=key, + example_weighted=example_weighted, + **kwargs, + ) metric_types.register_metric(ObjectDetectionMaxRecall) -class ObjectDetectionThresholdAtRecall( - confusion_matrix_metrics.ThresholdAtRecall): - """Computes maximum threshold where recall is >= specified value. - - If `sample_weight` is `None`, weights default to 1. - Use `sample_weight` of 0 to mask values. - """ - - def __init__(self, - recall: Union[float, List[float]], - thresholds: Optional[List[float]] = None, - num_thresholds: Optional[int] = None, - name: Optional[str] = None, - iou_threshold: Optional[float] = None, - class_id: Optional[int] = None, - class_weight: Optional[float] = None, - area_range: Optional[Tuple[float, float]] = None, - max_num_detections: Optional[int] = None, - labels_to_stack: Optional[List[str]] = None, - predictions_to_stack: Optional[List[str]] = None, - num_detections_key: Optional[str] = None, - allow_missing_key: bool = False): - """Initializes ThresholdAtRecall metric. - - The metric supports using multiple outputs to form the labels/predictions if - the user specifies the label/predcition keys to stack. In this case, the - metric is not expected to work with multi-outputs. The metric only supports - multi outputs if the output of model is already pre-stacked in the expected - format, i.e. ['xmin', 'ymin', 'xmax', 'ymax', 'class_id'] for labels and - ['xmin', 'ymin', 'xmax', 'ymax', 'class_id', 'confidence scores'] for - predictions. - - Args: - recall: A scalar or a list of scalar values in range `[0, 1]`. - thresholds: (Optional) Thresholds to use for calculating the matrices. Use - one of either thresholds or num_thresholds. - num_thresholds: (Optional) Defaults to 1000. The number of thresholds to - use for matching the given recall. - name: (Optional) string name of the metric instance. - iou_threshold: (Optional) Thresholds for a detection and ground truth pair - with specific iou to be considered as a match. Default to 0.5 - class_id: (Optional) The class id for calculating metrics. - class_weight: (Optional) The weight associated with the object class id. - area_range: (Optional) A tuple (inclusive) representing the area-range for - objects to be considered for metrics. Default to (0, inf). - max_num_detections: (Optional) The maximum number of detections for a - single image. Default to None. - labels_to_stack: (Optional) Keys for columns to be stacked as a single - numpy array as the labels. It is searched under the key labels, features - and transformed features. The desired format is [left bounadary, top - boudnary, right boundary, bottom boundary, class id]. e.g. ['xmin', - 'ymin', 'xmax', 'ymax', 'class_id'] - predictions_to_stack: (Optional) Output names for columns to be stacked as - a single numpy array as the prediction. It should be the model's output - names. The desired format is [left bounadary, top boudnary, right - boundary, bottom boundary, class id, confidence score]. e.g. ['xmin', - 'ymin', 'xmax', 'ymax', 'class_id', 'scores'] - num_detections_key: (Optional) An output name in which to find the number - of detections to use for evaluation for a given example. It does nothing - if predictions_to_stack is not set. The value for this output should be - a scalar value or a single-value tensor. 
The stacked predicitions will - be truncated with the specified number of detections. - allow_missing_key: (Optional) If true, the preprocessor will return empty - array instead of raising errors. +class ObjectDetectionThresholdAtRecall(confusion_matrix_metrics.ThresholdAtRecall): + """Computes maximum threshold where recall is >= specified value. + + If `sample_weight` is `None`, weights default to 1. + Use `sample_weight` of 0 to mask values. """ - for r in [recall] if isinstance(recall, float) else recall: - if r < 0 or r > 1: - raise ValueError('Argument `recall` must be in the range [0, 1]. ' - f'Received: recall={r}') - - super().__init__( - thresholds=thresholds, - num_thresholds=num_thresholds, - recall=recall, - name=name, - class_id=class_id, - iou_threshold=iou_threshold, - area_range=area_range, - max_num_detections=max_num_detections, - class_weight=class_weight, - labels_to_stack=labels_to_stack, - predictions_to_stack=predictions_to_stack, - num_detections_key=num_detections_key, - allow_missing_key=allow_missing_key) - - def _default_name(self) -> str: - return OBJECT_DETECTION_THRESHOLD_AT_RECALL_NAME - - def _metric_computations( - self, - thresholds: Optional[Union[float, List[float]]] = None, - num_thresholds: Optional[int] = None, - iou_threshold: Optional[float] = None, - class_id: Optional[int] = None, - class_weight: Optional[float] = None, - area_range: Optional[Tuple[float, float]] = None, - max_num_detections: Optional[int] = None, - name: Optional[str] = None, - eval_config: Optional[config_pb2.EvalConfig] = None, - model_name: str = '', - output_name: str = '', - labels_to_stack: Optional[List[str]] = None, - predictions_to_stack: Optional[List[str]] = None, - num_detections_key: Optional[str] = None, - allow_missing_key: bool = False, - example_weighted: bool = False, - **kwargs, - ) -> metric_types.MetricComputations: - metric_util.validate_object_detection_arguments( - class_id=class_id, - class_weight=class_weight, - area_range=area_range, - max_num_detections=max_num_detections, - labels_to_stack=labels_to_stack, - predictions_to_stack=predictions_to_stack, - output_name=output_name) - - key = metric_types.MetricKey( - name=name, - model_name=model_name, - output_name=output_name, - sub_key=metric_types.SubKey(class_id=class_id), - example_weighted=example_weighted, - aggregation_type=None, - ) - - preprocessor = preprocessors.BoundingBoxMatchPreprocessor( - class_id=class_id, - iou_threshold=iou_threshold, - area_range=area_range, - max_num_detections=max_num_detections, - class_weight=class_weight, - labels_to_stack=labels_to_stack, - predictions_to_stack=predictions_to_stack, - num_detections_key=num_detections_key, - allow_missing_key=allow_missing_key, - model_name=model_name) - return super()._metric_computations( - thresholds=thresholds, - num_thresholds=num_thresholds, - name=name, - eval_config=eval_config, - model_name=model_name, - output_name=output_name, - preprocessors=[preprocessor], - metric_key=key, - example_weighted=example_weighted, + + def __init__( + self, + recall: Union[float, List[float]], + thresholds: Optional[List[float]] = None, + num_thresholds: Optional[int] = None, + name: Optional[str] = None, + iou_threshold: Optional[float] = None, + class_id: Optional[int] = None, + class_weight: Optional[float] = None, + area_range: Optional[Tuple[float, float]] = None, + max_num_detections: Optional[int] = None, + labels_to_stack: Optional[List[str]] = None, + predictions_to_stack: Optional[List[str]] = None, + num_detections_key: 
Optional[str] = None, + allow_missing_key: bool = False, + ): + """Initializes ThresholdAtRecall metric. + + The metric supports using multiple outputs to form the labels/predictions if + the user specifies the label/predcition keys to stack. In this case, the + metric is not expected to work with multi-outputs. The metric only supports + multi outputs if the output of model is already pre-stacked in the expected + format, i.e. ['xmin', 'ymin', 'xmax', 'ymax', 'class_id'] for labels and + ['xmin', 'ymin', 'xmax', 'ymax', 'class_id', 'confidence scores'] for + predictions. + + Args: + ---- + recall: A scalar or a list of scalar values in range `[0, 1]`. + thresholds: (Optional) Thresholds to use for calculating the matrices. Use + one of either thresholds or num_thresholds. + num_thresholds: (Optional) Defaults to 1000. The number of thresholds to + use for matching the given recall. + name: (Optional) string name of the metric instance. + iou_threshold: (Optional) Thresholds for a detection and ground truth pair + with specific iou to be considered as a match. Default to 0.5 + class_id: (Optional) The class id for calculating metrics. + class_weight: (Optional) The weight associated with the object class id. + area_range: (Optional) A tuple (inclusive) representing the area-range for + objects to be considered for metrics. Default to (0, inf). + max_num_detections: (Optional) The maximum number of detections for a + single image. Default to None. + labels_to_stack: (Optional) Keys for columns to be stacked as a single + numpy array as the labels. It is searched under the key labels, features + and transformed features. The desired format is [left bounadary, top + boudnary, right boundary, bottom boundary, class id]. e.g. ['xmin', + 'ymin', 'xmax', 'ymax', 'class_id'] + predictions_to_stack: (Optional) Output names for columns to be stacked as + a single numpy array as the prediction. It should be the model's output + names. The desired format is [left bounadary, top boudnary, right + boundary, bottom boundary, class id, confidence score]. e.g. ['xmin', + 'ymin', 'xmax', 'ymax', 'class_id', 'scores'] + num_detections_key: (Optional) An output name in which to find the number + of detections to use for evaluation for a given example. It does nothing + if predictions_to_stack is not set. The value for this output should be + a scalar value or a single-value tensor. The stacked predicitions will + be truncated with the specified number of detections. + allow_missing_key: (Optional) If true, the preprocessor will return empty + array instead of raising errors. + """ + for r in [recall] if isinstance(recall, float) else recall: + if r < 0 or r > 1: + raise ValueError( + "Argument `recall` must be in the range [0, 1]. 
" + f"Received: recall={r}" + ) + + super().__init__( + thresholds=thresholds, + num_thresholds=num_thresholds, + recall=recall, + name=name, + class_id=class_id, + iou_threshold=iou_threshold, + area_range=area_range, + max_num_detections=max_num_detections, + class_weight=class_weight, + labels_to_stack=labels_to_stack, + predictions_to_stack=predictions_to_stack, + num_detections_key=num_detections_key, + allow_missing_key=allow_missing_key, + ) + + def _default_name(self) -> str: + return OBJECT_DETECTION_THRESHOLD_AT_RECALL_NAME + + def _metric_computations( + self, + thresholds: Optional[Union[float, List[float]]] = None, + num_thresholds: Optional[int] = None, + iou_threshold: Optional[float] = None, + class_id: Optional[int] = None, + class_weight: Optional[float] = None, + area_range: Optional[Tuple[float, float]] = None, + max_num_detections: Optional[int] = None, + name: Optional[str] = None, + eval_config: Optional[config_pb2.EvalConfig] = None, + model_name: str = "", + output_name: str = "", + labels_to_stack: Optional[List[str]] = None, + predictions_to_stack: Optional[List[str]] = None, + num_detections_key: Optional[str] = None, + allow_missing_key: bool = False, + example_weighted: bool = False, **kwargs, - ) + ) -> metric_types.MetricComputations: + metric_util.validate_object_detection_arguments( + class_id=class_id, + class_weight=class_weight, + area_range=area_range, + max_num_detections=max_num_detections, + labels_to_stack=labels_to_stack, + predictions_to_stack=predictions_to_stack, + output_name=output_name, + ) + + key = metric_types.MetricKey( + name=name, + model_name=model_name, + output_name=output_name, + sub_key=metric_types.SubKey(class_id=class_id), + example_weighted=example_weighted, + aggregation_type=None, + ) + + preprocessor = preprocessors.BoundingBoxMatchPreprocessor( + class_id=class_id, + iou_threshold=iou_threshold, + area_range=area_range, + max_num_detections=max_num_detections, + class_weight=class_weight, + labels_to_stack=labels_to_stack, + predictions_to_stack=predictions_to_stack, + num_detections_key=num_detections_key, + allow_missing_key=allow_missing_key, + model_name=model_name, + ) + return super()._metric_computations( + thresholds=thresholds, + num_thresholds=num_thresholds, + name=name, + eval_config=eval_config, + model_name=model_name, + output_name=output_name, + preprocessors=[preprocessor], + metric_key=key, + example_weighted=example_weighted, + **kwargs, + ) metric_types.register_metric(ObjectDetectionThresholdAtRecall) diff --git a/tensorflow_model_analysis/metrics/object_detection_confusion_matrix_metrics_test.py b/tensorflow_model_analysis/metrics/object_detection_confusion_matrix_metrics_test.py index 9d772e6e00..c64e71c197 100644 --- a/tensorflow_model_analysis/metrics/object_detection_confusion_matrix_metrics_test.py +++ b/tensorflow_model_analysis/metrics/object_detection_confusion_matrix_metrics_test.py @@ -12,21 +12,24 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Tests for object detection related confusion matrix metrics.""" -from absl.testing import absltest -from absl.testing import parameterized + import apache_beam as beam -from apache_beam.testing import util import numpy as np +from absl.testing import absltest, parameterized +from apache_beam.testing import util +from google.protobuf import text_format + import tensorflow_model_analysis as tfma -from tensorflow_model_analysis.proto import config_pb2 from tensorflow_model_analysis.metrics import metric_types -from google.protobuf import text_format +from tensorflow_model_analysis.proto import config_pb2 -class ObjectDetectionConfusionMatrixMetricsTest(parameterized.TestCase): - @parameterized.named_parameters(('_max_recall', - text_format.Parse( - """ +class ObjectDetectionConfusionMatrixMetricsTest(parameterized.TestCase): + @parameterized.named_parameters( + ( + "_max_recall", + text_format.Parse( + """ model_specs { signature_name: "serving_default" prediction_key: "predictions" # placeholder @@ -41,10 +44,16 @@ class ObjectDetectionConfusionMatrixMetricsTest(parameterized.TestCase): '"max_num_detections":100, "name":"maxrecall"' } } - """, config_pb2.EvalConfig()), ['maxrecall'], [2 / 3]), - ('_precision_at_recall', - text_format.Parse( - """ + """, + config_pb2.EvalConfig(), + ), + ["maxrecall"], + [2 / 3], + ), + ( + "_precision_at_recall", + text_format.Parse( + """ model_specs { signature_name: "serving_default" prediction_key: "predictions" # placeholder @@ -59,10 +68,16 @@ class ObjectDetectionConfusionMatrixMetricsTest(parameterized.TestCase): '"max_num_detections":100, "name":"precisionatrecall"' } } - """, config_pb2.EvalConfig()), ['precisionatrecall'], [3 / 5]), - ('_recall', - text_format.Parse( - """ + """, + config_pb2.EvalConfig(), + ), + ["precisionatrecall"], + [3 / 5], + ), + ( + "_recall", + text_format.Parse( + """ model_specs { signature_name: "serving_default" prediction_key: "predictions" # placeholder @@ -77,9 +92,16 @@ class ObjectDetectionConfusionMatrixMetricsTest(parameterized.TestCase): '"max_num_detections":100, "name":"recall"' } } - """, config_pb2.EvalConfig()), ['recall'], [2 / 3]), ('_precision', - text_format.Parse( - """ + """, + config_pb2.EvalConfig(), + ), + ["recall"], + [2 / 3], + ), + ( + "_precision", + text_format.Parse( + """ model_specs { signature_name: "serving_default" prediction_key: "predictions" # placeholder @@ -94,9 +116,16 @@ class ObjectDetectionConfusionMatrixMetricsTest(parameterized.TestCase): '"max_num_detections":100, "name":"precision"' } } - """, config_pb2.EvalConfig()), ['precision'], [0.5]), ('_threshold_at_recall', - text_format.Parse( - """ + """, + config_pb2.EvalConfig(), + ), + ["precision"], + [0.5], + ), + ( + "_threshold_at_recall", + text_format.Parse( + """ model_specs { signature_name: "serving_default" prediction_key: "predictions" # placeholder @@ -111,86 +140,99 @@ class ObjectDetectionConfusionMatrixMetricsTest(parameterized.TestCase): '"max_num_detections":100, "name":"thresholdatrecall"' } } - """, config_pb2.EvalConfig()), ['thresholdatrecall'], [0.3])) - def testObjectDetectionMetrics(self, eval_config, name_list, - expected_results): - - extracts = [ - { + """, + config_pb2.EvalConfig(), + ), + ["thresholdatrecall"], + [0.3], + ), + ) + def testObjectDetectionMetrics(self, eval_config, name_list, expected_results): + extracts = [ + { + # The match at iou_threshold = 0.5 is + # gt_matches: [[0]] dt_matches: [[0, -1]] + # Results after preprocess: + # 'labels': np.asarray([1., 0.]), + # 'predictions': 
np.asarray([0.7, 0.3]) + "features": { + "labels": np.asarray( + [[[30, 100, 70, 300, 0], [50, 100, 80, 200, 1]]] + ), + "predictions": np.asarray( + [ + [ + [20, 130, 60, 290, 0, 0.7], + [30, 100, 70, 300, 0, 0.3], + [500, 100, 800, 300, 1, 0.1], + ] + ] + ), + } + }, + # This is a binary classification case, the iou matrix should be: + # [[0., 2/3], [0., 4/11]] # The match at iou_threshold = 0.5 is - # gt_matches: [[0]] dt_matches: [[0, -1]] + # gt_matches: [[-1, 0]] dt_matches: [[1, -1]] # Results after preprocess: - # 'labels': np.asarray([1., 0.]), - # 'predictions': np.asarray([0.7, 0.3]) - 'features': { - 'labels': - np.asarray([[[30, 100, 70, 300, 0], [50, 100, 80, 200, - 1]]]), - 'predictions': - np.asarray([[[20, 130, 60, 290, 0, 0.7], - [30, 100, 70, 300, 0, 0.3], - [500, 100, 800, 300, 1, 0.1]]]) - } - }, - # This is a binary classification case, the iou matrix should be: - # [[0., 2/3], [0., 4/11]] - # The match at iou_threshold = 0.5 is - # gt_matches: [[-1, 0]] dt_matches: [[1, -1]] - # Results after preprocess: - # 'labels': np.asarray([1., 1., 0.]), - # 'predictions': np.asarray([0., 0.4, 0.3]) - # thresholds=[-1e-7, 0.5, 1.0 + 1e-7], - # tp=[3.0, 1.0, 0.0], - # fp=[2.0, 0.0, 0.0], - # tn=[0.0, 2.0, 2.0], - # fn=[0.0, 2.0, 3.0]) - # Precision: [3/5, 1.0, 'nan'] - # Recall: [1.0, 1/3, 0.0] - { - 'features': { - 'labels': - np.asarray([[[30, 100, 70, 400, 0], [10, 200, 80, 300, - 0]]]), - 'predictions': - np.asarray([[[100, 130, 160, 290, 0, 0.4], - [30, 100, 70, 300, 0, 0.3]]]) - } - } - ] - - evaluators = tfma.default_evaluators(eval_config=eval_config) - extractors = tfma.default_extractors( - eval_shared_model=None, eval_config=eval_config) + # 'labels': np.asarray([1., 1., 0.]), + # 'predictions': np.asarray([0., 0.4, 0.3]) + # thresholds=[-1e-7, 0.5, 1.0 + 1e-7], + # tp=[3.0, 1.0, 0.0], + # fp=[2.0, 0.0, 0.0], + # tn=[0.0, 2.0, 2.0], + # fn=[0.0, 2.0, 3.0]) + # Precision: [3/5, 1.0, 'nan'] + # Recall: [1.0, 1/3, 0.0] + { + "features": { + "labels": np.asarray( + [[[30, 100, 70, 400, 0], [10, 200, 80, 300, 0]]] + ), + "predictions": np.asarray( + [[[100, 130, 160, 290, 0, 0.4], [30, 100, 70, 300, 0, 0.3]]] + ), + } + }, + ] - with beam.Pipeline() as p: - result = ( - p | 'LoadData' >> beam.Create(extracts) - | 'ExtractEval' >> tfma.ExtractAndEvaluate( - extractors=extractors, evaluators=evaluators)) + evaluators = tfma.default_evaluators(eval_config=eval_config) + extractors = tfma.default_extractors( + eval_shared_model=None, eval_config=eval_config + ) - def check_result(got): - try: - self.assertLen(got, 1) - got_slice_key, got_metrics = got[0] - self.assertEqual(got_slice_key, ()) - self.assertLen(got_metrics, len(name_list)) - for name, expected_result in zip(name_list, expected_results): - key = metric_types.MetricKey( - name=name, sub_key=metric_types.SubKey(class_id=0) + with beam.Pipeline() as p: + result = ( + p + | "LoadData" >> beam.Create(extracts) + | "ExtractEval" + >> tfma.ExtractAndEvaluate(extractors=extractors, evaluators=evaluators) ) - self.assertIn(key, got_metrics) - got_metric = got_metrics[key] - np.testing.assert_allclose( - expected_result, - got_metric, - rtol=1e-3, - err_msg=f'This {name} metric fails.') - except AssertionError as err: - raise util.BeamAssertException(err) - self.assertIn('metrics', result) - util.assert_that(result['metrics'], check_result, label='result') + def check_result(got): + try: + self.assertLen(got, 1) + got_slice_key, got_metrics = got[0] + self.assertEqual(got_slice_key, ()) + self.assertLen(got_metrics, 
len(name_list)) + for name, expected_result in zip(name_list, expected_results): + key = metric_types.MetricKey( + name=name, sub_key=metric_types.SubKey(class_id=0) + ) + self.assertIn(key, got_metrics) + got_metric = got_metrics[key] + np.testing.assert_allclose( + expected_result, + got_metric, + rtol=1e-3, + err_msg=f"This {name} metric fails.", + ) + except AssertionError as err: + raise util.BeamAssertException(err) + + self.assertIn("metrics", result) + util.assert_that(result["metrics"], check_result, label="result") -if __name__ == '__main__': - absltest.main() +if __name__ == "__main__": + absltest.main() diff --git a/tensorflow_model_analysis/metrics/object_detection_confusion_matrix_plot.py b/tensorflow_model_analysis/metrics/object_detection_confusion_matrix_plot.py index 24d365aa39..6de6f83bff 100644 --- a/tensorflow_model_analysis/metrics/object_detection_confusion_matrix_plot.py +++ b/tensorflow_model_analysis/metrics/object_detection_confusion_matrix_plot.py @@ -15,145 +15,152 @@ from typing import List, Optional, Tuple -from tensorflow_model_analysis.metrics import confusion_matrix_plot -from tensorflow_model_analysis.metrics import metric_types -from tensorflow_model_analysis.metrics import metric_util -from tensorflow_model_analysis.metrics import preprocessors +from tensorflow_model_analysis.metrics import ( + confusion_matrix_plot, + metric_types, + metric_util, + preprocessors, +) from tensorflow_model_analysis.proto import config_pb2 DEFAULT_NUM_THRESHOLDS = 1000 -CONFUSION_MATRIX_PLOT_NAME = 'confusion_matrix_plot' +CONFUSION_MATRIX_PLOT_NAME = "confusion_matrix_plot" -class ObjectDetectionConfusionMatrixPlot( - confusion_matrix_plot.ConfusionMatrixPlot): - """Object Detection Confusion matrix plot.""" +class ObjectDetectionConfusionMatrixPlot(confusion_matrix_plot.ConfusionMatrixPlot): + """Object Detection Confusion matrix plot.""" - def __init__(self, - num_thresholds: int = DEFAULT_NUM_THRESHOLDS, - name: str = CONFUSION_MATRIX_PLOT_NAME, - iou_threshold: Optional[float] = None, - class_id: Optional[int] = None, - class_weight: Optional[float] = None, - area_range: Optional[Tuple[float, float]] = None, - max_num_detections: Optional[int] = None, - labels_to_stack: Optional[List[str]] = None, - predictions_to_stack: Optional[List[str]] = None, - num_detections_key: Optional[str] = None, - allow_missing_key: bool = False): - """Initializes confusion matrix plot for object detection. + def __init__( + self, + num_thresholds: int = DEFAULT_NUM_THRESHOLDS, + name: str = CONFUSION_MATRIX_PLOT_NAME, + iou_threshold: Optional[float] = None, + class_id: Optional[int] = None, + class_weight: Optional[float] = None, + area_range: Optional[Tuple[float, float]] = None, + max_num_detections: Optional[int] = None, + labels_to_stack: Optional[List[str]] = None, + predictions_to_stack: Optional[List[str]] = None, + num_detections_key: Optional[str] = None, + allow_missing_key: bool = False, + ): + """Initializes confusion matrix plot for object detection. - The metric supports using multiple outputs to form the labels/predictions if - the user specifies the label/predcition keys to stack. In this case, the - metric is not expected to work with multi-outputs. The metric only supports - multi outputs if the output of model is already pre-stacked in the expected - format, i.e. ['xmin', 'ymin', 'xmax', 'ymax', 'class_id'] for labels and - ['xmin', 'ymin', 'xmax', 'ymax', 'class_id', 'confidence scores'] for - predictions. 
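The hand-computed IoU values in the metrics test above (the [[0., 2/3], [0., 4/11]] matrix noted in its comments) can be reproduced with a few lines; a self-contained check using only the box coordinates from that example:

```
import numpy as np


def iou(box_a, box_b):
    """IoU of two [xmin, ymin, xmax, ymax] boxes."""
    ix = max(0, min(box_a[2], box_b[2]) - max(box_a[0], box_b[0]))
    iy = max(0, min(box_a[3], box_b[3]) - max(box_a[1], box_b[1]))
    inter = ix * iy
    area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
    area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
    return inter / (area_a + area_b - inter)


# Box coordinates from the second example in the test above (class id and
# confidence score columns dropped).
gt = [[30, 100, 70, 400], [10, 200, 80, 300]]
dt = [[100, 130, 160, 290], [30, 100, 70, 300]]

print(np.round([[iou(g, d) for d in dt] for g in gt], 4))
# [[0.     0.6667]
#  [0.     0.3636]]   i.e. [[0., 2/3], [0., 4/11]] as stated in the test comment.
```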
+ The metric supports using multiple outputs to form the labels/predictions if + the user specifies the label/predcition keys to stack. In this case, the + metric is not expected to work with multi-outputs. The metric only supports + multi outputs if the output of model is already pre-stacked in the expected + format, i.e. ['xmin', 'ymin', 'xmax', 'ymax', 'class_id'] for labels and + ['xmin', 'ymin', 'xmax', 'ymax', 'class_id', 'confidence scores'] for + predictions. - Args: - num_thresholds: Number of thresholds to use when discretizing the curve. - Values must be > 1. Defaults to 1000. - name: Metric name. - iou_threshold: (Optional) Thresholds for a detection and ground truth pair - with specific iou to be considered as a match. Default to 0.5 - class_id: (Optional) The class id for calculating metrics. - class_weight: (Optional) The weight associated with the object class id. - area_range: (Optional) A tuple (inclusive) representing the area-range for - objects to be considered for metrics. Default to (0, inf). - max_num_detections: (Optional) The maximum number of detections for a - single image. Default to None. - labels_to_stack: (Optional) Keys for columns to be stacked as a single - numpy array as the labels. It is searched under the key labels, features - and transformed features. The desired format is [left bounadary, top - boudnary, right boundary, bottom boundary, class id]. e.g. ['xmin', - 'ymin', 'xmax', 'ymax', 'class_id'] - predictions_to_stack: (Optional) Output names for columns to be stacked as - a single numpy array as the prediction. It should be the model's output - names. The desired format is [left bounadary, top boudnary, right - boundary, bottom boundary, class id, confidence score]. e.g. ['xmin', - 'ymin', 'xmax', 'ymax', 'class_id', 'scores'] - num_detections_key: (Optional) An output name in which to find the number - of detections to use for evaluation for a given example. It does nothing - if predictions_to_stack is not set. The value for this output should be - a scalar value or a single-value tensor. The stacked predicitions will - be truncated with the specified number of detections. - allow_missing_key: (Optional) If true, the preprocessor will return empty - array instead of raising errors. - """ - super().__init__( - num_thresholds=num_thresholds, - name=name, - class_id=class_id, - iou_threshold=iou_threshold, - area_range=area_range, - max_num_detections=max_num_detections, - class_weight=class_weight, - labels_to_stack=labels_to_stack, - predictions_to_stack=predictions_to_stack, - num_detections_key=num_detections_key, - allow_missing_key=allow_missing_key) + Args: + ---- + num_thresholds: Number of thresholds to use when discretizing the curve. + Values must be > 1. Defaults to 1000. + name: Metric name. + iou_threshold: (Optional) Thresholds for a detection and ground truth pair + with specific iou to be considered as a match. Default to 0.5 + class_id: (Optional) The class id for calculating metrics. + class_weight: (Optional) The weight associated with the object class id. + area_range: (Optional) A tuple (inclusive) representing the area-range for + objects to be considered for metrics. Default to (0, inf). + max_num_detections: (Optional) The maximum number of detections for a + single image. Default to None. + labels_to_stack: (Optional) Keys for columns to be stacked as a single + numpy array as the labels. It is searched under the key labels, features + and transformed features. 
The desired format is [left bounadary, top + boudnary, right boundary, bottom boundary, class id]. e.g. ['xmin', + 'ymin', 'xmax', 'ymax', 'class_id'] + predictions_to_stack: (Optional) Output names for columns to be stacked as + a single numpy array as the prediction. It should be the model's output + names. The desired format is [left bounadary, top boudnary, right + boundary, bottom boundary, class id, confidence score]. e.g. ['xmin', + 'ymin', 'xmax', 'ymax', 'class_id', 'scores'] + num_detections_key: (Optional) An output name in which to find the number + of detections to use for evaluation for a given example. It does nothing + if predictions_to_stack is not set. The value for this output should be + a scalar value or a single-value tensor. The stacked predicitions will + be truncated with the specified number of detections. + allow_missing_key: (Optional) If true, the preprocessor will return empty + array instead of raising errors. + """ + super().__init__( + num_thresholds=num_thresholds, + name=name, + class_id=class_id, + iou_threshold=iou_threshold, + area_range=area_range, + max_num_detections=max_num_detections, + class_weight=class_weight, + labels_to_stack=labels_to_stack, + predictions_to_stack=predictions_to_stack, + num_detections_key=num_detections_key, + allow_missing_key=allow_missing_key, + ) - def _confusion_matrix_plot( - self, - num_thresholds: int = DEFAULT_NUM_THRESHOLDS, - name: str = CONFUSION_MATRIX_PLOT_NAME, - eval_config: Optional[config_pb2.EvalConfig] = None, - model_name: str = '', - output_name: str = '', - iou_threshold: Optional[float] = None, - class_id: Optional[int] = None, - class_weight: Optional[float] = None, - area_range: Optional[Tuple[float, float]] = None, - max_num_detections: Optional[int] = None, - labels_to_stack: Optional[List[str]] = None, - predictions_to_stack: Optional[List[str]] = None, - num_detections_key: Optional[str] = None, - allow_missing_key: bool = False, - example_weighted: bool = False, - **kwargs - ) -> metric_types.MetricComputations: - metric_util.validate_object_detection_arguments( - class_id=class_id, - class_weight=class_weight, - area_range=area_range, - max_num_detections=max_num_detections, - labels_to_stack=labels_to_stack, - predictions_to_stack=predictions_to_stack, - output_name=output_name) + def _confusion_matrix_plot( + self, + num_thresholds: int = DEFAULT_NUM_THRESHOLDS, + name: str = CONFUSION_MATRIX_PLOT_NAME, + eval_config: Optional[config_pb2.EvalConfig] = None, + model_name: str = "", + output_name: str = "", + iou_threshold: Optional[float] = None, + class_id: Optional[int] = None, + class_weight: Optional[float] = None, + area_range: Optional[Tuple[float, float]] = None, + max_num_detections: Optional[int] = None, + labels_to_stack: Optional[List[str]] = None, + predictions_to_stack: Optional[List[str]] = None, + num_detections_key: Optional[str] = None, + allow_missing_key: bool = False, + example_weighted: bool = False, + **kwargs, + ) -> metric_types.MetricComputations: + metric_util.validate_object_detection_arguments( + class_id=class_id, + class_weight=class_weight, + area_range=area_range, + max_num_detections=max_num_detections, + labels_to_stack=labels_to_stack, + predictions_to_stack=predictions_to_stack, + output_name=output_name, + ) - key = metric_types.PlotKey( - name=name, - model_name=model_name, - output_name=output_name, - sub_key=metric_types.SubKey(class_id=class_id), - example_weighted=example_weighted, - ) + key = metric_types.PlotKey( + name=name, + 
model_name=model_name, + output_name=output_name, + sub_key=metric_types.SubKey(class_id=class_id), + example_weighted=example_weighted, + ) - preprocessor = preprocessors.BoundingBoxMatchPreprocessor( - class_id=class_id, - iou_threshold=iou_threshold, - area_range=area_range, - max_num_detections=max_num_detections, - class_weight=class_weight, - labels_to_stack=labels_to_stack, - predictions_to_stack=predictions_to_stack, - num_detections_key=num_detections_key, - allow_missing_key=allow_missing_key, - model_name=model_name) + preprocessor = preprocessors.BoundingBoxMatchPreprocessor( + class_id=class_id, + iou_threshold=iou_threshold, + area_range=area_range, + max_num_detections=max_num_detections, + class_weight=class_weight, + labels_to_stack=labels_to_stack, + predictions_to_stack=predictions_to_stack, + num_detections_key=num_detections_key, + allow_missing_key=allow_missing_key, + model_name=model_name, + ) - return super()._confusion_matrix_plot( - num_thresholds=num_thresholds, - name=name, - eval_config=eval_config, - model_name=model_name, - output_name=output_name, - preprocessors=[preprocessor], - plot_key=key, - example_weighted=example_weighted, - **kwargs - ) + return super()._confusion_matrix_plot( + num_thresholds=num_thresholds, + name=name, + eval_config=eval_config, + model_name=model_name, + output_name=output_name, + preprocessors=[preprocessor], + plot_key=key, + example_weighted=example_weighted, + **kwargs, + ) metric_types.register_metric(ObjectDetectionConfusionMatrixPlot) diff --git a/tensorflow_model_analysis/metrics/object_detection_confusion_matrix_plot_test.py b/tensorflow_model_analysis/metrics/object_detection_confusion_matrix_plot_test.py index 3289cd5b15..0844929048 100644 --- a/tensorflow_model_analysis/metrics/object_detection_confusion_matrix_plot_test.py +++ b/tensorflow_model_analysis/metrics/object_detection_confusion_matrix_plot_test.py @@ -13,25 +13,24 @@ # limitations under the License. 
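Before the test below, a rough sketch of what the labels_to_stack / predictions_to_stack options documented above describe: several per-box output columns are gathered into the single stacked array format the preprocessor expects. The column names here are hypothetical examples, not TFMA code:

# Rough sketch of the stacking behind labels_to_stack / predictions_to_stack.
import numpy as np

features = {
    "xmin": np.array([[20.0, 30.0]]),
    "ymin": np.array([[130.0, 100.0]]),
    "xmax": np.array([[60.0, 70.0]]),
    "ymax": np.array([[290.0, 300.0]]),
    "class_id": np.array([[0.0, 0.0]]),
    "scores": np.array([[0.7, 0.3]]),
}

# Stack into [xmin, ymin, xmax, ymax, class_id, score] per detection, the
# pre-stacked prediction format documented in this change.
keys = ["xmin", "ymin", "xmax", "ymax", "class_id", "scores"]
predictions = np.stack([features[k] for k in keys], axis=-1)
print(predictions.shape)  # (1, 2, 6): one image with two stacked detections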
"""Tests for object detection confusion matrix plot.""" -from absl.testing import absltest import apache_beam as beam -from apache_beam.testing import util import numpy as np +from absl.testing import absltest +from apache_beam.testing import util +from google.protobuf import text_format + import tensorflow_model_analysis as tfma -from tensorflow_model_analysis.proto import config_pb2 from tensorflow_model_analysis.metrics import metric_types +from tensorflow_model_analysis.proto import config_pb2 from tensorflow_model_analysis.utils import test_util -from google.protobuf import text_format - class ObjectDetectionConfusionMatrixPlotTest( test_util.TensorflowModelAnalysisTest, absltest.TestCase ): - - def testConfusionMatrixPlot(self): - eval_config = text_format.Parse( - """ + def testConfusionMatrixPlot(self): + eval_config = text_format.Parse( + """ model_specs { signature_name: "serving_default" prediction_key: "predictions" @@ -46,71 +45,81 @@ def testConfusionMatrixPlot(self): '"max_num_detections":100, "name":"iou0.5"' } } - """, config_pb2.EvalConfig()) - extracts = [ - # The match at iou_threshold = 0.5 is - # gt_matches: [[0]] dt_matches: [[0, -1]] - # Results after preprocess: - # 'labels': np.asarray([1., 0.]), - # 'predictions': np.asarray([0.7, 0.3]) - { - 'features': { - 'labels': - np.asarray([[[30, 100, 70, 300, 0], [50, 100, 80, 200, - 1]]]), - 'predictions': - np.asarray([[[20, 130, 60, 290, 0, 0.7], - [30, 100, 70, 300, 0, 0.3], - [500, 100, 800, 300, 1, 0.1]]]) - } - }, - # This is a binary classification case, the iou matrix should be: - # [[0., 2/3], [0., 4/11]] - # The match at iou_threshold = 0.5 is - # gt_matches: [[-1, 0]] dt_matches: [[1, -1]] - # Results after preprocess: - # 'labels': np.asarray([1., 1., 0.]), - # 'predictions': np.asarray([0., 0.4, 0.3]) - # thresholds=[-1e-7, 0.5, 1.0 + 1e-7], - # tp=[3.0, 1.0, 0.0], - # fp=[2.0, 0.0, 0.0], - # tn=[0.0, 2.0, 2.0], - # fn=[0.0, 2.0, 3.0]) - # Precision: [3/5, 1.0, 'nan'] - # Recall: [1.0, 1/3, 0.0] - { - 'features': { - 'labels': - np.asarray([[[30, 100, 70, 400, 0], [10, 200, 80, 300, - 0]]]), - 'predictions': - np.asarray([[[100, 130, 160, 290, 0, 0.4], - [30, 100, 70, 300, 0, 0.3]]]) - } - } - ] - evaluators = tfma.default_evaluators(eval_config=eval_config) - extractors = tfma.default_extractors( - eval_shared_model=None, eval_config=eval_config) + """, + config_pb2.EvalConfig(), + ) + extracts = [ + # The match at iou_threshold = 0.5 is + # gt_matches: [[0]] dt_matches: [[0, -1]] + # Results after preprocess: + # 'labels': np.asarray([1., 0.]), + # 'predictions': np.asarray([0.7, 0.3]) + { + "features": { + "labels": np.asarray( + [[[30, 100, 70, 300, 0], [50, 100, 80, 200, 1]]] + ), + "predictions": np.asarray( + [ + [ + [20, 130, 60, 290, 0, 0.7], + [30, 100, 70, 300, 0, 0.3], + [500, 100, 800, 300, 1, 0.1], + ] + ] + ), + } + }, + # This is a binary classification case, the iou matrix should be: + # [[0., 2/3], [0., 4/11]] + # The match at iou_threshold = 0.5 is + # gt_matches: [[-1, 0]] dt_matches: [[1, -1]] + # Results after preprocess: + # 'labels': np.asarray([1., 1., 0.]), + # 'predictions': np.asarray([0., 0.4, 0.3]) + # thresholds=[-1e-7, 0.5, 1.0 + 1e-7], + # tp=[3.0, 1.0, 0.0], + # fp=[2.0, 0.0, 0.0], + # tn=[0.0, 2.0, 2.0], + # fn=[0.0, 2.0, 3.0]) + # Precision: [3/5, 1.0, 'nan'] + # Recall: [1.0, 1/3, 0.0] + { + "features": { + "labels": np.asarray( + [[[30, 100, 70, 400, 0], [10, 200, 80, 300, 0]]] + ), + "predictions": np.asarray( + [[[100, 130, 160, 290, 0, 0.4], [30, 100, 70, 300, 0, 
0.3]]] + ), + } + }, + ] + evaluators = tfma.default_evaluators(eval_config=eval_config) + extractors = tfma.default_extractors( + eval_shared_model=None, eval_config=eval_config + ) - with beam.Pipeline() as p: - result = ( - p | 'LoadData' >> beam.Create(extracts) - | 'ExtractEval' >> tfma.ExtractAndEvaluate( - extractors=extractors, evaluators=evaluators)) + with beam.Pipeline() as p: + result = ( + p + | "LoadData" >> beam.Create(extracts) + | "ExtractEval" + >> tfma.ExtractAndEvaluate(extractors=extractors, evaluators=evaluators) + ) - def check_result(got): - try: - self.assertLen(got, 1) - got_slice_key, got_plots = got[0] - self.assertEqual(got_slice_key, ()) - key = metric_types.PlotKey( - name='iou0.5', sub_key=metric_types.SubKey(class_id=1) - ) - self.assertIn(key, got_plots) - got_plot = got_plots[key] - self.assertProtoEquals( - """ + def check_result(got): + try: + self.assertLen(got, 1) + got_slice_key, got_plots = got[0] + self.assertEqual(got_slice_key, ()) + key = metric_types.PlotKey( + name="iou0.5", sub_key=metric_types.SubKey(class_id=1) + ) + self.assertIn(key, got_plots) + got_plot = got_plots[key] + self.assertProtoEquals( + """ matrices { threshold: -1e-06 false_positives: 1.0 @@ -173,14 +182,14 @@ def check_result(got): false_omission_rate: 0.5 } """, - got_plot, - ) - except AssertionError as err: - raise util.BeamAssertException(err) + got_plot, + ) + except AssertionError as err: + raise util.BeamAssertException(err) - self.assertIn('plots', result) - util.assert_that(result['plots'], check_result, label='result') + self.assertIn("plots", result) + util.assert_that(result["plots"], check_result, label="result") -if __name__ == '__main__': - absltest.main() +if __name__ == "__main__": + absltest.main() diff --git a/tensorflow_model_analysis/metrics/object_detection_metrics.py b/tensorflow_model_analysis/metrics/object_detection_metrics.py index 3a17353094..66468c2317 100644 --- a/tensorflow_model_analysis/metrics/object_detection_metrics.py +++ b/tensorflow_model_analysis/metrics/object_detection_metrics.py @@ -18,172 +18,182 @@ from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np -from tensorflow_model_analysis.metrics import confusion_matrix_metrics -from tensorflow_model_analysis.metrics import metric_types -from tensorflow_model_analysis.metrics import metric_util -from tensorflow_model_analysis.metrics import object_detection_confusion_matrix_metrics -AVERAGE_RECALL_NAME = 'average_recall' -AVERAGE_PRECISION_NAME = 'average_precision' -MEAN_AVERAGE_PRECISION_NAME = 'mean_average_precision' -MEAN_AVERAGE_RECALL_NAME = 'mean_average_recall' +from tensorflow_model_analysis.metrics import ( + confusion_matrix_metrics, + metric_types, + metric_util, + object_detection_confusion_matrix_metrics, +) + +AVERAGE_RECALL_NAME = "average_recall" +AVERAGE_PRECISION_NAME = "average_precision" +MEAN_AVERAGE_PRECISION_NAME = "mean_average_precision" +MEAN_AVERAGE_RECALL_NAME = "mean_average_recall" class COCOAveragePrecision(metric_types.Metric): - """Confusion matrix at thresholds. - - It computes the average precision of object detections for a single class and - a single iou_threshold. 
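The COCO-style metrics defined in this file are registered via metric_types.register_metric, so they can be requested by class name from an EvalConfig in the same way the plot test above requests ObjectDetectionConfusionMatrixPlot. A hedged sketch, assuming the JSON config keys mirror the constructor arguments documented here:

# Sketch of requesting a COCO-style metric via an EvalConfig, following the
# pattern used by the test above. The config keys are assumed to map onto the
# constructor arguments documented in this file.
from google.protobuf import text_format

from tensorflow_model_analysis.proto import config_pb2

eval_config = text_format.Parse(
    """
    model_specs {
      signature_name: "serving_default"
      prediction_key: "predictions"
    }
    metrics_specs {
      metrics {
        class_name: "COCOMeanAveragePrecision"
        config: '"class_ids": [1], "iou_thresholds": [0.5, 0.75], "max_num_detections": 100'
      }
    }
    """,
    config_pb2.EvalConfig(),
)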
- """ - - def __init__(self, - num_thresholds: Optional[int] = None, - iou_threshold: Optional[float] = None, - class_id: Optional[int] = None, - class_weight: Optional[float] = None, - area_range: Optional[Tuple[float, float]] = None, - max_num_detections: Optional[int] = None, - recalls: Optional[List[float]] = None, - num_recalls: Optional[int] = None, - name: Optional[str] = None, - labels_to_stack: Optional[List[str]] = None, - predictions_to_stack: Optional[List[str]] = None, - num_detections_key: Optional[str] = None, - allow_missing_key: bool = False): - """Initialize average precision metric. - - This metric is only used in object-detection setting. It does not support - sub_key parameters due to the matching algorithm of bounding boxes. - - The metric supports using multiple outputs to form the labels/predictions if - the user specifies the label/predcition keys to stack. In this case, the - metric is not expected to work with multi-outputs. The metric only supports - multi outputs if the output of model is already pre-stacked in the expected - format, i.e. ['xmin', 'ymin', 'xmax', 'ymax', 'class_id'] for labels and - ['xmin', 'ymin', 'xmax', 'ymax', 'class_id', 'confidence scores'] for - predictions. - - Args: - num_thresholds: (Optional) Number of thresholds to use for calculating the - matrices and finding the precision at given recall. - iou_threshold: (Optional) Threholds for a detection and ground truth pair - with specific iou to be considered as a match. - class_id: (Optional) The class id for calculating metrics. - class_weight: (Optional) The weight associated with the object class id. - area_range: (Optional) The area-range for objects to be considered for - metrics. - max_num_detections: (Optional) The maximum number of detections for a - single image. - recalls: (Optional) recalls at which precisions will be calculated. - num_recalls: (Optional) Used for objecth detection, the number of recalls - for calculating average precision, it equally generates points bewteen 0 - and 1. (Only one of recalls and num_recalls should be used). - name: (Optional) string name of the metric instance. - labels_to_stack: (Optional) Keys for columns to be stacked as a single - numpy array as the labels. It is searched under the key labels, features - and transformed features. The desired format is [left bounadary, top - boudnary, right boundary, bottom boundary, class id]. e.g. ['xmin', - 'ymin', 'xmax', 'ymax', 'class_id'] - predictions_to_stack: (Optional) Output names for columns to be stacked as - a single numpy array as the prediction. It should be the model's output - names. The desired format is [left bounadary, top boudnary, right - boundary, bottom boundary, class id, confidence score]. e.g. ['xmin', - 'ymin', 'xmax', 'ymax', 'class_id', 'scores'] - num_detections_key: (Optional) An output name in which to find the number - of detections to use for evaluation for a given example. It does nothing - if predictions_to_stack is not set. The value for this output should be - a scalar value or a single-value tensor. The stacked predicitions will - be truncated with the specified number of detections. - allow_missing_key: (Optional) If true, the preprocessor will return empty - array instead of raising errors. + """Confusion matrix at thresholds. + + It computes the average precision of object detections for a single class and + a single iou_threshold. 
""" - if recalls is not None: - recall_thresholds = recalls - elif num_recalls is not None: - recall_thresholds = np.linspace(0.0, 1.0, num_recalls) - else: - # by default set recall_thresholds to [0.0:0.01:1.0]. - recall_thresholds = np.linspace(0.0, 1.0, 101) - - super().__init__( - metric_util.merge_per_key_computations(self._metric_computations), - num_thresholds=num_thresholds, - iou_threshold=iou_threshold, - class_id=class_id, - class_weight=class_weight, - area_range=area_range, - max_num_detections=max_num_detections, - recall_thresholds=recall_thresholds, - name=name, - labels_to_stack=labels_to_stack, - predictions_to_stack=predictions_to_stack, - num_detections_key=num_detections_key, - allow_missing_key=allow_missing_key) - - def _default_name(self) -> str: - return AVERAGE_PRECISION_NAME - - def _metric_computations( - self, - num_thresholds: Optional[int] = None, - iou_threshold: Optional[float] = None, - class_id: Optional[int] = None, - class_weight: Optional[float] = None, - max_num_detections: Optional[int] = None, - area_range: Optional[Tuple[float, float]] = None, - recall_thresholds: Optional[List[float]] = None, - name: Optional[str] = None, - model_name: str = '', - output_name: str = '', - example_weighted: bool = False, - labels_to_stack: Optional[List[str]] = None, - predictions_to_stack: Optional[List[str]] = None, - num_detections_key: Optional[str] = None, - allow_missing_key: bool = False, - ) -> metric_types.MetricComputations: - """Returns computations for confusion matrix metric.""" - metric_util.validate_object_detection_arguments( - class_id=class_id, - class_weight=class_weight, - area_range=area_range, - max_num_detections=max_num_detections, - labels_to_stack=labels_to_stack, - predictions_to_stack=predictions_to_stack, - output_name=output_name) - - key = metric_types.MetricKey( - name=name, - model_name=model_name, - output_name=output_name, - sub_key=None, - example_weighted=example_weighted, - aggregation_type=None) - - if recall_thresholds is None: - # If recall thresholds is not defined, initialize it as [0.0] - recall_thresholds = [0.0] - - if num_thresholds is None: - num_thresholds = 10000 - thresholds = [1.e-12] + [ - (i + 1) * 1.0 / (num_thresholds - 1) for i in range(num_thresholds - 2) - ] + [1.0 - 1.e-12] - # PrecisionAtRecall is a public function. To hide it from users who do not - # need it, we make the name private with '_'. - precision_at_recall_name = metric_util.generate_private_name_from_arguments( - confusion_matrix_metrics.PRECISION_AT_RECALL_NAME, - recall=recall_thresholds, - num_thresholds=num_thresholds, - iou_threshold=iou_threshold, - class_id=class_id, - class_weight=class_weight, - area_range=area_range, - max_num_detections=max_num_detections, - allow_missing_key=allow_missing_key) - - pr = ( - object_detection_confusion_matrix_metrics - .ObjectDetectionPrecisionAtRecall( + + def __init__( + self, + num_thresholds: Optional[int] = None, + iou_threshold: Optional[float] = None, + class_id: Optional[int] = None, + class_weight: Optional[float] = None, + area_range: Optional[Tuple[float, float]] = None, + max_num_detections: Optional[int] = None, + recalls: Optional[List[float]] = None, + num_recalls: Optional[int] = None, + name: Optional[str] = None, + labels_to_stack: Optional[List[str]] = None, + predictions_to_stack: Optional[List[str]] = None, + num_detections_key: Optional[str] = None, + allow_missing_key: bool = False, + ): + """Initialize average precision metric. 
+ + This metric is only used in object-detection setting. It does not support + sub_key parameters due to the matching algorithm of bounding boxes. + + The metric supports using multiple outputs to form the labels/predictions if + the user specifies the label/predcition keys to stack. In this case, the + metric is not expected to work with multi-outputs. The metric only supports + multi outputs if the output of model is already pre-stacked in the expected + format, i.e. ['xmin', 'ymin', 'xmax', 'ymax', 'class_id'] for labels and + ['xmin', 'ymin', 'xmax', 'ymax', 'class_id', 'confidence scores'] for + predictions. + + Args: + ---- + num_thresholds: (Optional) Number of thresholds to use for calculating the + matrices and finding the precision at given recall. + iou_threshold: (Optional) Threholds for a detection and ground truth pair + with specific iou to be considered as a match. + class_id: (Optional) The class id for calculating metrics. + class_weight: (Optional) The weight associated with the object class id. + area_range: (Optional) The area-range for objects to be considered for + metrics. + max_num_detections: (Optional) The maximum number of detections for a + single image. + recalls: (Optional) recalls at which precisions will be calculated. + num_recalls: (Optional) Used for objecth detection, the number of recalls + for calculating average precision, it equally generates points bewteen 0 + and 1. (Only one of recalls and num_recalls should be used). + name: (Optional) string name of the metric instance. + labels_to_stack: (Optional) Keys for columns to be stacked as a single + numpy array as the labels. It is searched under the key labels, features + and transformed features. The desired format is [left bounadary, top + boudnary, right boundary, bottom boundary, class id]. e.g. ['xmin', + 'ymin', 'xmax', 'ymax', 'class_id'] + predictions_to_stack: (Optional) Output names for columns to be stacked as + a single numpy array as the prediction. It should be the model's output + names. The desired format is [left bounadary, top boudnary, right + boundary, bottom boundary, class id, confidence score]. e.g. ['xmin', + 'ymin', 'xmax', 'ymax', 'class_id', 'scores'] + num_detections_key: (Optional) An output name in which to find the number + of detections to use for evaluation for a given example. It does nothing + if predictions_to_stack is not set. The value for this output should be + a scalar value or a single-value tensor. The stacked predicitions will + be truncated with the specified number of detections. + allow_missing_key: (Optional) If true, the preprocessor will return empty + array instead of raising errors. + """ + if recalls is not None: + recall_thresholds = recalls + elif num_recalls is not None: + recall_thresholds = np.linspace(0.0, 1.0, num_recalls) + else: + # by default set recall_thresholds to [0.0:0.01:1.0]. 
+ recall_thresholds = np.linspace(0.0, 1.0, 101) + + super().__init__( + metric_util.merge_per_key_computations(self._metric_computations), + num_thresholds=num_thresholds, + iou_threshold=iou_threshold, + class_id=class_id, + class_weight=class_weight, + area_range=area_range, + max_num_detections=max_num_detections, + recall_thresholds=recall_thresholds, + name=name, + labels_to_stack=labels_to_stack, + predictions_to_stack=predictions_to_stack, + num_detections_key=num_detections_key, + allow_missing_key=allow_missing_key, + ) + + def _default_name(self) -> str: + return AVERAGE_PRECISION_NAME + + def _metric_computations( + self, + num_thresholds: Optional[int] = None, + iou_threshold: Optional[float] = None, + class_id: Optional[int] = None, + class_weight: Optional[float] = None, + max_num_detections: Optional[int] = None, + area_range: Optional[Tuple[float, float]] = None, + recall_thresholds: Optional[List[float]] = None, + name: Optional[str] = None, + model_name: str = "", + output_name: str = "", + example_weighted: bool = False, + labels_to_stack: Optional[List[str]] = None, + predictions_to_stack: Optional[List[str]] = None, + num_detections_key: Optional[str] = None, + allow_missing_key: bool = False, + ) -> metric_types.MetricComputations: + """Returns computations for confusion matrix metric.""" + metric_util.validate_object_detection_arguments( + class_id=class_id, + class_weight=class_weight, + area_range=area_range, + max_num_detections=max_num_detections, + labels_to_stack=labels_to_stack, + predictions_to_stack=predictions_to_stack, + output_name=output_name, + ) + + key = metric_types.MetricKey( + name=name, + model_name=model_name, + output_name=output_name, + sub_key=None, + example_weighted=example_weighted, + aggregation_type=None, + ) + + if recall_thresholds is None: + # If recall thresholds is not defined, initialize it as [0.0] + recall_thresholds = [0.0] + + if num_thresholds is None: + num_thresholds = 10000 + thresholds = ( + [1.0e-12] + + [(i + 1) * 1.0 / (num_thresholds - 1) for i in range(num_thresholds - 2)] + + [1.0 - 1.0e-12] + ) + # PrecisionAtRecall is a public function. To hide it from users who do not + # need it, we make the name private with '_'. 
+ precision_at_recall_name = metric_util.generate_private_name_from_arguments( + confusion_matrix_metrics.PRECISION_AT_RECALL_NAME, + recall=recall_thresholds, + num_thresholds=num_thresholds, + iou_threshold=iou_threshold, + class_id=class_id, + class_weight=class_weight, + area_range=area_range, + max_num_detections=max_num_detections, + allow_missing_key=allow_missing_key, + ) + + pr = object_detection_confusion_matrix_metrics.ObjectDetectionPrecisionAtRecall( recall=recall_thresholds, thresholds=thresholds, iou_threshold=iou_threshold, @@ -195,532 +205,562 @@ def _metric_computations( labels_to_stack=labels_to_stack, predictions_to_stack=predictions_to_stack, num_detections_key=num_detections_key, - allow_missing_key=allow_missing_key)) - computations = pr.computations( - model_names=[model_name], output_names=[output_name]) - precisions_key = computations[-1].keys[-1] - - def result( - metrics: Dict[metric_types.MetricKey, Any] - ) -> Dict[metric_types.MetricKey, Union[float, np.ndarray]]: - value = np.nanmean(metrics[precisions_key]) - return {key: value} - - derived_computation = metric_types.DerivedMetricComputation( - keys=[key], result=result) - computations.append(derived_computation) - return computations + allow_missing_key=allow_missing_key, + ) + computations = pr.computations( + model_names=[model_name], output_names=[output_name] + ) + precisions_key = computations[-1].keys[-1] + + def result( + metrics: Dict[metric_types.MetricKey, Any], + ) -> Dict[metric_types.MetricKey, Union[float, np.ndarray]]: + value = np.nanmean(metrics[precisions_key]) + return {key: value} + + derived_computation = metric_types.DerivedMetricComputation( + keys=[key], result=result + ) + computations.append(derived_computation) + return computations metric_types.register_metric(COCOAveragePrecision) class COCOMeanAveragePrecision(metric_types.Metric): - """Mean average precision for object detections. - - It calculates the mean average precision metric for object detections. It - averages COCOAveragePrecision over multiple classes and IoU thresholds. - """ - - def __init__(self, - num_thresholds: Optional[int] = None, - iou_thresholds: Optional[List[float]] = None, - class_ids: Optional[List[int]] = None, - class_weights: Optional[List[float]] = None, - area_range: Optional[Tuple[float, float]] = None, - max_num_detections: Optional[int] = None, - recalls: Optional[List[float]] = None, - num_recalls: Optional[int] = None, - name: Optional[str] = None, - labels_to_stack: Optional[List[str]] = None, - predictions_to_stack: Optional[List[str]] = None, - num_detections_key: Optional[str] = None, - allow_missing_key: bool = False): - """Initializes mean average precision metric. - - This metric is only used in object-detection setting. It does not support - sub_key parameters due to the matching algorithm of bounding boxes. - - The metric supports using multiple outputs to form the labels/predictions if - the user specifies the label/predcition keys to stack. In this case, the - metric is not expected to work with multi-outputs. The metric only supports - multi outputs if the output of model is already pre-stacked in the expected - format, i.e. ['xmin', 'ymin', 'xmax', 'ymax', 'class_id'] for labels and - ['xmin', 'ymin', 'xmax', 'ymax', 'class_id', 'confidence scores'] for - predictions. - - Args: - num_thresholds: (Optional) Number of thresholds to use for calculating the - matrices and finding the precision at given recall. 
- iou_thresholds: (Optional) Threholds for a detection and ground truth pair - with specific iou to be considered as a match. - class_ids: (Optional) The class ids for calculating metrics. - class_weights: (Optional) The weight associated with the object class ids. - If it is provided, it should have the same length as class_ids. - area_range: (Optional) The area-range for objects to be considered for - metrics. - max_num_detections: (Optional) The maximum number of detections for a - single image. - recalls: (Optional) recalls at which precisions will be calculated. - num_recalls: (Optional) Used for objecth detection, the number of recalls - for calculating average precision, it equally generates points bewteen 0 - and 1. (Only one of recalls and num_recalls should be used). - name: (Optional) Metric name. - labels_to_stack: (Optional) Keys for columns to be stacked as a single - numpy array as the labels. It is searched under the key labels, features - and transformed features. The desired format is [left bounadary, top - boudnary, right boundary, bottom boundary, class id]. e.g. ['xmin', - 'ymin', 'xmax', 'ymax', 'class_id'] - predictions_to_stack: (Optional) Output names for columns to be stacked as - a single numpy array as the prediction. It should be the model's output - names. The desired format is [left bounadary, top boudnary, right - boundary, bottom boundary, class id, confidence score]. e.g. ['xmin', - 'ymin', 'xmax', 'ymax', 'class_id', 'scores'] - num_detections_key: (Optional) An output name in which to find the number - of detections to use for evaluation for a given example. It does nothing - if predictions_to_stack is not set. The value for this output should be - a scalar value or a single-value tensor. The stacked predicitions will - be truncated with the specified number of detections. - allow_missing_key: (Optional) If true, the preprocessor will return empty - array instead of raising errors. - """ + """Mean average precision for object detections. 
- super().__init__( - metric_util.merge_per_key_computations(self._metric_computations), - num_thresholds=num_thresholds, - iou_thresholds=iou_thresholds, - class_ids=class_ids, - class_weights=class_weights, - area_range=area_range, - max_num_detections=max_num_detections, - recalls=recalls, - num_recalls=num_recalls, - name=name, - labels_to_stack=labels_to_stack, - predictions_to_stack=predictions_to_stack, - num_detections_key=num_detections_key, - allow_missing_key=allow_missing_key) - - def _default_name(self) -> str: - return MEAN_AVERAGE_PRECISION_NAME - - def _metric_computations(self, - num_thresholds: Optional[int] = None, - iou_thresholds: Optional[List[float]] = None, - class_ids: Optional[List[int]] = None, - class_weights: Optional[List[float]] = None, - max_num_detections: Optional[int] = None, - area_range: Optional[Tuple[float, float]] = None, - recalls: Optional[List[float]] = None, - num_recalls: Optional[int] = None, - name: Optional[str] = None, - model_name: str = '', - output_name: str = '', - example_weighted: bool = False, - labels_to_stack: Optional[List[str]] = None, - predictions_to_stack: Optional[List[str]] = None, - num_detections_key: Optional[str] = None, - allow_missing_key: bool = False, - **kwargs) -> metric_types.MetricComputations: - """Returns computations for confusion matrix metric.""" - - metric_util.validate_object_detection_arguments( - class_id=class_ids, - class_weight=class_weights, - area_range=area_range, - max_num_detections=max_num_detections, - labels_to_stack=labels_to_stack, - predictions_to_stack=predictions_to_stack, - output_name=output_name) - - # set default value according to COCO metrics - if iou_thresholds is None: - iou_thresholds = np.linspace(0.5, 0.95, 10) - if class_weights is None: - class_weights = [1.0] * len(class_ids) - - key = metric_types.MetricKey( - name=name, - model_name=model_name, - output_name=output_name, - sub_key=None, - example_weighted=example_weighted, - aggregation_type=None) - - computations = [] - precisions_keys = [] - for iou_threshold in iou_thresholds: - for class_id, class_weight in zip(class_ids, class_weights): - - average_precision_name = ( - metric_util.generate_private_name_from_arguments( - AVERAGE_PRECISION_NAME, - recall=recalls, - num_recalls=num_recalls, - num_thresholds=num_thresholds, - iou_threshold=iou_threshold, - class_id=class_id, - class_weight=class_weight, - area_range=area_range, - max_num_detections=max_num_detections, - allow_missing_key=allow_missing_key)) + It calculates the mean average precision metric for object detections. It + averages COCOAveragePrecision over multiple classes and IoU thresholds. + """ - ap = COCOAveragePrecision( + def __init__( + self, + num_thresholds: Optional[int] = None, + iou_thresholds: Optional[List[float]] = None, + class_ids: Optional[List[int]] = None, + class_weights: Optional[List[float]] = None, + area_range: Optional[Tuple[float, float]] = None, + max_num_detections: Optional[int] = None, + recalls: Optional[List[float]] = None, + num_recalls: Optional[int] = None, + name: Optional[str] = None, + labels_to_stack: Optional[List[str]] = None, + predictions_to_stack: Optional[List[str]] = None, + num_detections_key: Optional[str] = None, + allow_missing_key: bool = False, + ): + """Initializes mean average precision metric. + + This metric is only used in object-detection setting. It does not support + sub_key parameters due to the matching algorithm of bounding boxes. 
+ + The metric supports using multiple outputs to form the labels/predictions if + the user specifies the label/predcition keys to stack. In this case, the + metric is not expected to work with multi-outputs. The metric only supports + multi outputs if the output of model is already pre-stacked in the expected + format, i.e. ['xmin', 'ymin', 'xmax', 'ymax', 'class_id'] for labels and + ['xmin', 'ymin', 'xmax', 'ymax', 'class_id', 'confidence scores'] for + predictions. + + Args: + ---- + num_thresholds: (Optional) Number of thresholds to use for calculating the + matrices and finding the precision at given recall. + iou_thresholds: (Optional) Threholds for a detection and ground truth pair + with specific iou to be considered as a match. + class_ids: (Optional) The class ids for calculating metrics. + class_weights: (Optional) The weight associated with the object class ids. + If it is provided, it should have the same length as class_ids. + area_range: (Optional) The area-range for objects to be considered for + metrics. + max_num_detections: (Optional) The maximum number of detections for a + single image. + recalls: (Optional) recalls at which precisions will be calculated. + num_recalls: (Optional) Used for objecth detection, the number of recalls + for calculating average precision, it equally generates points bewteen 0 + and 1. (Only one of recalls and num_recalls should be used). + name: (Optional) Metric name. + labels_to_stack: (Optional) Keys for columns to be stacked as a single + numpy array as the labels. It is searched under the key labels, features + and transformed features. The desired format is [left bounadary, top + boudnary, right boundary, bottom boundary, class id]. e.g. ['xmin', + 'ymin', 'xmax', 'ymax', 'class_id'] + predictions_to_stack: (Optional) Output names for columns to be stacked as + a single numpy array as the prediction. It should be the model's output + names. The desired format is [left bounadary, top boudnary, right + boundary, bottom boundary, class id, confidence score]. e.g. ['xmin', + 'ymin', 'xmax', 'ymax', 'class_id', 'scores'] + num_detections_key: (Optional) An output name in which to find the number + of detections to use for evaluation for a given example. It does nothing + if predictions_to_stack is not set. The value for this output should be + a scalar value or a single-value tensor. The stacked predicitions will + be truncated with the specified number of detections. + allow_missing_key: (Optional) If true, the preprocessor will return empty + array instead of raising errors. 
+ """ + super().__init__( + metric_util.merge_per_key_computations(self._metric_computations), num_thresholds=num_thresholds, - iou_threshold=iou_threshold, - class_id=class_id, - class_weight=class_weight, + iou_thresholds=iou_thresholds, + class_ids=class_ids, + class_weights=class_weights, area_range=area_range, max_num_detections=max_num_detections, recalls=recalls, num_recalls=num_recalls, - name=average_precision_name, + name=name, labels_to_stack=labels_to_stack, predictions_to_stack=predictions_to_stack, num_detections_key=num_detections_key, - allow_missing_key=allow_missing_key) - computations.extend( - ap.computations( - model_names=[model_name], output_names=[output_name])) - precisions_keys.append(computations[-1].keys[-1]) - - def result( - metrics: Dict[metric_types.MetricKey, Any] - ) -> Dict[metric_types.MetricKey, Union[float, np.ndarray]]: - precisions = [ - metrics[precisions_key] for precisions_key in precisions_keys - ] - value = np.nanmean(precisions) - return {key: value} - - derived_computation = metric_types.DerivedMetricComputation( - keys=[key], result=result) - computations.append(derived_computation) - return computations + allow_missing_key=allow_missing_key, + ) + + def _default_name(self) -> str: + return MEAN_AVERAGE_PRECISION_NAME + + def _metric_computations( + self, + num_thresholds: Optional[int] = None, + iou_thresholds: Optional[List[float]] = None, + class_ids: Optional[List[int]] = None, + class_weights: Optional[List[float]] = None, + max_num_detections: Optional[int] = None, + area_range: Optional[Tuple[float, float]] = None, + recalls: Optional[List[float]] = None, + num_recalls: Optional[int] = None, + name: Optional[str] = None, + model_name: str = "", + output_name: str = "", + example_weighted: bool = False, + labels_to_stack: Optional[List[str]] = None, + predictions_to_stack: Optional[List[str]] = None, + num_detections_key: Optional[str] = None, + allow_missing_key: bool = False, + **kwargs, + ) -> metric_types.MetricComputations: + """Returns computations for confusion matrix metric.""" + metric_util.validate_object_detection_arguments( + class_id=class_ids, + class_weight=class_weights, + area_range=area_range, + max_num_detections=max_num_detections, + labels_to_stack=labels_to_stack, + predictions_to_stack=predictions_to_stack, + output_name=output_name, + ) + + # set default value according to COCO metrics + if iou_thresholds is None: + iou_thresholds = np.linspace(0.5, 0.95, 10) + if class_weights is None: + class_weights = [1.0] * len(class_ids) + + key = metric_types.MetricKey( + name=name, + model_name=model_name, + output_name=output_name, + sub_key=None, + example_weighted=example_weighted, + aggregation_type=None, + ) + + computations = [] + precisions_keys = [] + for iou_threshold in iou_thresholds: + for class_id, class_weight in zip(class_ids, class_weights): + average_precision_name = ( + metric_util.generate_private_name_from_arguments( + AVERAGE_PRECISION_NAME, + recall=recalls, + num_recalls=num_recalls, + num_thresholds=num_thresholds, + iou_threshold=iou_threshold, + class_id=class_id, + class_weight=class_weight, + area_range=area_range, + max_num_detections=max_num_detections, + allow_missing_key=allow_missing_key, + ) + ) + + ap = COCOAveragePrecision( + num_thresholds=num_thresholds, + iou_threshold=iou_threshold, + class_id=class_id, + class_weight=class_weight, + area_range=area_range, + max_num_detections=max_num_detections, + recalls=recalls, + num_recalls=num_recalls, + name=average_precision_name, + 
labels_to_stack=labels_to_stack, + predictions_to_stack=predictions_to_stack, + num_detections_key=num_detections_key, + allow_missing_key=allow_missing_key, + ) + computations.extend( + ap.computations( + model_names=[model_name], output_names=[output_name] + ) + ) + precisions_keys.append(computations[-1].keys[-1]) + + def result( + metrics: Dict[metric_types.MetricKey, Any], + ) -> Dict[metric_types.MetricKey, Union[float, np.ndarray]]: + precisions = [metrics[precisions_key] for precisions_key in precisions_keys] + value = np.nanmean(precisions) + return {key: value} + + derived_computation = metric_types.DerivedMetricComputation( + keys=[key], result=result + ) + computations.append(derived_computation) + return computations metric_types.register_metric(COCOMeanAveragePrecision) class COCOAverageRecall(metric_types.Metric): - """Average recall metric for object detection. - - It computes the average precision metric for object detections for a single - class. It averages MaxRecall metric over mulitple IoU thresholds. - """ - - def __init__(self, - iou_thresholds: Optional[List[float]] = None, - class_id: Optional[int] = None, - class_weight: Optional[float] = None, - area_range: Optional[Tuple[float, float]] = None, - max_num_detections: Optional[int] = None, - name: Optional[str] = None, - labels_to_stack: Optional[List[str]] = None, - predictions_to_stack: Optional[List[str]] = None, - num_detections_key: Optional[str] = None, - allow_missing_key: bool = False): - """Initializes average recall metric. - - This metric is only used in object-detection setting. It does not support - sub_key parameters due to the matching algorithm of bounding boxes. - - The metric supports using multiple outputs to form the labels/predictions if - the user specifies the label/predcition keys to stack. In this case, the - metric is not expected to work with multi-outputs. The metric only supports - multi outputs if the output of model is already pre-stacked in the expected - format, i.e. ['xmin', 'ymin', 'xmax', 'ymax', 'class_id'] for labels and - ['xmin', 'ymin', 'xmax', 'ymax', 'class_id', 'confidence scores'] for - predictions. - - Args: - iou_thresholds: (Optional) Threholds for a detection and ground truth pair - with specific iou to be considered as a match. - class_id: (Optional) The class ids for calculating metrics. - class_weight: (Optional) The weight associated with the object class ids. - If it is provided, it should have the same length as class_ids. - area_range: (Optional) The area-range for objects to be considered for - metrics. - max_num_detections: (Optional) The maximum number of detections for a - single image. - name: (Optional) Metric name. - labels_to_stack: (Optional) Keys for columns to be stacked as a single - numpy array as the labels. It is searched under the key labels, features - and transformed features. The desired format is [left bounadary, top - boudnary, right boundary, bottom boundary, class id]. e.g. ['xmin', - 'ymin', 'xmax', 'ymax', 'class_id'] - predictions_to_stack: (Optional) Output names for columns to be stacked as - a single numpy array as the prediction. It should be the model's output - names. The desired format is [left bounadary, top boudnary, right - boundary, bottom boundary, class id, confidence score]. e.g. ['xmin', - 'ymin', 'xmax', 'ymax', 'class_id', 'scores'] - num_detections_key: (Optional) An output name in which to find the number - of detections to use for evaluation for a given example. It does nothing - if predictions_to_stack is not set. 
The value for this output should be - a scalar value or a single-value tensor. The stacked predicitions will - be truncated with the specified number of detections. - allow_missing_key: (Optional) If true, the preprocessor will return empty - array instead of raising errors. + """Average recall metric for object detection. + + It computes the average precision metric for object detections for a single + class. It averages MaxRecall metric over mulitple IoU thresholds. """ - super().__init__( - metric_util.merge_per_key_computations(self._metric_computations), - iou_thresholds=iou_thresholds, - class_id=class_id, - class_weight=class_weight, - area_range=area_range, - max_num_detections=max_num_detections, - name=name, - labels_to_stack=labels_to_stack, - predictions_to_stack=predictions_to_stack, - num_detections_key=num_detections_key, - allow_missing_key=allow_missing_key) - - def _default_name(self) -> str: - return AVERAGE_RECALL_NAME - - def _metric_computations( - self, - iou_thresholds: Optional[Union[float, List[float]]] = None, - class_id: Optional[int] = None, - class_weight: Optional[float] = None, - max_num_detections: Optional[int] = None, - area_range: Optional[Tuple[float, float]] = None, - name: Optional[str] = None, - model_name: str = '', - output_name: str = '', - example_weighted: bool = False, - labels_to_stack: Optional[List[str]] = None, - predictions_to_stack: Optional[List[str]] = None, - num_detections_key: Optional[str] = None, - allow_missing_key: bool = False, - ) -> metric_types.MetricComputations: - """Returns computations for confusion matrix metric.""" - - metric_util.validate_object_detection_arguments( - class_id=class_id, - class_weight=class_weight, - area_range=area_range, - max_num_detections=max_num_detections, - labels_to_stack=labels_to_stack, - predictions_to_stack=predictions_to_stack, - output_name=output_name) - - # set default value according to COCO metrics - if iou_thresholds is None: - iou_thresholds = np.linspace(0.5, 0.95, 10) - if class_weight is None: - class_weight = 1.0 - - key = metric_types.MetricKey( - name=name, - model_name=model_name, - output_name=output_name, - sub_key=None, - example_weighted=example_weighted, - aggregation_type=None) - - computations = [] - recalls_keys = [] - for iou_threshold in iou_thresholds: - max_recall_name = metric_util.generate_private_name_from_arguments( - confusion_matrix_metrics.MAX_RECALL_NAME, - iou_threshold=iou_threshold, - class_id=class_id, - class_weight=class_weight, - area_range=area_range, - max_num_detections=max_num_detections, - allow_missing_key=allow_missing_key) - - mr = object_detection_confusion_matrix_metrics.ObjectDetectionMaxRecall( - iou_threshold=iou_threshold, - class_id=class_id, - class_weight=class_weight, - area_range=area_range, - max_num_detections=max_num_detections, - name=max_recall_name, - labels_to_stack=labels_to_stack, - predictions_to_stack=predictions_to_stack, - num_detections_key=num_detections_key, - allow_missing_key=allow_missing_key) - computations.extend( - mr.computations(model_names=[model_name], output_names=[output_name])) - recalls_keys.append(computations[-1].keys[-1]) - - def result( - metrics: Dict[metric_types.MetricKey, Any] - ) -> Dict[metric_types.MetricKey, Union[float, np.ndarray]]: - for recalls_key in recalls_keys: - if math.isnan(metrics[recalls_key]): - logging.warning( - 'Recall with metric key %s is NaN, it will be' - ' ignored in the following calculation.', recalls_key) - recalls = [metrics[recalls_key] for recalls_key in 
recalls_keys] - value = np.nanmean(recalls) - return {key: value} - - derived_computation = metric_types.DerivedMetricComputation( - keys=[key], result=result) - computations.append(derived_computation) - return computations + def __init__( + self, + iou_thresholds: Optional[List[float]] = None, + class_id: Optional[int] = None, + class_weight: Optional[float] = None, + area_range: Optional[Tuple[float, float]] = None, + max_num_detections: Optional[int] = None, + name: Optional[str] = None, + labels_to_stack: Optional[List[str]] = None, + predictions_to_stack: Optional[List[str]] = None, + num_detections_key: Optional[str] = None, + allow_missing_key: bool = False, + ): + """Initializes average recall metric. + + This metric is only used in object-detection setting. It does not support + sub_key parameters due to the matching algorithm of bounding boxes. + + The metric supports using multiple outputs to form the labels/predictions if + the user specifies the label/predcition keys to stack. In this case, the + metric is not expected to work with multi-outputs. The metric only supports + multi outputs if the output of model is already pre-stacked in the expected + format, i.e. ['xmin', 'ymin', 'xmax', 'ymax', 'class_id'] for labels and + ['xmin', 'ymin', 'xmax', 'ymax', 'class_id', 'confidence scores'] for + predictions. + + Args: + ---- + iou_thresholds: (Optional) Threholds for a detection and ground truth pair + with specific iou to be considered as a match. + class_id: (Optional) The class ids for calculating metrics. + class_weight: (Optional) The weight associated with the object class ids. + If it is provided, it should have the same length as class_ids. + area_range: (Optional) The area-range for objects to be considered for + metrics. + max_num_detections: (Optional) The maximum number of detections for a + single image. + name: (Optional) Metric name. + labels_to_stack: (Optional) Keys for columns to be stacked as a single + numpy array as the labels. It is searched under the key labels, features + and transformed features. The desired format is [left bounadary, top + boudnary, right boundary, bottom boundary, class id]. e.g. ['xmin', + 'ymin', 'xmax', 'ymax', 'class_id'] + predictions_to_stack: (Optional) Output names for columns to be stacked as + a single numpy array as the prediction. It should be the model's output + names. The desired format is [left bounadary, top boudnary, right + boundary, bottom boundary, class id, confidence score]. e.g. ['xmin', + 'ymin', 'xmax', 'ymax', 'class_id', 'scores'] + num_detections_key: (Optional) An output name in which to find the number + of detections to use for evaluation for a given example. It does nothing + if predictions_to_stack is not set. The value for this output should be + a scalar value or a single-value tensor. The stacked predicitions will + be truncated with the specified number of detections. + allow_missing_key: (Optional) If true, the preprocessor will return empty + array instead of raising errors. 
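The aggregation in COCOAverageRecall, shown in the result() function above and again in reformatted form below, averages the per-IoU-threshold max recall with np.nanmean; NaN entries are logged and then ignored. A toy illustration of that aggregation with made-up recall values:

# Toy illustration of the COCOAverageRecall aggregation over IoU thresholds.
import numpy as np

iou_thresholds = np.linspace(0.5, 0.95, 10)  # default IoU thresholds
max_recall_per_iou = np.array(
    [0.9, 0.9, 0.8, 0.8, 0.7, 0.6, 0.5, 0.4, np.nan, np.nan]
)
average_recall = np.nanmean(max_recall_per_iou)  # NaNs are skipped
print(round(float(average_recall), 3))  # 0.7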
+ """ + super().__init__( + metric_util.merge_per_key_computations(self._metric_computations), + iou_thresholds=iou_thresholds, + class_id=class_id, + class_weight=class_weight, + area_range=area_range, + max_num_detections=max_num_detections, + name=name, + labels_to_stack=labels_to_stack, + predictions_to_stack=predictions_to_stack, + num_detections_key=num_detections_key, + allow_missing_key=allow_missing_key, + ) + + def _default_name(self) -> str: + return AVERAGE_RECALL_NAME + + def _metric_computations( + self, + iou_thresholds: Optional[Union[float, List[float]]] = None, + class_id: Optional[int] = None, + class_weight: Optional[float] = None, + max_num_detections: Optional[int] = None, + area_range: Optional[Tuple[float, float]] = None, + name: Optional[str] = None, + model_name: str = "", + output_name: str = "", + example_weighted: bool = False, + labels_to_stack: Optional[List[str]] = None, + predictions_to_stack: Optional[List[str]] = None, + num_detections_key: Optional[str] = None, + allow_missing_key: bool = False, + ) -> metric_types.MetricComputations: + """Returns computations for confusion matrix metric.""" + metric_util.validate_object_detection_arguments( + class_id=class_id, + class_weight=class_weight, + area_range=area_range, + max_num_detections=max_num_detections, + labels_to_stack=labels_to_stack, + predictions_to_stack=predictions_to_stack, + output_name=output_name, + ) + + # set default value according to COCO metrics + if iou_thresholds is None: + iou_thresholds = np.linspace(0.5, 0.95, 10) + if class_weight is None: + class_weight = 1.0 + + key = metric_types.MetricKey( + name=name, + model_name=model_name, + output_name=output_name, + sub_key=None, + example_weighted=example_weighted, + aggregation_type=None, + ) + + computations = [] + recalls_keys = [] + for iou_threshold in iou_thresholds: + max_recall_name = metric_util.generate_private_name_from_arguments( + confusion_matrix_metrics.MAX_RECALL_NAME, + iou_threshold=iou_threshold, + class_id=class_id, + class_weight=class_weight, + area_range=area_range, + max_num_detections=max_num_detections, + allow_missing_key=allow_missing_key, + ) + + mr = object_detection_confusion_matrix_metrics.ObjectDetectionMaxRecall( + iou_threshold=iou_threshold, + class_id=class_id, + class_weight=class_weight, + area_range=area_range, + max_num_detections=max_num_detections, + name=max_recall_name, + labels_to_stack=labels_to_stack, + predictions_to_stack=predictions_to_stack, + num_detections_key=num_detections_key, + allow_missing_key=allow_missing_key, + ) + computations.extend( + mr.computations(model_names=[model_name], output_names=[output_name]) + ) + recalls_keys.append(computations[-1].keys[-1]) + + def result( + metrics: Dict[metric_types.MetricKey, Any], + ) -> Dict[metric_types.MetricKey, Union[float, np.ndarray]]: + for recalls_key in recalls_keys: + if math.isnan(metrics[recalls_key]): + logging.warning( + "Recall with metric key %s is NaN, it will be" + " ignored in the following calculation.", + recalls_key, + ) + recalls = [metrics[recalls_key] for recalls_key in recalls_keys] + value = np.nanmean(recalls) + return {key: value} + + derived_computation = metric_types.DerivedMetricComputation( + keys=[key], result=result + ) + computations.append(derived_computation) + return computations metric_types.register_metric(COCOAverageRecall) class COCOMeanAverageRecall(metric_types.Metric): - """Mean Average recall metric for object detection. 
- - It computes the mean average precision metric for object detections for a - single class. It averages COCOAverageRecall metric over mulitple classes. - """ - - def __init__(self, - iou_thresholds: Optional[List[float]] = None, - class_ids: Optional[List[int]] = None, - class_weights: Optional[List[float]] = None, - area_range: Optional[Tuple[float, float]] = None, - max_num_detections: Optional[int] = None, - name: Optional[str] = None, - labels_to_stack: Optional[List[str]] = None, - predictions_to_stack: Optional[List[str]] = None, - num_detections_key: Optional[str] = None, - allow_missing_key: bool = False): - """Initializes average recall metric. - - This metric is only used in object-detection setting. It does not support - sub_key parameters due to the matching algorithm of bounding boxes. - - The metric supports using multiple outputs to form the labels/predictions if - the user specifies the label/predcition keys to stack. In this case, the - metric is not expected to work with multi-outputs. The metric only supports - multi outputs if the output of model is already pre-stacked in the expected - format, i.e. ['xmin', 'ymin', 'xmax', 'ymax', 'class_id'] for labels and - ['xmin', 'ymin', 'xmax', 'ymax', 'class_id', 'confidence scores'] for - predictions. - - Args: - iou_thresholds: (Optional) Threholds for a detection and ground truth pair - with specific iou to be considered as a match. - class_ids: (Optional) The class ids for calculating metrics. - class_weights: (Optional) The weight associated with the object class ids. - If it is provided, it should have the same length as class_ids. - area_range: (Optional) The area-range for objects to be considered for - metrics. - max_num_detections: (Optional) The maximum number of detections for a - single image. - name: (Optional) Metric name. - labels_to_stack: (Optional) Keys for columns to be stacked as a single - numpy array as the labels. It is searched under the key labels, features - and transformed features. The desired format is [left bounadary, top - boudnary, right boundary, bottom boundary, class id]. e.g. ['xmin', - 'ymin', 'xmax', 'ymax', 'class_id'] - predictions_to_stack: (Optional) Output names for columns to be stacked as - a single numpy array as the prediction. It should be the model's output - names. The desired format is [left bounadary, top boudnary, right - boundary, bottom boundary, class id, confidence score]. e.g. ['xmin', - 'ymin', 'xmax', 'ymax', 'class_id', 'scores'] - num_detections_key: (Optional) An output name in which to find the number - of detections to use for evaluation for a given example. It does nothing - if predictions_to_stack is not set. The value for this output should be - a scalar value or a single-value tensor. The stacked predicitions will - be truncated with the specified number of detections. - allow_missing_key: (Optional) If true, the preprocessor will return empty - array instead of raising errors. + """Mean Average recall metric for object detection. + + It computes the mean average precision metric for object detections for a + single class. It averages COCOAverageRecall metric over mulitple classes. 
""" - super().__init__( - metric_util.merge_per_key_computations(self._metric_computations), - iou_thresholds=iou_thresholds, - class_ids=class_ids, - class_weights=class_weights, - area_range=area_range, - max_num_detections=max_num_detections, - name=name, - labels_to_stack=labels_to_stack, - predictions_to_stack=predictions_to_stack, - num_detections_key=num_detections_key, - allow_missing_key=allow_missing_key) - - def _default_name(self) -> str: - return MEAN_AVERAGE_RECALL_NAME - - def _metric_computations( - self, - iou_thresholds: Optional[List[float]] = None, - class_ids: Optional[Union[int, List[int]]] = None, - class_weights: Optional[Union[float, List[float]]] = None, - max_num_detections: Optional[int] = None, - area_range: Optional[Tuple[float, float]] = None, - name: Optional[str] = None, - model_name: str = '', - output_name: str = '', - example_weighted: bool = False, - labels_to_stack: Optional[List[str]] = None, - predictions_to_stack: Optional[List[str]] = None, - num_detections_key: Optional[str] = None, - allow_missing_key: bool = False, - ) -> metric_types.MetricComputations: - """Returns computations for confusion matrix metric.""" - - metric_util.validate_object_detection_arguments( - class_id=class_ids, - class_weight=class_weights, - area_range=area_range, - max_num_detections=max_num_detections, - labels_to_stack=labels_to_stack, - predictions_to_stack=predictions_to_stack, - output_name=output_name) - - if class_weights is None: - class_weights = [1.0] * len(class_ids) - - key = metric_types.MetricKey( - name=name, - model_name=model_name, - output_name=output_name, - sub_key=None, - example_weighted=example_weighted, - aggregation_type=None) - - computations = [] - recalls_keys = [] - for class_id, class_weight in zip(class_ids, class_weights): - max_recall_name = metric_util.generate_private_name_from_arguments( - AVERAGE_RECALL_NAME, - iou_thresholds=iou_thresholds, - class_id=class_id, - class_weight=class_weight, - area_range=area_range, - max_num_detections=max_num_detections, - allow_missing_key=allow_missing_key) - - mr = COCOAverageRecall( - iou_thresholds=iou_thresholds, - class_id=class_id, - class_weight=class_weight, - area_range=area_range, - max_num_detections=max_num_detections, - name=max_recall_name, - labels_to_stack=labels_to_stack, - predictions_to_stack=predictions_to_stack, - num_detections_key=num_detections_key, - allow_missing_key=allow_missing_key) - computations.extend( - mr.computations(model_names=[model_name], output_names=[output_name])) - recalls_keys.append(computations[-1].keys[-1]) - - def result( - metrics: Dict[metric_types.MetricKey, Any] - ) -> Dict[metric_types.MetricKey, Union[float, np.ndarray]]: - recalls = [metrics[recalls_key] for recalls_key in recalls_keys] - value = np.nanmean(recalls) - return {key: value} - - derived_computation = metric_types.DerivedMetricComputation( - keys=[key], result=result) - computations.append(derived_computation) - return computations + def __init__( + self, + iou_thresholds: Optional[List[float]] = None, + class_ids: Optional[List[int]] = None, + class_weights: Optional[List[float]] = None, + area_range: Optional[Tuple[float, float]] = None, + max_num_detections: Optional[int] = None, + name: Optional[str] = None, + labels_to_stack: Optional[List[str]] = None, + predictions_to_stack: Optional[List[str]] = None, + num_detections_key: Optional[str] = None, + allow_missing_key: bool = False, + ): + """Initializes average recall metric. 
+ + This metric is only used in object-detection setting. It does not support + sub_key parameters due to the matching algorithm of bounding boxes. + + The metric supports using multiple outputs to form the labels/predictions if + the user specifies the label/predcition keys to stack. In this case, the + metric is not expected to work with multi-outputs. The metric only supports + multi outputs if the output of model is already pre-stacked in the expected + format, i.e. ['xmin', 'ymin', 'xmax', 'ymax', 'class_id'] for labels and + ['xmin', 'ymin', 'xmax', 'ymax', 'class_id', 'confidence scores'] for + predictions. + + Args: + ---- + iou_thresholds: (Optional) Threholds for a detection and ground truth pair + with specific iou to be considered as a match. + class_ids: (Optional) The class ids for calculating metrics. + class_weights: (Optional) The weight associated with the object class ids. + If it is provided, it should have the same length as class_ids. + area_range: (Optional) The area-range for objects to be considered for + metrics. + max_num_detections: (Optional) The maximum number of detections for a + single image. + name: (Optional) Metric name. + labels_to_stack: (Optional) Keys for columns to be stacked as a single + numpy array as the labels. It is searched under the key labels, features + and transformed features. The desired format is [left bounadary, top + boudnary, right boundary, bottom boundary, class id]. e.g. ['xmin', + 'ymin', 'xmax', 'ymax', 'class_id'] + predictions_to_stack: (Optional) Output names for columns to be stacked as + a single numpy array as the prediction. It should be the model's output + names. The desired format is [left bounadary, top boudnary, right + boundary, bottom boundary, class id, confidence score]. e.g. ['xmin', + 'ymin', 'xmax', 'ymax', 'class_id', 'scores'] + num_detections_key: (Optional) An output name in which to find the number + of detections to use for evaluation for a given example. It does nothing + if predictions_to_stack is not set. The value for this output should be + a scalar value or a single-value tensor. The stacked predicitions will + be truncated with the specified number of detections. + allow_missing_key: (Optional) If true, the preprocessor will return empty + array instead of raising errors. 
+ """ + super().__init__( + metric_util.merge_per_key_computations(self._metric_computations), + iou_thresholds=iou_thresholds, + class_ids=class_ids, + class_weights=class_weights, + area_range=area_range, + max_num_detections=max_num_detections, + name=name, + labels_to_stack=labels_to_stack, + predictions_to_stack=predictions_to_stack, + num_detections_key=num_detections_key, + allow_missing_key=allow_missing_key, + ) + + def _default_name(self) -> str: + return MEAN_AVERAGE_RECALL_NAME + + def _metric_computations( + self, + iou_thresholds: Optional[List[float]] = None, + class_ids: Optional[Union[int, List[int]]] = None, + class_weights: Optional[Union[float, List[float]]] = None, + max_num_detections: Optional[int] = None, + area_range: Optional[Tuple[float, float]] = None, + name: Optional[str] = None, + model_name: str = "", + output_name: str = "", + example_weighted: bool = False, + labels_to_stack: Optional[List[str]] = None, + predictions_to_stack: Optional[List[str]] = None, + num_detections_key: Optional[str] = None, + allow_missing_key: bool = False, + ) -> metric_types.MetricComputations: + """Returns computations for confusion matrix metric.""" + metric_util.validate_object_detection_arguments( + class_id=class_ids, + class_weight=class_weights, + area_range=area_range, + max_num_detections=max_num_detections, + labels_to_stack=labels_to_stack, + predictions_to_stack=predictions_to_stack, + output_name=output_name, + ) + + if class_weights is None: + class_weights = [1.0] * len(class_ids) + + key = metric_types.MetricKey( + name=name, + model_name=model_name, + output_name=output_name, + sub_key=None, + example_weighted=example_weighted, + aggregation_type=None, + ) + + computations = [] + recalls_keys = [] + for class_id, class_weight in zip(class_ids, class_weights): + max_recall_name = metric_util.generate_private_name_from_arguments( + AVERAGE_RECALL_NAME, + iou_thresholds=iou_thresholds, + class_id=class_id, + class_weight=class_weight, + area_range=area_range, + max_num_detections=max_num_detections, + allow_missing_key=allow_missing_key, + ) + + mr = COCOAverageRecall( + iou_thresholds=iou_thresholds, + class_id=class_id, + class_weight=class_weight, + area_range=area_range, + max_num_detections=max_num_detections, + name=max_recall_name, + labels_to_stack=labels_to_stack, + predictions_to_stack=predictions_to_stack, + num_detections_key=num_detections_key, + allow_missing_key=allow_missing_key, + ) + computations.extend( + mr.computations(model_names=[model_name], output_names=[output_name]) + ) + recalls_keys.append(computations[-1].keys[-1]) + + def result( + metrics: Dict[metric_types.MetricKey, Any], + ) -> Dict[metric_types.MetricKey, Union[float, np.ndarray]]: + recalls = [metrics[recalls_key] for recalls_key in recalls_keys] + value = np.nanmean(recalls) + return {key: value} + + derived_computation = metric_types.DerivedMetricComputation( + keys=[key], result=result + ) + computations.append(derived_computation) + return computations metric_types.register_metric(COCOMeanAverageRecall) diff --git a/tensorflow_model_analysis/metrics/object_detection_metrics_test.py b/tensorflow_model_analysis/metrics/object_detection_metrics_test.py index 6cfa3e357e..879342dcf0 100644 --- a/tensorflow_model_analysis/metrics/object_detection_metrics_test.py +++ b/tensorflow_model_analysis/metrics/object_detection_metrics_test.py @@ -12,40 +12,43 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Tests for object detection related metrics.""" -from absl.testing import absltest -from absl.testing import parameterized + import apache_beam as beam -from apache_beam.testing import util import numpy as np +from absl.testing import absltest, parameterized +from apache_beam.testing import util +from google.protobuf import text_format + import tensorflow_model_analysis as tfma -from tensorflow_model_analysis.proto import config_pb2 from tensorflow_model_analysis.metrics import metric_types -from google.protobuf import text_format +from tensorflow_model_analysis.proto import config_pb2 class ObjectDetectionMetricsTest(parameterized.TestCase): - """This tests the object detection metrics. + """This tests the object detection metrics. - Results provided from COCOAPI: AP with all IoUs causes overflow of memory, - thus we do not check it here, but check the single value instead and the - average of two IoUs. - Average Precision @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.533 - Average Precision @[ IoU=0.50 | area= all | maxDets=100 ] = 0.916 - Average Precision @[ IoU=0.75 | area= all | maxDets=100 ] = 0.416 - Average Precision @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.500 - Average Precision @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.303 - Average Precision @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.701 - Average Recall @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.375 - Average Recall @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.533 - Average Recall @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.533 - Average Recall @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.500 - Average Recall @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.300 - Average Recall @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.700 - """ + Results provided from COCOAPI: AP with all IoUs causes overflow of memory, + thus we do not check it here, but check the single value instead and the + average of two IoUs. 
+ Average Precision @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.533 + Average Precision @[ IoU=0.50 | area= all | maxDets=100 ] = 0.916 + Average Precision @[ IoU=0.75 | area= all | maxDets=100 ] = 0.416 + Average Precision @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.500 + Average Precision @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.303 + Average Precision @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.701 + Average Recall @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.375 + Average Recall @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.533 + Average Recall @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.533 + Average Recall @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.500 + Average Recall @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.300 + Average Recall @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.700 + """ - @parameterized.named_parameters(('_average_precision_iou0.5', - text_format.Parse( - """ + @parameterized.named_parameters( + ( + "_average_precision_iou0.5", + text_format.Parse( + """ model_specs { signature_name: "serving_default" prediction_key: "predictions" # placeholder @@ -60,10 +63,16 @@ class ObjectDetectionMetricsTest(parameterized.TestCase): '"max_num_detections":100, "name":"iou0.5"' } } - """, config_pb2.EvalConfig()), ['iou0.5'], [0.916]), - ('_average_precision_iou0.75', - text_format.Parse( - """ + """, + config_pb2.EvalConfig(), + ), + ["iou0.5"], + [0.916], + ), + ( + "_average_precision_iou0.75", + text_format.Parse( + """ model_specs { signature_name: "serving_default" prediction_key: "predictions" # placeholder @@ -78,10 +87,16 @@ class ObjectDetectionMetricsTest(parameterized.TestCase): '"max_num_detections":100, "name":"iou0.75"' } } - """, config_pb2.EvalConfig()), ['iou0.75'], [0.416]), - ('_average_precision_ave', - text_format.Parse( - """ + """, + config_pb2.EvalConfig(), + ), + ["iou0.75"], + [0.416], + ), + ( + "_average_precision_ave", + text_format.Parse( + """ model_specs { signature_name: "serving_default" prediction_key: "predictions" # placeholder @@ -96,9 +111,16 @@ class ObjectDetectionMetricsTest(parameterized.TestCase): '"max_num_detections":100, "name":"iouave"' } } - """, config_pb2.EvalConfig()), ['iouave'], [0.666]), ('_average_recall_mdet1', - text_format.Parse( - """ + """, + config_pb2.EvalConfig(), + ), + ["iouave"], + [0.666], + ), + ( + "_average_recall_mdet1", + text_format.Parse( + """ model_specs { signature_name: "serving_default" prediction_key: "predictions" # placeholder @@ -113,9 +135,16 @@ class ObjectDetectionMetricsTest(parameterized.TestCase): '"name":"mdet1"' } } - """, config_pb2.EvalConfig()), ['mdet1'], [0.375]), ('_average_recall_mdet10', - text_format.Parse( - """ + """, + config_pb2.EvalConfig(), + ), + ["mdet1"], + [0.375], + ), + ( + "_average_recall_mdet10", + text_format.Parse( + """ model_specs { signature_name: "serving_default" prediction_key: "predictions" # placeholder @@ -130,10 +159,16 @@ class ObjectDetectionMetricsTest(parameterized.TestCase): '"name":"mdet10"' } } - """, config_pb2.EvalConfig()), ['mdet10'], [0.533]), - ('_average_recall_mdet100', - text_format.Parse( - """ + """, + config_pb2.EvalConfig(), + ), + ["mdet10"], + [0.533], + ), + ( + "_average_recall_mdet100", + text_format.Parse( + """ model_specs { signature_name: "serving_default" prediction_key: "predictions" # placeholder @@ -148,10 +183,16 @@ class ObjectDetectionMetricsTest(parameterized.TestCase): '"name":"mdet100"' } } - """, config_pb2.EvalConfig()), ['mdet100'], [0.533]), - 
('_average_recall_arsmall', - text_format.Parse( - """ + """, + config_pb2.EvalConfig(), + ), + ["mdet100"], + [0.533], + ), + ( + "_average_recall_arsmall", + text_format.Parse( + """ model_specs { signature_name: "serving_default" prediction_key: "predictions" # placeholder @@ -166,10 +207,16 @@ class ObjectDetectionMetricsTest(parameterized.TestCase): '"max_num_detections":100, "name":"arsmall"' } } - """, config_pb2.EvalConfig()), ['arsmall'], [0.500]), - ('_average_recall_armedium', - text_format.Parse( - """ + """, + config_pb2.EvalConfig(), + ), + ["arsmall"], + [0.500], + ), + ( + "_average_recall_armedium", + text_format.Parse( + """ model_specs { signature_name: "serving_default" prediction_key: "predictions" # placeholder @@ -184,10 +231,16 @@ class ObjectDetectionMetricsTest(parameterized.TestCase): '"max_num_detections":100, "name":"armedium"' } } - """, config_pb2.EvalConfig()), ['armedium'], [0.300]), - ('_average_recall_arlarge', - text_format.Parse( - """ + """, + config_pb2.EvalConfig(), + ), + ["armedium"], + [0.300], + ), + ( + "_average_recall_arlarge", + text_format.Parse( + """ model_specs { signature_name: "serving_default" prediction_key: "predictions" # placeholder @@ -202,72 +255,105 @@ class ObjectDetectionMetricsTest(parameterized.TestCase): '"max_num_detections":100, "name":"arlarge"' } } - """, config_pb2.EvalConfig()), ['arlarge'], [0.700])) - def testMetricValuesWithLargerData(self, eval_config, name_list, - expected_results): - - extracts = [{ - 'features': { - 'labels': - np.array([[[272.1, 200.23, 424.07, 480., 2.], - [181.23, 86.28, 208.67, 159.81, 2.], - [174.74, 0., 435.78, 220.79, 2.]]]), - 'predictions': - np.array([[[271.2, 178.86, 429.52, 459.57, 2., 0.64], - [178.53, 92.57, 206.39, 159.71, 2., 0.38], - [167.96, 9.97, 442.79, 235.07, 2., 0.95]]]) - } - }, { - 'features': { - 'labels': - np.array([[[473.07, 395.93, 503.07, 424.6, 1.], - [204.01, 235.08, 264.85, 412.44, 2.], - [0.43, 499.79, 340.22, 606.24, 2.], - [204.42, 304.1, 256.93, 456.86, 2.]]]), - 'predictions': - np.array([[[471.15, 398.57, 502.29, 428.26, 1., 0.54], - [198.53, 242.14, 263.93, 427.51, 2., 0.95], - [-32.86, 505.75, 338.82, 619.66, 2., 0.17], - [201.59, 299.39, 258.4, 452.88, 1., 0.05]]]) - } - }] + """, + config_pb2.EvalConfig(), + ), + ["arlarge"], + [0.700], + ), + ) + def testMetricValuesWithLargerData(self, eval_config, name_list, expected_results): + extracts = [ + { + "features": { + "labels": np.array( + [ + [ + [272.1, 200.23, 424.07, 480.0, 2.0], + [181.23, 86.28, 208.67, 159.81, 2.0], + [174.74, 0.0, 435.78, 220.79, 2.0], + ] + ] + ), + "predictions": np.array( + [ + [ + [271.2, 178.86, 429.52, 459.57, 2.0, 0.64], + [178.53, 92.57, 206.39, 159.71, 2.0, 0.38], + [167.96, 9.97, 442.79, 235.07, 2.0, 0.95], + ] + ] + ), + } + }, + { + "features": { + "labels": np.array( + [ + [ + [473.07, 395.93, 503.07, 424.6, 1.0], + [204.01, 235.08, 264.85, 412.44, 2.0], + [0.43, 499.79, 340.22, 606.24, 2.0], + [204.42, 304.1, 256.93, 456.86, 2.0], + ] + ] + ), + "predictions": np.array( + [ + [ + [471.15, 398.57, 502.29, 428.26, 1.0, 0.54], + [198.53, 242.14, 263.93, 427.51, 2.0, 0.95], + [-32.86, 505.75, 338.82, 619.66, 2.0, 0.17], + [201.59, 299.39, 258.4, 452.88, 1.0, 0.05], + ] + ] + ), + } + }, + ] - evaluators = tfma.default_evaluators(eval_config=eval_config) - extractors = tfma.default_extractors( - eval_shared_model=None, eval_config=eval_config) + evaluators = tfma.default_evaluators(eval_config=eval_config) + extractors = tfma.default_extractors( + 
eval_shared_model=None, eval_config=eval_config + ) - with beam.Pipeline() as p: - result = ( - p | 'LoadData' >> beam.Create(extracts) - | 'ExtractEval' >> tfma.ExtractAndEvaluate( - extractors=extractors, evaluators=evaluators)) + with beam.Pipeline() as p: + result = ( + p + | "LoadData" >> beam.Create(extracts) + | "ExtractEval" + >> tfma.ExtractAndEvaluate(extractors=extractors, evaluators=evaluators) + ) - # pylint: enable=no-value-for-parameter + # pylint: enable=no-value-for-parameter - def check_result(got): - try: - self.assertLen(got, 1) - got_slice_key, got_metrics = got[0] - self.assertEqual(got_slice_key, ()) - self.assertLen(got_metrics, len(name_list)) - for name, expected_result in zip(name_list, expected_results): - key = metric_types.MetricKey(name=name) - self.assertIn(key, got_metrics) - got_metric = got_metrics[key] - np.testing.assert_allclose( - expected_result, - got_metric, - rtol=1e-3, - err_msg=f'This {name} metric fails.') - except AssertionError as err: - raise util.BeamAssertException(err) + def check_result(got): + try: + self.assertLen(got, 1) + got_slice_key, got_metrics = got[0] + self.assertEqual(got_slice_key, ()) + self.assertLen(got_metrics, len(name_list)) + for name, expected_result in zip(name_list, expected_results): + key = metric_types.MetricKey(name=name) + self.assertIn(key, got_metrics) + got_metric = got_metrics[key] + np.testing.assert_allclose( + expected_result, + got_metric, + rtol=1e-3, + err_msg=f"This {name} metric fails.", + ) + except AssertionError as err: + raise util.BeamAssertException(err) - self.assertIn('metrics', result) - util.assert_that(result['metrics'], check_result, label='result') + self.assertIn("metrics", result) + util.assert_that(result["metrics"], check_result, label="result") - @parameterized.named_parameters(('_average_precision_iou0.5', - text_format.Parse( - """ + @parameterized.named_parameters( + ( + "_average_precision_iou0.5", + text_format.Parse( + """ model_specs { signature_name: "serving_default" prediction_key: "predictions" # placeholder @@ -284,86 +370,105 @@ def check_result(got): '"predictions_to_stack":["bbox", "class_id", "scores"]' } } - """, config_pb2.EvalConfig()), ['iou0.5'], [0.916])) - def testMetricValuesWithSplittedData(self, eval_config, name_list, - expected_results): - - extracts = [{ - 'features': { - 'labels': { - 'xmin': np.array([[272.1, 181.23, 174.74]]), - 'ymin': np.array([[200.23, 86.28, 0.]]), - 'xmax': np.array([[424.07, 208.67, 435.78]]), - 'ymax': np.array([[480., 159.81, 220.79]]), - 'class_id': np.array([[2., 2., 2.]]), + """, + config_pb2.EvalConfig(), + ), + ["iou0.5"], + [0.916], + ) + ) + def testMetricValuesWithSplittedData( + self, eval_config, name_list, expected_results + ): + extracts = [ + { + "features": { + "labels": { + "xmin": np.array([[272.1, 181.23, 174.74]]), + "ymin": np.array([[200.23, 86.28, 0.0]]), + "xmax": np.array([[424.07, 208.67, 435.78]]), + "ymax": np.array([[480.0, 159.81, 220.79]]), + "class_id": np.array([[2.0, 2.0, 2.0]]), + }, + "predictions": { + "bbox": np.array( + [ + [ + [271.2, 178.86, 429.52, 459.57], + [178.53, 92.57, 206.39, 159.71], + [167.96, 9.97, 442.79, 235.07], + ] + ] + ), + "class_id": np.array([[2.0, 2.0, 2.0]]), + "scores": np.array([[0.64, 0.38, 0.95]]), + }, + } }, - 'predictions': { - 'bbox': - np.array([[[271.2, 178.86, 429.52, 459.57], - [178.53, 92.57, 206.39, 159.71], - [167.96, 9.97, 442.79, 235.07]]]), - 'class_id': - np.array([[2., 2., 2.]]), - 'scores': - np.array([[0.64, 0.38, 0.95]]), - } - } - }, { - 
'features': { - 'labels': { - 'xmin': np.array([[473.07, 204.01, 0.43, 204.42]]), - 'ymin': np.array([[395.93, 235.08, 499.79, 304.1]]), - 'xmax': np.array([[503.07, 264.85, 340.22, 256.93]]), - 'ymax': np.array([[424.6, 412.44, 606.24, 456.86]]), - 'class_id': np.array([[1., 2., 2., 2.]]), + { + "features": { + "labels": { + "xmin": np.array([[473.07, 204.01, 0.43, 204.42]]), + "ymin": np.array([[395.93, 235.08, 499.79, 304.1]]), + "xmax": np.array([[503.07, 264.85, 340.22, 256.93]]), + "ymax": np.array([[424.6, 412.44, 606.24, 456.86]]), + "class_id": np.array([[1.0, 2.0, 2.0, 2.0]]), + }, + "predictions": { + "bbox": np.array( + [ + [ + [471.15, 398.57, 502.29, 428.26], + [198.53, 242.14, 263.93, 427.51], + [-32.86, 505.75, 338.82, 619.66], + [201.59, 299.39, 258.4, 452.88], + ] + ] + ), + "class_id": np.array([[1.0, 2.0, 2.0, 1.0]]), + "scores": np.array([[0.54, 0.95, 0.17, 0.05]]), + }, + } }, - 'predictions': { - 'bbox': - np.array([[[471.15, 398.57, 502.29, 428.26], - [198.53, 242.14, 263.93, 427.51], - [-32.86, 505.75, 338.82, 619.66], - [201.59, 299.39, 258.4, 452.88]]]), - 'class_id': - np.array([[1., 2., 2., 1.]]), - 'scores': - np.array([[0.54, 0.95, 0.17, 0.05]]), - } - } - }] + ] - evaluators = tfma.default_evaluators(eval_config=eval_config) - extractors = tfma.default_extractors( - eval_shared_model=None, eval_config=eval_config) + evaluators = tfma.default_evaluators(eval_config=eval_config) + extractors = tfma.default_extractors( + eval_shared_model=None, eval_config=eval_config + ) - with beam.Pipeline() as p: - result = ( - p | 'LoadData' >> beam.Create(extracts) - | 'ExtractEval' >> tfma.ExtractAndEvaluate( - extractors=extractors, evaluators=evaluators)) + with beam.Pipeline() as p: + result = ( + p + | "LoadData" >> beam.Create(extracts) + | "ExtractEval" + >> tfma.ExtractAndEvaluate(extractors=extractors, evaluators=evaluators) + ) - # pylint: enable=no-value-for-parameter + # pylint: enable=no-value-for-parameter - def check_result(got): - try: - self.assertLen(got, 1) - got_slice_key, got_metrics = got[0] - self.assertEqual(got_slice_key, ()) - self.assertLen(got_metrics, len(name_list)) - for name, expected_result in zip(name_list, expected_results): - key = metric_types.MetricKey(name=name) - self.assertIn(key, got_metrics) - got_metric = got_metrics[key] - np.testing.assert_allclose( - expected_result, - got_metric, - rtol=1e-3, - err_msg=f'This {name} metric fails.') - except AssertionError as err: - raise util.BeamAssertException(err) + def check_result(got): + try: + self.assertLen(got, 1) + got_slice_key, got_metrics = got[0] + self.assertEqual(got_slice_key, ()) + self.assertLen(got_metrics, len(name_list)) + for name, expected_result in zip(name_list, expected_results): + key = metric_types.MetricKey(name=name) + self.assertIn(key, got_metrics) + got_metric = got_metrics[key] + np.testing.assert_allclose( + expected_result, + got_metric, + rtol=1e-3, + err_msg=f"This {name} metric fails.", + ) + except AssertionError as err: + raise util.BeamAssertException(err) - self.assertIn('metrics', result) - util.assert_that(result['metrics'], check_result, label='result') + self.assertIn("metrics", result) + util.assert_that(result["metrics"], check_result, label="result") -if __name__ == '__main__': - absltest.main() +if __name__ == "__main__": + absltest.main() diff --git a/tensorflow_model_analysis/metrics/prediction_difference_metrics.py b/tensorflow_model_analysis/metrics/prediction_difference_metrics.py index 49de133697..62e52dbce1 100644 --- 
a/tensorflow_model_analysis/metrics/prediction_difference_metrics.py +++ b/tensorflow_model_analysis/metrics/prediction_difference_metrics.py @@ -14,29 +14,29 @@ """PredictionDifference metrics.""" import dataclasses -from typing import Optional, Dict, Iterable, List +from typing import Dict, Iterable, List, Optional import apache_beam as beam -from tensorflow_model_analysis.metrics import metric_types -from tensorflow_model_analysis.metrics import metric_util + +from tensorflow_model_analysis.metrics import metric_types, metric_util from tensorflow_model_analysis.proto import config_pb2 from tensorflow_model_analysis.utils import model_util -SYMMETRIC_PREDICITON_DIFFERENCE_NAME = 'symmetric_prediction_difference' +SYMMETRIC_PREDICITON_DIFFERENCE_NAME = "symmetric_prediction_difference" _K_EPSILON = 1e-7 class SymmetricPredictionDifference(metric_types.Metric): - """PredictionDifference computes the avg pointwise diff between models.""" - - def __init__(self, name: str = SYMMETRIC_PREDICITON_DIFFERENCE_NAME): - """Initializes PredictionDifference metric. + """PredictionDifference computes the avg pointwise diff between models.""" - Args: - name: Metric name. - """ + def __init__(self, name: str = SYMMETRIC_PREDICITON_DIFFERENCE_NAME): + """Initializes PredictionDifference metric. - super().__init__(_symmetric_prediction_difference_computations, name=name) + Args: + ---- + name: Metric name. + """ + super().__init__(_symmetric_prediction_difference_computations, name=name) metric_types.register_metric(SymmetricPredictionDifference) @@ -48,118 +48,137 @@ def _symmetric_prediction_difference_computations( model_names: Optional[List[str]] = None, output_names: Optional[List[str]] = None, sub_keys: Optional[List[metric_types.SubKey]] = None, - example_weighted: bool = False) -> metric_types.MetricComputations: - """Returns metric computations for SymmetricPredictionDifference. - - This is not meant to be used with merge_per_key_computations because we - don't want to create computations for the baseline model, and we want to - provide the baseline model name to each Combiner - - Args: - name: The name of the metric returned by the computations. - eval_config: The EvalConfig for this TFMA evaluation. - model_names: The set of models for which to compute this metric. - output_names: The set of output names for which to compute this metric. - sub_keys: The set of sub_key settings for which to compute this metric. - example_weighted: Whether to compute this metric using example weights. - """ - computations = [] - baseline_spec = model_util.get_baseline_model_spec(eval_config) - baseline_model_name = baseline_spec.name if baseline_spec else None - for model_name in model_names or ['']: - if model_name == baseline_model_name: - continue - for output_name in output_names or ['']: - for sub_key in sub_keys or [None]: - key = metric_types.MetricKey( - name=name, - model_name=model_name, - output_name=output_name, - sub_key=sub_key, - example_weighted=example_weighted, - is_diff=True) - computations.append( - metric_types.MetricComputation( - keys=[key], - preprocessors=None, - combiner=_SymmetricPredictionDifferenceCombiner( - eval_config, baseline_model_name, model_name, output_name, - key, example_weighted))) - return computations + example_weighted: bool = False, +) -> metric_types.MetricComputations: + """Returns metric computations for SymmetricPredictionDifference. 
+ + This is not meant to be used with merge_per_key_computations because we + don't want to create computations for the baseline model, and we want to + provide the baseline model name to each Combiner + + Args: + ---- + name: The name of the metric returned by the computations. + eval_config: The EvalConfig for this TFMA evaluation. + model_names: The set of models for which to compute this metric. + output_names: The set of output names for which to compute this metric. + sub_keys: The set of sub_key settings for which to compute this metric. + example_weighted: Whether to compute this metric using example weights. + """ + computations = [] + baseline_spec = model_util.get_baseline_model_spec(eval_config) + baseline_model_name = baseline_spec.name if baseline_spec else None + for model_name in model_names or [""]: + if model_name == baseline_model_name: + continue + for output_name in output_names or [""]: + for sub_key in sub_keys or [None]: + key = metric_types.MetricKey( + name=name, + model_name=model_name, + output_name=output_name, + sub_key=sub_key, + example_weighted=example_weighted, + is_diff=True, + ) + computations.append( + metric_types.MetricComputation( + keys=[key], + preprocessors=None, + combiner=_SymmetricPredictionDifferenceCombiner( + eval_config, + baseline_model_name, + model_name, + output_name, + key, + example_weighted, + ), + ) + ) + return computations @dataclasses.dataclass class _SymmetricPredictionDifferenceAccumulator: - num_weighted_examples: float = 0.0 - total_pointwise_sym_diff: float = 0.0 + num_weighted_examples: float = 0.0 + total_pointwise_sym_diff: float = 0.0 - def merge(self, other: '_SymmetricPredictionDifferenceAccumulator'): - self.num_weighted_examples += other.num_weighted_examples - self.total_pointwise_sym_diff += other.total_pointwise_sym_diff + def merge(self, other: "_SymmetricPredictionDifferenceAccumulator"): + self.num_weighted_examples += other.num_weighted_examples + self.total_pointwise_sym_diff += other.total_pointwise_sym_diff class _SymmetricPredictionDifferenceCombiner(beam.CombineFn): - """Computes PredictionDifference.""" - - def __init__(self, eval_config: config_pb2.EvalConfig, - baseline_model_name: str, model_name: str, output_name: str, - key: metric_types.MetricKey, example_weighted: bool): - self._eval_config = eval_config - self._baseline_model_name = baseline_model_name - self._model_name = model_name - self._output_name = output_name - self._key = key - self._example_weighted = example_weighted - - def create_accumulator(self) -> _SymmetricPredictionDifferenceAccumulator: - return _SymmetricPredictionDifferenceAccumulator() - - def add_input( - self, accumulator: _SymmetricPredictionDifferenceAccumulator, - element: metric_types.StandardMetricInputs - ) -> _SymmetricPredictionDifferenceAccumulator: - - _, base_prediction, base_example_weight = next( - metric_util.to_label_prediction_example_weight( - element, - eval_config=self._eval_config, - model_name=self._baseline_model_name, - output_name=self._output_name, - flatten=True, - example_weighted=self._example_weighted)) - - _, model_prediction, _ = next( - metric_util.to_label_prediction_example_weight( - element, - eval_config=self._eval_config, - model_name=self._model_name, - output_name=self._output_name, - flatten=True, - example_weighted=self._example_weighted)) - base_example_weight = metric_util.safe_to_scalar(base_example_weight) - accumulator.num_weighted_examples += base_example_weight - numerator = 2 * abs(base_prediction - model_prediction) - 
denominator = abs(base_prediction + model_prediction) - if numerator < _K_EPSILON and denominator < _K_EPSILON: - sym_pd = 0.0 - else: - sym_pd = metric_util.safe_to_scalar((numerator / denominator)) - accumulator.total_pointwise_sym_diff += sym_pd * base_example_weight - return accumulator - - def merge_accumulators( - self, accumulators: Iterable[_SymmetricPredictionDifferenceAccumulator] - ) -> _SymmetricPredictionDifferenceAccumulator: - result = next(iter(accumulators)) - for accumulator in accumulators: - result.merge(accumulator) - return result - - def extract_output( - self, accumulator: _SymmetricPredictionDifferenceAccumulator - ) -> Dict[metric_types.MetricKey, float]: - return { - self._key: - accumulator.total_pointwise_sym_diff / - accumulator.num_weighted_examples - } + """Computes PredictionDifference.""" + + def __init__( + self, + eval_config: config_pb2.EvalConfig, + baseline_model_name: str, + model_name: str, + output_name: str, + key: metric_types.MetricKey, + example_weighted: bool, + ): + self._eval_config = eval_config + self._baseline_model_name = baseline_model_name + self._model_name = model_name + self._output_name = output_name + self._key = key + self._example_weighted = example_weighted + + def create_accumulator(self) -> _SymmetricPredictionDifferenceAccumulator: + return _SymmetricPredictionDifferenceAccumulator() + + def add_input( + self, + accumulator: _SymmetricPredictionDifferenceAccumulator, + element: metric_types.StandardMetricInputs, + ) -> _SymmetricPredictionDifferenceAccumulator: + _, base_prediction, base_example_weight = next( + metric_util.to_label_prediction_example_weight( + element, + eval_config=self._eval_config, + model_name=self._baseline_model_name, + output_name=self._output_name, + flatten=True, + example_weighted=self._example_weighted, + ) + ) + + _, model_prediction, _ = next( + metric_util.to_label_prediction_example_weight( + element, + eval_config=self._eval_config, + model_name=self._model_name, + output_name=self._output_name, + flatten=True, + example_weighted=self._example_weighted, + ) + ) + base_example_weight = metric_util.safe_to_scalar(base_example_weight) + accumulator.num_weighted_examples += base_example_weight + numerator = 2 * abs(base_prediction - model_prediction) + denominator = abs(base_prediction + model_prediction) + if numerator < _K_EPSILON and denominator < _K_EPSILON: + sym_pd = 0.0 + else: + sym_pd = metric_util.safe_to_scalar(numerator / denominator) + accumulator.total_pointwise_sym_diff += sym_pd * base_example_weight + return accumulator + + def merge_accumulators( + self, accumulators: Iterable[_SymmetricPredictionDifferenceAccumulator] + ) -> _SymmetricPredictionDifferenceAccumulator: + result = next(iter(accumulators)) + for accumulator in accumulators: + result.merge(accumulator) + return result + + def extract_output( + self, accumulator: _SymmetricPredictionDifferenceAccumulator + ) -> Dict[metric_types.MetricKey, float]: + return { + self._key: accumulator.total_pointwise_sym_diff + / accumulator.num_weighted_examples + } diff --git a/tensorflow_model_analysis/metrics/prediction_difference_metrics_test.py b/tensorflow_model_analysis/metrics/prediction_difference_metrics_test.py index e79cc289e2..5fb52a1ede 100644 --- a/tensorflow_model_analysis/metrics/prediction_difference_metrics_test.py +++ b/tensorflow_model_analysis/metrics/prediction_difference_metrics_test.py @@ -13,23 +13,24 @@ # limitations under the License. 
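The symmetric prediction difference tests that follow exercise the combiner defined above. As a standalone illustration, here is a small pure-Python sketch of that arithmetic; the prediction and weight values deliberately mirror the first test case below, and `_K_EPSILON` reuses the guard value from the module.

```python
_K_EPSILON = 1e-7  # same guard value as in the module above


def symmetric_prediction_difference(base_preds, cand_preds, weights):
    """Weighted mean of 2*|b - c| / |b + c|, mirroring the combiner above."""
    total, weight_sum = 0.0, 0.0
    for b, c, w in zip(base_preds, cand_preds, weights):
        numerator = 2 * abs(b - c)
        denominator = abs(b + c)
        if numerator < _K_EPSILON and denominator < _K_EPSILON:
            sym_pd = 0.0  # both predictions ~0: treat the difference as zero
        else:
            sym_pd = numerator / denominator
        total += sym_pd * w
        weight_sum += w
    return total / weight_sum


# Mirrors the first test case below:
# (2*0.1/0.3*1 + 2*0.1/0.5*2 + 2*0.1/1.7*3) / 6 ~= 0.3033
print(symmetric_prediction_difference([0.1, 0.2, 0.9], [0.2, 0.3, 0.8], [1, 2, 3]))
```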
"""Tests for prediction difference metrics.""" -from absl.testing import absltest import apache_beam as beam +from absl.testing import absltest from apache_beam.testing import util -from tensorflow_model_analysis.metrics import metric_types -from tensorflow_model_analysis.metrics import metric_util -from tensorflow_model_analysis.metrics import prediction_difference_metrics +from google.protobuf import text_format + +from tensorflow_model_analysis.metrics import ( + metric_types, + metric_util, + prediction_difference_metrics, +) from tensorflow_model_analysis.proto import config_pb2 from tensorflow_model_analysis.writers import metrics_plots_and_validations_writer -from google.protobuf import text_format - class SymmetricPredictionDifferenceTest(absltest.TestCase): - - def testSymmetricPredictionDifference(self): - eval_config = text_format.Parse( - """ + def testSymmetricPredictionDifference(self): + eval_config = text_format.Parse( + """ model_specs { name: "baseline" is_baseline: true @@ -37,85 +38,95 @@ def testSymmetricPredictionDifference(self): model_specs { name: "candidate" } - """, config_pb2.EvalConfig()) - baseline_model_name = 'baseline' - candidate_model_name = 'candidate' - computations = prediction_difference_metrics.SymmetricPredictionDifference( - ).computations( - eval_config=eval_config, - model_names=['baseline', 'candidate'], - output_names=[''], - example_weighted=True) - self.assertLen(computations, 1) - computation = computations[0] - - examples = [{ - 'labels': [0], - 'example_weights': [1], - 'predictions': { - baseline_model_name: [0.1], - candidate_model_name: [0.2] - } - }, { - 'labels': [0], - 'example_weights': [2], - 'predictions': { - baseline_model_name: [0.2], - candidate_model_name: [0.3] - } - }, { - 'labels': [1], - 'example_weights': [3], - 'predictions': { - baseline_model_name: [0.9], - candidate_model_name: [0.8] - } - }] - - with beam.Pipeline() as pipeline: - # pylint: disable=no-value-for-parameter - result = ( - pipeline - | 'Create' >> beam.Create(examples) - | 'Process' >> beam.Map(metric_util.to_standard_metric_inputs) - | 'AddSlice' >> beam.Map(lambda x: ((), x)) - | 'ComputeMetric' >> beam.CombinePerKey(computation.combiner)) - - # pylint: enable=no-value-for-parameter - - def check_result(got): - try: - self.assertLen(got, 1) - got_proto = ( - metrics_plots_and_validations_writer - .convert_slice_metrics_to_proto( - got[0], add_metrics_callbacks=None)) - self.assertLen(got_proto.metric_keys_and_values, 1) - got_kv_proto = got_proto.metric_keys_and_values[0] - self.assertEqual( - got_kv_proto.value.WhichOneof('type'), 'double_value') - got_slice_key, got_metrics = got[0] - self.assertEqual(got_slice_key, ()) - pd_key = metric_types.MetricKey( - name=prediction_difference_metrics - .SYMMETRIC_PREDICITON_DIFFERENCE_NAME, - model_name=candidate_model_name, - output_name='', - example_weighted=True, - is_diff=True) - self.assertIn(pd_key, got_metrics) - self.assertAlmostEqual( - got_metrics[pd_key], - (2 * 0.1 / 0.3 * 1 + 2 * 0.1 / 0.5 * 2 + 2 * 0.1 / 1.7 * 3) / 6) - - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that(result, check_result, label='result') - - def testSymmetricPredictionDifferenceEpsilon(self): - eval_config = text_format.Parse( - """ + """, + config_pb2.EvalConfig(), + ) + baseline_model_name = "baseline" + candidate_model_name = "candidate" + computations = ( + prediction_difference_metrics.SymmetricPredictionDifference().computations( + eval_config=eval_config, + 
model_names=["baseline", "candidate"], + output_names=[""], + example_weighted=True, + ) + ) + self.assertLen(computations, 1) + computation = computations[0] + + examples = [ + { + "labels": [0], + "example_weights": [1], + "predictions": { + baseline_model_name: [0.1], + candidate_model_name: [0.2], + }, + }, + { + "labels": [0], + "example_weights": [2], + "predictions": { + baseline_model_name: [0.2], + candidate_model_name: [0.3], + }, + }, + { + "labels": [1], + "example_weights": [3], + "predictions": { + baseline_model_name: [0.9], + candidate_model_name: [0.8], + }, + }, + ] + + with beam.Pipeline() as pipeline: + # pylint: disable=no-value-for-parameter + result = ( + pipeline + | "Create" >> beam.Create(examples) + | "Process" >> beam.Map(metric_util.to_standard_metric_inputs) + | "AddSlice" >> beam.Map(lambda x: ((), x)) + | "ComputeMetric" >> beam.CombinePerKey(computation.combiner) + ) + + # pylint: enable=no-value-for-parameter + + def check_result(got): + try: + self.assertLen(got, 1) + got_proto = metrics_plots_and_validations_writer.convert_slice_metrics_to_proto( + got[0], add_metrics_callbacks=None + ) + self.assertLen(got_proto.metric_keys_and_values, 1) + got_kv_proto = got_proto.metric_keys_and_values[0] + self.assertEqual( + got_kv_proto.value.WhichOneof("type"), "double_value" + ) + got_slice_key, got_metrics = got[0] + self.assertEqual(got_slice_key, ()) + pd_key = metric_types.MetricKey( + name=prediction_difference_metrics.SYMMETRIC_PREDICITON_DIFFERENCE_NAME, + model_name=candidate_model_name, + output_name="", + example_weighted=True, + is_diff=True, + ) + self.assertIn(pd_key, got_metrics) + self.assertAlmostEqual( + got_metrics[pd_key], + (2 * 0.1 / 0.3 * 1 + 2 * 0.1 / 0.5 * 2 + 2 * 0.1 / 1.7 * 3) / 6, + ) + + except AssertionError as err: + raise util.BeamAssertException(err) + + util.assert_that(result, check_result, label="result") + + def testSymmetricPredictionDifferenceEpsilon(self): + eval_config = text_format.Parse( + """ model_specs { name: "baseline" is_baseline: true @@ -123,58 +134,65 @@ def testSymmetricPredictionDifferenceEpsilon(self): model_specs { name: "candidate" } - """, config_pb2.EvalConfig()) - baseline_model_name = 'baseline' - candidate_model_name = 'candidate' - computations = prediction_difference_metrics.SymmetricPredictionDifference( - ).computations( - eval_config=eval_config, - model_names=['baseline', 'candidate'], - output_names=[''], - example_weighted=True) - self.assertLen(computations, 1) - computation = computations[0] - - examples = [{ - 'labels': [0], - 'example_weights': [1], - 'predictions': { - baseline_model_name: [0.1], - candidate_model_name: [0.1] - } - }] - - with beam.Pipeline() as pipeline: - # pylint: disable=no-value-for-parameter - result = ( - pipeline - | 'Create' >> beam.Create(examples) - | 'Process' >> beam.Map(metric_util.to_standard_metric_inputs) - | 'AddSlice' >> beam.Map(lambda x: ((), x)) - | 'ComputeMetric' >> beam.CombinePerKey(computation.combiner)) - - # pylint: enable=no-value-for-parameter - - def check_result(got): - try: - self.assertLen(got, 1) - got_slice_key, got_metrics = got[0] - self.assertEqual(got_slice_key, ()) - pd_key = metric_types.MetricKey( - name=prediction_difference_metrics - .SYMMETRIC_PREDICITON_DIFFERENCE_NAME, - model_name=candidate_model_name, - output_name='', - example_weighted=True, - is_diff=True) - self.assertIn(pd_key, got_metrics) - self.assertAlmostEqual(got_metrics[pd_key], 0.0) - - except AssertionError as err: - raise util.BeamAssertException(err) - - 
util.assert_that(result, check_result, label='result') - - -if __name__ == '__main__': - absltest.main() + """, + config_pb2.EvalConfig(), + ) + baseline_model_name = "baseline" + candidate_model_name = "candidate" + computations = ( + prediction_difference_metrics.SymmetricPredictionDifference().computations( + eval_config=eval_config, + model_names=["baseline", "candidate"], + output_names=[""], + example_weighted=True, + ) + ) + self.assertLen(computations, 1) + computation = computations[0] + + examples = [ + { + "labels": [0], + "example_weights": [1], + "predictions": { + baseline_model_name: [0.1], + candidate_model_name: [0.1], + }, + } + ] + + with beam.Pipeline() as pipeline: + # pylint: disable=no-value-for-parameter + result = ( + pipeline + | "Create" >> beam.Create(examples) + | "Process" >> beam.Map(metric_util.to_standard_metric_inputs) + | "AddSlice" >> beam.Map(lambda x: ((), x)) + | "ComputeMetric" >> beam.CombinePerKey(computation.combiner) + ) + + # pylint: enable=no-value-for-parameter + + def check_result(got): + try: + self.assertLen(got, 1) + got_slice_key, got_metrics = got[0] + self.assertEqual(got_slice_key, ()) + pd_key = metric_types.MetricKey( + name=prediction_difference_metrics.SYMMETRIC_PREDICITON_DIFFERENCE_NAME, + model_name=candidate_model_name, + output_name="", + example_weighted=True, + is_diff=True, + ) + self.assertIn(pd_key, got_metrics) + self.assertAlmostEqual(got_metrics[pd_key], 0.0) + + except AssertionError as err: + raise util.BeamAssertException(err) + + util.assert_that(result, check_result, label="result") + + +if __name__ == "__main__": + absltest.main() diff --git a/tensorflow_model_analysis/metrics/preprocessors/__init__.py b/tensorflow_model_analysis/metrics/preprocessors/__init__.py index 41d3ee86a9..f13e6cb6c3 100644 --- a/tensorflow_model_analysis/metrics/preprocessors/__init__.py +++ b/tensorflow_model_analysis/metrics/preprocessors/__init__.py @@ -12,8 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. 
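The preprocessor package changes below re-export `DecodeImagePreprocessor` from `tensorflow_model_analysis.metrics.preprocessors`. A hedged usage sketch follows; the `image/encoded` and `image/pred/encoded` keys simply mirror the placeholder keys used in the image preprocessor tests later in this diff, not any required naming convention.

```python
# Assumption: the package-level re-export shown in the __init__.py hunk below.
from tensorflow_model_analysis.metrics.preprocessors import DecodeImagePreprocessor

preprocessor = DecodeImagePreprocessor(
    ground_truth_key="image/encoded",       # placeholder label key
    prediction_key="image/pred/encoded",    # placeholder prediction key
)

# Given a TFMA extracts dict containing those keys, process() yields standard
# metric inputs whose labels/predictions are decoded numpy arrays:
# standard_inputs = next(preprocessor.process(extracts=extracts))
```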
"""Init module for TensorFlow Model Analysis related preprocssors.""" + from tensorflow_model_analysis.metrics.preprocessors import utils -from tensorflow_model_analysis.metrics.preprocessors.image_preprocessors import DecodeImagePreprocessor -from tensorflow_model_analysis.metrics.preprocessors.invert_logarithm_preprocessors import InvertBinaryLogarithmPreprocessor -from tensorflow_model_analysis.metrics.preprocessors.object_detection_preprocessors import BoundingBoxMatchPreprocessor -from tensorflow_model_analysis.metrics.preprocessors.set_match_preprocessors import SetMatchPreprocessor +from tensorflow_model_analysis.metrics.preprocessors.image_preprocessors import ( + DecodeImagePreprocessor, +) +from tensorflow_model_analysis.metrics.preprocessors.invert_logarithm_preprocessors import ( + InvertBinaryLogarithmPreprocessor, +) +from tensorflow_model_analysis.metrics.preprocessors.object_detection_preprocessors import ( + BoundingBoxMatchPreprocessor, +) +from tensorflow_model_analysis.metrics.preprocessors.set_match_preprocessors import ( + SetMatchPreprocessor, +) diff --git a/tensorflow_model_analysis/metrics/preprocessors/image_preprocessors.py b/tensorflow_model_analysis/metrics/preprocessors/image_preprocessors.py index 51af2f642b..7c9c0e0929 100644 --- a/tensorflow_model_analysis/metrics/preprocessors/image_preprocessors.py +++ b/tensorflow_model_analysis/metrics/preprocessors/image_preprocessors.py @@ -19,140 +19,141 @@ import numpy as np from PIL import Image + from tensorflow_model_analysis import constants from tensorflow_model_analysis.api import types -from tensorflow_model_analysis.metrics import metric_types -from tensorflow_model_analysis.metrics import metric_util +from tensorflow_model_analysis.metrics import metric_types, metric_util from tensorflow_model_analysis.utils import util -_DECODE_IMAGE_PREPROCESSOR_BASE_NAME = 'decode_image_preprocessor' +_DECODE_IMAGE_PREPROCESSOR_BASE_NAME = "decode_image_preprocessor" -def _image_bytes_to_numpy_array( - image_bytes: Union[bytes, np.ndarray] -) -> np.ndarray: - """Read bytes or a numpy scalar of bytes of image and return an ndarray.""" - if isinstance(image_bytes, np.ndarray): - assert image_bytes.size == 1 - image_bytes = image_bytes.item() - image = Image.open(io.BytesIO(image_bytes)) - return np.array(image, dtype=np.uint8) +def _image_bytes_to_numpy_array(image_bytes: Union[bytes, np.ndarray]) -> np.ndarray: + """Read bytes or a numpy scalar of bytes of image and return an ndarray.""" + if isinstance(image_bytes, np.ndarray): + assert image_bytes.size == 1 + image_bytes = image_bytes.item() + image = Image.open(io.BytesIO(image_bytes)) + return np.array(image, dtype=np.uint8) class DecodeImagePreprocessor(metric_types.Preprocessor): - """Read images of label and prediciton from bytes to numpy array.""" - - def __init__( - self, - ground_truth_key: str, - prediction_key: str, - decode_ground_truth: bool = True, - decode_prediction: bool = True, - name: Optional[str] = None, - model_name: str = '', - ): - """Initialize the preprocessor for image reading. - - Args: - ground_truth_key: the key for storing the ground truth of encoded image - with class ids. - prediction_key: the key for storing the predictions of encoded image with - class ids. - decode_ground_truth: If true, the ground truth is assumed to be bytes of - images and will be decoded. - decode_prediction: If true, the prediction is assumed to be bytes of - images and will be decoded. - name: (Optional) name for the preprocessor. 
- model_name: (Optional) model name (if multi-model evaluation). - """ - if not name: - name = metric_util.generate_private_name_from_arguments( - _DECODE_IMAGE_PREPROCESSOR_BASE_NAME, - ground_truth_key=ground_truth_key, - prediction_key=prediction_key, - ) - super().__init__(name=name) - self._ground_truth_key = ground_truth_key - self._prediction_key = prediction_key - self._model_name = model_name - self._decode_ground_truth = decode_ground_truth - self._decode_prediction = decode_prediction - - def _read_image_in_mutliple_dicts( - self, - key: str, - label_or_prediction: str, - decode_image: bool, - extracts: util.StandardExtracts, - ): - """Reads images from extracts.""" - if label_or_prediction not in ['label', 'prediction']: - raise ValueError( - 'The function could only search in the lables or predictions.' - ) - if label_or_prediction == 'label': - one_dict_to_search = extracts.get_labels(self._model_name) or {} - else: - one_dict_to_search = extracts.get_predictions(self._model_name) or {} - dict_to_search = collections.ChainMap( - one_dict_to_search, - extracts.get_features() or {}, - extracts.get_transformed_features(self._model_name) or {}, - ) - if not dict_to_search or key not in dict_to_search: - raise ValueError(f'{key} is not found in {list(dict_to_search.keys())}') - if decode_image: - result = _image_bytes_to_numpy_array(dict_to_search[key]) - else: - result = dict_to_search[key] - return result - - def process( - self, extracts: types.Extracts - ) -> Iterator[metric_types.StandardMetricInputs]: - """Reads and decodes images from extracts. - - It will search in labels/predictions, features and transformed features. It - also support decoding image from bytes. - - Args: - extracts: A tfma extract contains the image data. - - Yields: - A standard metric input contains the following key and values: - - {'labels'}: A numpy array represents the image of labels. - - {'predictions'}: A numpy array represents the image of predictions. - - {'example_weights'}: (Optional) A numpy array represents the example - weights. - """ - extracts = util.StandardExtracts(extracts) - extracted_result = {} - - extracted_result[constants.LABELS_KEY] = self._read_image_in_mutliple_dicts( - self._ground_truth_key, 'label', self._decode_ground_truth, extracts - ) - - extracted_result[constants.PREDICTIONS_KEY] = ( - self._read_image_in_mutliple_dicts( - self._prediction_key, - 'prediction', - self._decode_prediction, - extracts, - ) - ) - - if ( - extracted_result[constants.LABELS_KEY].shape - != extracted_result[constants.PREDICTIONS_KEY].shape + """Read images of label and prediciton from bytes to numpy array.""" + + def __init__( + self, + ground_truth_key: str, + prediction_key: str, + decode_ground_truth: bool = True, + decode_prediction: bool = True, + name: Optional[str] = None, + model_name: str = "", + ): + """Initialize the preprocessor for image reading. + + Args: + ---- + ground_truth_key: the key for storing the ground truth of encoded image + with class ids. + prediction_key: the key for storing the predictions of encoded image with + class ids. + decode_ground_truth: If true, the ground truth is assumed to be bytes of + images and will be decoded. + decode_prediction: If true, the prediction is assumed to be bytes of + images and will be decoded. + name: (Optional) name for the preprocessor. + model_name: (Optional) model name (if multi-model evaluation). 
+ """ + if not name: + name = metric_util.generate_private_name_from_arguments( + _DECODE_IMAGE_PREPROCESSOR_BASE_NAME, + ground_truth_key=ground_truth_key, + prediction_key=prediction_key, + ) + super().__init__(name=name) + self._ground_truth_key = ground_truth_key + self._prediction_key = prediction_key + self._model_name = model_name + self._decode_ground_truth = decode_ground_truth + self._decode_prediction = decode_prediction + + def _read_image_in_mutliple_dicts( + self, + key: str, + label_or_prediction: str, + decode_image: bool, + extracts: util.StandardExtracts, ): - raise ValueError( - 'The image size of ground truth ' - f'{extracted_result[constants.LABELS_KEY].shape} does not match ' - 'with the image size of prediction ' - f'{extracted_result[constants.PREDICTIONS_KEY]}' - ) + """Reads images from extracts.""" + if label_or_prediction not in ["label", "prediction"]: + raise ValueError( + "The function could only search in the lables or predictions." + ) + if label_or_prediction == "label": + one_dict_to_search = extracts.get_labels(self._model_name) or {} + else: + one_dict_to_search = extracts.get_predictions(self._model_name) or {} + dict_to_search = collections.ChainMap( + one_dict_to_search, + extracts.get_features() or {}, + extracts.get_transformed_features(self._model_name) or {}, + ) + if not dict_to_search or key not in dict_to_search: + raise ValueError(f"{key} is not found in {list(dict_to_search.keys())}") + if decode_image: + result = _image_bytes_to_numpy_array(dict_to_search[key]) + else: + result = dict_to_search[key] + return result + + def process( + self, extracts: types.Extracts + ) -> Iterator[metric_types.StandardMetricInputs]: + """Reads and decodes images from extracts. + + It will search in labels/predictions, features and transformed features. It + also support decoding image from bytes. + + Args: + ---- + extracts: A tfma extract contains the image data. + + Yields: + ------ + A standard metric input contains the following key and values: + - {'labels'}: A numpy array represents the image of labels. + - {'predictions'}: A numpy array represents the image of predictions. + - {'example_weights'}: (Optional) A numpy array represents the example + weights. 
+ """ + extracts = util.StandardExtracts(extracts) + extracted_result = {} + + extracted_result[constants.LABELS_KEY] = self._read_image_in_mutliple_dicts( + self._ground_truth_key, "label", self._decode_ground_truth, extracts + ) - if constants.EXAMPLE_WEIGHTS_KEY in extracts.keys(): - extracted_result = extracts[constants.EXAMPLE_WEIGHTS_KEY] + extracted_result[constants.PREDICTIONS_KEY] = ( + self._read_image_in_mutliple_dicts( + self._prediction_key, + "prediction", + self._decode_prediction, + extracts, + ) + ) - yield metric_util.to_standard_metric_inputs(extracted_result) + if ( + extracted_result[constants.LABELS_KEY].shape + != extracted_result[constants.PREDICTIONS_KEY].shape + ): + raise ValueError( + "The image size of ground truth " + f"{extracted_result[constants.LABELS_KEY].shape} does not match " + "with the image size of prediction " + f"{extracted_result[constants.PREDICTIONS_KEY]}" + ) + + if constants.EXAMPLE_WEIGHTS_KEY in extracts.keys(): + extracted_result = extracts[constants.EXAMPLE_WEIGHTS_KEY] + + yield metric_util.to_standard_metric_inputs(extracted_result) diff --git a/tensorflow_model_analysis/metrics/preprocessors/image_preprocessors_test.py b/tensorflow_model_analysis/metrics/preprocessors/image_preprocessors_test.py index cf2f6ab615..9de2fb4f88 100644 --- a/tensorflow_model_analysis/metrics/preprocessors/image_preprocessors_test.py +++ b/tensorflow_model_analysis/metrics/preprocessors/image_preprocessors_test.py @@ -14,152 +14,154 @@ """Tests for image related preprocessors.""" import io -from absl.testing import absltest -from absl.testing import parameterized + import apache_beam as beam -from apache_beam.testing import util as beam_testing_util import numpy as np +from absl.testing import absltest, parameterized +from apache_beam.testing import util as beam_testing_util from PIL import Image + from tensorflow_model_analysis import constants from tensorflow_model_analysis.metrics.preprocessors import image_preprocessors from tensorflow_model_analysis.utils import util class ImageDecodeTest(parameterized.TestCase): + def setUp(self): + super().setUp() + image_array = np.array([[1, 2, 3], [1, 2, 3]], dtype=np.uint8) + image = Image.fromarray(image_array) + encoded_buffer = io.BytesIO() + image.save(encoded_buffer, format="PNG") + encoded_image = encoded_buffer.getvalue() - def setUp(self): - super().setUp() - image_array = np.array([[1, 2, 3], [1, 2, 3]], dtype=np.uint8) - image = Image.fromarray(image_array) - encoded_buffer = io.BytesIO() - image.save(encoded_buffer, format='PNG') - encoded_image = encoded_buffer.getvalue() - - self._extract_inputs = [ - util.StandardExtracts({ - constants.LABELS_KEY: { - 'image/encoded': encoded_image, - 'image/decoded': image_array, - 'number_of_classes': 2, - }, - constants.PREDICTIONS_KEY: { - 'image/pred/encoded': encoded_image, - 'image/pred/decoded': image_array, - 'number_of_classes': 3, - }, - }) - ] - self._expected_processed_inputs = [ - util.StandardExtracts({ - constants.LABELS_KEY: image_array, - constants.PREDICTIONS_KEY: image_array, - }) - ] - - @parameterized.named_parameters( - ( - 'decode_encoded_image', - image_preprocessors.DecodeImagePreprocessor( - ground_truth_key='image/encoded', - prediction_key='image/pred/encoded', - ), - ), - ( - 'no_decode_image', - image_preprocessors.DecodeImagePreprocessor( - ground_truth_key='image/decoded', - prediction_key='image/pred/decoded', - decode_ground_truth=False, - decode_prediction=False, - ), - ), - ) - def testImageDecodePreprocessor(self, preprocessor): - 
with beam.Pipeline() as p: - updated_pcoll = ( - p - | 'Create' >> beam.Create(self._extract_inputs) - | 'Preprocess' >> beam.ParDo(preprocessor) - ) - - def check_result(result): - # Only single extract case is tested - self.assertLen(result, len(self._expected_processed_inputs)) - for updated_extracts, expected_input in zip( - result, self._expected_processed_inputs - ): - self.assertIn(constants.PREDICTIONS_KEY, updated_extracts) - np.testing.assert_allclose( - updated_extracts[constants.PREDICTIONS_KEY], - expected_input[constants.PREDICTIONS_KEY], - ) - self.assertIn(constants.LABELS_KEY, updated_extracts) - np.testing.assert_allclose( - updated_extracts[constants.LABELS_KEY], - expected_input[constants.LABELS_KEY], - ) - if constants.EXAMPLE_WEIGHTS_KEY in expected_input: - self.assertIn(constants.EXAMPLE_WEIGHTS_KEY, updated_extracts) - np.testing.assert_allclose( - updated_extracts[constants.EXAMPLE_WEIGHTS_KEY], - expected_input[constants.EXAMPLE_WEIGHTS_KEY], + self._extract_inputs = [ + util.StandardExtracts( + { + constants.LABELS_KEY: { + "image/encoded": encoded_image, + "image/decoded": image_array, + "number_of_classes": 2, + }, + constants.PREDICTIONS_KEY: { + "image/pred/encoded": encoded_image, + "image/pred/decoded": image_array, + "number_of_classes": 3, + }, + } ) + ] + self._expected_processed_inputs = [ + util.StandardExtracts( + { + constants.LABELS_KEY: image_array, + constants.PREDICTIONS_KEY: image_array, + } + ) + ] - beam_testing_util.assert_that(updated_pcoll, check_result) - - def testName(self): - preprocessor = image_preprocessors.DecodeImagePreprocessor( - ground_truth_key='image/encoded', - prediction_key='image/pred/encoded', - ) - self.assertEqual( - preprocessor.name, + @parameterized.named_parameters( ( - '_decode_image_preprocessor' - ':ground_truth_key=image/encoded,' - 'prediction_key=image/pred/encoded' + "decode_encoded_image", + image_preprocessors.DecodeImagePreprocessor( + ground_truth_key="image/encoded", + prediction_key="image/pred/encoded", + ), + ), + ( + "no_decode_image", + image_preprocessors.DecodeImagePreprocessor( + ground_truth_key="image/decoded", + prediction_key="image/pred/decoded", + decode_ground_truth=False, + decode_prediction=False, + ), ), ) + def testImageDecodePreprocessor(self, preprocessor): + with beam.Pipeline() as p: + updated_pcoll = ( + p + | "Create" >> beam.Create(self._extract_inputs) + | "Preprocess" >> beam.ParDo(preprocessor) + ) + + def check_result(result): + # Only single extract case is tested + self.assertLen(result, len(self._expected_processed_inputs)) + for updated_extracts, expected_input in zip( + result, self._expected_processed_inputs + ): + self.assertIn(constants.PREDICTIONS_KEY, updated_extracts) + np.testing.assert_allclose( + updated_extracts[constants.PREDICTIONS_KEY], + expected_input[constants.PREDICTIONS_KEY], + ) + self.assertIn(constants.LABELS_KEY, updated_extracts) + np.testing.assert_allclose( + updated_extracts[constants.LABELS_KEY], + expected_input[constants.LABELS_KEY], + ) + if constants.EXAMPLE_WEIGHTS_KEY in expected_input: + self.assertIn(constants.EXAMPLE_WEIGHTS_KEY, updated_extracts) + np.testing.assert_allclose( + updated_extracts[constants.EXAMPLE_WEIGHTS_KEY], + expected_input[constants.EXAMPLE_WEIGHTS_KEY], + ) - def testLabelNotFoundImage(self): - with self.assertRaisesRegex(ValueError, 'image/encodederror is not found'): - _ = next( - image_preprocessors.DecodeImagePreprocessor( - ground_truth_key='image/encodederror', - prediction_key='image/pred/encoded', - 
).process(extracts=self._extract_inputs[0]) - ) + beam_testing_util.assert_that(updated_pcoll, check_result) - def testPredictionNotFoundImage(self): - with self.assertRaisesRegex( - ValueError, 'image/pred/encodederror is not found' - ): - _ = next( - image_preprocessors.DecodeImagePreprocessor( - ground_truth_key='image/encoded', - prediction_key='image/pred/encodederror', - ).process(extracts=self._extract_inputs[0]) - ) + def testName(self): + preprocessor = image_preprocessors.DecodeImagePreprocessor( + ground_truth_key="image/encoded", + prediction_key="image/pred/encoded", + ) + self.assertEqual( + preprocessor.name, + ( + "_decode_image_preprocessor" + ":ground_truth_key=image/encoded," + "prediction_key=image/pred/encoded" + ), + ) - def testLabelPreidictionImageSizeMismatch(self): - extracts = { - constants.LABELS_KEY: { - 'image/encoded': np.array([[1, 2]]), - }, - constants.PREDICTIONS_KEY: { - 'image/pred/encoded': np.array([[1, 2, 3]]), - }, - } - with self.assertRaisesRegex(ValueError, 'does not match'): - _ = next( - image_preprocessors.DecodeImagePreprocessor( - ground_truth_key='image/encoded', - prediction_key='image/pred/encoded', - decode_ground_truth=False, - decode_prediction=False, - ).process(extracts=extracts) - ) + def testLabelNotFoundImage(self): + with self.assertRaisesRegex(ValueError, "image/encodederror is not found"): + _ = next( + image_preprocessors.DecodeImagePreprocessor( + ground_truth_key="image/encodederror", + prediction_key="image/pred/encoded", + ).process(extracts=self._extract_inputs[0]) + ) + + def testPredictionNotFoundImage(self): + with self.assertRaisesRegex(ValueError, "image/pred/encodederror is not found"): + _ = next( + image_preprocessors.DecodeImagePreprocessor( + ground_truth_key="image/encoded", + prediction_key="image/pred/encodederror", + ).process(extracts=self._extract_inputs[0]) + ) + + def testLabelPreidictionImageSizeMismatch(self): + extracts = { + constants.LABELS_KEY: { + "image/encoded": np.array([[1, 2]]), + }, + constants.PREDICTIONS_KEY: { + "image/pred/encoded": np.array([[1, 2, 3]]), + }, + } + with self.assertRaisesRegex(ValueError, "does not match"): + _ = next( + image_preprocessors.DecodeImagePreprocessor( + ground_truth_key="image/encoded", + prediction_key="image/pred/encoded", + decode_ground_truth=False, + decode_prediction=False, + ).process(extracts=extracts) + ) -if __name__ == '__main__': - absltest.main() +if __name__ == "__main__": + absltest.main() diff --git a/tensorflow_model_analysis/metrics/preprocessors/invert_logarithm_preprocessors.py b/tensorflow_model_analysis/metrics/preprocessors/invert_logarithm_preprocessors.py index b4a76699ec..9cc16b8c80 100644 --- a/tensorflow_model_analysis/metrics/preprocessors/invert_logarithm_preprocessors.py +++ b/tensorflow_model_analysis/metrics/preprocessors/invert_logarithm_preprocessors.py @@ -16,111 +16,110 @@ from typing import Iterator, Optional import numpy as np + from tensorflow_model_analysis import constants from tensorflow_model_analysis.api import types -from tensorflow_model_analysis.metrics import metric_types -from tensorflow_model_analysis.metrics import metric_util +from tensorflow_model_analysis.metrics import metric_types, metric_util from tensorflow_model_analysis.utils import util -_INVERT_BINARY_LOGARITHM_PREPROCESSOR_BASE_NAME = ( - 'invert_binary_logarithm_preprocessor' -) +_INVERT_BINARY_LOGARITHM_PREPROCESSOR_BASE_NAME = "invert_binary_logarithm_preprocessor" def _invert_log2_values( log_values: np.ndarray, ) -> np.ndarray: - """Invert 
the binary logarithm and return an ndarray.""" - # We invert the following formula: log_2(y_pred + 1.0) - return np.power(2.0, log_values) - 1.0 + """Invert the binary logarithm and return an ndarray.""" + # We invert the following formula: log_2(y_pred + 1.0) + return np.power(2.0, log_values) - 1.0 class InvertBinaryLogarithmPreprocessor(metric_types.Preprocessor): - """Read label and prediction from binary logarithm to numpy array.""" - - def __init__( - self, - name: Optional[str] = None, - model_name: str = '', - prediction_winsorisation_limit_max: Optional[float] = None, - ): - """Initialize the preprocessor for binary logarithm inversion. - - Args: - name: (Optional) name for the preprocessor. - model_name: (Optional) model name (if multi-model evaluation). - prediction_winsorisation_limit_max: should the winsorisation max limit be - applied to the predictions. - """ - if not name: - name = metric_util.generate_private_name_from_arguments( - _INVERT_BINARY_LOGARITHM_PREPROCESSOR_BASE_NAME - ) - super().__init__(name=name) - self._model_name = model_name - self._prediction_winsorisation_limit_max = ( - prediction_winsorisation_limit_max - ) - - def _read_label_or_prediction_in_multiple_dicts( - self, - key: str, - extracts: util.StandardExtracts, - ) -> np.ndarray: - """Reads and inverts the binary logarithm from extracts.""" - if key == constants.LABELS_KEY: - value = extracts.get_labels(self._model_name) - else: - value = extracts.get_predictions(self._model_name) - return _invert_log2_values(value) - - def process( - self, extracts: types.Extracts - ) -> Iterator[metric_types.StandardMetricInputs]: - """Reads and inverts the binary logarithm from extracts. - - It will search in labels/predictions, features and transformed features. - - Args: - extracts: A tfma extract contains the regression data. - - Yields: - A standard metric input contains the following key and values: - - {'labels'}: A numpy array represents the regressed values. - - {'predictions'}: A numpy array represents the regression predictions. - - {'example_weights'}: (Optional) A numpy array represents the example - weights. - """ - extracts = util.StandardExtracts(extracts) - - extracts[constants.LABELS_KEY] = ( - self._read_label_or_prediction_in_multiple_dicts( - constants.LABELS_KEY, extracts - ) - ) - - predictions = self._read_label_or_prediction_in_multiple_dicts( - constants.PREDICTIONS_KEY, - extracts, - ) - if self._prediction_winsorisation_limit_max is not None: - np.clip( - predictions, - 0.0, - self._prediction_winsorisation_limit_max, - out=predictions, - ) - extracts[constants.PREDICTIONS_KEY] = predictions - - if ( - extracts[constants.LABELS_KEY].shape - != extracts[constants.PREDICTIONS_KEY].shape + """Read label and prediction from binary logarithm to numpy array.""" + + def __init__( + self, + name: Optional[str] = None, + model_name: str = "", + prediction_winsorisation_limit_max: Optional[float] = None, ): - raise ValueError( - 'The size of ground truth ' - f'{extracts[constants.LABELS_KEY].shape} does not match ' - 'with the size of prediction ' - f'{extracts[constants.PREDICTIONS_KEY].shape}' - ) - - yield metric_util.to_standard_metric_inputs(extracts) + """Initialize the preprocessor for binary logarithm inversion. + + Args: + ---- + name: (Optional) name for the preprocessor. + model_name: (Optional) model name (if multi-model evaluation). + prediction_winsorisation_limit_max: should the winsorisation max limit be + applied to the predictions. 
+ """ + if not name: + name = metric_util.generate_private_name_from_arguments( + _INVERT_BINARY_LOGARITHM_PREPROCESSOR_BASE_NAME + ) + super().__init__(name=name) + self._model_name = model_name + self._prediction_winsorisation_limit_max = prediction_winsorisation_limit_max + + def _read_label_or_prediction_in_multiple_dicts( + self, + key: str, + extracts: util.StandardExtracts, + ) -> np.ndarray: + """Reads and inverts the binary logarithm from extracts.""" + if key == constants.LABELS_KEY: + value = extracts.get_labels(self._model_name) + else: + value = extracts.get_predictions(self._model_name) + return _invert_log2_values(value) + + def process( + self, extracts: types.Extracts + ) -> Iterator[metric_types.StandardMetricInputs]: + """Reads and inverts the binary logarithm from extracts. + + It will search in labels/predictions, features and transformed features. + + Args: + ---- + extracts: A tfma extract contains the regression data. + + Yields: + ------ + A standard metric input contains the following key and values: + - {'labels'}: A numpy array represents the regressed values. + - {'predictions'}: A numpy array represents the regression predictions. + - {'example_weights'}: (Optional) A numpy array represents the example + weights. + """ + extracts = util.StandardExtracts(extracts) + + extracts[constants.LABELS_KEY] = ( + self._read_label_or_prediction_in_multiple_dicts( + constants.LABELS_KEY, extracts + ) + ) + + predictions = self._read_label_or_prediction_in_multiple_dicts( + constants.PREDICTIONS_KEY, + extracts, + ) + if self._prediction_winsorisation_limit_max is not None: + np.clip( + predictions, + 0.0, + self._prediction_winsorisation_limit_max, + out=predictions, + ) + extracts[constants.PREDICTIONS_KEY] = predictions + + if ( + extracts[constants.LABELS_KEY].shape + != extracts[constants.PREDICTIONS_KEY].shape + ): + raise ValueError( + "The size of ground truth " + f"{extracts[constants.LABELS_KEY].shape} does not match " + "with the size of prediction " + f"{extracts[constants.PREDICTIONS_KEY].shape}" + ) + + yield metric_util.to_standard_metric_inputs(extracts) diff --git a/tensorflow_model_analysis/metrics/preprocessors/invert_logarithm_preprocessors_test.py b/tensorflow_model_analysis/metrics/preprocessors/invert_logarithm_preprocessors_test.py index 3b15d511c1..b7a7d12c35 100644 --- a/tensorflow_model_analysis/metrics/preprocessors/invert_logarithm_preprocessors_test.py +++ b/tensorflow_model_analysis/metrics/preprocessors/invert_logarithm_preprocessors_test.py @@ -13,114 +13,117 @@ # limitations under the License. 
"""Tests for invert logarithm preprocessors.""" -from absl.testing import absltest -from absl.testing import parameterized import apache_beam as beam -from apache_beam.testing import util as beam_testing_util import numpy as np +from absl.testing import absltest, parameterized +from apache_beam.testing import util as beam_testing_util + from tensorflow_model_analysis import constants -from tensorflow_model_analysis.metrics.preprocessors import invert_logarithm_preprocessors +from tensorflow_model_analysis.metrics.preprocessors import ( + invert_logarithm_preprocessors, +) from tensorflow_model_analysis.utils import util class InvertBinaryLogarithmPreprocessorTest(parameterized.TestCase): + def setUp(self): + super().setUp() + values = np.array([[1, 2, 4], [1, 2, 4]], dtype=np.int32) - def setUp(self): - super().setUp() - values = np.array([[1, 2, 4], [1, 2, 4]], dtype=np.int32) + self._extract_inputs = [ + { + constants.LABELS_KEY: values, + constants.PREDICTIONS_KEY: values, + } + ] - self._extract_inputs = [{ - constants.LABELS_KEY: values, - constants.PREDICTIONS_KEY: values, - }] - - @parameterized.named_parameters( - ( - 'NoWinsorisation', - None, - np.array([[1, 3, 15], [1, 3, 15]], dtype=np.float32), - np.array([[1, 3, 15], [1, 3, 15]], dtype=np.float32), - ), - ( - 'Winsorised', - 1.0, - np.array([[1, 3, 15], [1, 3, 15]], dtype=np.float32), - np.array([[1, 1, 1], [1, 1, 1]], dtype=np.float32), - ), - ) - def testInvertBinaryLogarithmPreprocessor( - self, - prediction_winsorisation_limit_max, - processed_labels, - processed_predictions, - ): - expected_processed_inputs = [ - util.StandardExtracts({ - constants.LABELS_KEY: processed_labels, - constants.PREDICTIONS_KEY: processed_predictions, - }) - ] + @parameterized.named_parameters( + ( + "NoWinsorisation", + None, + np.array([[1, 3, 15], [1, 3, 15]], dtype=np.float32), + np.array([[1, 3, 15], [1, 3, 15]], dtype=np.float32), + ), + ( + "Winsorised", + 1.0, + np.array([[1, 3, 15], [1, 3, 15]], dtype=np.float32), + np.array([[1, 1, 1], [1, 1, 1]], dtype=np.float32), + ), + ) + def testInvertBinaryLogarithmPreprocessor( + self, + prediction_winsorisation_limit_max, + processed_labels, + processed_predictions, + ): + expected_processed_inputs = [ + util.StandardExtracts( + { + constants.LABELS_KEY: processed_labels, + constants.PREDICTIONS_KEY: processed_predictions, + } + ) + ] - def check_result(result, expected_processed_inputs): - # Only single extract case is tested - self.assertLen(result, len(expected_processed_inputs)) - for updated_extracts, expected_input in zip( - result, expected_processed_inputs - ): - self.assertIn(constants.PREDICTIONS_KEY, updated_extracts) - np.testing.assert_allclose( - updated_extracts[constants.PREDICTIONS_KEY], - expected_input[constants.PREDICTIONS_KEY], - ) - self.assertIn(constants.LABELS_KEY, updated_extracts) - np.testing.assert_allclose( - updated_extracts[constants.LABELS_KEY], - expected_input[constants.LABELS_KEY], - ) - if constants.EXAMPLE_WEIGHTS_KEY in expected_input: - self.assertIn(constants.EXAMPLE_WEIGHTS_KEY, updated_extracts) - np.testing.assert_allclose( - updated_extracts[constants.EXAMPLE_WEIGHTS_KEY], - expected_input[constants.EXAMPLE_WEIGHTS_KEY], - ) + def check_result(result, expected_processed_inputs): + # Only single extract case is tested + self.assertLen(result, len(expected_processed_inputs)) + for updated_extracts, expected_input in zip( + result, expected_processed_inputs + ): + self.assertIn(constants.PREDICTIONS_KEY, updated_extracts) + 
np.testing.assert_allclose( + updated_extracts[constants.PREDICTIONS_KEY], + expected_input[constants.PREDICTIONS_KEY], + ) + self.assertIn(constants.LABELS_KEY, updated_extracts) + np.testing.assert_allclose( + updated_extracts[constants.LABELS_KEY], + expected_input[constants.LABELS_KEY], + ) + if constants.EXAMPLE_WEIGHTS_KEY in expected_input: + self.assertIn(constants.EXAMPLE_WEIGHTS_KEY, updated_extracts) + np.testing.assert_allclose( + updated_extracts[constants.EXAMPLE_WEIGHTS_KEY], + expected_input[constants.EXAMPLE_WEIGHTS_KEY], + ) - with beam.Pipeline() as pipeline: - updated_pcoll = ( - pipeline - | 'Create' >> beam.Create(self._extract_inputs) - | 'Preprocess' - >> beam.ParDo( - invert_logarithm_preprocessors.InvertBinaryLogarithmPreprocessor( - prediction_winsorisation_limit_max=prediction_winsorisation_limit_max - ) - ) - ) + with beam.Pipeline() as pipeline: + updated_pcoll = ( + pipeline + | "Create" >> beam.Create(self._extract_inputs) + | "Preprocess" + >> beam.ParDo( + invert_logarithm_preprocessors.InvertBinaryLogarithmPreprocessor( + prediction_winsorisation_limit_max=prediction_winsorisation_limit_max + ) + ) + ) - beam_testing_util.assert_that( - updated_pcoll, - lambda result: check_result(result, expected_processed_inputs), - ) + beam_testing_util.assert_that( + updated_pcoll, + lambda result: check_result(result, expected_processed_inputs), + ) - def testName(self): - preprocessor = ( - invert_logarithm_preprocessors.InvertBinaryLogarithmPreprocessor() - ) - self.assertEqual( - preprocessor.name, '_invert_binary_logarithm_preprocessor:' - ) + def testName(self): + preprocessor = ( + invert_logarithm_preprocessors.InvertBinaryLogarithmPreprocessor() + ) + self.assertEqual(preprocessor.name, "_invert_binary_logarithm_preprocessor:") - def testLabelPreidictionSizeMismatch(self): - extracts = { - constants.LABELS_KEY: np.array([[1, 2]]), - constants.PREDICTIONS_KEY: np.array([[1, 2, 3]]), - } - with self.assertRaisesRegex(ValueError, 'does not match'): - _ = next( - invert_logarithm_preprocessors.InvertBinaryLogarithmPreprocessor().process( - extracts=extracts - ) - ) + def testLabelPreidictionSizeMismatch(self): + extracts = { + constants.LABELS_KEY: np.array([[1, 2]]), + constants.PREDICTIONS_KEY: np.array([[1, 2, 3]]), + } + with self.assertRaisesRegex(ValueError, "does not match"): + _ = next( + invert_logarithm_preprocessors.InvertBinaryLogarithmPreprocessor().process( + extracts=extracts + ) + ) -if __name__ == '__main__': - absltest.main() +if __name__ == "__main__": + absltest.main() diff --git a/tensorflow_model_analysis/metrics/preprocessors/object_detection_preprocessors.py b/tensorflow_model_analysis/metrics/preprocessors/object_detection_preprocessors.py index ef3ec474f5..8dfd35a5a4 100644 --- a/tensorflow_model_analysis/metrics/preprocessors/object_detection_preprocessors.py +++ b/tensorflow_model_analysis/metrics/preprocessors/object_detection_preprocessors.py @@ -22,155 +22,173 @@ from tensorflow_model_analysis import constants from tensorflow_model_analysis.api import types -from tensorflow_model_analysis.metrics import metric_types -from tensorflow_model_analysis.metrics import metric_util -from tensorflow_model_analysis.metrics.preprocessors.utils import bounding_box -from tensorflow_model_analysis.metrics.preprocessors.utils import box_match -from tensorflow_model_analysis.metrics.preprocessors.utils import object_detection_format +from tensorflow_model_analysis.metrics import metric_types, metric_util +from 
tensorflow_model_analysis.metrics.preprocessors.utils import ( + bounding_box, + box_match, + object_detection_format, +) from tensorflow_model_analysis.utils import util # indices for the inputs, it should be arranged in the following format: LEFT, TOP, RIGHT, BOTTOM, CLASS, CONFIDENCE = range(6) -_DEFAULT_BOUNDING_BOX_MATCH_PREPROCESSOR_NAME = 'bounding_box_match_preprocessor' +_DEFAULT_BOUNDING_BOX_MATCH_PREPROCESSOR_NAME = "bounding_box_match_preprocessor" _DEFAULT_IOU_THRESHOLD = 0.5 -_DEFAULT_AREA_RANGE = (0, float('inf')) +_DEFAULT_AREA_RANGE = (0, float("inf")) class BoundingBoxMatchPreprocessor(metric_types.Preprocessor): - """Computes label and prediction pairs for object detection.""" + """Computes label and prediction pairs for object detection.""" - def __init__(self, - class_id: int, - iou_threshold: Optional[float] = 0.5, - area_range: Tuple[float, float] = (0, float('inf')), - max_num_detections: Optional[int] = None, - class_weight: Optional[float] = None, - name: Optional[str] = None, - labels_to_stack: Optional[List[str]] = None, - predictions_to_stack: Optional[List[str]] = None, - num_detections_key: Optional[str] = None, - allow_missing_key: bool = False, - model_name: str = ''): - """Initialize the preprocessor for bounding box match. + def __init__( + self, + class_id: int, + iou_threshold: Optional[float] = 0.5, + area_range: Tuple[float, float] = (0, float("inf")), + max_num_detections: Optional[int] = None, + class_weight: Optional[float] = None, + name: Optional[str] = None, + labels_to_stack: Optional[List[str]] = None, + predictions_to_stack: Optional[List[str]] = None, + num_detections_key: Optional[str] = None, + allow_missing_key: bool = False, + model_name: str = "", + ): + """Initialize the preprocessor for bounding box match. - Args: - class_id: Used for object detection, the class id for calculating metrics. - It must be provided if use_object_detection is True. - iou_threshold: (Optional) Used for object detection, threholds for a - detection and ground truth pair with specific iou to be considered as a - match. - area_range: (Optional) Used for object detection, the area-range for - objects to be considered for metrics. - max_num_detections: (Optional) Used for object detection, the maximum - number of detections for a single image. - class_weight: (Optional) Used for object detection, the weight associated - with the object class id. - name: (Optional) name for the preprocessor. - labels_to_stack: (Optional) Keys for columns to be stacked as a single - numpy array as the labels. It is searched under the key labels, features - and transformed features. The desired format is [left bounadary, top - boudnary, right boundary, bottom boundary, class id]. e.g. ['xmin', - 'ymin', 'xmax', 'ymax', 'class_id'] - predictions_to_stack: (Optional) Keys for columns to be stacked as a - single numpy array as the prediction. It should be the model's output - names. The desired format is [left bounadary, top boudnary, right - boundary, bottom boundary, class id, confidence score]. e.g. ['xmin', - 'ymin', 'xmax', 'ymax', 'class_id', 'scores'] - num_detections_key: (Optional) Number of detections in each column except - the paddings. - allow_missing_key: (Optional) If true, the preprocessor will return empty - array instead of raising errors. - model_name: Optional model name (if multi-model evaluation). 
- """ - if not name: - name = metric_util.generate_private_name_from_arguments( - _DEFAULT_BOUNDING_BOX_MATCH_PREPROCESSOR_NAME, - class_id=class_id, - iou_threshold=iou_threshold, - area_range=area_range, - max_num_detections=max_num_detections, - class_weight=class_weight) - if not area_range: - area_range = _DEFAULT_AREA_RANGE - if iou_threshold is None: - iou_threshold = _DEFAULT_IOU_THRESHOLD - super().__init__(name=name) - self._threshold = iou_threshold - self._class_id = class_id - self._area_range = area_range - self._max_num_detections = max_num_detections - self._class_weight = class_weight - self._labels_to_stack = labels_to_stack - self._predictions_to_stack = predictions_to_stack - self._num_detections_key = num_detections_key - self._allow_missing_key = allow_missing_key - self._model_name = model_name + Args: + ---- + class_id: Used for object detection, the class id for calculating metrics. + It must be provided if use_object_detection is True. + iou_threshold: (Optional) Used for object detection, threholds for a + detection and ground truth pair with specific iou to be considered as a + match. + area_range: (Optional) Used for object detection, the area-range for + objects to be considered for metrics. + max_num_detections: (Optional) Used for object detection, the maximum + number of detections for a single image. + class_weight: (Optional) Used for object detection, the weight associated + with the object class id. + name: (Optional) name for the preprocessor. + labels_to_stack: (Optional) Keys for columns to be stacked as a single + numpy array as the labels. It is searched under the key labels, features + and transformed features. The desired format is [left bounadary, top + boudnary, right boundary, bottom boundary, class id]. e.g. ['xmin', + 'ymin', 'xmax', 'ymax', 'class_id'] + predictions_to_stack: (Optional) Keys for columns to be stacked as a + single numpy array as the prediction. It should be the model's output + names. The desired format is [left bounadary, top boudnary, right + boundary, bottom boundary, class id, confidence score]. e.g. ['xmin', + 'ymin', 'xmax', 'ymax', 'class_id', 'scores'] + num_detections_key: (Optional) Number of detections in each column except + the paddings. + allow_missing_key: (Optional) If true, the preprocessor will return empty + array instead of raising errors. + model_name: Optional model name (if multi-model evaluation). + """ + if not name: + name = metric_util.generate_private_name_from_arguments( + _DEFAULT_BOUNDING_BOX_MATCH_PREPROCESSOR_NAME, + class_id=class_id, + iou_threshold=iou_threshold, + area_range=area_range, + max_num_detections=max_num_detections, + class_weight=class_weight, + ) + if not area_range: + area_range = _DEFAULT_AREA_RANGE + if iou_threshold is None: + iou_threshold = _DEFAULT_IOU_THRESHOLD + super().__init__(name=name) + self._threshold = iou_threshold + self._class_id = class_id + self._area_range = area_range + self._max_num_detections = max_num_detections + self._class_weight = class_weight + self._labels_to_stack = labels_to_stack + self._predictions_to_stack = predictions_to_stack + self._num_detections_key = num_detections_key + self._allow_missing_key = allow_missing_key + self._model_name = model_name - def process( - self, - extracts: types.Extracts) -> Iterator[metric_types.StandardMetricInputs]: - # stack all the columns of labels and predictions(e.g. 
xmin, xmax, ymin, - # ymax, and class_id of bounding boxes) into one single numpy array - extracts = util.StandardExtracts(extracts) - if self._labels_to_stack: - extracts[constants.LABELS_KEY] = object_detection_format.stack_labels( - extracts=extracts, - col_names=self._labels_to_stack, - model_name=self._model_name, - allow_missing_key=self._allow_missing_key) - if self._predictions_to_stack: - predictions = object_detection_format.stack_predictions( - extracts=extracts, - col_names=self._predictions_to_stack, - model_name=self._model_name, - allow_missing_key=self._allow_missing_key) - else: - predictions = extracts.get_predictions(model_name=self._model_name) + def process( + self, extracts: types.Extracts + ) -> Iterator[metric_types.StandardMetricInputs]: + # stack all the columns of labels and predictions(e.g. xmin, xmax, ymin, + # ymax, and class_id of bounding boxes) into one single numpy array + extracts = util.StandardExtracts(extracts) + if self._labels_to_stack: + extracts[constants.LABELS_KEY] = object_detection_format.stack_labels( + extracts=extracts, + col_names=self._labels_to_stack, + model_name=self._model_name, + allow_missing_key=self._allow_missing_key, + ) + if self._predictions_to_stack: + predictions = object_detection_format.stack_predictions( + extracts=extracts, + col_names=self._predictions_to_stack, + model_name=self._model_name, + allow_missing_key=self._allow_missing_key, + ) + else: + predictions = extracts.get_predictions(model_name=self._model_name) - if self._num_detections_key and predictions.size: - predictions = ( - object_detection_format.truncate_by_num_detections( - extracts=extracts, - num_rows_key=self._num_detections_key, - array_to_truncate=predictions, - model_name=self._model_name, - allow_missing_key=self._allow_missing_key)) + if self._num_detections_key and predictions.size: + predictions = object_detection_format.truncate_by_num_detections( + extracts=extracts, + num_rows_key=self._num_detections_key, + array_to_truncate=predictions, + model_name=self._model_name, + allow_missing_key=self._allow_missing_key, + ) - extracts[constants.PREDICTIONS_KEY] = predictions + extracts[constants.PREDICTIONS_KEY] = predictions - if extracts[constants.LABELS_KEY].ndim != 2 or extracts[ - constants.LABELS_KEY].shape[1] <= 4: - raise ValueError('Raw data of ground truth should be a 2d array of shape ' - '( , 5+), ground truth is ' - f'{extracts[constants.LABELS_KEY]}') - if extracts[constants.PREDICTIONS_KEY].ndim != 2 or extracts[ - constants.PREDICTIONS_KEY].shape[1] <= 5: - raise ValueError('Raw data of prediction should be a 2d array of shape ' - '( , 6+), prediction is ' - f'{extracts[constants.PREDICTIONS_KEY]}') - boxes_gt = extracts[constants.LABELS_KEY] - boxes_pred = bounding_box.sort_boxes_by_confidence( - extracts[constants.PREDICTIONS_KEY]) - if constants.EXAMPLE_WEIGHTS_KEY in extracts: - weight = extracts[constants.EXAMPLE_WEIGHTS_KEY] - else: - weight = None + if ( + extracts[constants.LABELS_KEY].ndim != 2 + or extracts[constants.LABELS_KEY].shape[1] <= 4 + ): + raise ValueError( + "Raw data of ground truth should be a 2d array of shape " + "( , 5+), ground truth is " + f"{extracts[constants.LABELS_KEY]}" + ) + if ( + extracts[constants.PREDICTIONS_KEY].ndim != 2 + or extracts[constants.PREDICTIONS_KEY].shape[1] <= 5 + ): + raise ValueError( + "Raw data of prediction should be a 2d array of shape " + "( , 6+), prediction is " + f"{extracts[constants.PREDICTIONS_KEY]}" + ) + boxes_gt = extracts[constants.LABELS_KEY] + boxes_pred = 
bounding_box.sort_boxes_by_confidence( + extracts[constants.PREDICTIONS_KEY] + ) + if constants.EXAMPLE_WEIGHTS_KEY in extracts: + weight = extracts[constants.EXAMPLE_WEIGHTS_KEY] + else: + weight = None - labels, predictions, weights = ( - box_match.boxes_to_label_prediction_example_weight( - boxes_gt=boxes_gt, - boxes_pred=boxes_pred, - iou_threshold=self._threshold, - area_range=self._area_range, - class_id=self._class_id, - max_num_detections=self._max_num_detections, - class_weight=self._class_weight, - weight=weight)) + labels, predictions, weights = ( + box_match.boxes_to_label_prediction_example_weight( + boxes_gt=boxes_gt, + boxes_pred=boxes_pred, + iou_threshold=self._threshold, + area_range=self._area_range, + class_id=self._class_id, + max_num_detections=self._max_num_detections, + class_weight=self._class_weight, + weight=weight, + ) + ) - for l, p, w in zip(labels, predictions, weights): - result = {} - result[constants.LABELS_KEY] = [l] - result[constants.PREDICTIONS_KEY] = [p] - result[constants.EXAMPLE_WEIGHTS_KEY] = [w] - yield metric_util.to_standard_metric_inputs(result) + for l, p, w in zip(labels, predictions, weights): + result = {} + result[constants.LABELS_KEY] = [l] + result[constants.PREDICTIONS_KEY] = [p] + result[constants.EXAMPLE_WEIGHTS_KEY] = [w] + yield metric_util.to_standard_metric_inputs(result) diff --git a/tensorflow_model_analysis/metrics/preprocessors/object_detection_preprocessors_test.py b/tensorflow_model_analysis/metrics/preprocessors/object_detection_preprocessors_test.py index 4a40bd37c7..f3d6ba77fc 100644 --- a/tensorflow_model_analysis/metrics/preprocessors/object_detection_preprocessors_test.py +++ b/tensorflow_model_analysis/metrics/preprocessors/object_detection_preprocessors_test.py @@ -13,13 +13,15 @@ # limitations under the License. 
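# A minimal usage sketch (not part of the patch), assuming the stacked-array
# layout documented above: label rows are [xmin, ymin, xmax, ymax, class_id]
# and prediction rows are [xmin, ymin, xmax, ymax, class_id, confidence].
# Boxes are matched at the given IoU threshold and one
# (label, prediction, example_weight) triple is yielded per matched or
# unmatched box; the data mirrors _BOXMATCH_CASE1_BINARY and
# _BOXMATCH_CASE1_RESULT below.
import numpy as np

from tensorflow_model_analysis import constants
from tensorflow_model_analysis.metrics.preprocessors import (
    object_detection_preprocessors,
)

extracts = {
    constants.LABELS_KEY: np.array(
        [[30, 100, 70, 300, 0], [50, 100, 80, 200, 0]]
    ),
    constants.PREDICTIONS_KEY: np.array(
        [
            [20, 130, 60, 290, 0, 0.5],
            [30, 100, 70, 300, 0, 0.3],
            [500, 100, 800, 300, 0, 0.1],
        ]
    ),
}
preprocessor = object_detection_preprocessors.BoundingBoxMatchPreprocessor(
    class_id=0, iou_threshold=0.5
)
standard_inputs = list(preprocessor.process(extracts=extracts))
# Four metric inputs are produced: a true positive at score 0.5, a false
# negative for the unmatched ground truth, and false positives at 0.3 and
# 0.1, each with example weight 1.0.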
"""Tests for object_detection_preprocessor.""" -from absl.testing import absltest -from absl.testing import parameterized import apache_beam as beam -from apache_beam.testing import util as beam_testing_util import numpy as np +from absl.testing import absltest, parameterized +from apache_beam.testing import util as beam_testing_util + from tensorflow_model_analysis import constants -from tensorflow_model_analysis.metrics.preprocessors import object_detection_preprocessors +from tensorflow_model_analysis.metrics.preprocessors import ( + object_detection_preprocessors, +) from tensorflow_model_analysis.utils import util # Initialize test data @@ -28,306 +30,412 @@ # [[0.5, 1., 0.], [7 / 87, 2 / 9, 0.]] # The match at iou_threshold = 0.5 is # gt_matches: [[0, -1]] dt_matches: [[0, -1, -1]] -_BOXMATCH_CASE1_BINARY = util.StandardExtracts({ - constants.LABELS_KEY: - np.array([[30, 100, 70, 300, 0], [50, 100, 80, 200, 0]]), - constants.PREDICTIONS_KEY: - np.array([[20, 130, 60, 290, 0, 0.5], [30, 100, 70, 300, 0, 0.3], - [500, 100, 800, 300, 0, 0.1]]) -}) +_BOXMATCH_CASE1_BINARY = util.StandardExtracts( + { + constants.LABELS_KEY: np.array([[30, 100, 70, 300, 0], [50, 100, 80, 200, 0]]), + constants.PREDICTIONS_KEY: np.array( + [ + [20, 130, 60, 290, 0, 0.5], + [30, 100, 70, 300, 0, 0.3], + [500, 100, 800, 300, 0, 0.1], + ] + ), + } +) -_BOXMATCH_CASE1_SPLITFORMAT_BINARY = util.StandardExtracts({ - constants.FEATURES_KEY: { - 'xmin': np.array([30, 50]), - 'ymin': np.array([100, 100]), - 'xmax': np.array([70, 80]), - 'ymax': np.array([300, 200]), - 'class_id': np.array([0, 0]) - }, - constants.PREDICTIONS_KEY: { - 'xmin': np.array([20, 30, 500]), - 'ymin': np.array([130, 100, 100]), - 'xmax': np.array([60, 70, 800]), - 'ymax': np.array([290, 300, 300]), - 'class_id': np.array([0, 0, 0]), - 'score': np.array([0.5, 0.3, 0.1]) - }, -}) +_BOXMATCH_CASE1_SPLITFORMAT_BINARY = util.StandardExtracts( + { + constants.FEATURES_KEY: { + "xmin": np.array([30, 50]), + "ymin": np.array([100, 100]), + "xmax": np.array([70, 80]), + "ymax": np.array([300, 200]), + "class_id": np.array([0, 0]), + }, + constants.PREDICTIONS_KEY: { + "xmin": np.array([20, 30, 500]), + "ymin": np.array([130, 100, 100]), + "xmax": np.array([60, 70, 800]), + "ymax": np.array([290, 300, 300]), + "class_id": np.array([0, 0, 0]), + "score": np.array([0.5, 0.3, 0.1]), + }, + } +) -_BOXMATCH_CASE1_BBOXFORMAT_BINARY = util.StandardExtracts({ - constants.LABELS_KEY: { - 'bbox': np.array([[30, 100, 70, 300], [50, 100, 80, 200]]), - 'class_id': np.array([0, 0]) - }, - constants.FEATURES_KEY: { - 'bbox': - np.array([[20, 130, 60, 290], [30, 100, 70, 300], - [500, 100, 800, 300]]), - 'class_id': - np.array([0, 0, 0]), - 'score': - np.array([0.5, 0.3, 0.1]) - }, -}) +_BOXMATCH_CASE1_BBOXFORMAT_BINARY = util.StandardExtracts( + { + constants.LABELS_KEY: { + "bbox": np.array([[30, 100, 70, 300], [50, 100, 80, 200]]), + "class_id": np.array([0, 0]), + }, + constants.FEATURES_KEY: { + "bbox": np.array( + [[20, 130, 60, 290], [30, 100, 70, 300], [500, 100, 800, 300]] + ), + "class_id": np.array([0, 0, 0]), + "score": np.array([0.5, 0.3, 0.1]), + }, + } +) -_BOXMATCH_CASE1_MULTI_MODEL_SPLITFORMAT_BINARY = util.StandardExtracts({ - constants.LABELS_KEY: {}, - # Searching labels in tranformed features - constants.TRANSFORMED_FEATURES_KEY: { - 'baseline': { - 'xmin': np.array([30, 50]), - 'ymin': np.array([100, 100]), - 'xmax': np.array([70, 80]), - 'ymax': np.array([300, 200]), - 'class_id': np.array([0, 0]) 
+_BOXMATCH_CASE1_MULTI_MODEL_SPLITFORMAT_BINARY = util.StandardExtracts( + { + constants.LABELS_KEY: {}, + # Searching labels in tranformed features + constants.TRANSFORMED_FEATURES_KEY: { + "baseline": { + "xmin": np.array([30, 50]), + "ymin": np.array([100, 100]), + "xmax": np.array([70, 80]), + "ymax": np.array([300, 200]), + "class_id": np.array([0, 0]), + }, + "model1": { + "xmin": np.array([30, 50]), + "ymin": np.array([100, 100]), + "xmax": np.array([70, 80]), + "ymax": np.array([300, 200]), + "class_id": np.array([0, 0]), + }, }, - 'model1': { - 'xmin': np.array([30, 50]), - 'ymin': np.array([100, 100]), - 'xmax': np.array([70, 80]), - 'ymax': np.array([300, 200]), - 'class_id': np.array([0, 0]) - } - }, - constants.PREDICTIONS_KEY: { - 'baseline': { - 'xmin': np.array([20, 30, 500]), - 'ymin': np.array([130, 100, 100]), - 'xmax': np.array([60, 70, 800]), - 'ymax': np.array([290, 300, 300]), - 'class_id': np.array([0, 0, 0]), - 'score': np.array([0.5, 0.3, 0.1]) - } - }, -}) + constants.PREDICTIONS_KEY: { + "baseline": { + "xmin": np.array([20, 30, 500]), + "ymin": np.array([130, 100, 100]), + "xmax": np.array([60, 70, 800]), + "ymax": np.array([290, 300, 300]), + "class_id": np.array([0, 0, 0]), + "score": np.array([0.5, 0.3, 0.1]), + } + }, + } +) -_BOXMATCH_CASE1_MULTI_MODEL_BBOXFORMAT_BINARY = util.StandardExtracts({ - constants.LABELS_KEY: { - 'bbox': np.array([[30, 100, 70, 300], [50, 100, 80, 200]]), - 'class_id': np.array([0, 0]) - }, - constants.PREDICTIONS_KEY: { - 'baseline': { - 'bbox': - np.array([[20, 130, 60, 290], [30, 100, 70, 300], - [500, 100, 800, 300]]), - 'class_id': - np.array([0, 0, 0]), - 'score': - np.array([0.5, 0.3, 0.1]) +_BOXMATCH_CASE1_MULTI_MODEL_BBOXFORMAT_BINARY = util.StandardExtracts( + { + constants.LABELS_KEY: { + "bbox": np.array([[30, 100, 70, 300], [50, 100, 80, 200]]), + "class_id": np.array([0, 0]), }, - 'model1': { - 'bbox': - np.array([[120, 230, 60, 290], [30, 100, 70, 300], - [500, 100, 800, 300]]), - 'class_id': - np.array([0, 0, 0]), - 'score': - np.array([0.5, 0.3, 0.1]) + constants.PREDICTIONS_KEY: { + "baseline": { + "bbox": np.array( + [[20, 130, 60, 290], [30, 100, 70, 300], [500, 100, 800, 300]] + ), + "class_id": np.array([0, 0, 0]), + "score": np.array([0.5, 0.3, 0.1]), + }, + "model1": { + "bbox": np.array( + [[120, 230, 60, 290], [30, 100, 70, 300], [500, 100, 800, 300]] + ), + "class_id": np.array([0, 0, 0]), + "score": np.array([0.5, 0.3, 0.1]), + }, }, + } +) + +_BOXMATCH_CASE1_RESULT = [ + { + constants.LABELS_KEY: np.array([1.0]), + constants.PREDICTIONS_KEY: np.array([0.5]), + constants.EXAMPLE_WEIGHTS_KEY: np.array([1.0]), + }, + { + constants.LABELS_KEY: np.array([1.0]), + constants.PREDICTIONS_KEY: np.array([0.0]), + constants.EXAMPLE_WEIGHTS_KEY: np.array([1.0]), + }, + { + constants.LABELS_KEY: np.array([0.0]), + constants.PREDICTIONS_KEY: np.array([0.3]), + constants.EXAMPLE_WEIGHTS_KEY: np.array([1.0]), }, -}) + { + constants.LABELS_KEY: np.array([0.0]), + constants.PREDICTIONS_KEY: np.array([0.1]), + constants.EXAMPLE_WEIGHTS_KEY: np.array([1.0]), + }, +] -_BOXMATCH_CASE1_RESULT = [{ - constants.LABELS_KEY: np.array([1.]), - constants.PREDICTIONS_KEY: np.array([0.5]), - constants.EXAMPLE_WEIGHTS_KEY: np.array([1.]) -}, { - constants.LABELS_KEY: np.array([1.]), - constants.PREDICTIONS_KEY: np.array([0.]), - constants.EXAMPLE_WEIGHTS_KEY: np.array([1.]) -}, { - constants.LABELS_KEY: np.array([0.]), - constants.PREDICTIONS_KEY: np.array([0.3]), - constants.EXAMPLE_WEIGHTS_KEY: np.array([1.]) -}, { - 
constants.LABELS_KEY: np.array([0.]), - constants.PREDICTIONS_KEY: np.array([0.1]), - constants.EXAMPLE_WEIGHTS_KEY: np.array([1.]) -}] +_BOXMATCH_CASE2_PREDICT_NOT_FOUND = util.StandardExtracts( + { + constants.LABELS_KEY: { + "xmin": np.array([30, 50]), + "ymin": np.array([100, 100]), + "xmax": np.array([70, 80]), + "ymax": np.array([300, 200]), + "class_id": np.array([0, 0]), + }, + # Searching labels in tranformed features + constants.TRANSFORMED_FEATURES_KEY: {}, + constants.PREDICTIONS_KEY: { + "xmin": np.array([20, 30, 500]), + "ymin": np.array([130, 100, 100]), + "ymax": np.array([290, 300, 300]), + "class_id": np.array([0, 0, 0]), + "score": np.array([0.5, 0.3, 0.1]), + }, + } +) -_BOXMATCH_CASE2_PREDICT_NOT_FOUND = util.StandardExtracts({ - constants.LABELS_KEY: { - 'xmin': np.array([30, 50]), - 'ymin': np.array([100, 100]), - 'xmax': np.array([70, 80]), - 'ymax': np.array([300, 200]), - 'class_id': np.array([0, 0]) +_BOXMATCH_CASE2_PREDICT_NOT_FOUND_RESULT = [ + { + constants.LABELS_KEY: np.array([1.0]), + constants.PREDICTIONS_KEY: np.array([0]), + constants.EXAMPLE_WEIGHTS_KEY: np.array([1.0]), }, - # Searching labels in tranformed features - constants.TRANSFORMED_FEATURES_KEY: {}, - constants.PREDICTIONS_KEY: { - 'xmin': np.array([20, 30, 500]), - 'ymin': np.array([130, 100, 100]), - 'ymax': np.array([290, 300, 300]), - 'class_id': np.array([0, 0, 0]), - 'score': np.array([0.5, 0.3, 0.1]) + { + constants.LABELS_KEY: np.array([1.0]), + constants.PREDICTIONS_KEY: np.array([0]), + constants.EXAMPLE_WEIGHTS_KEY: np.array([1.0]), }, -}) - -_BOXMATCH_CASE2_PREDICT_NOT_FOUND_RESULT = [{ - constants.LABELS_KEY: np.array([1.]), - constants.PREDICTIONS_KEY: np.array([0]), - constants.EXAMPLE_WEIGHTS_KEY: np.array([1.]) -}, { - constants.LABELS_KEY: np.array([1.]), - constants.PREDICTIONS_KEY: np.array([0]), - constants.EXAMPLE_WEIGHTS_KEY: np.array([1.]) -}] +] class ObjectDetectionPreprocessorTest(parameterized.TestCase): + @parameterized.named_parameters( + ( + "binary_classification", + _BOXMATCH_CASE1_BINARY, + 0, + 0.5, + _BOXMATCH_CASE1_RESULT, + ) + ) + def testBoundingBoxMatchPreprocessor( + self, extracts, class_id, iou_threshold, expected_inputs + ): + with beam.Pipeline() as p: + updated_pcoll = ( + p + | "Create" >> beam.Create([extracts]) + | "Preprocess" + >> beam.ParDo( + object_detection_preprocessors.BoundingBoxMatchPreprocessor( + class_id=class_id, iou_threshold=iou_threshold + ) + ) + ) - @parameterized.named_parameters( - ('binary_classification', _BOXMATCH_CASE1_BINARY, 0, 0.5, - _BOXMATCH_CASE1_RESULT)) - def testBoundingBoxMatchPreprocessor(self, extracts, class_id, iou_threshold, - expected_inputs): - with beam.Pipeline() as p: - updated_pcoll = ( - p | 'Create' >> beam.Create([extracts]) - | 'Preprocess' >> beam.ParDo( - object_detection_preprocessors.BoundingBoxMatchPreprocessor( - class_id=class_id, iou_threshold=iou_threshold))) - - def check_result(result): - # Only single extract case is tested - self.assertLen(result, 4) - for updated_extracts, expected_input in zip(result, expected_inputs): - self.assertIn(constants.PREDICTIONS_KEY, updated_extracts) - np.testing.assert_allclose( - updated_extracts[constants.PREDICTIONS_KEY], - expected_input[constants.PREDICTIONS_KEY]) - self.assertIn(constants.LABELS_KEY, updated_extracts) - np.testing.assert_allclose(updated_extracts[constants.LABELS_KEY], - expected_input[constants.LABELS_KEY]) - self.assertIn(constants.EXAMPLE_WEIGHTS_KEY, updated_extracts) - np.testing.assert_allclose( - 
updated_extracts[constants.EXAMPLE_WEIGHTS_KEY], - expected_input[constants.EXAMPLE_WEIGHTS_KEY]) + def check_result(result): + # Only single extract case is tested + self.assertLen(result, 4) + for updated_extracts, expected_input in zip(result, expected_inputs): + self.assertIn(constants.PREDICTIONS_KEY, updated_extracts) + np.testing.assert_allclose( + updated_extracts[constants.PREDICTIONS_KEY], + expected_input[constants.PREDICTIONS_KEY], + ) + self.assertIn(constants.LABELS_KEY, updated_extracts) + np.testing.assert_allclose( + updated_extracts[constants.LABELS_KEY], + expected_input[constants.LABELS_KEY], + ) + self.assertIn(constants.EXAMPLE_WEIGHTS_KEY, updated_extracts) + np.testing.assert_allclose( + updated_extracts[constants.EXAMPLE_WEIGHTS_KEY], + expected_input[constants.EXAMPLE_WEIGHTS_KEY], + ) - beam_testing_util.assert_that(updated_pcoll, check_result) + beam_testing_util.assert_that(updated_pcoll, check_result) - @parameterized.named_parameters( - ('split_format', _BOXMATCH_CASE1_SPLITFORMAT_BINARY, 0, 0.5, [ - 'xmin', 'ymin', 'xmax', 'ymax', 'class_id' - ], ['xmin', 'ymin', 'xmax', 'ymax', 'class_id', 'score' - ], _BOXMATCH_CASE1_RESULT), - ('bbox_format', _BOXMATCH_CASE1_BBOXFORMAT_BINARY, 0, 0.5, [ - 'bbox', 'class_id' - ], ['bbox', 'class_id', 'score'], _BOXMATCH_CASE1_RESULT)) - def testBoundingBoxMatchPreprocessorWithFormatChange(self, extracts, class_id, - iou_threshold, - labels_stack, - predictions_stack, - expected_inputs): - with beam.Pipeline() as p: - updated_pcoll = ( - p | 'Create' >> beam.Create([extracts]) - | 'Preprocess' >> beam.ParDo( - object_detection_preprocessors.BoundingBoxMatchPreprocessor( - class_id=class_id, - iou_threshold=iou_threshold, - labels_to_stack=labels_stack, - predictions_to_stack=predictions_stack))) + @parameterized.named_parameters( + ( + "split_format", + _BOXMATCH_CASE1_SPLITFORMAT_BINARY, + 0, + 0.5, + ["xmin", "ymin", "xmax", "ymax", "class_id"], + ["xmin", "ymin", "xmax", "ymax", "class_id", "score"], + _BOXMATCH_CASE1_RESULT, + ), + ( + "bbox_format", + _BOXMATCH_CASE1_BBOXFORMAT_BINARY, + 0, + 0.5, + ["bbox", "class_id"], + ["bbox", "class_id", "score"], + _BOXMATCH_CASE1_RESULT, + ), + ) + def testBoundingBoxMatchPreprocessorWithFormatChange( + self, + extracts, + class_id, + iou_threshold, + labels_stack, + predictions_stack, + expected_inputs, + ): + with beam.Pipeline() as p: + updated_pcoll = ( + p + | "Create" >> beam.Create([extracts]) + | "Preprocess" + >> beam.ParDo( + object_detection_preprocessors.BoundingBoxMatchPreprocessor( + class_id=class_id, + iou_threshold=iou_threshold, + labels_to_stack=labels_stack, + predictions_to_stack=predictions_stack, + ) + ) + ) - def check_result(result): - # Only single extract case is tested - self.assertLen(result, 4) - for updated_extracts, expected_input in zip(result, expected_inputs): - self.assertIn(constants.PREDICTIONS_KEY, updated_extracts) - np.testing.assert_allclose( - updated_extracts[constants.PREDICTIONS_KEY], - expected_input[constants.PREDICTIONS_KEY]) - self.assertIn(constants.LABELS_KEY, updated_extracts) - np.testing.assert_allclose(updated_extracts[constants.LABELS_KEY], - expected_input[constants.LABELS_KEY]) - self.assertIn(constants.EXAMPLE_WEIGHTS_KEY, updated_extracts) - np.testing.assert_allclose( - updated_extracts[constants.EXAMPLE_WEIGHTS_KEY], - expected_input[constants.EXAMPLE_WEIGHTS_KEY]) + def check_result(result): + # Only single extract case is tested + self.assertLen(result, 4) + for updated_extracts, expected_input in zip(result, 
expected_inputs): + self.assertIn(constants.PREDICTIONS_KEY, updated_extracts) + np.testing.assert_allclose( + updated_extracts[constants.PREDICTIONS_KEY], + expected_input[constants.PREDICTIONS_KEY], + ) + self.assertIn(constants.LABELS_KEY, updated_extracts) + np.testing.assert_allclose( + updated_extracts[constants.LABELS_KEY], + expected_input[constants.LABELS_KEY], + ) + self.assertIn(constants.EXAMPLE_WEIGHTS_KEY, updated_extracts) + np.testing.assert_allclose( + updated_extracts[constants.EXAMPLE_WEIGHTS_KEY], + expected_input[constants.EXAMPLE_WEIGHTS_KEY], + ) - beam_testing_util.assert_that(updated_pcoll, check_result) + beam_testing_util.assert_that(updated_pcoll, check_result) - @parameterized.named_parameters( - ('multi_output', _BOXMATCH_CASE1_MULTI_MODEL_SPLITFORMAT_BINARY, 0, 0.5, [ - 'xmin', 'ymin', 'xmax', 'ymax', 'class_id' - ], ['xmin', 'ymin', 'xmax', 'ymax', 'class_id', 'score' - ], _BOXMATCH_CASE1_RESULT, 'baseline'), - ('multi_model', _BOXMATCH_CASE1_MULTI_MODEL_BBOXFORMAT_BINARY, 0, 0.5, [ - 'bbox', 'class_id' - ], ['bbox', 'class_id', 'score'], _BOXMATCH_CASE1_RESULT, 'baseline')) - def testBoundingBoxMatchPreprocessorWithMulitModel( - self, extracts, class_id, iou_threshold, labels_stack, predictions_stack, - expected_inputs, model_name): - with beam.Pipeline() as p: - updated_pcoll = ( - p | 'Create' >> beam.Create([extracts]) - | 'Preprocess' >> beam.ParDo( - object_detection_preprocessors.BoundingBoxMatchPreprocessor( - class_id=class_id, - iou_threshold=iou_threshold, - labels_to_stack=labels_stack, - predictions_to_stack=predictions_stack, - model_name=model_name))) + @parameterized.named_parameters( + ( + "multi_output", + _BOXMATCH_CASE1_MULTI_MODEL_SPLITFORMAT_BINARY, + 0, + 0.5, + ["xmin", "ymin", "xmax", "ymax", "class_id"], + ["xmin", "ymin", "xmax", "ymax", "class_id", "score"], + _BOXMATCH_CASE1_RESULT, + "baseline", + ), + ( + "multi_model", + _BOXMATCH_CASE1_MULTI_MODEL_BBOXFORMAT_BINARY, + 0, + 0.5, + ["bbox", "class_id"], + ["bbox", "class_id", "score"], + _BOXMATCH_CASE1_RESULT, + "baseline", + ), + ) + def testBoundingBoxMatchPreprocessorWithMulitModel( + self, + extracts, + class_id, + iou_threshold, + labels_stack, + predictions_stack, + expected_inputs, + model_name, + ): + with beam.Pipeline() as p: + updated_pcoll = ( + p + | "Create" >> beam.Create([extracts]) + | "Preprocess" + >> beam.ParDo( + object_detection_preprocessors.BoundingBoxMatchPreprocessor( + class_id=class_id, + iou_threshold=iou_threshold, + labels_to_stack=labels_stack, + predictions_to_stack=predictions_stack, + model_name=model_name, + ) + ) + ) - def check_result(result): - # Only single extract case is tested - self.assertLen(result, 4) - for updated_extracts, expected_input in zip(result, expected_inputs): - self.assertIn(constants.PREDICTIONS_KEY, updated_extracts) - np.testing.assert_allclose( - updated_extracts[constants.PREDICTIONS_KEY], - expected_input[constants.PREDICTIONS_KEY]) - self.assertIn(constants.LABELS_KEY, updated_extracts) - np.testing.assert_allclose(updated_extracts[constants.LABELS_KEY], - expected_input[constants.LABELS_KEY]) - self.assertIn(constants.EXAMPLE_WEIGHTS_KEY, updated_extracts) - np.testing.assert_allclose( - updated_extracts[constants.EXAMPLE_WEIGHTS_KEY], - expected_input[constants.EXAMPLE_WEIGHTS_KEY]) + def check_result(result): + # Only single extract case is tested + self.assertLen(result, 4) + for updated_extracts, expected_input in zip(result, expected_inputs): + self.assertIn(constants.PREDICTIONS_KEY, updated_extracts) + 
np.testing.assert_allclose( + updated_extracts[constants.PREDICTIONS_KEY], + expected_input[constants.PREDICTIONS_KEY], + ) + self.assertIn(constants.LABELS_KEY, updated_extracts) + np.testing.assert_allclose( + updated_extracts[constants.LABELS_KEY], + expected_input[constants.LABELS_KEY], + ) + self.assertIn(constants.EXAMPLE_WEIGHTS_KEY, updated_extracts) + np.testing.assert_allclose( + updated_extracts[constants.EXAMPLE_WEIGHTS_KEY], + expected_input[constants.EXAMPLE_WEIGHTS_KEY], + ) - beam_testing_util.assert_that(updated_pcoll, check_result) + beam_testing_util.assert_that(updated_pcoll, check_result) - @parameterized.named_parameters( - ('not_found', _BOXMATCH_CASE2_PREDICT_NOT_FOUND, 0, 0.5, - ['xmin', 'ymin', 'xmax', 'ymax', - 'class_id'], ['xmin', 'ymin', 'xmax', 'ymax', 'class_id', - 'score'], _BOXMATCH_CASE2_PREDICT_NOT_FOUND_RESULT)) - def testBoundingBoxMatchPreprocessorWithKeyNotFound(self, extracts, class_id, - iou_threshold, - labels_stack, - predictions_stack, - expected_inputs): - with beam.Pipeline() as p: - updated_pcoll = ( - p | 'Create' >> beam.Create([extracts]) - | 'Preprocess' >> beam.ParDo( - object_detection_preprocessors.BoundingBoxMatchPreprocessor( - class_id=class_id, - iou_threshold=iou_threshold, - labels_to_stack=labels_stack, - predictions_to_stack=predictions_stack, - allow_missing_key=True))) + @parameterized.named_parameters( + ( + "not_found", + _BOXMATCH_CASE2_PREDICT_NOT_FOUND, + 0, + 0.5, + ["xmin", "ymin", "xmax", "ymax", "class_id"], + ["xmin", "ymin", "xmax", "ymax", "class_id", "score"], + _BOXMATCH_CASE2_PREDICT_NOT_FOUND_RESULT, + ) + ) + def testBoundingBoxMatchPreprocessorWithKeyNotFound( + self, + extracts, + class_id, + iou_threshold, + labels_stack, + predictions_stack, + expected_inputs, + ): + with beam.Pipeline() as p: + updated_pcoll = ( + p + | "Create" >> beam.Create([extracts]) + | "Preprocess" + >> beam.ParDo( + object_detection_preprocessors.BoundingBoxMatchPreprocessor( + class_id=class_id, + iou_threshold=iou_threshold, + labels_to_stack=labels_stack, + predictions_to_stack=predictions_stack, + allow_missing_key=True, + ) + ) + ) - def check_result(result): - # Only single extract case is tested - self.assertLen(result, 2) - for updated_extracts, expected_input in zip(result, expected_inputs): - self.assertIn(constants.PREDICTIONS_KEY, updated_extracts) - np.testing.assert_allclose( - updated_extracts[constants.PREDICTIONS_KEY], - expected_input[constants.PREDICTIONS_KEY]) - self.assertIn(constants.LABELS_KEY, updated_extracts) - np.testing.assert_allclose(updated_extracts[constants.LABELS_KEY], - expected_input[constants.LABELS_KEY]) - self.assertIn(constants.EXAMPLE_WEIGHTS_KEY, updated_extracts) - np.testing.assert_allclose( - updated_extracts[constants.EXAMPLE_WEIGHTS_KEY], - expected_input[constants.EXAMPLE_WEIGHTS_KEY]) + def check_result(result): + # Only single extract case is tested + self.assertLen(result, 2) + for updated_extracts, expected_input in zip(result, expected_inputs): + self.assertIn(constants.PREDICTIONS_KEY, updated_extracts) + np.testing.assert_allclose( + updated_extracts[constants.PREDICTIONS_KEY], + expected_input[constants.PREDICTIONS_KEY], + ) + self.assertIn(constants.LABELS_KEY, updated_extracts) + np.testing.assert_allclose( + updated_extracts[constants.LABELS_KEY], + expected_input[constants.LABELS_KEY], + ) + self.assertIn(constants.EXAMPLE_WEIGHTS_KEY, updated_extracts) + np.testing.assert_allclose( + updated_extracts[constants.EXAMPLE_WEIGHTS_KEY], + 
expected_input[constants.EXAMPLE_WEIGHTS_KEY], + ) - beam_testing_util.assert_that(updated_pcoll, check_result) + beam_testing_util.assert_that(updated_pcoll, check_result) -if __name__ == '__main__': - absltest.main() +if __name__ == "__main__": + absltest.main() diff --git a/tensorflow_model_analysis/metrics/preprocessors/set_match_preprocessors.py b/tensorflow_model_analysis/metrics/preprocessors/set_match_preprocessors.py index da2ee5bfaa..2393b2c711 100644 --- a/tensorflow_model_analysis/metrics/preprocessors/set_match_preprocessors.py +++ b/tensorflow_model_analysis/metrics/preprocessors/set_match_preprocessors.py @@ -22,237 +22,239 @@ import apache_beam as beam import numpy as np + from tensorflow_model_analysis import constants from tensorflow_model_analysis.api import types -from tensorflow_model_analysis.metrics import metric_types -from tensorflow_model_analysis.metrics import metric_util +from tensorflow_model_analysis.metrics import metric_types, metric_util from tensorflow_model_analysis.utils import util # indices for the inputs, it should be arranged in the following format: LEFT, TOP, RIGHT, BOTTOM, CLASS, CONFIDENCE = range(6) -_DEFAULT_SET_MATCH_PREPROCESSOR_NAME = 'set_match_preprocessor' +_DEFAULT_SET_MATCH_PREPROCESSOR_NAME = "set_match_preprocessor" class SetMatchPreprocessor(metric_types.Preprocessor): - """Computes label and prediction pairs for set matching.""" - - def __init__( - self, - prediction_class_key: str, - prediction_score_key: str, - class_key: Optional[str] = None, - weight_key: Optional[str] = None, - top_k: Optional[int] = None, - name: Optional[str] = None, - model_name: str = '', - ): - """Initialize the preprocessor for set matching. - - Example: - Labels: ['sun', 'moon'] - Predictions: { - 'classes': ['sun', 'sea', 'light'], - } - - The (label, prediction) tuples generated are: - (1, 1) (TP for sun) - (1, 0) (FN for moon) - (0, 1) (FP for sea) - (0, 1) (FP for light) - - Example with class weights: - Note: The preporcessor supports class wise weights inside each example. The - weight should be a numpy array stored in the features. The user could - provide the corresponding classes of the weights. If it is not provided, - then by default, the preprocessor assumes the weights are for labels. The - classes and weights should be of the same length. - - For classes with specified weights, the final weights of the label - prediction pair is class weight * example weight. - For the classes not listed in the class-weight pairs, the weight will be - the example_weight by default. - - 'labels': ['sun', 'moon'] - 'predictions': { - 'classes': ['sun', 'sea', 'light'], - 'scores': [1, 0.7, 0.3], - } - 'example_weights': [0.1] - 'features': 'class_weights': [0.3, 0.4] - - The (label, prediction, example weight) tuples generated are: - (1, 1, 0.03) (TP for sun with weight 0.3 * 0.1) - (1, 0, 0.04) (FN for moon with weight 0.4 * 0.1) - (0, 0.7, 0.1) (FP for sea with weight 0.1) - (0, 0.3, 0.1) (FP for light with weight 0.1) - - Args: - prediction_class_key: The key name of the classes in predictions. - prediction_score_key: The key name of the scores in predictions. - class_key: (Optional) The key name of the classes in class-weight pairs. - If it is not provided, the classes are assumed to be the label classes. - weight_key: (Optional) The key name of the weights of classes in - class-weight pairs. The value in this key should be a numpy array of the - same length as the classes in class_key. The key should be stored under - the features key. 
- top_k: (Optional) Used with a multi-class model to specify that the top-k - values should be used to compute the confusion matrix. The net effect is - that the non-top-k values are truncated and the matrix is then - constructed from the average TP, FP, TN, FN across the classes. When - top_k is used, metrics_specs.binarize settings must not be present. - name: (Optional) name for the preprocessor. - model_name: (Optional) model name (if multi-model evaluation). - """ - if not name: - name = metric_util.generate_private_name_from_arguments( - _DEFAULT_SET_MATCH_PREPROCESSOR_NAME, top_k=top_k - ) - super().__init__(name=name) - self._top_k = top_k - self._model_name = model_name - self._prediction_class_key = prediction_class_key - self._prediction_score_key = prediction_score_key - self._class_key = class_key - self._weight_key = weight_key - self._model_top_k_exceeds_prediction_length_distribution = ( - beam.metrics.Metrics.distribution( - constants.METRICS_NAMESPACE, 'top_k_exceeds_prediction_length' - ) - ) - - def process( - self, extracts: types.Extracts - ) -> Iterator[metric_types.StandardMetricInputs]: - extracts = util.StandardExtracts(extracts) - - labels = extracts.get_labels(self._model_name) - - if labels is None or labels.ndim != 1: - raise ValueError( - f'Labels must be a 1d numpy array. The classes are {labels}.' - ) - - # classes and weights are two lists representing the class wise weights. - classes = None - weights = None - if self._weight_key: - features = extracts.get_features() - if features is None: - raise ValueError( - 'Weights should be under "features" key. However, features is None' - ) - weights = util.get_by_keys(features, [self._weight_key]) - if self._class_key: - classes = util.get_by_keys(features, [self._class_key]) - else: - classes = labels - - if classes is None or not isinstance(classes, np.ndarray): - raise TypeError( - 'The classes for class-weight pair should be a numpy' - f' array. The classes are {classes}.' - ) - if weights is None or not isinstance(weights, np.ndarray): - raise TypeError( - 'The classes for class_weight pair should be a numpy' - f' array. The classes are {weights}.' - ) - if classes.shape != weights.shape: - raise ValueError( - 'Classes and weights must be of the same shape.' - f' Classes and weights are {classes} and {weights}.' + """Computes label and prediction pairs for set matching.""" + + def __init__( + self, + prediction_class_key: str, + prediction_score_key: str, + class_key: Optional[str] = None, + weight_key: Optional[str] = None, + top_k: Optional[int] = None, + name: Optional[str] = None, + model_name: str = "", + ): + """Initialize the preprocessor for set matching. + + Example: + ------- + Labels: ['sun', 'moon'] + Predictions: { + 'classes': ['sun', 'sea', 'light'], + } + + The (label, prediction) tuples generated are: + (1, 1) (TP for sun) + (1, 0) (FN for moon) + (0, 1) (FP for sea) + (0, 1) (FP for light) + + Example with class weights: + Note: The preporcessor supports class wise weights inside each example. The + weight should be a numpy array stored in the features. The user could + provide the corresponding classes of the weights. If it is not provided, + then by default, the preprocessor assumes the weights are for labels. The + classes and weights should be of the same length. + + For classes with specified weights, the final weights of the label + prediction pair is class weight * example weight. 
+ For the classes not listed in the class-weight pairs, the weight will be + the example_weight by default. + + 'labels': ['sun', 'moon'] + 'predictions': { + 'classes': ['sun', 'sea', 'light'], + 'scores': [1, 0.7, 0.3], + } + 'example_weights': [0.1] + 'features': 'class_weights': [0.3, 0.4] + + The (label, prediction, example weight) tuples generated are: + (1, 1, 0.03) (TP for sun with weight 0.3 * 0.1) + (1, 0, 0.04) (FN for moon with weight 0.4 * 0.1) + (0, 0.7, 0.1) (FP for sea with weight 0.1) + (0, 0.3, 0.1) (FP for light with weight 0.1) + + Args: + ---- + prediction_class_key: The key name of the classes in predictions. + prediction_score_key: The key name of the scores in predictions. + class_key: (Optional) The key name of the classes in class-weight pairs. + If it is not provided, the classes are assumed to be the label classes. + weight_key: (Optional) The key name of the weights of classes in + class-weight pairs. The value in this key should be a numpy array of the + same length as the classes in class_key. The key should be stored under + the features key. + top_k: (Optional) Used with a multi-class model to specify that the top-k + values should be used to compute the confusion matrix. The net effect is + that the non-top-k values are truncated and the matrix is then + constructed from the average TP, FP, TN, FN across the classes. When + top_k is used, metrics_specs.binarize settings must not be present. + name: (Optional) name for the preprocessor. + model_name: (Optional) model name (if multi-model evaluation). + """ + if not name: + name = metric_util.generate_private_name_from_arguments( + _DEFAULT_SET_MATCH_PREPROCESSOR_NAME, top_k=top_k + ) + super().__init__(name=name) + self._top_k = top_k + self._model_name = model_name + self._prediction_class_key = prediction_class_key + self._prediction_score_key = prediction_score_key + self._class_key = class_key + self._weight_key = weight_key + self._model_top_k_exceeds_prediction_length_distribution = ( + beam.metrics.Metrics.distribution( + constants.METRICS_NAMESPACE, "top_k_exceeds_prediction_length" + ) ) - predictions = extracts.get_predictions(model_name=self._model_name) - if not isinstance(predictions, dict): - raise TypeError( - 'Predictions are expected to be a dictionary conatining ' - 'classes and scores.' - ) - - pred_classes = util.get_by_keys(predictions, [self._prediction_class_key]) - if pred_classes is None: - raise KeyError( - f'Key {self._prediction_class_key} is not found under ' - 'predictions of the extracts.' - ) - - pred_scores = None - if self._prediction_score_key: - pred_scores = util.get_by_keys(predictions, [self._prediction_score_key]) - if pred_scores is None: - raise KeyError( - f'Key {self._prediction_score_key} is not found under ' - 'predictions of the extracts.' - ) - - if pred_classes.shape != pred_scores.shape: - raise ValueError( - 'Classes and scores must be of the same shape. Classes and scores ' - f'are {pred_classes} and {pred_scores}.' - ) - - if pred_classes.ndim != 1: - raise ValueError( - 'Predicted classes must be a 1d numpy array. The classes are ' - f'{pred_classes}.' 
- ) - - if self._top_k is not None: - if self._top_k > len(pred_classes): - self._model_top_k_exceeds_prediction_length_distribution.update( - len(pred_classes) - ) - top_k = min(self._top_k, len(pred_classes)) - pred_classes = pred_classes[:top_k] - if pred_scores is not None: - pred_scores = pred_scores[:top_k] - - example_weight = extracts.get_example_weights(model_name=self._model_name) - - class_weights = None - if classes is not None and weights is not None: - class_weights = dict(zip(classes, weights)) - - label_classes = set(labels) - pred_classes_scores = dict(zip(pred_classes, pred_scores)) - pred_classes = set(pred_classes) - - def calculate_weights(class_name): - weight = np.array([1.0]) - if not example_weight and not class_weights: - return None - if example_weight: - weight *= example_weight - if class_weights and class_name in class_weights: - weight *= class_weights[class_name] - return weight - - # yield all true positives - for class_name in label_classes & pred_classes: - result = {} - result[constants.LABELS_KEY] = np.array([1.0]) - result[constants.PREDICTIONS_KEY] = np.array( - [pred_classes_scores[class_name]] - ) - result[constants.EXAMPLE_WEIGHTS_KEY] = calculate_weights(class_name) - yield metric_util.to_standard_metric_inputs(result) - - # yield all the false negatives - for class_name in label_classes - pred_classes: - result = {} - result[constants.LABELS_KEY] = np.array([1.0]) - # set the prediction score to float('-inf') such that it will always be - # counted as negative - result[constants.PREDICTIONS_KEY] = np.array([float('-inf')]) - result[constants.EXAMPLE_WEIGHTS_KEY] = calculate_weights(class_name) - yield metric_util.to_standard_metric_inputs(result) - - # yield all false positives - for class_name in pred_classes - label_classes: - result = {} - result[constants.LABELS_KEY] = [0.0] - result[constants.PREDICTIONS_KEY] = [pred_classes_scores[class_name]] - result[constants.EXAMPLE_WEIGHTS_KEY] = calculate_weights(class_name) - yield metric_util.to_standard_metric_inputs(result) + def process( + self, extracts: types.Extracts + ) -> Iterator[metric_types.StandardMetricInputs]: + extracts = util.StandardExtracts(extracts) + + labels = extracts.get_labels(self._model_name) + + if labels is None or labels.ndim != 1: + raise ValueError( + f"Labels must be a 1d numpy array. The classes are {labels}." + ) + + # classes and weights are two lists representing the class wise weights. + classes = None + weights = None + if self._weight_key: + features = extracts.get_features() + if features is None: + raise ValueError( + 'Weights should be under "features" key. However, features is None' + ) + weights = util.get_by_keys(features, [self._weight_key]) + if self._class_key: + classes = util.get_by_keys(features, [self._class_key]) + else: + classes = labels + + if classes is None or not isinstance(classes, np.ndarray): + raise TypeError( + "The classes for class-weight pair should be a numpy" + f" array. The classes are {classes}." + ) + if weights is None or not isinstance(weights, np.ndarray): + raise TypeError( + "The classes for class_weight pair should be a numpy" + f" array. The classes are {weights}." + ) + if classes.shape != weights.shape: + raise ValueError( + "Classes and weights must be of the same shape." + f" Classes and weights are {classes} and {weights}." 
+ ) + + predictions = extracts.get_predictions(model_name=self._model_name) + if not isinstance(predictions, dict): + raise TypeError( + "Predictions are expected to be a dictionary conatining " + "classes and scores." + ) + + pred_classes = util.get_by_keys(predictions, [self._prediction_class_key]) + if pred_classes is None: + raise KeyError( + f"Key {self._prediction_class_key} is not found under " + "predictions of the extracts." + ) + + pred_scores = None + if self._prediction_score_key: + pred_scores = util.get_by_keys(predictions, [self._prediction_score_key]) + if pred_scores is None: + raise KeyError( + f"Key {self._prediction_score_key} is not found under " + "predictions of the extracts." + ) + + if pred_classes.shape != pred_scores.shape: + raise ValueError( + "Classes and scores must be of the same shape. Classes and scores " + f"are {pred_classes} and {pred_scores}." + ) + + if pred_classes.ndim != 1: + raise ValueError( + "Predicted classes must be a 1d numpy array. The classes are " + f"{pred_classes}." + ) + + if self._top_k is not None: + if self._top_k > len(pred_classes): + self._model_top_k_exceeds_prediction_length_distribution.update( + len(pred_classes) + ) + top_k = min(self._top_k, len(pred_classes)) + pred_classes = pred_classes[:top_k] + if pred_scores is not None: + pred_scores = pred_scores[:top_k] + + example_weight = extracts.get_example_weights(model_name=self._model_name) + + class_weights = None + if classes is not None and weights is not None: + class_weights = dict(zip(classes, weights)) + + label_classes = set(labels) + pred_classes_scores = dict(zip(pred_classes, pred_scores)) + pred_classes = set(pred_classes) + + def calculate_weights(class_name): + weight = np.array([1.0]) + if not example_weight and not class_weights: + return None + if example_weight: + weight *= example_weight + if class_weights and class_name in class_weights: + weight *= class_weights[class_name] + return weight + + # yield all true positives + for class_name in label_classes & pred_classes: + result = {} + result[constants.LABELS_KEY] = np.array([1.0]) + result[constants.PREDICTIONS_KEY] = np.array( + [pred_classes_scores[class_name]] + ) + result[constants.EXAMPLE_WEIGHTS_KEY] = calculate_weights(class_name) + yield metric_util.to_standard_metric_inputs(result) + + # yield all the false negatives + for class_name in label_classes - pred_classes: + result = {} + result[constants.LABELS_KEY] = np.array([1.0]) + # set the prediction score to float('-inf') such that it will always be + # counted as negative + result[constants.PREDICTIONS_KEY] = np.array([float("-inf")]) + result[constants.EXAMPLE_WEIGHTS_KEY] = calculate_weights(class_name) + yield metric_util.to_standard_metric_inputs(result) + + # yield all false positives + for class_name in pred_classes - label_classes: + result = {} + result[constants.LABELS_KEY] = [0.0] + result[constants.PREDICTIONS_KEY] = [pred_classes_scores[class_name]] + result[constants.EXAMPLE_WEIGHTS_KEY] = calculate_weights(class_name) + yield metric_util.to_standard_metric_inputs(result) diff --git a/tensorflow_model_analysis/metrics/preprocessors/set_match_preprocessors_test.py b/tensorflow_model_analysis/metrics/preprocessors/set_match_preprocessors_test.py index 0ea07488f7..f9ba68a2a2 100644 --- a/tensorflow_model_analysis/metrics/preprocessors/set_match_preprocessors_test.py +++ b/tensorflow_model_analysis/metrics/preprocessors/set_match_preprocessors_test.py @@ -13,24 +13,26 @@ # limitations under the License. 
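For context on the behavior exercised by the tests below, here is a minimal standalone sketch of the set-matching expansion that the SetMatchPreprocessor docstring describes: every true positive, false negative, and false positive becomes one (label, prediction, weight) tuple for binary metrics. The helper name `expand_set_match` and its plain-dict inputs are illustrative assumptions, not TFMA APIs; extract handling, input validation, and top-k truncation are omitted.

```python
# Illustrative sketch only (hypothetical helper, not part of TFMA): expands a
# label set and a {class: score} prediction dict into binary-classification
# (label, prediction, weight) tuples, mirroring the docstring example above.
def expand_set_match(labels, pred_scores, example_weight=1.0, class_weights=None):
    class_weights = class_weights or {}
    label_set, pred_set = set(labels), set(pred_scores)

    def weight(cls):
        # Class weight applies only to listed classes; example weight to all.
        return example_weight * class_weights.get(cls, 1.0)

    rows = []
    for cls in label_set & pred_set:  # true positives
        rows.append((1.0, pred_scores[cls], weight(cls)))
    for cls in label_set - pred_set:  # false negatives: -inf score counts as negative
        rows.append((1.0, float("-inf"), weight(cls)))
    for cls in pred_set - label_set:  # false positives
        rows.append((0.0, pred_scores[cls], weight(cls)))
    return rows


# Reproduces the weighted docstring example: TP sun (weight 0.03), FN moon (0.04),
# FP sea and light (0.1 each).
print(expand_set_match(
    ["sun", "moon"],
    {"sun": 1.0, "sea": 0.7, "light": 0.3},
    example_weight=0.1,
    class_weights={"sun": 0.3, "moon": 0.4},
))
```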
"""Tests for set match preprocessors.""" -from absl.testing import absltest -from absl.testing import parameterized import apache_beam as beam -from apache_beam.testing import util as beam_testing_util import numpy as np +from absl.testing import absltest, parameterized +from apache_beam.testing import util as beam_testing_util + from tensorflow_model_analysis import constants from tensorflow_model_analysis.metrics.preprocessors import set_match_preprocessors from tensorflow_model_analysis.utils import util # Initialize test data -_SET_MATCH_INPUT = util.StandardExtracts({ - constants.LABELS_KEY: np.array(['cats', 'dogs']), - constants.PREDICTIONS_KEY: { - 'classes': np.array(['dogs', 'birds']), - 'scores': np.array([0.3, 0.1]), - }, -}) +_SET_MATCH_INPUT = util.StandardExtracts( + { + constants.LABELS_KEY: np.array(["cats", "dogs"]), + constants.PREDICTIONS_KEY: { + "classes": np.array(["dogs", "birds"]), + "scores": np.array([0.3, 0.1]), + }, + } +) _SET_MATCH_RESULT = [ { @@ -39,7 +41,7 @@ }, { constants.LABELS_KEY: np.array([1.0]), - constants.PREDICTIONS_KEY: np.array([float('-inf')]), + constants.PREDICTIONS_KEY: np.array([float("-inf")]), }, { constants.LABELS_KEY: np.array([0.0]), @@ -47,14 +49,16 @@ }, ] -_SET_MATCH_INPUT_WITH_WEIGHT = util.StandardExtracts({ - constants.LABELS_KEY: np.array(['cats', 'dogs']), - constants.PREDICTIONS_KEY: { - 'classes': np.array(['dogs', 'birds']), - 'scores': np.array([0.3, 0.1]), - }, - constants.EXAMPLE_WEIGHTS_KEY: np.array([0.7]), -}) +_SET_MATCH_INPUT_WITH_WEIGHT = util.StandardExtracts( + { + constants.LABELS_KEY: np.array(["cats", "dogs"]), + constants.PREDICTIONS_KEY: { + "classes": np.array(["dogs", "birds"]), + "scores": np.array([0.3, 0.1]), + }, + constants.EXAMPLE_WEIGHTS_KEY: np.array([0.7]), + } +) _SET_MATCH_RESULT_WITH_WEIGHT = [ { @@ -64,7 +68,7 @@ }, { constants.LABELS_KEY: np.array([1.0]), - constants.PREDICTIONS_KEY: np.array([float('-inf')]), + constants.PREDICTIONS_KEY: np.array([float("-inf")]), constants.EXAMPLE_WEIGHTS_KEY: np.array([0.7]), }, { @@ -74,15 +78,17 @@ }, ] -_SET_MATCH_INPUT_WITH_CLASS_WEIGHT = util.StandardExtracts({ - constants.LABELS_KEY: np.array(['cats', 'dogs']), - constants.FEATURES_KEY: {'class_weights': np.array([0.7, 0.2])}, - constants.PREDICTIONS_KEY: { - 'classes': np.array(['dogs', 'birds']), - 'scores': np.array([0.3, 0.1]), - }, - constants.EXAMPLE_WEIGHTS_KEY: np.array([0.7]), -}) +_SET_MATCH_INPUT_WITH_CLASS_WEIGHT = util.StandardExtracts( + { + constants.LABELS_KEY: np.array(["cats", "dogs"]), + constants.FEATURES_KEY: {"class_weights": np.array([0.7, 0.2])}, + constants.PREDICTIONS_KEY: { + "classes": np.array(["dogs", "birds"]), + "scores": np.array([0.3, 0.1]), + }, + constants.EXAMPLE_WEIGHTS_KEY: np.array([0.7]), + } +) _SET_MATCH_RESULT_WITH_CLASS_WEIGHT = [ { @@ -92,7 +98,7 @@ }, { constants.LABELS_KEY: np.array([1.0]), - constants.PREDICTIONS_KEY: np.array([float('-inf')]), + constants.PREDICTIONS_KEY: np.array([float("-inf")]), constants.EXAMPLE_WEIGHTS_KEY: np.array([0.49]), }, { @@ -104,286 +110,305 @@ class SetMatchPreprocessorTest(parameterized.TestCase): - - @parameterized.named_parameters( - ( - 'two_sets', - _SET_MATCH_INPUT, - _SET_MATCH_RESULT, - set_match_preprocessors.SetMatchPreprocessor( - class_key='', - weight_key='', - prediction_class_key='classes', - prediction_score_key='scores', - top_k=None, - ), - ), - ( - 'two_sets_with_example_weight', - _SET_MATCH_INPUT_WITH_WEIGHT, - _SET_MATCH_RESULT_WITH_WEIGHT, - set_match_preprocessors.SetMatchPreprocessor( - 
class_key='', - weight_key='', - prediction_class_key='classes', - prediction_score_key='scores', - top_k=None, - ), - ), - ( - 'two_sets_with_top_k', - _SET_MATCH_INPUT, - _SET_MATCH_RESULT[:2], - set_match_preprocessors.SetMatchPreprocessor( - class_key='', - weight_key='', - prediction_class_key='classes', - prediction_score_key='scores', - top_k=1, - ), - ), - ( - 'two_sets_with_class_weight', - _SET_MATCH_INPUT_WITH_CLASS_WEIGHT, - _SET_MATCH_RESULT_WITH_CLASS_WEIGHT, - set_match_preprocessors.SetMatchPreprocessor( - class_key='', - weight_key='class_weights', - prediction_class_key='classes', - prediction_score_key='scores', - top_k=None, - ), - ), - ) - def testSetMatchPreprocessor(self, extracts, expected_inputs, preprocessor): - with beam.Pipeline() as p: - updated_pcoll = ( - p - | 'Create' >> beam.Create([extracts]) - | 'Preprocess' >> beam.ParDo(preprocessor) - ) - - def check_result(result): - # Only single extract case is tested - self.assertLen(result, len(expected_inputs)) - for updated_extracts, expected_input in zip(result, expected_inputs): - self.assertIn(constants.PREDICTIONS_KEY, updated_extracts) - np.testing.assert_allclose( - updated_extracts[constants.PREDICTIONS_KEY], - expected_input[constants.PREDICTIONS_KEY], - ) - self.assertIn(constants.LABELS_KEY, updated_extracts) - np.testing.assert_allclose( - updated_extracts[constants.LABELS_KEY], - expected_input[constants.LABELS_KEY], - ) - if constants.EXAMPLE_WEIGHTS_KEY in expected_input: - self.assertIn(constants.EXAMPLE_WEIGHTS_KEY, updated_extracts) - np.testing.assert_allclose( - updated_extracts[constants.EXAMPLE_WEIGHTS_KEY], - expected_input[constants.EXAMPLE_WEIGHTS_KEY], + @parameterized.named_parameters( + ( + "two_sets", + _SET_MATCH_INPUT, + _SET_MATCH_RESULT, + set_match_preprocessors.SetMatchPreprocessor( + class_key="", + weight_key="", + prediction_class_key="classes", + prediction_score_key="scores", + top_k=None, + ), + ), + ( + "two_sets_with_example_weight", + _SET_MATCH_INPUT_WITH_WEIGHT, + _SET_MATCH_RESULT_WITH_WEIGHT, + set_match_preprocessors.SetMatchPreprocessor( + class_key="", + weight_key="", + prediction_class_key="classes", + prediction_score_key="scores", + top_k=None, + ), + ), + ( + "two_sets_with_top_k", + _SET_MATCH_INPUT, + _SET_MATCH_RESULT[:2], + set_match_preprocessors.SetMatchPreprocessor( + class_key="", + weight_key="", + prediction_class_key="classes", + prediction_score_key="scores", + top_k=1, + ), + ), + ( + "two_sets_with_class_weight", + _SET_MATCH_INPUT_WITH_CLASS_WEIGHT, + _SET_MATCH_RESULT_WITH_CLASS_WEIGHT, + set_match_preprocessors.SetMatchPreprocessor( + class_key="", + weight_key="class_weights", + prediction_class_key="classes", + prediction_score_key="scores", + top_k=None, + ), + ), + ) + def testSetMatchPreprocessor(self, extracts, expected_inputs, preprocessor): + with beam.Pipeline() as p: + updated_pcoll = ( + p + | "Create" >> beam.Create([extracts]) + | "Preprocess" >> beam.ParDo(preprocessor) ) - beam_testing_util.assert_that(updated_pcoll, check_result) + def check_result(result): + # Only single extract case is tested + self.assertLen(result, len(expected_inputs)) + for updated_extracts, expected_input in zip(result, expected_inputs): + self.assertIn(constants.PREDICTIONS_KEY, updated_extracts) + np.testing.assert_allclose( + updated_extracts[constants.PREDICTIONS_KEY], + expected_input[constants.PREDICTIONS_KEY], + ) + self.assertIn(constants.LABELS_KEY, updated_extracts) + np.testing.assert_allclose( + updated_extracts[constants.LABELS_KEY], + 
expected_input[constants.LABELS_KEY], + ) + if constants.EXAMPLE_WEIGHTS_KEY in expected_input: + self.assertIn(constants.EXAMPLE_WEIGHTS_KEY, updated_extracts) + np.testing.assert_allclose( + updated_extracts[constants.EXAMPLE_WEIGHTS_KEY], + expected_input[constants.EXAMPLE_WEIGHTS_KEY], + ) - def testName(self): - preprocessor = set_match_preprocessors.SetMatchPreprocessor( - class_key='', - weight_key='', - prediction_class_key='classes', - prediction_score_key='scores', - top_k=3, - ) - self.assertEqual(preprocessor.name, '_set_match_preprocessor:top_k=3') + beam_testing_util.assert_that(updated_pcoll, check_result) - def testClassWeightShapeMismatch(self): - extracts = util.StandardExtracts({ - constants.LABELS_KEY: np.array(['cats', 'dogs']), - constants.FEATURES_KEY: {'class_weights': np.array([0.7])}, - constants.PREDICTIONS_KEY: np.array(['birds', 'dogs']), - }) - with self.assertRaisesRegex( - ValueError, - 'Classes and weights must be of the same shape.', - ): - _ = next( - set_match_preprocessors.SetMatchPreprocessor( - class_key='', - weight_key='class_weights', - prediction_class_key='classes', - prediction_score_key='', - ).process(extracts=extracts) - ) + def testName(self): + preprocessor = set_match_preprocessors.SetMatchPreprocessor( + class_key="", + weight_key="", + prediction_class_key="classes", + prediction_score_key="scores", + top_k=3, + ) + self.assertEqual(preprocessor.name, "_set_match_preprocessor:top_k=3") - def testLabelNotFoundClasses(self): - extracts = util.StandardExtracts({ - constants.LABELS_KEY: np.array(['cats', 'dogs']), - constants.FEATURES_KEY: {'class_weights': np.array([0.7, 0.2])}, - constants.PREDICTIONS_KEY: { - 'classes': np.array(['birds', 'dogs']), - 'scores': np.array([0.1, 0.3]), - }, - }) - with self.assertRaisesRegex(ValueError, 'key not found'): - _ = next( - set_match_preprocessors.SetMatchPreprocessor( - class_key='cla', - weight_key='weights', - prediction_class_key='classes', - prediction_score_key='scores', - ).process(extracts=extracts) - ) + def testClassWeightShapeMismatch(self): + extracts = util.StandardExtracts( + { + constants.LABELS_KEY: np.array(["cats", "dogs"]), + constants.FEATURES_KEY: {"class_weights": np.array([0.7])}, + constants.PREDICTIONS_KEY: np.array(["birds", "dogs"]), + } + ) + with self.assertRaisesRegex( + ValueError, + "Classes and weights must be of the same shape.", + ): + _ = next( + set_match_preprocessors.SetMatchPreprocessor( + class_key="", + weight_key="class_weights", + prediction_class_key="classes", + prediction_score_key="", + ).process(extracts=extracts) + ) - def testNotFoundClassWeights(self): - extracts = util.StandardExtracts({ - constants.LABELS_KEY: np.array(['cats', 'dogs']), - constants.FEATURES_KEY: {'class_weights': np.array([0.7, 0.2])}, - constants.PREDICTIONS_KEY: { - 'classes': np.array([['birds', 'dogs']]), - 'scores': np.array([[0.1, 0.3]]), - }, - }) - with self.assertRaisesRegex(ValueError, 'key not found'): - _ = next( - set_match_preprocessors.SetMatchPreprocessor( - class_key='', - weight_key='weigh', - prediction_class_key='classes', - prediction_score_key='score', - ).process(extracts=extracts) - ) + def testLabelNotFoundClasses(self): + extracts = util.StandardExtracts( + { + constants.LABELS_KEY: np.array(["cats", "dogs"]), + constants.FEATURES_KEY: {"class_weights": np.array([0.7, 0.2])}, + constants.PREDICTIONS_KEY: { + "classes": np.array(["birds", "dogs"]), + "scores": np.array([0.1, 0.3]), + }, + } + ) + with self.assertRaisesRegex(ValueError, "key not found"): 
+ _ = next( + set_match_preprocessors.SetMatchPreprocessor( + class_key="cla", + weight_key="weights", + prediction_class_key="classes", + prediction_score_key="scores", + ).process(extracts=extracts) + ) - def testNotFoundFeatures(self): - extracts = util.StandardExtracts({ - constants.LABELS_KEY: np.array(['cats', 'dogs']), - constants.PREDICTIONS_KEY: { - 'classes': np.array([['birds', 'dogs']]), - 'scores': np.array([[0.1, 0.3]]), - }, - }) - with self.assertRaisesRegex(ValueError, 'features is None'): - _ = next( - set_match_preprocessors.SetMatchPreprocessor( - class_key='', - weight_key='weigh', - prediction_class_key='classes', - prediction_score_key='score', - ).process(extracts=extracts) - ) + def testNotFoundClassWeights(self): + extracts = util.StandardExtracts( + { + constants.LABELS_KEY: np.array(["cats", "dogs"]), + constants.FEATURES_KEY: {"class_weights": np.array([0.7, 0.2])}, + constants.PREDICTIONS_KEY: { + "classes": np.array([["birds", "dogs"]]), + "scores": np.array([[0.1, 0.3]]), + }, + } + ) + with self.assertRaisesRegex(ValueError, "key not found"): + _ = next( + set_match_preprocessors.SetMatchPreprocessor( + class_key="", + weight_key="weigh", + prediction_class_key="classes", + prediction_score_key="score", + ).process(extracts=extracts) + ) - def testInvalidLabel(self): - extracts = util.StandardExtracts({ - constants.LABELS_KEY: np.array([['cats', 'dogs']]), - constants.PREDICTIONS_KEY: { - 'classes': np.array(['birds', 'dogs']), - 'scores': np.array([0.1, 0.3]), - }, - }) - with self.assertRaisesRegex(ValueError, 'Labels must be a 1d numpy array.'): - _ = next( - set_match_preprocessors.SetMatchPreprocessor( - prediction_class_key='classes', - prediction_score_key='scores', - ).process(extracts=extracts) - ) + def testNotFoundFeatures(self): + extracts = util.StandardExtracts( + { + constants.LABELS_KEY: np.array(["cats", "dogs"]), + constants.PREDICTIONS_KEY: { + "classes": np.array([["birds", "dogs"]]), + "scores": np.array([[0.1, 0.3]]), + }, + } + ) + with self.assertRaisesRegex(ValueError, "features is None"): + _ = next( + set_match_preprocessors.SetMatchPreprocessor( + class_key="", + weight_key="weigh", + prediction_class_key="classes", + prediction_score_key="score", + ).process(extracts=extracts) + ) - def testPredictionNotADict(self): - extracts = util.StandardExtracts({ - constants.LABELS_KEY: np.array(['cats', 'dogs']), - constants.PREDICTIONS_KEY: np.array(['birds', 'dogs']), - }) - with self.assertRaisesRegex( - TypeError, - ( - 'Predictions are expected to be a ' - 'dictionary conatining classes and scores.' 
- ), - ): - _ = next( - set_match_preprocessors.SetMatchPreprocessor( - class_key='', - weight_key='', - prediction_class_key='classes', - prediction_score_key='scores', - ).process(extracts=extracts) - ) + def testInvalidLabel(self): + extracts = util.StandardExtracts( + { + constants.LABELS_KEY: np.array([["cats", "dogs"]]), + constants.PREDICTIONS_KEY: { + "classes": np.array(["birds", "dogs"]), + "scores": np.array([0.1, 0.3]), + }, + } + ) + with self.assertRaisesRegex(ValueError, "Labels must be a 1d numpy array."): + _ = next( + set_match_preprocessors.SetMatchPreprocessor( + prediction_class_key="classes", + prediction_score_key="scores", + ).process(extracts=extracts) + ) - def testPredictionNotFoundClasses(self): - extracts = util.StandardExtracts({ - constants.LABELS_KEY: np.array(['cats', 'dogs']), - constants.PREDICTIONS_KEY: { - 'classes': np.array([['birds', 'dogs']]), - 'scores': np.array([[0.1, 0.3]]), - }, - }) - with self.assertRaisesRegex(ValueError, 'key not found'): - _ = next( - set_match_preprocessors.SetMatchPreprocessor( - class_key='', - weight_key='', - prediction_class_key='clas', - prediction_score_key='scores', - ).process(extracts=extracts) - ) + def testPredictionNotADict(self): + extracts = util.StandardExtracts( + { + constants.LABELS_KEY: np.array(["cats", "dogs"]), + constants.PREDICTIONS_KEY: np.array(["birds", "dogs"]), + } + ) + with self.assertRaisesRegex( + TypeError, + ( + "Predictions are expected to be a " + "dictionary conatining classes and scores." + ), + ): + _ = next( + set_match_preprocessors.SetMatchPreprocessor( + class_key="", + weight_key="", + prediction_class_key="classes", + prediction_score_key="scores", + ).process(extracts=extracts) + ) - def testPredictionNotFoundScores(self): - extracts = util.StandardExtracts({ - constants.LABELS_KEY: np.array(['cats', 'dogs']), - constants.PREDICTIONS_KEY: { - 'classes': np.array([['birds', 'dogs']]), - 'scores': np.array([[0.1, 0.3]]), - }, - }) - with self.assertRaisesRegex(ValueError, 'key not found'): - _ = next( - set_match_preprocessors.SetMatchPreprocessor( - class_key='', - weight_key='', - prediction_class_key='classes', - prediction_score_key='scor', - ).process(extracts=extracts) - ) + def testPredictionNotFoundClasses(self): + extracts = util.StandardExtracts( + { + constants.LABELS_KEY: np.array(["cats", "dogs"]), + constants.PREDICTIONS_KEY: { + "classes": np.array([["birds", "dogs"]]), + "scores": np.array([[0.1, 0.3]]), + }, + } + ) + with self.assertRaisesRegex(ValueError, "key not found"): + _ = next( + set_match_preprocessors.SetMatchPreprocessor( + class_key="", + weight_key="", + prediction_class_key="clas", + prediction_score_key="scores", + ).process(extracts=extracts) + ) - def testInvalidPrediction(self): - extracts = util.StandardExtracts({ - constants.LABELS_KEY: np.array(['cats', 'dogs']), - constants.PREDICTIONS_KEY: { - 'classes': np.array([['birds', 'dogs']]), - 'scores': np.array([[0.1, 0.3]]), - }, - }) - with self.assertRaisesRegex( - ValueError, 'Predicted classes must be a 1d numpy array.' 
- ): - _ = next( - set_match_preprocessors.SetMatchPreprocessor( - class_key='', - weight_key='', - prediction_class_key='classes', - prediction_score_key='scores', - ).process(extracts=extracts) - ) + def testPredictionNotFoundScores(self): + extracts = util.StandardExtracts( + { + constants.LABELS_KEY: np.array(["cats", "dogs"]), + constants.PREDICTIONS_KEY: { + "classes": np.array([["birds", "dogs"]]), + "scores": np.array([[0.1, 0.3]]), + }, + } + ) + with self.assertRaisesRegex(ValueError, "key not found"): + _ = next( + set_match_preprocessors.SetMatchPreprocessor( + class_key="", + weight_key="", + prediction_class_key="classes", + prediction_score_key="scor", + ).process(extracts=extracts) + ) - def testMismatchClassesAndScores(self): - extracts = util.StandardExtracts({ - constants.LABELS_KEY: np.array(['cats', 'dogs']), - constants.PREDICTIONS_KEY: { - 'classes': np.array([['birds', 'dogs']]), - 'scores': np.array([0.1, 0.3]), - }, - }) - with self.assertRaisesRegex( - ValueError, 'Classes and scores must be of the same shape.' - ): - _ = next( - set_match_preprocessors.SetMatchPreprocessor( - class_key='', - weight_key='', - prediction_class_key='classes', - prediction_score_key='scores', - ).process(extracts=extracts) - ) + def testInvalidPrediction(self): + extracts = util.StandardExtracts( + { + constants.LABELS_KEY: np.array(["cats", "dogs"]), + constants.PREDICTIONS_KEY: { + "classes": np.array([["birds", "dogs"]]), + "scores": np.array([[0.1, 0.3]]), + }, + } + ) + with self.assertRaisesRegex( + ValueError, "Predicted classes must be a 1d numpy array." + ): + _ = next( + set_match_preprocessors.SetMatchPreprocessor( + class_key="", + weight_key="", + prediction_class_key="classes", + prediction_score_key="scores", + ).process(extracts=extracts) + ) + + def testMismatchClassesAndScores(self): + extracts = util.StandardExtracts( + { + constants.LABELS_KEY: np.array(["cats", "dogs"]), + constants.PREDICTIONS_KEY: { + "classes": np.array([["birds", "dogs"]]), + "scores": np.array([0.1, 0.3]), + }, + } + ) + with self.assertRaisesRegex( + ValueError, "Classes and scores must be of the same shape." + ): + _ = next( + set_match_preprocessors.SetMatchPreprocessor( + class_key="", + weight_key="", + prediction_class_key="classes", + prediction_score_key="scores", + ).process(extracts=extracts) + ) -if __name__ == '__main__': - absltest.main() +if __name__ == "__main__": + absltest.main() diff --git a/tensorflow_model_analysis/metrics/preprocessors/utils/__init__.py b/tensorflow_model_analysis/metrics/preprocessors/utils/__init__.py index 680116f3b5..dc2283fe78 100644 --- a/tensorflow_model_analysis/metrics/preprocessors/utils/__init__.py +++ b/tensorflow_model_analysis/metrics/preprocessors/utils/__init__.py @@ -12,5 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
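The tests above all follow the same Apache Beam pattern: build a pipeline, feed extracts in with `beam.Create`, apply the preprocessor via `beam.ParDo`, and verify the output with `assert_that`. A minimal sketch of that pattern, using a trivial stand-in `DoFn` rather than a real TFMA preprocessor (assumes only `apache_beam` is installed):

```python
# Minimal sketch of the Create -> ParDo -> assert_that test pattern used above,
# with a trivial DoFn standing in for a TFMA preprocessor.
import apache_beam as beam
from apache_beam.testing import util as beam_testing_util


class _AddOne(beam.DoFn):
    def process(self, element):
        yield element + 1


with beam.Pipeline() as p:
    updated = (
        p
        | "Create" >> beam.Create([1, 2, 3])
        | "Preprocess" >> beam.ParDo(_AddOne())
    )
    # assert_that is checked when the pipeline finishes at the end of the `with` block.
    beam_testing_util.assert_that(updated, beam_testing_util.equal_to([2, 3, 4]))
```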
"""Init module for TensorFlow Model Analysis CV related utilities.""" -from tensorflow_model_analysis.metrics.preprocessors.utils import bounding_box -from tensorflow_model_analysis.metrics.preprocessors.utils import box_match + +from tensorflow_model_analysis.metrics.preprocessors.utils import ( + bounding_box, + box_match, +) diff --git a/tensorflow_model_analysis/metrics/preprocessors/utils/bounding_box.py b/tensorflow_model_analysis/metrics/preprocessors/utils/bounding_box.py index b353c4ab1e..39356c6f3a 100644 --- a/tensorflow_model_analysis/metrics/preprocessors/utils/bounding_box.py +++ b/tensorflow_model_analysis/metrics/preprocessors/utils/bounding_box.py @@ -17,8 +17,9 @@ etc. """ -from typing import Iterable, Tuple import warnings +from typing import Iterable, Tuple + import numpy as np # indices for the inputs, it should be arranged in the following format: @@ -27,114 +28,135 @@ def bounding_box_area(boxes: np.ndarray) -> np.ndarray: - """Compute areas for a list of boxes. - - Args: - boxes: a numpy array with dimenstion [ , 4+] in corners format - [LEFT, TOP, RIGHT, BOTTOM] - - Returns: - boxes_areas: a numpy array with dimension len(boxes) - """ - if boxes.ndim != 2 or boxes.shape[1] < 4: - raise ValueError('Input boxes list should be a 2d array of shape ' - '( , 4+)') - # Calculate the height and width for each dimension - if np.any(boxes[:, RIGHT] - boxes[:, LEFT] < 0): - raise ValueError('The right boundary is less than the left boundary ' - 'right = {}, left = {}.'.format(boxes[:, RIGHT], - boxes[:, LEFT])) - if np.any(boxes[:, BOTTOM] - boxes[:, TOP] < 0): - raise ValueError('The BOTTOM boundary is less than the TOP boundary ' - 'BOTTOM = {}, TOP = {}.'.format(boxes[:, BOTTOM], - boxes[:, TOP])) - boxes_width = boxes[:, RIGHT] - boxes[:, LEFT] - boxes_height = boxes[:, BOTTOM] - boxes[:, TOP] - - # Calculate the area of each box - boxes_areas = boxes_width * boxes_height - return boxes_areas - - -def filter_boxes_by_class(boxes: np.ndarray, - class_id: Iterable[int]) -> np.ndarray: - """Select boxes for a given set of classes. - - Args: - boxes: a numpy array representing the bounding boxes in the following format - [LEFT, TOP, RIGHT, BOTTOM, CLASS, CONFIDENCE(Optional)] - class_id: id of target num_classes - - Returns: - filtered_boxes: the filtered bounding boxes - """ - if boxes.ndim != 2 or boxes.shape[1] <= CLASS: - raise ValueError(f'Input boxes list should be a 2d array of shape ' - f'(, {CLASS}+)') - if isinstance(class_id, int): - class_id = [class_id] - return boxes[np.isin(boxes[:, CLASS], class_id), :] - - -def check_boxes_in_area_range(boxes: np.ndarray, - area_range: Tuple[float, float]) -> np.ndarray: - """Check boxes whether their areas fall in a given range. - - Args: - boxes: a numpy array representing the bounding boxes in the following format - [LEFT, TOP, RIGHT, BOTTOM, CLASS, CONFIDENCE(Optional)] - area_range: [lowerbound(inclusive), upperbound(inclusive)] of the box area - - Returns: - A numpy array of bool indicates whether the box is in the range. 
- """ - if boxes.ndim != 2 or boxes.shape[1] <= 3: - raise ValueError('Input boxes list should be a 2d array of shape ' - '(,4+)') - - if len(area_range) != 2: - raise ValueError('Invalid shape of area_range') - if area_range[1] <= area_range[0]: - raise ValueError('Invalid input of area range: lower bound is greater' - 'or equal to upperbound') - - area = bounding_box_area(boxes) - return (area_range[0] <= area) & (area <= area_range[1]) - - -def filter_boxes_by_area_range(boxes: np.ndarray, - area_range: Tuple[float, float]) -> np.ndarray: - """Select boxes whose areas fall in a given range. - - Args: - boxes: a numpy array representing the bounding boxes in the following format - [LEFT, TOP, RIGHT, BOTTOM, CLASS, CONFIDENCE(Optional)] - area_range: [lowerbound(inclusive), upperbound(exclusive)] of the box area - - Returns: - filtered_boxes: the filtered bounding Boxes - """ - return boxes[check_boxes_in_area_range(boxes, area_range), :] + """Compute areas for a list of boxes. + + Args: + ---- + boxes: a numpy array with dimenstion [ , 4+] in corners format + [LEFT, TOP, RIGHT, BOTTOM] + + Returns: + ------- + boxes_areas: a numpy array with dimension len(boxes) + """ + if boxes.ndim != 2 or boxes.shape[1] < 4: + raise ValueError( + "Input boxes list should be a 2d array of shape " "( , 4+)" + ) + # Calculate the height and width for each dimension + if np.any(boxes[:, RIGHT] - boxes[:, LEFT] < 0): + raise ValueError( + "The right boundary is less than the left boundary " + f"right = {boxes[:, RIGHT]}, left = {boxes[:, LEFT]}." + ) + if np.any(boxes[:, BOTTOM] - boxes[:, TOP] < 0): + raise ValueError( + "The BOTTOM boundary is less than the TOP boundary " + f"BOTTOM = {boxes[:, BOTTOM]}, TOP = {boxes[:, TOP]}." + ) + boxes_width = boxes[:, RIGHT] - boxes[:, LEFT] + boxes_height = boxes[:, BOTTOM] - boxes[:, TOP] + + # Calculate the area of each box + boxes_areas = boxes_width * boxes_height + return boxes_areas + + +def filter_boxes_by_class(boxes: np.ndarray, class_id: Iterable[int]) -> np.ndarray: + """Select boxes for a given set of classes. + + Args: + ---- + boxes: a numpy array representing the bounding boxes in the following format + [LEFT, TOP, RIGHT, BOTTOM, CLASS, CONFIDENCE(Optional)] + class_id: id of target num_classes + + Returns: + ------- + filtered_boxes: the filtered bounding boxes + """ + if boxes.ndim != 2 or boxes.shape[1] <= CLASS: + raise ValueError( + f"Input boxes list should be a 2d array of shape " + f"(, {CLASS}+)" + ) + if isinstance(class_id, int): + class_id = [class_id] + return boxes[np.isin(boxes[:, CLASS], class_id), :] + + +def check_boxes_in_area_range( + boxes: np.ndarray, area_range: Tuple[float, float] +) -> np.ndarray: + """Check boxes whether their areas fall in a given range. + + Args: + ---- + boxes: a numpy array representing the bounding boxes in the following format + [LEFT, TOP, RIGHT, BOTTOM, CLASS, CONFIDENCE(Optional)] + area_range: [lowerbound(inclusive), upperbound(inclusive)] of the box area + + Returns: + ------- + A numpy array of bool indicates whether the box is in the range. 
+ """ + if boxes.ndim != 2 or boxes.shape[1] <= 3: + raise ValueError( + "Input boxes list should be a 2d array of shape " "(,4+)" + ) + + if len(area_range) != 2: + raise ValueError("Invalid shape of area_range") + if area_range[1] <= area_range[0]: + raise ValueError( + "Invalid input of area range: lower bound is greater" + "or equal to upperbound" + ) + + area = bounding_box_area(boxes) + return (area_range[0] <= area) & (area <= area_range[1]) + + +def filter_boxes_by_area_range( + boxes: np.ndarray, area_range: Tuple[float, float] +) -> np.ndarray: + """Select boxes whose areas fall in a given range. + + Args: + ---- + boxes: a numpy array representing the bounding boxes in the following format + [LEFT, TOP, RIGHT, BOTTOM, CLASS, CONFIDENCE(Optional)] + area_range: [lowerbound(inclusive), upperbound(exclusive)] of the box area + + Returns: + ------- + filtered_boxes: the filtered bounding Boxes + """ + return boxes[check_boxes_in_area_range(boxes, area_range), :] def sort_boxes_by_confidence(boxes: np.ndarray) -> np.ndarray: - """Sort boxes according the confidence in descending order. - - It is using merge sort to agree with COCO metrics. - - Args: - boxes: a numpy array representing the bounding boxes in the following format - [LEFT, TOP, RIGHT, BOTTOM, CLASS, CONFIDENCE] - - Returns: - sorted_boxes: the sorted list of bounding boxes - """ - if boxes.ndim != 2: - raise ValueError(f'Input boxes list should be a 2d array of shape ' - f'( , {CONFIDENCE}+)') - - if boxes.shape[1] <= CONFIDENCE: - warnings.warn('The axis for sort does not exist, return the original data') - return boxes - inds = np.argsort(-boxes[:, CONFIDENCE], kind='mergesort') - return boxes[inds] + """Sort boxes according the confidence in descending order. + + It is using merge sort to agree with COCO metrics. + + Args: + ---- + boxes: a numpy array representing the bounding boxes in the following format + [LEFT, TOP, RIGHT, BOTTOM, CLASS, CONFIDENCE] + + Returns: + ------- + sorted_boxes: the sorted list of bounding boxes + """ + if boxes.ndim != 2: + raise ValueError( + f"Input boxes list should be a 2d array of shape " + f"( , {CONFIDENCE}+)" + ) + + if boxes.shape[1] <= CONFIDENCE: + warnings.warn("The axis for sort does not exist, return the original data") + return boxes + inds = np.argsort(-boxes[:, CONFIDENCE], kind="mergesort") + return boxes[inds] diff --git a/tensorflow_model_analysis/metrics/preprocessors/utils/bounding_box_test.py b/tensorflow_model_analysis/metrics/preprocessors/utils/bounding_box_test.py index a757668292..dbdab3c94c 100644 --- a/tensorflow_model_analysis/metrics/preprocessors/utils/bounding_box_test.py +++ b/tensorflow_model_analysis/metrics/preprocessors/utils/bounding_box_test.py @@ -12,9 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Tests for bounding_box.""" -from absl.testing import absltest -from absl.testing import parameterized + import numpy as np +from absl.testing import absltest, parameterized from tensorflow_model_analysis.metrics.preprocessors.utils import bounding_box @@ -24,66 +24,96 @@ class BoundingBoxTest(parameterized.TestCase): + def test_input_check_bounding_box_area(self): + # Input should not be empty + expected_exception = ValueError + expected_regex = "Input boxes list should be a 2d array" + self.assertRaisesRegex( + expected_exception, + expected_regex, + bounding_box.bounding_box_area, + np.array([20, 60, 290]), + ) - def test_input_check_bounding_box_area(self): - # Input should not be empty - expected_exception = ValueError - expected_regex = 'Input boxes list should be a 2d array' - self.assertRaisesRegex(expected_exception, expected_regex, - bounding_box.bounding_box_area, - np.array([20, 60, 290])) - - def test_input_value_check_bounding_box_area(self): - boxes = np.array([[20, 300, 60, 290]]) - expected_exception = ValueError - expected_regex = 'The BOTTOM boundary is less than the TOP boundary ' - self.assertRaisesRegex(expected_exception, expected_regex, - bounding_box.bounding_box_area, boxes) + def test_input_value_check_bounding_box_area(self): + boxes = np.array([[20, 300, 60, 290]]) + expected_exception = ValueError + expected_regex = "The BOTTOM boundary is less than the TOP boundary " + self.assertRaisesRegex( + expected_exception, expected_regex, bounding_box.bounding_box_area, boxes + ) - def test_compute_box_area(self): - boxes = np.array([[30, 100, 70, 300], [50, 100, 80, 110]]) - np.testing.assert_allclose( - np.array([8000, 300]), bounding_box.bounding_box_area(boxes)) + def test_compute_box_area(self): + boxes = np.array([[30, 100, 70, 300], [50, 100, 80, 110]]) + np.testing.assert_allclose( + np.array([8000, 300]), bounding_box.bounding_box_area(boxes) + ) - def test_input_check_filter_boxes_by_class(self): - with self.assertRaisesRegex(ValueError, - 'Input boxes list should be a 2d array'): - _ = bounding_box.filter_boxes_by_class(np.array([20, 60, 290]), [3]) + def test_input_check_filter_boxes_by_class(self): + with self.assertRaisesRegex( + ValueError, "Input boxes list should be a 2d array" + ): + _ = bounding_box.filter_boxes_by_class(np.array([20, 60, 290]), [3]) - def test_filter_boxes_by_one_class(self): - boxes = np.array([[30, 100, 70, 300, 1], [50, 100, 80, 90, 2], - [40, 100, 100, 290, 1]]) - result = bounding_box.filter_boxes_by_class(boxes, [1]) - expected_result = np.array([[30, 100, 70, 300, 1], [40, 100, 100, 290, 1]]) - np.testing.assert_equal(result, expected_result) + def test_filter_boxes_by_one_class(self): + boxes = np.array( + [[30, 100, 70, 300, 1], [50, 100, 80, 90, 2], [40, 100, 100, 290, 1]] + ) + result = bounding_box.filter_boxes_by_class(boxes, [1]) + expected_result = np.array([[30, 100, 70, 300, 1], [40, 100, 100, 290, 1]]) + np.testing.assert_equal(result, expected_result) - def test_filter_boxes_by_multi_classes(self): - boxes = np.array([[30, 100, 70, 300, 1], [50, 100, 80, 90, 2], - [40, 100, 100, 290, 1], [55, 200, 88, 390, 0]]) - result = bounding_box.filter_boxes_by_class(boxes, [0, 1]) - expected_result = np.array([[30, 100, 70, 300, 1], [40, 100, 100, 290, 1], - [55, 200, 88, 390, 0]]) - np.testing.assert_equal(result, expected_result) + def test_filter_boxes_by_multi_classes(self): + boxes = np.array( + [ + [30, 100, 70, 300, 1], + [50, 100, 80, 90, 2], + [40, 100, 100, 290, 1], + [55, 200, 88, 390, 0], + ] + ) + result = 
bounding_box.filter_boxes_by_class(boxes, [0, 1]) + expected_result = np.array( + [[30, 100, 70, 300, 1], [40, 100, 100, 290, 1], [55, 200, 88, 390, 0]] + ) + np.testing.assert_equal(result, expected_result) - @parameterized.named_parameters( - ('_filtering_to_empty', np.array([[0, 100, 100, 300] - ]), [50000, 80000], np.empty([0, 4])), - ('_case1', - np.array([[30, 100, 70, 300, 1], [50, 100, 80, 100, 2], - [40, 100, 100, 290, 1], [55, 200, 88, 390, 0]]), [7000, 12000], - np.array([[30, 100, 70, 300, 1], [40, 100, 100, 290, 1]]))) - def test_filter_boxes_by_area_range(self, boxes, area_range, expected_result): - result = bounding_box.filter_boxes_by_area_range(boxes, area_range) - np.testing.assert_equal(result, expected_result) + @parameterized.named_parameters( + ( + "_filtering_to_empty", + np.array([[0, 100, 100, 300]]), + [50000, 80000], + np.empty([0, 4]), + ), + ( + "_case1", + np.array( + [ + [30, 100, 70, 300, 1], + [50, 100, 80, 100, 2], + [40, 100, 100, 290, 1], + [55, 200, 88, 390, 0], + ] + ), + [7000, 12000], + np.array([[30, 100, 70, 300, 1], [40, 100, 100, 290, 1]]), + ), + ) + def test_filter_boxes_by_area_range(self, boxes, area_range, expected_result): + result = bounding_box.filter_boxes_by_area_range(boxes, area_range) + np.testing.assert_equal(result, expected_result) - def test_input_check_sort_boxes_by_confidence(self): - # Input should not be empty - expected_exception = ValueError - expected_regex = 'Input boxes list should be a 2d array' - self.assertRaisesRegex(expected_exception, expected_regex, - bounding_box.sort_boxes_by_confidence, - np.array([20, 60, 290])) + def test_input_check_sort_boxes_by_confidence(self): + # Input should not be empty + expected_exception = ValueError + expected_regex = "Input boxes list should be a 2d array" + self.assertRaisesRegex( + expected_exception, + expected_regex, + bounding_box.sort_boxes_by_confidence, + np.array([20, 60, 290]), + ) -if __name__ == '__main__': - absltest.main() +if __name__ == "__main__": + absltest.main() diff --git a/tensorflow_model_analysis/metrics/preprocessors/utils/box_match.py b/tensorflow_model_analysis/metrics/preprocessors/utils/box_match.py index 72850ba893..110cdd7348 100644 --- a/tensorflow_model_analysis/metrics/preprocessors/utils/box_match.py +++ b/tensorflow_model_analysis/metrics/preprocessors/utils/box_match.py @@ -16,7 +16,8 @@ It includes functions for pairwise iou calculations and box matching utilities. """ -from typing import Callable, Iterable, Union, Tuple, Optional +from typing import Callable, Iterable, Optional, Tuple, Union + import numpy as np from tensorflow_model_analysis.metrics.preprocessors.utils import bounding_box @@ -26,124 +27,134 @@ def _match_boxes( - ious: np.ndarray, - thresholds: Union[float, Iterable[float]]) -> Tuple[np.ndarray, np.ndarray]: - """Match predictions and ground_truth through the pairwise IoUs. - - Args: - ious: a numpy array, ious[i,j] is the iou between the i th prediction and the - j th ground truth. - thresholds: the minimum IoU for a pair to be considered match. 
- - Returns: - (matches_gt, matches_pred): a tuple of ndarray of the following, - matches_gt: a numpy array with shape [T, G], the matched prediction index - at each iou threshold (-1 means unmatched) - matches_pred: a numpy array with shape [T, P], the matched ground truth - index at each iou threshold (-1 means unmatched) - where, - T: num of thresholds - P: num of predictions - G: num of ground truth - """ - if ious.ndim != 2: - raise ValueError('Input ious list should be a 2d array') - - if isinstance(thresholds, float): - thresholds = [thresholds] - for threshold in thresholds: - if threshold < 0: - raise ValueError(f'Invalid input of threshold = {threshold}, should be' - ' greater than or equal to 0') - - num_iou_thresholds = len(thresholds) - num_pred = ious.shape[0] - num_gt = ious.shape[1] - - # initialize the matching, and set the type to int. - matches_gt = -1 * np.ones((num_iou_thresholds, num_gt), dtype=int) - matches_pred = -1 * np.ones((num_iou_thresholds, num_pred), dtype=int) - - for i, threshold in enumerate(thresholds): - # find the matched ground truth, for each prediction - for pred_idx in range(num_pred): - # initialize the index of ground truth which will match with the - # prediction. - match_gt_index = -1 - # make sure the threshold is < 1.0 - iou = np.minimum(threshold, 1.0 - 1e-10) - - for gt_idx in range(num_gt): - # if this ground truth is already matched, skip it - if matches_gt[i, gt_idx] != -1: - continue - - if ious[pred_idx, gt_idx] < iou: - continue - - iou = ious[pred_idx, gt_idx] - match_gt_index = gt_idx - - matches_pred[i, pred_idx] = match_gt_index - if match_gt_index != -1: - matches_gt[i, match_gt_index] = pred_idx - return (matches_gt, matches_pred) - - -def compute_ious_for_image(boxes1: np.ndarray, - boxes2: np.ndarray) -> np.ndarray: - """Computes pairwise ious for two lists of boxes. - - Args: - boxes1: numpy array, containing a list of bounding boxes in 'corners' format. - boxes2: numpy array, containing a list of bounding boxes in 'corners' format. - Bounding boxes are expected to be in the corners format of [LEFT, TOP, - RIGHT, BOTTOM] For example, the bounding box with it's left bound at 20, - right bound at 100, TOP_bound at 110, BOTTOM bound at 300 is be represented - as [20, 110, 100, 300] - - Returns: - iou_lookup_table: a vector containing the pairwise ious of boxes1 and - boxes2. The (i,j) element of the table is the iou of the i th box of - boxes1 and the j th element of boxes2. 
- """ - - # Sanity check of the dimension and shape of inputs - if boxes1.ndim != 2 or boxes1.shape[ - 1] != 4 or boxes2.ndim != 2 or boxes2.shape[1] != 4: - raise ValueError('Input boxes lists should be a 2d array of shape ' - f'(, 4), Input shapes are {boxes1.shape},' - f' {boxes2.shape}') - - # Split each dimension - boxes1_xmin, boxes1_ymin, boxes1_xmax, boxes1_ymax = np.split( - boxes1, 4, axis=1) - boxes2_xmin, boxes2_ymin, boxes2_xmax, boxes2_ymax = np.split( - boxes2, 4, axis=1) - - # Calculate the area of each box - boxes1_area = bounding_box.bounding_box_area(boxes1) - boxes2_area = bounding_box.bounding_box_area(boxes2) - - # Calculate the intersection area for each boxes pair - intersect_ymin = np.maximum(boxes1_ymin, boxes2_ymin.transpose()) - intersect_xmin = np.maximum(boxes1_xmin, boxes2_xmin.transpose()) - intersect_ymax = np.minimum(boxes1_ymax, boxes2_ymax.transpose()) - intersect_xmax = np.minimum(boxes1_xmax, boxes2_xmax.transpose()) - - intersect_width = np.maximum(intersect_xmax - intersect_xmin, 0) - intersect_height = np.maximum(intersect_ymax - intersect_ymin, 0) - intersect_area = intersect_width * intersect_height - - # Calculate the union area of each boxes pair - union_area = boxes1_area[..., np.newaxis] + boxes2_area[np.newaxis, - ...] - intersect_area - # Return with a out arg to avoid divide by zero - return np.divide( - intersect_area, - union_area, - out=np.zeros_like(intersect_area, dtype=float), - where=union_area != 0) + ious: np.ndarray, thresholds: Union[float, Iterable[float]] +) -> Tuple[np.ndarray, np.ndarray]: + """Match predictions and ground_truth through the pairwise IoUs. + + Args: + ---- + ious: a numpy array, ious[i,j] is the iou between the i th prediction and the + j th ground truth. + thresholds: the minimum IoU for a pair to be considered match. + + Returns: + ------- + (matches_gt, matches_pred): a tuple of ndarray of the following, + matches_gt: a numpy array with shape [T, G], the matched prediction index + at each iou threshold (-1 means unmatched) + matches_pred: a numpy array with shape [T, P], the matched ground truth + index at each iou threshold (-1 means unmatched) + where, + T: num of thresholds + P: num of predictions + G: num of ground truth + """ + if ious.ndim != 2: + raise ValueError("Input ious list should be a 2d array") + + if isinstance(thresholds, float): + thresholds = [thresholds] + for threshold in thresholds: + if threshold < 0: + raise ValueError( + f"Invalid input of threshold = {threshold}, should be" + " greater than or equal to 0" + ) + + num_iou_thresholds = len(thresholds) + num_pred = ious.shape[0] + num_gt = ious.shape[1] + + # initialize the matching, and set the type to int. + matches_gt = -1 * np.ones((num_iou_thresholds, num_gt), dtype=int) + matches_pred = -1 * np.ones((num_iou_thresholds, num_pred), dtype=int) + + for i, threshold in enumerate(thresholds): + # find the matched ground truth, for each prediction + for pred_idx in range(num_pred): + # initialize the index of ground truth which will match with the + # prediction. 
+ match_gt_index = -1 + # make sure the threshold is < 1.0 + iou = np.minimum(threshold, 1.0 - 1e-10) + + for gt_idx in range(num_gt): + # if this ground truth is already matched, skip it + if matches_gt[i, gt_idx] != -1: + continue + + if ious[pred_idx, gt_idx] < iou: + continue + + iou = ious[pred_idx, gt_idx] + match_gt_index = gt_idx + + matches_pred[i, pred_idx] = match_gt_index + if match_gt_index != -1: + matches_gt[i, match_gt_index] = pred_idx + return (matches_gt, matches_pred) + + +def compute_ious_for_image(boxes1: np.ndarray, boxes2: np.ndarray) -> np.ndarray: + """Computes pairwise ious for two lists of boxes. + + Args: + ---- + boxes1: numpy array, containing a list of bounding boxes in 'corners' format. + boxes2: numpy array, containing a list of bounding boxes in 'corners' format. + Bounding boxes are expected to be in the corners format of [LEFT, TOP, + RIGHT, BOTTOM] For example, the bounding box with it's left bound at 20, + right bound at 100, TOP_bound at 110, BOTTOM bound at 300 is be represented + as [20, 110, 100, 300] + + Returns: + ------- + iou_lookup_table: a vector containing the pairwise ious of boxes1 and + boxes2. The (i,j) element of the table is the iou of the i th box of + boxes1 and the j th element of boxes2. + """ + # Sanity check of the dimension and shape of inputs + if ( + boxes1.ndim != 2 + or boxes1.shape[1] != 4 + or boxes2.ndim != 2 + or boxes2.shape[1] != 4 + ): + raise ValueError( + "Input boxes lists should be a 2d array of shape " + f"(, 4), Input shapes are {boxes1.shape}," + f" {boxes2.shape}" + ) + + # Split each dimension + boxes1_xmin, boxes1_ymin, boxes1_xmax, boxes1_ymax = np.split(boxes1, 4, axis=1) + boxes2_xmin, boxes2_ymin, boxes2_xmax, boxes2_ymax = np.split(boxes2, 4, axis=1) + + # Calculate the area of each box + boxes1_area = bounding_box.bounding_box_area(boxes1) + boxes2_area = bounding_box.bounding_box_area(boxes2) + + # Calculate the intersection area for each boxes pair + intersect_ymin = np.maximum(boxes1_ymin, boxes2_ymin.transpose()) + intersect_xmin = np.maximum(boxes1_xmin, boxes2_xmin.transpose()) + intersect_ymax = np.minimum(boxes1_ymax, boxes2_ymax.transpose()) + intersect_xmax = np.minimum(boxes1_xmax, boxes2_xmax.transpose()) + + intersect_width = np.maximum(intersect_xmax - intersect_xmin, 0) + intersect_height = np.maximum(intersect_ymax - intersect_ymin, 0) + intersect_area = intersect_width * intersect_height + + # Calculate the union area of each boxes pair + union_area = ( + boxes1_area[..., np.newaxis] + boxes2_area[np.newaxis, ...] - intersect_area + ) + # Return with a out arg to avoid divide by zero + return np.divide( + intersect_area, + union_area, + out=np.zeros_like(intersect_area, dtype=float), + where=union_area != 0, + ) def boxes_to_label_prediction_example_weight( @@ -151,7 +162,7 @@ def boxes_to_label_prediction_example_weight( boxes_pred: np.ndarray, class_id: int, iou_threshold: float, - area_range: Optional[Tuple[float, float]] = (0, float('inf')), + area_range: Optional[Tuple[float, float]] = (0, float("inf")), max_num_detections: Optional[int] = None, class_weight: Optional[float] = None, weight: Optional[float] = None, @@ -163,90 +174,98 @@ def boxes_to_label_prediction_example_weight( ] = None, iou_func: Optional[Callable[[np.ndarray, np.ndarray], np.ndarray]] = None, ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: - """Generate label prediction weight tuple from ground truths and detections. 
- - Args: - boxes_gt: a numpy array representing the bounding boxes in the following - format [LEFT, TOP, RIGHT, BOTTOM, CLASS] - boxes_pred: a numpy array representing the bounding boxes in the following - format [LEFT, TOP, RIGHT, BOTTOM, CLASS, CONFIDENCE] - class_id: the class to consider classification. - iou_threshold: the threshold for two bounding boxes to be considered as a - Match. - area_range: objects outside of this arange will be excluded. - max_num_detections: maximum number of detections in a single image. - class_weight: the weight associated with this class. - weight: weight of this image/example. - match_boxes_func: optional alternative function to compute match_boxes. - iou_func: optional alternative function to compute box ious. - - Returns: - (label, prediction, weight): three lists of numpy array for binary - classfication. - """ - # Sanity check of the dimension and shape of inputs - if len(boxes_gt.shape) != 2 or boxes_gt.shape[1] != 5 or len( - boxes_pred.shape) != 2 or boxes_pred.shape[1] != 6: - raise ValueError('Input boxes list should be a 2d array of shape ( ,5)' - 'for ground truth and ( ,6) for prediction, where' - f'boxes_gt.shape = {boxes_gt.shape}, ' - f'boxes_pred.shape = {boxes_pred.shape}') - if max_num_detections is not None and max_num_detections <= 0: - raise ValueError( - f'max_num_detections = {max_num_detections} must be positive.') - if class_id < 0: - raise ValueError(f'class_id = {class_id} must be 0 or positive.') - if class_weight is not None and class_weight < 0.: - raise ValueError(f'class_weight = {class_weight} must be 0 or positive.') - - # Filter all bounding boxes with a specific Class - boxes_gt = bounding_box.filter_boxes_by_class(boxes_gt, [class_id]) - boxes_pred = bounding_box.filter_boxes_by_class(boxes_pred, [class_id]) - - # Filter ground truth bounding boxes within an area range - boxes_gt = bounding_box.filter_boxes_by_area_range(boxes_gt, area_range) - - # Sort predictions with confidence(larger ones first) - boxes_pred = bounding_box.sort_boxes_by_confidence(boxes_pred) - - # Limit detection numbers to max_num_detections - if (max_num_detections - is not None) and (boxes_pred.shape[0] > max_num_detections): - boxes_pred = boxes_pred[:max_num_detections] - - if not iou_func: - iou_func = compute_ious_for_image - ious = iou_func(boxes_pred[:, :CLASS], boxes_gt[:, :CLASS]) - if not match_boxes_func: - match_boxes_func = _match_boxes - matches_gt, matches_pred = match_boxes_func(ious, iou_threshold) - - # It is assumed that it only takes one iou_threshold result while it - # returns a list of results, so the matches only needs to take the first - # entry of the results - matches_gt = matches_gt[0] - matches_pred = matches_pred[0] - - # Ignore the unmatched predictions which is out of the area range - boxes_pred_area_tag = bounding_box.check_boxes_in_area_range( - boxes_pred, area_range) - # Only count unmatched predictions within the area range - boxes_pred_num_unmatched = np.count_nonzero(boxes_pred_area_tag - & (matches_pred == -1)) - - # If the index is -1, it means unmatched - labels = np.append( - np.ones(boxes_gt.shape[0]), np.zeros(boxes_pred_num_unmatched)) - - predictions = np.concatenate( - (boxes_pred[matches_gt[matches_gt != -1], - CONFIDENCE], np.zeros(np.count_nonzero((matches_gt == -1))), - boxes_pred[boxes_pred_area_tag & (matches_pred == -1), CONFIDENCE])) - - if weight is None: - weight = 1.0 - if class_weight is None: - class_weight = 1.0 - weights = np.ones_like(labels) * weight * class_weight - - return 
(labels, predictions, weights)
+    """Generate label prediction weight tuple from ground truths and detections.
+
+    Args:
+    ----
+      boxes_gt: a numpy array representing the bounding boxes in the following
+        format [LEFT, TOP, RIGHT, BOTTOM, CLASS]
+      boxes_pred: a numpy array representing the bounding boxes in the following
+        format [LEFT, TOP, RIGHT, BOTTOM, CLASS, CONFIDENCE]
+      class_id: the class to consider for classification.
+      iou_threshold: the threshold for two bounding boxes to be considered a
+        match.
+      area_range: objects outside of this area range will be excluded.
+      max_num_detections: maximum number of detections in a single image.
+      class_weight: the weight associated with this class.
+      weight: weight of this image/example.
+      match_boxes_func: optional alternative function to compute match_boxes.
+      iou_func: optional alternative function to compute box ious.
+
+    Returns:
+    -------
+      (label, prediction, weight): three numpy arrays for binary
+      classification.
+    """
+    # Sanity check of the dimension and shape of inputs
+    if (
+        len(boxes_gt.shape) != 2
+        or boxes_gt.shape[1] != 5
+        or len(boxes_pred.shape) != 2
+        or boxes_pred.shape[1] != 6
+    ):
+        raise ValueError(
+            "Input boxes list should be a 2d array of shape ( ,5) "
+            "for ground truth and ( ,6) for prediction, where "
+            f"boxes_gt.shape = {boxes_gt.shape}, "
+            f"boxes_pred.shape = {boxes_pred.shape}"
+        )
+    if max_num_detections is not None and max_num_detections <= 0:
+        raise ValueError(f"max_num_detections = {max_num_detections} must be positive.")
+    if class_id < 0:
+        raise ValueError(f"class_id = {class_id} must be 0 or positive.")
+    if class_weight is not None and class_weight < 0.0:
+        raise ValueError(f"class_weight = {class_weight} must be 0 or positive.")
+
+    # Filter all bounding boxes by the specified class
+    boxes_gt = bounding_box.filter_boxes_by_class(boxes_gt, [class_id])
+    boxes_pred = bounding_box.filter_boxes_by_class(boxes_pred, [class_id])
+
+    # Filter ground truth bounding boxes within an area range
+    boxes_gt = bounding_box.filter_boxes_by_area_range(boxes_gt, area_range)
+
+    # Sort predictions by confidence (largest first)
+    boxes_pred = bounding_box.sort_boxes_by_confidence(boxes_pred)
+
+    # Limit the number of detections to max_num_detections
+    if (max_num_detections is not None) and (boxes_pred.shape[0] > max_num_detections):
+        boxes_pred = boxes_pred[:max_num_detections]
+
+    if not iou_func:
+        iou_func = compute_ious_for_image
+    ious = iou_func(boxes_pred[:, :CLASS], boxes_gt[:, :CLASS])
+    if not match_boxes_func:
+        match_boxes_func = _match_boxes
+    matches_gt, matches_pred = match_boxes_func(ious, iou_threshold)
+
+    # A single iou_threshold is passed in, but match_boxes_func returns results
+    # for a list of thresholds, so only the first entry of each result is
+    # needed.
+    matches_gt = matches_gt[0]
+    matches_pred = matches_pred[0]
+
+    # Ignore the unmatched predictions which are outside the area range
+    boxes_pred_area_tag = bounding_box.check_boxes_in_area_range(boxes_pred, area_range)
+    # Only count unmatched predictions within the area range
+    boxes_pred_num_unmatched = np.count_nonzero(
+        boxes_pred_area_tag & (matches_pred == -1)
+    )
+
+    # If the index is -1, it means unmatched
+    labels = np.append(np.ones(boxes_gt.shape[0]), np.zeros(boxes_pred_num_unmatched))
+
+    predictions = np.concatenate(
+        (
+            boxes_pred[matches_gt[matches_gt != -1], CONFIDENCE],
+            np.zeros(np.count_nonzero(matches_gt == -1)),
+            boxes_pred[boxes_pred_area_tag & (matches_pred == -1), CONFIDENCE],
+        )
+    )
+
+    if 
weight is None: + weight = 1.0 + if class_weight is None: + class_weight = 1.0 + weights = np.ones_like(labels) * weight * class_weight + + return (labels, predictions, weights) diff --git a/tensorflow_model_analysis/metrics/preprocessors/utils/box_match_test.py b/tensorflow_model_analysis/metrics/preprocessors/utils/box_match_test.py index f02eb1891f..1a082d2405 100644 --- a/tensorflow_model_analysis/metrics/preprocessors/utils/box_match_test.py +++ b/tensorflow_model_analysis/metrics/preprocessors/utils/box_match_test.py @@ -12,185 +12,248 @@ # See the License for the specific language governing permissions and # limitations under the License. """Tests for iou.""" -from absl.testing import absltest -from absl.testing import parameterized + import numpy as np +from absl.testing import absltest, parameterized + from tensorflow_model_analysis.metrics.preprocessors.utils import box_match class IouTest(parameterized.TestCase): + @parameterized.named_parameters( + ("_2dnot1d", np.array([30, 100, 70, 300]), np.array([[20, 130, 60, 290]])), + ("_4cols", np.array([30, 100, 70, 300]), np.array([[20, 60, 290]])), + ) + def test_input_check_compute_iou(self, boxes1, boxes2): + self.assertRaisesRegex( + ValueError, + "Input boxes lists should be a 2d array", + box_match.compute_ious_for_image, + boxes1, + boxes2, + ) - @parameterized.named_parameters( - ('_2dnot1d', np.array([30, 100, 70, 300]), np.array([[20, 130, 60, 290] - ])), - ('_4cols', np.array([30, 100, 70, 300]), np.array([[20, 60, 290]]))) - def test_input_check_compute_iou(self, boxes1, boxes2): - self.assertRaisesRegex(ValueError, 'Input boxes lists should be a 2d array', - box_match.compute_ious_for_image, boxes1, boxes2) + def test_compute_single_iou(self): + # Boxes are in the corners format [LEFT, RIGHT, TOP, BOTTOM] + boxes1 = np.array([[30, 100, 70, 300]]) + boxes2 = np.array([[20, 130, 60, 290]]) + result = box_match.compute_ious_for_image(boxes1, boxes2) + expected_result = np.array([[0.5]]) + np.testing.assert_allclose(result, expected_result) - def test_compute_single_iou(self): - # Boxes are in the corners format [LEFT, RIGHT, TOP, BOTTOM] - boxes1 = np.array([[30, 100, 70, 300]]) - boxes2 = np.array([[20, 130, 60, 290]]) - result = box_match.compute_ious_for_image(boxes1, boxes2) - expected_result = np.array([[0.5]]) - np.testing.assert_allclose(result, expected_result) - - def test_compute_multiple_iou(self): - boxes1 = np.array([[30, 100, 70, 300], [50, 100, 80, 200]]) - boxes2 = np.array([[20, 130, 60, 290], [30, 100, 70, 300], - [500, 100, 800, 300]]) - result = box_match.compute_ious_for_image(boxes1, boxes2) - expected_result = np.array([[0.5, 1., 0.], [7 / 87, 2 / 9, 0.]]) - np.testing.assert_allclose(result, expected_result) + def test_compute_multiple_iou(self): + boxes1 = np.array([[30, 100, 70, 300], [50, 100, 80, 200]]) + boxes2 = np.array( + [[20, 130, 60, 290], [30, 100, 70, 300], [500, 100, 800, 300]] + ) + result = box_match.compute_ious_for_image(boxes1, boxes2) + expected_result = np.array([[0.5, 1.0, 0.0], [7 / 87, 2 / 9, 0.0]]) + np.testing.assert_allclose(result, expected_result) class BoundingBoxTest(parameterized.TestCase): + def test_input_check_match_boxes(self): + # Input should include class_id + ious = np.array([20, 60, 290]) + thresholds = np.array(0.5) + with self.assertRaisesRegex(ValueError, "ious list should be a 2d array"): + _ = box_match._match_boxes(ious, thresholds) - def test_input_check_match_boxes(self): - # Input should include class_id - ious = np.array([20, 60, 290]) - thresholds = 
np.array(0.5) - with self.assertRaisesRegex(ValueError, 'ious list should be a 2d array'): - _ = box_match._match_boxes(ious, thresholds) - - @parameterized.named_parameters(('_one_gt_multi_pred', { - 'ious': np.array([[0.1], [0.8], [0.4]]), - 'thresholds': 0.5 - }, np.array([[1]]), np.array([[-1, 0, -1]])), ('_threshold_too_high', { - 'ious': np.array([[0.1], [0.8], [0.4]]), - 'thresholds': 0.9 - }, np.array([[-1]]), np.array([[-1, -1, -1]])), ('_multi_gt_multi_pred', { - 'ious': np.array([[0.1, 0.8, 0.4], [0.3, 0.1, 0.4], [0.6, 0.9, 0.4]]), - 'thresholds': [0., 0.5, 0.85] - }, np.array([[2, 0, 1], [2, 0, -1], [-1, 2, -1] - ]), np.array([[1, 2, 0], [1, -1, 0], [-1, -1, 1]]))) - def test_match_boxes(self, raw_input, expected_gt_match, expected_pred_match): - gt_match, pred_match = box_match._match_boxes(**raw_input) - np.testing.assert_equal(expected_gt_match, gt_match) - np.testing.assert_equal(expected_pred_match, pred_match) + @parameterized.named_parameters( + ( + "_one_gt_multi_pred", + {"ious": np.array([[0.1], [0.8], [0.4]]), "thresholds": 0.5}, + np.array([[1]]), + np.array([[-1, 0, -1]]), + ), + ( + "_threshold_too_high", + {"ious": np.array([[0.1], [0.8], [0.4]]), "thresholds": 0.9}, + np.array([[-1]]), + np.array([[-1, -1, -1]]), + ), + ( + "_multi_gt_multi_pred", + { + "ious": np.array([[0.1, 0.8, 0.4], [0.3, 0.1, 0.4], [0.6, 0.9, 0.4]]), + "thresholds": [0.0, 0.5, 0.85], + }, + np.array([[2, 0, 1], [2, 0, -1], [-1, 2, -1]]), + np.array([[1, 2, 0], [1, -1, 0], [-1, -1, 1]]), + ), + ) + def test_match_boxes(self, raw_input, expected_gt_match, expected_pred_match): + gt_match, pred_match = box_match._match_boxes(**raw_input) + np.testing.assert_equal(expected_gt_match, gt_match) + np.testing.assert_equal(expected_pred_match, pred_match) - @parameterized.named_parameters( - ('_single_case_matched', { - 'boxes_gt': np.array([[0, 50, 30, 100, 0]]), - 'boxes_pred': np.array([[10, 60, 40, 80, 0, 0.5]]), - 'iou_threshold': 0.1 - }, { - 'labels': np.array([1.]), - 'predictions': np.array([0.5]), - 'example_weights': np.array([1.]) - }), - ('_single_case_notmatched', { - 'boxes_gt': np.array([[0, 50, 30, 100, 0]]), - 'boxes_pred': np.array([[10, 60, 40, 80, 0, 0.5]]), - 'iou_threshold': 0.5 - }, { - 'labels': np.array([1., 0.]), - 'predictions': np.array([0., 0.5]), - 'example_weights': np.array([1., 1.]) - }), - ('_empty_ground_truth', { - 'boxes_gt': np.empty([0, 5]), - 'boxes_pred': np.array([[10, 60, 40, 80, 0, 0.5]]), - 'iou_threshold': 0.5 - }, { - 'labels': np.array([0.]), - 'predictions': np.array([0.5]), - 'example_weights': np.array([1.]) - }), - ('_empty_prediction', { - 'boxes_gt': np.array([[0, 50, 30, 100, 0]]), - 'boxes_pred': np.empty([0, 6]), - 'iou_threshold': 0.5 - }, { - 'labels': np.array([1]), - 'predictions': np.array([0]), - 'example_weights': np.array([1.]) - }), - ('_empty_both_truth_and_prediction', { - 'boxes_gt': np.empty([0, 5]), - 'boxes_pred': np.empty([0, 6]), - 'iou_threshold': 0.5 - }, { - 'labels': np.array([]), - 'predictions': np.array([]), - 'example_weights': np.array([]) - }), - # the following multi-example produces:(after_sorting) - # ious: np.array([[0., 0., 0.], [0., 0.5, 7/87], [0., 1., 2/9]]) - # matches_gt: [-1, 1, -1] - # matches_pred: [-1, 0, -1] - ('_multi_cases', { - 'boxes_gt': - np.array([[30, 1000, 70, 3000, 0], [30, 100, 70, 300, 0], - [50, 100, 80, 200, 0]]), - 'boxes_pred': - np.array([[20, 130, 60, 290, 0, 0.5], [30, 100, 70, 300, 0, 0.3], - [500, 100, 800, 300, 0, 0.9]]), - 'iou_threshold': - 0.3 - }, { - 'labels': np.array([1., 
1., 1., 0., 0.]), - 'predictions': np.array([0.5, 0., 0., 0.9, 0.3]), - 'example_weights': np.array([1., 1., 1., 1., 1.]), - })) - def test_boxes_to_label_prediction(self, raw_input, expected_result): - result = box_match.boxes_to_label_prediction_example_weight( - boxes_gt=raw_input['boxes_gt'], - boxes_pred=raw_input['boxes_pred'], - iou_threshold=raw_input['iou_threshold'], - class_id=0) - self.assertLen(result, len(expected_result)) - np.testing.assert_allclose(result[0], expected_result['labels']) - np.testing.assert_allclose(result[1], expected_result['predictions']) - np.testing.assert_allclose(result[2], expected_result['example_weights']) + @parameterized.named_parameters( + ( + "_single_case_matched", + { + "boxes_gt": np.array([[0, 50, 30, 100, 0]]), + "boxes_pred": np.array([[10, 60, 40, 80, 0, 0.5]]), + "iou_threshold": 0.1, + }, + { + "labels": np.array([1.0]), + "predictions": np.array([0.5]), + "example_weights": np.array([1.0]), + }, + ), + ( + "_single_case_notmatched", + { + "boxes_gt": np.array([[0, 50, 30, 100, 0]]), + "boxes_pred": np.array([[10, 60, 40, 80, 0, 0.5]]), + "iou_threshold": 0.5, + }, + { + "labels": np.array([1.0, 0.0]), + "predictions": np.array([0.0, 0.5]), + "example_weights": np.array([1.0, 1.0]), + }, + ), + ( + "_empty_ground_truth", + { + "boxes_gt": np.empty([0, 5]), + "boxes_pred": np.array([[10, 60, 40, 80, 0, 0.5]]), + "iou_threshold": 0.5, + }, + { + "labels": np.array([0.0]), + "predictions": np.array([0.5]), + "example_weights": np.array([1.0]), + }, + ), + ( + "_empty_prediction", + { + "boxes_gt": np.array([[0, 50, 30, 100, 0]]), + "boxes_pred": np.empty([0, 6]), + "iou_threshold": 0.5, + }, + { + "labels": np.array([1]), + "predictions": np.array([0]), + "example_weights": np.array([1.0]), + }, + ), + ( + "_empty_both_truth_and_prediction", + { + "boxes_gt": np.empty([0, 5]), + "boxes_pred": np.empty([0, 6]), + "iou_threshold": 0.5, + }, + { + "labels": np.array([]), + "predictions": np.array([]), + "example_weights": np.array([]), + }, + ), + # the following multi-example produces:(after_sorting) + # ious: np.array([[0., 0., 0.], [0., 0.5, 7/87], [0., 1., 2/9]]) + # matches_gt: [-1, 1, -1] + # matches_pred: [-1, 0, -1] + ( + "_multi_cases", + { + "boxes_gt": np.array( + [ + [30, 1000, 70, 3000, 0], + [30, 100, 70, 300, 0], + [50, 100, 80, 200, 0], + ] + ), + "boxes_pred": np.array( + [ + [20, 130, 60, 290, 0, 0.5], + [30, 100, 70, 300, 0, 0.3], + [500, 100, 800, 300, 0, 0.9], + ] + ), + "iou_threshold": 0.3, + }, + { + "labels": np.array([1.0, 1.0, 1.0, 0.0, 0.0]), + "predictions": np.array([0.5, 0.0, 0.0, 0.9, 0.3]), + "example_weights": np.array([1.0, 1.0, 1.0, 1.0, 1.0]), + }, + ), + ) + def test_boxes_to_label_prediction(self, raw_input, expected_result): + result = box_match.boxes_to_label_prediction_example_weight( + boxes_gt=raw_input["boxes_gt"], + boxes_pred=raw_input["boxes_pred"], + iou_threshold=raw_input["iou_threshold"], + class_id=0, + ) + self.assertLen(result, len(expected_result)) + np.testing.assert_allclose(result[0], expected_result["labels"]) + np.testing.assert_allclose(result[1], expected_result["predictions"]) + np.testing.assert_allclose(result[2], expected_result["example_weights"]) - @parameterized.named_parameters((('_filter_by_class'), { - 'boxes_gt': np.array([[0, 50, 30, 100, 0]]), - 'boxes_pred': np.array([[10, 60, 40, 80, 1, 0.5]]), - 'iou_threshold': 0.1, - 'class_id': 1, - 'area_range': [0, 10000], - 'max_num_detections': 1 - }, { - 'labels': np.array([0.]), - 'predictions': np.array([0.5]), - 
'example_weights': np.array([1.0]), - }), (('_filter_by_area_range'), { - 'boxes_gt': np.array([[0, 50, 30, 100, 0]]), - 'boxes_pred': np.array([[10, 60, 40, 80, 1, 0.5]]), - 'iou_threshold': 0.1, - 'class_id': 0, - 'area_range': [100, 200], - 'max_num_detections': 1 - }, { - 'labels': np.array([]), - 'predictions': np.array([]), - 'example_weights': np.array([]), - }), (('_filter_by_maximum_detections'), { - 'boxes_gt': - np.array([[0, 50, 30, 100, 0]]), - 'boxes_pred': - np.array([[10, 60, 40, 80, 1, 0.5], [10, 60, 40, 80, 0, 0.8]]), - 'iou_threshold': - 0.1, - 'class_id': - 1, - 'area_range': [0, 10000], - 'max_num_detections': - 1 - }, { - 'labels': np.array([0.]), - 'predictions': np.array([0.5]), - 'example_weights': np.array([1.0]), - })) - def test_boxes_to_label_prediction_filter(self, raw_input, expected_result): - result = box_match.boxes_to_label_prediction_example_weight(**raw_input) - self.assertLen(result, len(expected_result)) - np.testing.assert_allclose(result[0], expected_result['labels']) - np.testing.assert_allclose(result[1], expected_result['predictions']) - np.testing.assert_allclose(result[2], expected_result['example_weights']) + @parameterized.named_parameters( + ( + ("_filter_by_class"), + { + "boxes_gt": np.array([[0, 50, 30, 100, 0]]), + "boxes_pred": np.array([[10, 60, 40, 80, 1, 0.5]]), + "iou_threshold": 0.1, + "class_id": 1, + "area_range": [0, 10000], + "max_num_detections": 1, + }, + { + "labels": np.array([0.0]), + "predictions": np.array([0.5]), + "example_weights": np.array([1.0]), + }, + ), + ( + ("_filter_by_area_range"), + { + "boxes_gt": np.array([[0, 50, 30, 100, 0]]), + "boxes_pred": np.array([[10, 60, 40, 80, 1, 0.5]]), + "iou_threshold": 0.1, + "class_id": 0, + "area_range": [100, 200], + "max_num_detections": 1, + }, + { + "labels": np.array([]), + "predictions": np.array([]), + "example_weights": np.array([]), + }, + ), + ( + ("_filter_by_maximum_detections"), + { + "boxes_gt": np.array([[0, 50, 30, 100, 0]]), + "boxes_pred": np.array( + [[10, 60, 40, 80, 1, 0.5], [10, 60, 40, 80, 0, 0.8]] + ), + "iou_threshold": 0.1, + "class_id": 1, + "area_range": [0, 10000], + "max_num_detections": 1, + }, + { + "labels": np.array([0.0]), + "predictions": np.array([0.5]), + "example_weights": np.array([1.0]), + }, + ), + ) + def test_boxes_to_label_prediction_filter(self, raw_input, expected_result): + result = box_match.boxes_to_label_prediction_example_weight(**raw_input) + self.assertLen(result, len(expected_result)) + np.testing.assert_allclose(result[0], expected_result["labels"]) + np.testing.assert_allclose(result[1], expected_result["predictions"]) + np.testing.assert_allclose(result[2], expected_result["example_weights"]) -if __name__ == '__main__': - absltest.main() +if __name__ == "__main__": + absltest.main() diff --git a/tensorflow_model_analysis/metrics/preprocessors/utils/object_detection_format.py b/tensorflow_model_analysis/metrics/preprocessors/utils/object_detection_format.py index 7eaf3ffba7..1f96d67223 100644 --- a/tensorflow_model_analysis/metrics/preprocessors/utils/object_detection_format.py +++ b/tensorflow_model_analysis/metrics/preprocessors/utils/object_detection_format.py @@ -21,138 +21,160 @@ from typing import List, Optional import numpy as np + from tensorflow_model_analysis.metrics import metric_util from tensorflow_model_analysis.utils import util -def stack_labels(extracts: util.StandardExtracts, - col_names: List[str], - model_name: Optional[str] = None, - allow_missing_key: Optional[bool] = False) -> np.ndarray: - 
"""Stacks several numpy arrays in the extracts into a single one for labels. - - It will search for column_names in labels, features and transformed features. - If not found, it will raise an error. - - Examples: - Extracts - { - features: { - 'xmin': [0, 0, 0.1] - 'xmax': [1, 1, 0.5] - 'ymin_max': [[0.2, 1], [0.3, 1], [0.1, 1]] - } - } - stack_labels(extracts, ['xmin', 'xmax']) == - np.array([[0, 1], [0, 1], [0.1, 0.5]]) - stack_labels(extracts, ['xmin', 'xmax', 'ymin_max']) = - np.array([[0, 1, 0.2, 1], [0, 1, 0.3, 1], [0.1, 0.5, 0.1 ,1]]) - - Args: - extracts: TFMA extracts that stores the keys. - col_names: Keys of columns which will be stacked. - model_name: The name of the model for outputs. - allow_missing_key: (Optional) If true, it will return empty array instead of - raising errors when col_names are not found. - - Returns: - A numpy array that stacks all the columns together. - - Raises: - KeyError: The columns for stacking are not found in extracts. - ValueError: The format of the input is not valid for stacking. - """ - cols = [] - - dict_to_search = collections.ChainMap( - extracts.get_labels(model_name) or {}, - extracts.get_features() or {}, - extracts.get_transformed_features(model_name) or {}) - - for col_name in col_names: - if dict_to_search and col_name in dict_to_search: - new_cols = dict_to_search[col_name] - if new_cols.ndim == 2: - cols.append(new_cols) - elif new_cols.ndim == 1: - cols.append(new_cols[:, np.newaxis]) - else: - raise ValueError(f"Dimension of input under {col_name}" - " should be 1 or 2.") - else: - if allow_missing_key: - return np.empty((0, 5)) - else: - raise KeyError(f"Key {col_name} is not found under labels, " - "features, or transformed features of the extracts." - "Please set allow_missing_key to True, if you want to " - "return empty array instead.") - result = np.hstack(cols) - return result - - -def stack_predictions(extracts: util.StandardExtracts, - col_names: List[str], - model_name: Optional[str] = None, - allow_missing_key: Optional[bool] = False) -> np.ndarray: - """Stacks several numpy arrays in the extracts into a single predictions. - - It will search for column_names in labels, features and transformed features. - If not found, it will raise an error. - - Examples: - Extracts - { - features: { - 'xmin': [0, 0, 0.1] - 'xmax': [1, 1, 0.5] - 'ymin_max': [[0.2, 1], [0.3, 1], [0.1, 1]] - } - } - stack_predictions(extracts, ['xmin', 'xmax', 'ymin_max']) = - np.array([[0, 1, 0.2, 1], [0, 1, 0.3, 1], [0.1, 0.5, 0.1 ,1]]) - - Args: - extracts: TFMA extracts that stores the keys. - col_names: Keys of columns which will be stacked. - model_name: The name of the model for outputs. - allow_missing_key: (Optional) If true, it will return empty array instead of - raising errors when col_names are not found. - - Returns: - A numpy array that stacks all the columns together. - - Raises: - KeyError: The columns for stacking are not found in extracts. - ValueError: The format of the input is not valid for stacking. 
- """ - cols = [] - - dict_to_search = collections.ChainMap( - extracts.get_predictions(model_name) or {}, - extracts.get_features() or {}, - extracts.get_transformed_features(model_name) or {}) - - for col_name in col_names: - if dict_to_search and col_name in dict_to_search: - new_cols = dict_to_search[col_name] - if new_cols.ndim == 2: - cols.append(new_cols) - elif new_cols.ndim == 1: - cols.append(new_cols[:, np.newaxis]) - else: - raise ValueError(f"Dimension of input under {col_name} is " - f"{new_cols.ndim}, but should be 1 or 2.") - else: - if allow_missing_key: - return np.empty((0, 6)) - else: - raise KeyError(f"Key {col_name} is not found under predictions, " - "features, or transformed features of the extracts." - "Please set allow_missing_key to True, if you want to " - "return empty array instead.") - result = np.hstack(cols) - return result +def stack_labels( + extracts: util.StandardExtracts, + col_names: List[str], + model_name: Optional[str] = None, + allow_missing_key: Optional[bool] = False, +) -> np.ndarray: + """Stacks several numpy arrays in the extracts into a single one for labels. + + It will search for column_names in labels, features and transformed features. + If not found, it will raise an error. + + Examples: + -------- + Extracts + { + features: { + 'xmin': [0, 0, 0.1] + 'xmax': [1, 1, 0.5] + 'ymin_max': [[0.2, 1], [0.3, 1], [0.1, 1]] + } + } + stack_labels(extracts, ['xmin', 'xmax']) == + np.array([[0, 1], [0, 1], [0.1, 0.5]]) + stack_labels(extracts, ['xmin', 'xmax', 'ymin_max']) = + np.array([[0, 1, 0.2, 1], [0, 1, 0.3, 1], [0.1, 0.5, 0.1 ,1]]) + + Args: + ---- + extracts: TFMA extracts that stores the keys. + col_names: Keys of columns which will be stacked. + model_name: The name of the model for outputs. + allow_missing_key: (Optional) If true, it will return empty array instead of + raising errors when col_names are not found. + + Returns: + ------- + A numpy array that stacks all the columns together. + + Raises: + ------ + KeyError: The columns for stacking are not found in extracts. + ValueError: The format of the input is not valid for stacking. + """ + cols = [] + + dict_to_search = collections.ChainMap( + extracts.get_labels(model_name) or {}, + extracts.get_features() or {}, + extracts.get_transformed_features(model_name) or {}, + ) + + for col_name in col_names: + if dict_to_search and col_name in dict_to_search: + new_cols = dict_to_search[col_name] + if new_cols.ndim == 2: + cols.append(new_cols) + elif new_cols.ndim == 1: + cols.append(new_cols[:, np.newaxis]) + else: + raise ValueError( + f"Dimension of input under {col_name}" " should be 1 or 2." + ) + else: + if allow_missing_key: + return np.empty((0, 5)) + else: + raise KeyError( + f"Key {col_name} is not found under labels, " + "features, or transformed features of the extracts." + "Please set allow_missing_key to True, if you want to " + "return empty array instead." + ) + result = np.hstack(cols) + return result + + +def stack_predictions( + extracts: util.StandardExtracts, + col_names: List[str], + model_name: Optional[str] = None, + allow_missing_key: Optional[bool] = False, +) -> np.ndarray: + """Stacks several numpy arrays in the extracts into a single predictions. + + It will search for column_names in labels, features and transformed features. + If not found, it will raise an error. 
+ + Examples: + -------- + Extracts + { + features: { + 'xmin': [0, 0, 0.1] + 'xmax': [1, 1, 0.5] + 'ymin_max': [[0.2, 1], [0.3, 1], [0.1, 1]] + } + } + stack_predictions(extracts, ['xmin', 'xmax', 'ymin_max']) = + np.array([[0, 1, 0.2, 1], [0, 1, 0.3, 1], [0.1, 0.5, 0.1 ,1]]) + + Args: + ---- + extracts: TFMA extracts that stores the keys. + col_names: Keys of columns which will be stacked. + model_name: The name of the model for outputs. + allow_missing_key: (Optional) If true, it will return empty array instead of + raising errors when col_names are not found. + + Returns: + ------- + A numpy array that stacks all the columns together. + + Raises: + ------ + KeyError: The columns for stacking are not found in extracts. + ValueError: The format of the input is not valid for stacking. + """ + cols = [] + + dict_to_search = collections.ChainMap( + extracts.get_predictions(model_name) or {}, + extracts.get_features() or {}, + extracts.get_transformed_features(model_name) or {}, + ) + + for col_name in col_names: + if dict_to_search and col_name in dict_to_search: + new_cols = dict_to_search[col_name] + if new_cols.ndim == 2: + cols.append(new_cols) + elif new_cols.ndim == 1: + cols.append(new_cols[:, np.newaxis]) + else: + raise ValueError( + f"Dimension of input under {col_name} is " + f"{new_cols.ndim}, but should be 1 or 2." + ) + else: + if allow_missing_key: + return np.empty((0, 6)) + else: + raise KeyError( + f"Key {col_name} is not found under predictions, " + "features, or transformed features of the extracts." + "Please set allow_missing_key to True, if you want to " + "return empty array instead." + ) + result = np.hstack(cols) + return result def truncate_by_num_detections( @@ -162,42 +184,48 @@ def truncate_by_num_detections( model_name: Optional[str] = None, allow_missing_key: Optional[bool] = False, ) -> np.ndarray: - """Get the array to be truncated by the number of rows. - - Args: - extracts: TFMA extracts that stores the keys. - num_rows_key: Number of rows in each column except the paddings. For - multi-dimensional input, it will truncate on the first dimension. - array_to_truncate: the array to be truncated te - model_name: The name of the model for outputs. - allow_missing_key: (Optional) If true, it will do nothing instead of - raising errors when col_names are not found. - - Returns: - The array truncated by the number of rows. - - Raises: - KeyError: The num_rows_key is not found in extracts. - """ - num_of_rows = None - - dict_to_search = collections.ChainMap( - extracts.get_predictions(model_name) or {}, - extracts.get_features() or {}, - extracts.get_transformed_features(model_name) or {}) - - if num_rows_key: - if dict_to_search and num_rows_key in dict_to_search: - num_of_rows = dict_to_search[num_rows_key] - if isinstance(num_of_rows, np.ndarray): - num_of_rows = metric_util.safe_to_scalar(num_of_rows) - else: - if not allow_missing_key: - raise KeyError(f"Key {num_rows_key} is not found under predictions, " - "features, or transformed features of the extracts." - "Please set allow_missing_key to True, if you want to " - "skip truncation instead.") - result = array_to_truncate - if num_of_rows and num_of_rows > 0 and len(result) > num_of_rows: - result = result[:num_of_rows] - return result + """Get the array to be truncated by the number of rows. + + Args: + ---- + extracts: TFMA extracts that stores the keys. + num_rows_key: Number of rows in each column except the paddings. For + multi-dimensional input, it will truncate on the first dimension. 
+ array_to_truncate: the array to be truncated te + model_name: The name of the model for outputs. + allow_missing_key: (Optional) If true, it will do nothing instead of + raising errors when col_names are not found. + + Returns: + ------- + The array truncated by the number of rows. + + Raises: + ------ + KeyError: The num_rows_key is not found in extracts. + """ + num_of_rows = None + + dict_to_search = collections.ChainMap( + extracts.get_predictions(model_name) or {}, + extracts.get_features() or {}, + extracts.get_transformed_features(model_name) or {}, + ) + + if num_rows_key: + if dict_to_search and num_rows_key in dict_to_search: + num_of_rows = dict_to_search[num_rows_key] + if isinstance(num_of_rows, np.ndarray): + num_of_rows = metric_util.safe_to_scalar(num_of_rows) + else: + if not allow_missing_key: + raise KeyError( + f"Key {num_rows_key} is not found under predictions, " + "features, or transformed features of the extracts." + "Please set allow_missing_key to True, if you want to " + "skip truncation instead." + ) + result = array_to_truncate + if num_of_rows and num_of_rows > 0 and len(result) > num_of_rows: + result = result[:num_of_rows] + return result diff --git a/tensorflow_model_analysis/metrics/preprocessors/utils/object_detection_format_test.py b/tensorflow_model_analysis/metrics/preprocessors/utils/object_detection_format_test.py index fd787432c3..ceecc21d65 100644 --- a/tensorflow_model_analysis/metrics/preprocessors/utils/object_detection_format_test.py +++ b/tensorflow_model_analysis/metrics/preprocessors/utils/object_detection_format_test.py @@ -12,52 +12,64 @@ # See the License for the specific language governing permissions and # limitations under the License. """Tests for iou.""" -from absl.testing import absltest -from absl.testing import parameterized + import numpy as np +from absl.testing import absltest, parameterized + from tensorflow_model_analysis import constants -from tensorflow_model_analysis.metrics.preprocessors.utils import object_detection_format +from tensorflow_model_analysis.metrics.preprocessors.utils import ( + object_detection_format, +) from tensorflow_model_analysis.utils import util -_STACK_SPLITFORMAT = util.StandardExtracts({ - constants.FEATURES_KEY: { - 'xmin': np.array([30, 50]), - 'ymin': np.array([100, 100]), - 'xmax': np.array([70, 80]), - 'ymax': np.array([300, 200]), - 'class_id': np.array([0, 0]), - 'num_detections': np.array([1]) - }, -}) - -_STACK_GROUPFORMAT = util.StandardExtracts({ - constants.LABELS_KEY: { - 'bbox': np.array([[30, 100, 70, 300], [50, 100, 80, 200]]), - 'class_id': np.array([0, 0]) - }, -}) +_STACK_SPLITFORMAT = util.StandardExtracts( + { + constants.FEATURES_KEY: { + "xmin": np.array([30, 50]), + "ymin": np.array([100, 100]), + "xmax": np.array([70, 80]), + "ymax": np.array([300, 200]), + "class_id": np.array([0, 0]), + "num_detections": np.array([1]), + }, + } +) + +_STACK_GROUPFORMAT = util.StandardExtracts( + { + constants.LABELS_KEY: { + "bbox": np.array([[30, 100, 70, 300], [50, 100, 80, 200]]), + "class_id": np.array([0, 0]), + }, + } +) _STACK_RESULT = np.array([[30, 100, 70, 300, 0], [50, 100, 80, 200, 0]]) class ObjectDetectionFormatTest(parameterized.TestCase): + @parameterized.named_parameters( + ( + "_splitted_columns", + _STACK_SPLITFORMAT, + ["xmin", "ymin", "xmax", "ymax", "class_id"], + _STACK_RESULT, + ), + ("_partially_stacked", _STACK_GROUPFORMAT, ["bbox", "class_id"], _STACK_RESULT), + ) + def test_stack_column(self, extracts, column_names, expected_result): + result = 
object_detection_format.stack_labels(extracts, column_names) + np.testing.assert_allclose(result, expected_result) + + def test_stack_predictions(self): + result = object_detection_format.stack_predictions( + _STACK_SPLITFORMAT, ["xmin", "ymin", "xmax", "ymax", "class_id"] + ) + result = object_detection_format.truncate_by_num_detections( + _STACK_SPLITFORMAT, "num_detections", result + ) + np.testing.assert_allclose(result, _STACK_RESULT[:1]) + - @parameterized.named_parameters( - ('_splitted_columns', _STACK_SPLITFORMAT, - ['xmin', 'ymin', 'xmax', 'ymax', 'class_id'], _STACK_RESULT), - ('_partially_stacked', _STACK_GROUPFORMAT, ['bbox', 'class_id' - ], _STACK_RESULT)) - def test_stack_column(self, extracts, column_names, expected_result): - result = object_detection_format.stack_labels(extracts, column_names) - np.testing.assert_allclose(result, expected_result) - - def test_stack_predictions(self): - result = object_detection_format.stack_predictions( - _STACK_SPLITFORMAT, ['xmin', 'ymin', 'xmax', 'ymax', 'class_id']) - result = object_detection_format.truncate_by_num_detections( - _STACK_SPLITFORMAT, 'num_detections', result) - np.testing.assert_allclose(result, _STACK_RESULT[:1]) - - -if __name__ == '__main__': - absltest.main() +if __name__ == "__main__": + absltest.main() diff --git a/tensorflow_model_analysis/metrics/query_statistics.py b/tensorflow_model_analysis/metrics/query_statistics.py index 31565e58a1..4c8a86bb0e 100644 --- a/tensorflow_model_analysis/metrics/query_statistics.py +++ b/tensorflow_model_analysis/metrics/query_statistics.py @@ -16,42 +16,46 @@ from typing import Dict, Iterable, Optional import apache_beam as beam -from tensorflow_model_analysis.metrics import metric_types -from tensorflow_model_analysis.metrics import metric_util + +from tensorflow_model_analysis.metrics import metric_types, metric_util from tensorflow_model_analysis.proto import config_pb2 -TOTAL_QUERIES_NAME = 'total_queries' -TOTAL_DOCUMENTS_NAME = 'total_documents' -MIN_DOCUMENTS_NAME = 'min_documents' -MAX_DOCUMENTS_NAME = 'max_documents' +TOTAL_QUERIES_NAME = "total_queries" +TOTAL_DOCUMENTS_NAME = "total_documents" +MIN_DOCUMENTS_NAME = "min_documents" +MAX_DOCUMENTS_NAME = "max_documents" class QueryStatistics(metric_types.Metric): - """Query statistic metrics. - - These metrics are query/ranking based so a query_key must also be provided in - the associated metrics spec. - """ - - def __init__(self, - total_queries_name: str = TOTAL_QUERIES_NAME, - total_documents_name: str = TOTAL_DOCUMENTS_NAME, - min_documents_name: str = MIN_DOCUMENTS_NAME, - max_documents_name: str = MAX_DOCUMENTS_NAME): - """Initializes query statistics metrics. - - Args: - total_queries_name: Total queries metric name. - total_documents_name: Total documents metric name. - min_documents_name: Min documents name. - max_documents_name: Max documents name. + """Query statistic metrics. + + These metrics are query/ranking based so a query_key must also be provided in + the associated metrics spec. """ - super().__init__( - _query_statistics, - total_queries_name=total_queries_name, - total_documents_name=total_documents_name, - min_documents_name=min_documents_name, - max_documents_name=max_documents_name) + + def __init__( + self, + total_queries_name: str = TOTAL_QUERIES_NAME, + total_documents_name: str = TOTAL_DOCUMENTS_NAME, + min_documents_name: str = MIN_DOCUMENTS_NAME, + max_documents_name: str = MAX_DOCUMENTS_NAME, + ): + """Initializes query statistics metrics. 
+ + Args: + ---- + total_queries_name: Total queries metric name. + total_documents_name: Total documents metric name. + min_documents_name: Min documents name. + max_documents_name: Max documents name. + """ + super().__init__( + _query_statistics, + total_queries_name=total_queries_name, + total_documents_name=total_documents_name, + min_documents_name=min_documents_name, + max_documents_name=max_documents_name, + ) metric_types.register_metric(QueryStatistics) @@ -63,125 +67,141 @@ def _query_statistics( min_documents_name: str = MIN_DOCUMENTS_NAME, max_documents_name: str = MAX_DOCUMENTS_NAME, eval_config: Optional[config_pb2.EvalConfig] = None, - model_name: str = '', - output_name: str = '', - query_key: str = '', - example_weighted: bool = False) -> metric_types.MetricComputations: - """Returns metric computations for query statistics.""" - if not query_key: - raise ValueError('a query_key is required to use QueryStatistics metrics') - - total_queries_key = metric_types.MetricKey( - name=total_queries_name, - model_name=model_name, - output_name=output_name, - example_weighted=example_weighted) - total_documents_key = metric_types.MetricKey( - name=total_documents_name, - model_name=model_name, - output_name=output_name, - example_weighted=example_weighted) - min_documents_key = metric_types.MetricKey( - name=min_documents_name, - model_name=model_name, - output_name=output_name, - example_weighted=example_weighted) - max_documents_key = metric_types.MetricKey( - name=max_documents_name, - model_name=model_name, - output_name=output_name, - example_weighted=example_weighted) - - return [ - metric_types.MetricComputation( - keys=[ - total_queries_key, total_documents_key, min_documents_key, - max_documents_key - ], - preprocessors=None, - combiner=_QueryStatisticsCombiner(total_queries_key, - total_documents_key, - min_documents_key, - max_documents_key, eval_config, - model_name, output_name, - example_weighted)) - ] + model_name: str = "", + output_name: str = "", + query_key: str = "", + example_weighted: bool = False, +) -> metric_types.MetricComputations: + """Returns metric computations for query statistics.""" + if not query_key: + raise ValueError("a query_key is required to use QueryStatistics metrics") + + total_queries_key = metric_types.MetricKey( + name=total_queries_name, + model_name=model_name, + output_name=output_name, + example_weighted=example_weighted, + ) + total_documents_key = metric_types.MetricKey( + name=total_documents_name, + model_name=model_name, + output_name=output_name, + example_weighted=example_weighted, + ) + min_documents_key = metric_types.MetricKey( + name=min_documents_name, + model_name=model_name, + output_name=output_name, + example_weighted=example_weighted, + ) + max_documents_key = metric_types.MetricKey( + name=max_documents_name, + model_name=model_name, + output_name=output_name, + example_weighted=example_weighted, + ) + + return [ + metric_types.MetricComputation( + keys=[ + total_queries_key, + total_documents_key, + min_documents_key, + max_documents_key, + ], + preprocessors=None, + combiner=_QueryStatisticsCombiner( + total_queries_key, + total_documents_key, + min_documents_key, + max_documents_key, + eval_config, + model_name, + output_name, + example_weighted, + ), + ) + ] class _QueryStatisticsAccumulator: - """Query statistics accumulator.""" - __slots__ = [ - 'total_queries', 'total_documents', 'min_documents', 'max_documents' - ] + """Query statistics accumulator.""" + + __slots__ = ["total_queries", "total_documents", 
"min_documents", "max_documents"] - def __init__(self): - self.total_queries = 0.0 - self.total_documents = 0.0 - self.min_documents = float('inf') - self.max_documents = 0.0 + def __init__(self): + self.total_queries = 0.0 + self.total_documents = 0.0 + self.min_documents = float("inf") + self.max_documents = 0.0 class _QueryStatisticsCombiner(beam.CombineFn): - """Computes query statistics metrics.""" - - def __init__(self, total_queries_key: metric_types.MetricKey, - total_documents_key: metric_types.MetricKey, - min_documents_key: metric_types.MetricKey, - max_documents_key: metric_types.MetricKey, - eval_config: config_pb2.EvalConfig, model_name: str, - output_name: str, example_weighted: bool): - self._total_queries_key = total_queries_key - self._total_documents_key = total_documents_key - self._min_documents_key = min_documents_key - self._max_documents_key = max_documents_key - self._eval_config = eval_config - self._model_name = model_name - self._output_name = output_name - self._example_weighted = example_weighted - - def create_accumulator(self) -> _QueryStatisticsAccumulator: - return _QueryStatisticsAccumulator() - - def add_input( - self, accumulator: _QueryStatisticsAccumulator, - element: metric_types.StandardMetricInputs - ) -> _QueryStatisticsAccumulator: - for _, _, example_weight in (metric_util.to_label_prediction_example_weight( - element, - eval_config=self._eval_config, - model_name=self._model_name, - output_name=self._output_name, - example_weighted=self._example_weighted, - flatten=False, - require_single_example_weight=True)): - example_weight = float(example_weight) - accumulator.total_queries += example_weight - num_documents = len(element.prediction) * example_weight - accumulator.total_documents += num_documents - accumulator.min_documents = min(accumulator.min_documents, num_documents) - accumulator.max_documents = max(accumulator.max_documents, num_documents) - return accumulator - - def merge_accumulators( - self, accumulators: Iterable[_QueryStatisticsAccumulator] - ) -> _QueryStatisticsAccumulator: - accumulators = iter(accumulators) - result = next(accumulators) - for accumulator in accumulators: - result.total_queries += accumulator.total_queries - result.total_documents += accumulator.total_documents - result.min_documents = min(result.min_documents, - accumulator.min_documents) - result.max_documents = max(result.max_documents, - accumulator.max_documents) - return result - - def extract_output( - self, accumulator: _QueryStatisticsAccumulator - ) -> Dict[metric_types.MetricKey, float]: - return { - self._total_queries_key: accumulator.total_queries, - self._total_documents_key: accumulator.total_documents, - self._min_documents_key: accumulator.min_documents, - self._max_documents_key: accumulator.max_documents - } + """Computes query statistics metrics.""" + + def __init__( + self, + total_queries_key: metric_types.MetricKey, + total_documents_key: metric_types.MetricKey, + min_documents_key: metric_types.MetricKey, + max_documents_key: metric_types.MetricKey, + eval_config: config_pb2.EvalConfig, + model_name: str, + output_name: str, + example_weighted: bool, + ): + self._total_queries_key = total_queries_key + self._total_documents_key = total_documents_key + self._min_documents_key = min_documents_key + self._max_documents_key = max_documents_key + self._eval_config = eval_config + self._model_name = model_name + self._output_name = output_name + self._example_weighted = example_weighted + + def create_accumulator(self) -> 
_QueryStatisticsAccumulator: + return _QueryStatisticsAccumulator() + + def add_input( + self, + accumulator: _QueryStatisticsAccumulator, + element: metric_types.StandardMetricInputs, + ) -> _QueryStatisticsAccumulator: + for _, _, example_weight in metric_util.to_label_prediction_example_weight( + element, + eval_config=self._eval_config, + model_name=self._model_name, + output_name=self._output_name, + example_weighted=self._example_weighted, + flatten=False, + require_single_example_weight=True, + ): + example_weight = float(example_weight) + accumulator.total_queries += example_weight + num_documents = len(element.prediction) * example_weight + accumulator.total_documents += num_documents + accumulator.min_documents = min(accumulator.min_documents, num_documents) + accumulator.max_documents = max(accumulator.max_documents, num_documents) + return accumulator + + def merge_accumulators( + self, accumulators: Iterable[_QueryStatisticsAccumulator] + ) -> _QueryStatisticsAccumulator: + accumulators = iter(accumulators) + result = next(accumulators) + for accumulator in accumulators: + result.total_queries += accumulator.total_queries + result.total_documents += accumulator.total_documents + result.min_documents = min(result.min_documents, accumulator.min_documents) + result.max_documents = max(result.max_documents, accumulator.max_documents) + return result + + def extract_output( + self, accumulator: _QueryStatisticsAccumulator + ) -> Dict[metric_types.MetricKey, float]: + return { + self._total_queries_key: accumulator.total_queries, + self._total_documents_key: accumulator.total_documents, + self._min_documents_key: accumulator.min_documents, + self._max_documents_key: accumulator.max_documents, + } diff --git a/tensorflow_model_analysis/metrics/query_statistics_test.py b/tensorflow_model_analysis/metrics/query_statistics_test.py index 64ce3d5aec..051b411254 100644 --- a/tensorflow_model_analysis/metrics/query_statistics_test.py +++ b/tensorflow_model_analysis/metrics/query_statistics_test.py @@ -13,14 +13,17 @@ # limitations under the License. 
"""Tests for query statistics metrics.""" -from absl.testing import parameterized import apache_beam as beam -from apache_beam.testing import util import numpy as np import tensorflow as tf -from tensorflow_model_analysis.metrics import metric_types -from tensorflow_model_analysis.metrics import metric_util -from tensorflow_model_analysis.metrics import query_statistics +from absl.testing import parameterized +from apache_beam.testing import util + +from tensorflow_model_analysis.metrics import ( + metric_types, + metric_util, + query_statistics, +) from tensorflow_model_analysis.utils import test_util from tensorflow_model_analysis.utils import util as tfma_util @@ -28,114 +31,111 @@ class QueryStatisticsTest( test_util.TensorflowModelAnalysisTest, parameterized.TestCase ): + @parameterized.named_parameters( + ("weighted", True, 1.0 + 2.0 + 3.0, 1 * 2.0 + 3 * 2.0 + 1 * 3.0, 2.0, 3 * 2.0), + ("unweighted", False, 3.0, 6.0, 1.0, 3.0), + ) + def testQueryStatistics( + self, + example_weighted, + total_queries, + total_documents, + min_documents, + max_documents, + ): + metrics = query_statistics.QueryStatistics().computations( + query_key="query", example_weighted=example_weighted + )[0] - @parameterized.named_parameters(('weighted', True, 1.0 + 2.0 + 3.0, - 1 * 2.0 + 3 * 2.0 + 1 * 3.0, 2.0, 3 * 2.0), - ('unweighted', False, 3.0, 6.0, 1.0, 3.0)) - def testQueryStatistics(self, example_weighted, total_queries, - total_documents, min_documents, max_documents): - metrics = query_statistics.QueryStatistics().computations( - query_key='query', example_weighted=example_weighted)[0] - - query1_example1 = { - 'labels': np.array([1.0]), - 'predictions': np.array([0.2]), - 'example_weights': np.array([1.0]), - 'features': { - 'query': np.array(['query1']), - 'gain': np.array([1.0]) + query1_example1 = { + "labels": np.array([1.0]), + "predictions": np.array([0.2]), + "example_weights": np.array([1.0]), + "features": {"query": np.array(["query1"]), "gain": np.array([1.0])}, } - } - query1_example2 = { - 'labels': np.array([0.0]), - 'predictions': np.array([0.8]), - 'example_weights': np.array([1.0]), - 'features': { - 'query': np.array(['query1']), - 'gain': np.array([0.5]) + query1_example2 = { + "labels": np.array([0.0]), + "predictions": np.array([0.8]), + "example_weights": np.array([1.0]), + "features": {"query": np.array(["query1"]), "gain": np.array([0.5])}, } - } - query2_example1 = { - 'labels': np.array([0.0]), - 'predictions': np.array([0.5]), - 'example_weights': np.array([2.0]), - 'features': { - 'query': np.array(['query2']), - 'gain': np.array([0.5]) + query2_example1 = { + "labels": np.array([0.0]), + "predictions": np.array([0.5]), + "example_weights": np.array([2.0]), + "features": {"query": np.array(["query2"]), "gain": np.array([0.5])}, } - } - query2_example2 = { - 'labels': np.array([1.0]), - 'predictions': np.array([0.9]), - 'example_weights': np.array([2.0]), - 'features': { - 'query': np.array(['query2']), - 'gain': np.array([1.0]) + query2_example2 = { + "labels": np.array([1.0]), + "predictions": np.array([0.9]), + "example_weights": np.array([2.0]), + "features": {"query": np.array(["query2"]), "gain": np.array([1.0])}, } - } - query2_example3 = { - 'labels': np.array([0.0]), - 'predictions': np.array([0.1]), - 'example_weights': np.array([2.0]), - 'features': { - 'query': np.array(['query2']), - 'gain': np.array([0.1]) + query2_example3 = { + "labels": np.array([0.0]), + "predictions": np.array([0.1]), + "example_weights": np.array([2.0]), + "features": {"query": 
np.array(["query2"]), "gain": np.array([0.1])}, } - } - query3_example1 = { - 'labels': np.array([1.0]), - 'predictions': np.array([0.9]), - 'example_weights': np.array([3.0]), - 'features': { - 'query': np.array(['query3']), - 'gain': np.array([1.0]) + query3_example1 = { + "labels": np.array([1.0]), + "predictions": np.array([0.9]), + "example_weights": np.array([3.0]), + "features": {"query": np.array(["query3"]), "gain": np.array([1.0])}, } - } - examples = [ - tfma_util.merge_extracts([query1_example1, query1_example2]), - tfma_util.merge_extracts( - [query2_example1, query2_example2, query2_example3]), - tfma_util.merge_extracts([query3_example1]) - ] + examples = [ + tfma_util.merge_extracts([query1_example1, query1_example2]), + tfma_util.merge_extracts( + [query2_example1, query2_example2, query2_example3] + ), + tfma_util.merge_extracts([query3_example1]), + ] - with beam.Pipeline() as pipeline: - # pylint: disable=no-value-for-parameter - result = ( - pipeline - | 'Create' >> beam.Create(examples) - | 'Process' >> beam.Map(metric_util.to_standard_metric_inputs, True) - | 'AddSlice' >> beam.Map(lambda x: ((), x)) - | 'Combine' >> beam.CombinePerKey(metrics.combiner)) + with beam.Pipeline() as pipeline: + # pylint: disable=no-value-for-parameter + result = ( + pipeline + | "Create" >> beam.Create(examples) + | "Process" >> beam.Map(metric_util.to_standard_metric_inputs, True) + | "AddSlice" >> beam.Map(lambda x: ((), x)) + | "Combine" >> beam.CombinePerKey(metrics.combiner) + ) - # pylint: enable=no-value-for-parameter + # pylint: enable=no-value-for-parameter - def check_result(got): - try: - self.assertLen(got, 1) - got_slice_key, got_metrics = got[0] - self.assertEqual(got_slice_key, ()) - total_queries_key = metric_types.MetricKey( - name='total_queries', example_weighted=example_weighted) - total_documents_key = metric_types.MetricKey( - name='total_documents', example_weighted=example_weighted) - min_documents_key = metric_types.MetricKey( - name='min_documents', example_weighted=example_weighted) - max_documents_key = metric_types.MetricKey( - name='max_documents', example_weighted=example_weighted) - self.assertDictElementsAlmostEqual( - got_metrics, { - total_queries_key: total_queries, - total_documents_key: total_documents, - min_documents_key: min_documents, - max_documents_key: max_documents - }, - places=5) + def check_result(got): + try: + self.assertLen(got, 1) + got_slice_key, got_metrics = got[0] + self.assertEqual(got_slice_key, ()) + total_queries_key = metric_types.MetricKey( + name="total_queries", example_weighted=example_weighted + ) + total_documents_key = metric_types.MetricKey( + name="total_documents", example_weighted=example_weighted + ) + min_documents_key = metric_types.MetricKey( + name="min_documents", example_weighted=example_weighted + ) + max_documents_key = metric_types.MetricKey( + name="max_documents", example_weighted=example_weighted + ) + self.assertDictElementsAlmostEqual( + got_metrics, + { + total_queries_key: total_queries, + total_documents_key: total_documents, + min_documents_key: min_documents, + max_documents_key: max_documents, + }, + places=5, + ) - except AssertionError as err: - raise util.BeamAssertException(err) + except AssertionError as err: + raise util.BeamAssertException(err) - util.assert_that(result, check_result, label='result') + util.assert_that(result, check_result, label="result") -if __name__ == '__main__': - tf.test.main() +if __name__ == "__main__": + tf.test.main() diff --git 
a/tensorflow_model_analysis/metrics/rouge.py b/tensorflow_model_analysis/metrics/rouge.py index bd5006ff70..28b0f35720 100644 --- a/tensorflow_model_analysis/metrics/rouge.py +++ b/tensorflow_model_analysis/metrics/rouge.py @@ -16,147 +16,138 @@ import dataclasses from typing import Dict, Iterable, Optional -from absl import logging import apache_beam as beam import nltk -from tensorflow_model_analysis.metrics import metric_types -from tensorflow_model_analysis.metrics import metric_util -from tensorflow_model_analysis.proto import config_pb2 - -from rouge_score import rouge_scorer -from rouge_score import scoring -from rouge_score import tokenizers +from absl import logging +from rouge_score import rouge_scorer, scoring, tokenizers +from tensorflow_model_analysis.metrics import metric_types, metric_util +from tensorflow_model_analysis.proto import config_pb2 -_LOGGING_MESSAGE_TOKENIZER_PREPARER = ( - "Finding or downloading 'punkt' from nltk." -) +_LOGGING_MESSAGE_TOKENIZER_PREPARER = "Finding or downloading 'punkt' from nltk." # TODO(b/287700355): Add __slots__ to _Accumulator @dataclasses.dataclass class _Accumulator: - weighted_count: float = 0.0 - total_precision: float = 0.0 - total_recall: float = 0.0 - total_fmeasure: float = 0.0 + weighted_count: float = 0.0 + total_precision: float = 0.0 + total_recall: float = 0.0 + total_fmeasure: float = 0.0 class RougeCombiner(beam.CombineFn): - """Computes ROUGE Scores.""" - - def __init__( - self, - rouge_type: str, - key: metric_types.MetricKey, - eval_config: config_pb2.EvalConfig, - model_name: str, - output_name: str, - use_stemmer: bool, - split_summaries: bool, - tokenizer: tokenizers.Tokenizer, - ): - """Initializes ROUGE Combiner. - - Args: - rouge_type: ROUGE type to calculate. - key: MetricKey for extract_output(). - eval_config: Eval config. - model_name: The model for which to compute these metrics. - output_name: The output name for which to compute these metrics. - use_stemmer: Bool indicating whether Porter stemmer should be used to - strip word suffixes to improve matching. This arg is used in the - DefaultTokenizer, but other tokenizers might or might not choose to use - this. - split_summaries: Whether to add newlines between sentences for rougeLsum. - tokenizer: Tokenizer object which has a tokenize() method. 
- """ - self._use_nltks_recommended_sentence_tokenizer = ( - split_summaries and rouge_type == 'rougeLsum' - ) - self.rouge_type = rouge_type - self.key = key - self.eval_config = eval_config - self.model_name = model_name - self.output_name = output_name - self.scorer = rouge_scorer.RougeScorer( - rouge_types=[rouge_type], - use_stemmer=use_stemmer, - split_summaries=split_summaries, - tokenizer=tokenizer, - ) - - def setup(self): - if self._use_nltks_recommended_sentence_tokenizer: - tokenizer_installed = False - - if not tokenizer_installed: - logging.info(_LOGGING_MESSAGE_TOKENIZER_PREPARER) - nltk.download('punkt') - nltk.download('punkt_tab') - - def create_accumulator(self) -> _Accumulator: - return _Accumulator() - - def add_input( - self, - accumulator: _Accumulator, - element: metric_types.StandardMetricInputs, - ) -> _Accumulator: - labels, predictions, example_weights = next( - metric_util.to_label_prediction_example_weight( - element, - eval_config=self.eval_config, - model_name=self.model_name, - output_name=self.output_name, - example_weighted=True, - flatten=False, - require_single_example_weight=True, + """Computes ROUGE Scores.""" + + def __init__( + self, + rouge_type: str, + key: metric_types.MetricKey, + eval_config: config_pb2.EvalConfig, + model_name: str, + output_name: str, + use_stemmer: bool, + split_summaries: bool, + tokenizer: tokenizers.Tokenizer, + ): + """Initializes ROUGE Combiner. + + Args: + ---- + rouge_type: ROUGE type to calculate. + key: MetricKey for extract_output(). + eval_config: Eval config. + model_name: The model for which to compute these metrics. + output_name: The output name for which to compute these metrics. + use_stemmer: Bool indicating whether Porter stemmer should be used to + strip word suffixes to improve matching. This arg is used in the + DefaultTokenizer, but other tokenizers might or might not choose to use + this. + split_summaries: Whether to add newlines between sentences for rougeLsum. + tokenizer: Tokenizer object which has a tokenize() method. 
+ """ + self._use_nltks_recommended_sentence_tokenizer = ( + split_summaries and rouge_type == "rougeLsum" + ) + self.rouge_type = rouge_type + self.key = key + self.eval_config = eval_config + self.model_name = model_name + self.output_name = output_name + self.scorer = rouge_scorer.RougeScorer( + rouge_types=[rouge_type], + use_stemmer=use_stemmer, + split_summaries=split_summaries, + tokenizer=tokenizer, ) - ) - - example_weight = example_weights[0] - accumulator.weighted_count += example_weight - rouge_scores = self.scorer.score_multi(labels, predictions[0])[ - self.rouge_type - ] - accumulator.total_precision += rouge_scores.precision * example_weight - accumulator.total_recall += rouge_scores.recall * example_weight - accumulator.total_fmeasure += rouge_scores.fmeasure * example_weight - - return accumulator - - def merge_accumulators( - self, accumulators: Iterable[_Accumulator] - ) -> _Accumulator: - accumulators = iter(accumulators) - result = next(accumulators) - for accumulator in accumulators: - result.weighted_count += accumulator.weighted_count - result.total_precision += accumulator.total_precision - result.total_recall += accumulator.total_recall - result.total_fmeasure += accumulator.total_fmeasure - - return result - - def extract_output( - self, accumulator: _Accumulator - ) -> Dict[metric_types.MetricKey, scoring.Score]: - if accumulator.weighted_count == 0.0: - return { - self.key: scoring.Score( - precision=float('nan'), recall=float('nan'), fmeasure=float('nan') - ) - } - avg_precision = accumulator.total_precision / accumulator.weighted_count - avg_recall = accumulator.total_recall / accumulator.weighted_count - avg_fmeasure = accumulator.total_fmeasure / accumulator.weighted_count - return { - self.key: scoring.Score( - precision=avg_precision, recall=avg_recall, fmeasure=avg_fmeasure + def setup(self): + if self._use_nltks_recommended_sentence_tokenizer: + tokenizer_installed = False + + if not tokenizer_installed: + logging.info(_LOGGING_MESSAGE_TOKENIZER_PREPARER) + nltk.download("punkt") + nltk.download("punkt_tab") + + def create_accumulator(self) -> _Accumulator: + return _Accumulator() + + def add_input( + self, + accumulator: _Accumulator, + element: metric_types.StandardMetricInputs, + ) -> _Accumulator: + labels, predictions, example_weights = next( + metric_util.to_label_prediction_example_weight( + element, + eval_config=self.eval_config, + model_name=self.model_name, + output_name=self.output_name, + example_weighted=True, + flatten=False, + require_single_example_weight=True, + ) ) - } + + example_weight = example_weights[0] + accumulator.weighted_count += example_weight + + rouge_scores = self.scorer.score_multi(labels, predictions[0])[self.rouge_type] + accumulator.total_precision += rouge_scores.precision * example_weight + accumulator.total_recall += rouge_scores.recall * example_weight + accumulator.total_fmeasure += rouge_scores.fmeasure * example_weight + + return accumulator + + def merge_accumulators(self, accumulators: Iterable[_Accumulator]) -> _Accumulator: + accumulators = iter(accumulators) + result = next(accumulators) + for accumulator in accumulators: + result.weighted_count += accumulator.weighted_count + result.total_precision += accumulator.total_precision + result.total_recall += accumulator.total_recall + result.total_fmeasure += accumulator.total_fmeasure + + return result + + def extract_output( + self, accumulator: _Accumulator + ) -> Dict[metric_types.MetricKey, scoring.Score]: + if accumulator.weighted_count == 0.0: + return 
{ + self.key: scoring.Score( + precision=float("nan"), recall=float("nan"), fmeasure=float("nan") + ) + } + avg_precision = accumulator.total_precision / accumulator.weighted_count + avg_recall = accumulator.total_recall / accumulator.weighted_count + avg_fmeasure = accumulator.total_fmeasure / accumulator.weighted_count + return { + self.key: scoring.Score( + precision=avg_precision, recall=avg_recall, fmeasure=avg_fmeasure + ) + } def _rouge( @@ -169,116 +160,115 @@ def _rouge( split_summaries: bool, tokenizer: tokenizers.Tokenizer, ) -> metric_types.MetricComputations: - """Returns metric computations for ROUGE.""" - key = metric_types.MetricKey(name=name) - return [ - metric_types.MetricComputation( - keys=[key], - preprocessors=None, - combiner=RougeCombiner( - rouge_type=rouge_type, - key=key, - eval_config=eval_config, - model_name=model_name, - output_name=output_name, - use_stemmer=use_stemmer, - split_summaries=split_summaries, - tokenizer=tokenizer, - ), - ) - ] + """Returns metric computations for ROUGE.""" + key = metric_types.MetricKey(name=name) + return [ + metric_types.MetricComputation( + keys=[key], + preprocessors=None, + combiner=RougeCombiner( + rouge_type=rouge_type, + key=key, + eval_config=eval_config, + model_name=model_name, + output_name=output_name, + use_stemmer=use_stemmer, + split_summaries=split_summaries, + tokenizer=tokenizer, + ), + ) + ] class Rouge(metric_types.Metric): - """ROUGE Metrics. - - ROUGE stands for Recall-Oriented Understudy for Gisting Evaluation. It - includes measures to automatically determine the quality of a summary by - comparing it to other (ideal) reference / target summaries. - - ROUGE was originally introduced in the paper: - - Lin, Chin-Yew. ROUGE: a Package for Automatic Evaluation of Summaries. In - Proceedings of the Workshop on Text Summarization Branches Out (WAS 2004), - Barcelona, Spain, July 25 - 26, 2004. - - This implementation supports Rouge-N where N is an int in [1, 9], RougeL, and - RougeLsum. Note, to calculate multiple ROUGE Metrics, you will need to call - this metric multiple times. - - For this implementation, a Label is expected to be a list of texts containing - the target summaries. A Prediction is expected to be text containing the - predicted text. - - In the ROUGE paper, two flavors of ROUGE are described: - - 1. sentence-level: Compute longest common subsequence (LCS) between two pieces - of text. Newlines are ignored. This is called 'rougeL' in this package. - 2. summary-level: Newlines in the text are interpreted as sentence boundaries, - and the LCS is computed between each pair of reference and candidate - sentences, and the union-LCS is computed. This is called - 'rougeLsum' in this package. This is the ROUGE-L reported in *[Get To The - Point: Summarization with Pointer-Generator Networks] - (https://arxiv.org/abs/1704.04368)*, for example. If your - references/candidates do not have newline delimiters, you can use the - split_summaries argument. 
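A hedged aside on the rougeL / rougeLsum newline behaviour described in the docstring above; this sketch calls the wrapped rouge-score package directly (only its public `RougeScorer.score()` API is assumed) and is not part of the patch:

```python
# Illustrative only: the newline handling described in the docstring above.
from rouge_score import rouge_scorer

target = "First sentence.\nSecond sentence."
prediction = "Second sentence.\nFirst sentence."

# rougeLsum treats newlines as sentence boundaries, so the reordered but
# otherwise identical sentences still match pairwise (fmeasure should be 1.0,
# matching testRougeLSumSentenceSplitting below).
lsum = rouge_scorer.RougeScorer(["rougeLsum"]).score(target, prediction)["rougeLsum"]

# rougeL ignores newlines and computes a single LCS over the whole text, which
# only recovers about half of the tokens here (fmeasure around 0.5).
rl = rouge_scorer.RougeScorer(["rougeL"]).score(target, prediction)["rougeL"]

print(lsum.fmeasure, rl.fmeasure)
```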
- - This is a wrapper of the pure python implementation of ROUGE found here: - https://pypi.org/project/rouge-score/ - - To implement this metric, see the example below: - - eval_config = tfma.EvalConfig( - metrics_specs=[ - tfma.MetricsSpec(metrics=[ - tfma.MetricConfig( - class_name='Rouge', - config='"rouge_type":"rouge1"' - ]), - tfma.MetricsSpec(metrics=[ - tfma.MetricConfig( - class_name='Rouge', - config='"rouge_type":"rouge2"' - ]), - tfma.MetricsSpec(metrics=[ - tfma.MetricConfig( - class_name='Rouge', - config='"rouge_type":"rougeL"' - ]), - tfma.MetricsSpec(metrics=[ - tfma.MetricConfig( - class_name='Rouge', - config='"rouge_type":"rougeLsum"' - ]), - ... - ], - ... - ) - - evaluator = tfx.borg.components.Evaluator( - examples=example_gen.outputs['examples'], - model=trainer.outputs['model'], - eval_config=eval_config) - """ - - def __init__( - self, - rouge_type: str, - name: Optional[str] = None, - use_stemmer: Optional[bool] = False, - split_summaries: Optional[bool] = False, - tokenizer: Optional[tokenizers.Tokenizer] = None, - ): - """Initializes ROUGE Metrics.""" - - super().__init__( - metric_util.merge_per_key_computations(_rouge), - rouge_type=rouge_type, - name=name or rouge_type, - use_stemmer=use_stemmer, - split_summaries=split_summaries, - tokenizer=tokenizer, + """ROUGE Metrics. + + ROUGE stands for Recall-Oriented Understudy for Gisting Evaluation. It + includes measures to automatically determine the quality of a summary by + comparing it to other (ideal) reference / target summaries. + + ROUGE was originally introduced in the paper: + + Lin, Chin-Yew. ROUGE: a Package for Automatic Evaluation of Summaries. In + Proceedings of the Workshop on Text Summarization Branches Out (WAS 2004), + Barcelona, Spain, July 25 - 26, 2004. + + This implementation supports Rouge-N where N is an int in [1, 9], RougeL, and + RougeLsum. Note, to calculate multiple ROUGE Metrics, you will need to call + this metric multiple times. + + For this implementation, a Label is expected to be a list of texts containing + the target summaries. A Prediction is expected to be text containing the + predicted text. + + In the ROUGE paper, two flavors of ROUGE are described: + + 1. sentence-level: Compute longest common subsequence (LCS) between two pieces + of text. Newlines are ignored. This is called 'rougeL' in this package. + 2. summary-level: Newlines in the text are interpreted as sentence boundaries, + and the LCS is computed between each pair of reference and candidate + sentences, and the union-LCS is computed. This is called + 'rougeLsum' in this package. This is the ROUGE-L reported in *[Get To The + Point: Summarization with Pointer-Generator Networks] + (https://arxiv.org/abs/1704.04368)*, for example. If your + references/candidates do not have newline delimiters, you can use the + split_summaries argument. + + This is a wrapper of the pure python implementation of ROUGE found here: + https://pypi.org/project/rouge-score/ + + To implement this metric, see the example below: + + eval_config = tfma.EvalConfig( + metrics_specs=[ + tfma.MetricsSpec(metrics=[ + tfma.MetricConfig( + class_name='Rouge', + config='"rouge_type":"rouge1"' + ]), + tfma.MetricsSpec(metrics=[ + tfma.MetricConfig( + class_name='Rouge', + config='"rouge_type":"rouge2"' + ]), + tfma.MetricsSpec(metrics=[ + tfma.MetricConfig( + class_name='Rouge', + config='"rouge_type":"rougeL"' + ]), + tfma.MetricsSpec(metrics=[ + tfma.MetricConfig( + class_name='Rouge', + config='"rouge_type":"rougeLsum"' + ]), + ... 
+ ], + ... ) + evaluator = tfx.borg.components.Evaluator( + examples=example_gen.outputs['examples'], + model=trainer.outputs['model'], + eval_config=eval_config) + """ + + def __init__( + self, + rouge_type: str, + name: Optional[str] = None, + use_stemmer: Optional[bool] = False, + split_summaries: Optional[bool] = False, + tokenizer: Optional[tokenizers.Tokenizer] = None, + ): + """Initializes ROUGE Metrics.""" + super().__init__( + metric_util.merge_per_key_computations(_rouge), + rouge_type=rouge_type, + name=name or rouge_type, + use_stemmer=use_stemmer, + split_summaries=split_summaries, + tokenizer=tokenizer, + ) + metric_types.register_metric(Rouge) diff --git a/tensorflow_model_analysis/metrics/rouge_test.py b/tensorflow_model_analysis/metrics/rouge_test.py index a645be58e1..2f7403e287 100644 --- a/tensorflow_model_analysis/metrics/rouge_test.py +++ b/tensorflow_model_analysis/metrics/rouge_test.py @@ -15,628 +15,635 @@ import statistics as stats -from absl.testing import parameterized import apache_beam as beam -from apache_beam.testing import util import numpy as np import tensorflow as tf -import tensorflow_model_analysis as tfma -from tensorflow_model_analysis.proto import config_pb2 +from absl.testing import parameterized +from apache_beam.testing import util +from google.protobuf import text_format +from rouge_score import tokenizers + from tensorflow_model_analysis import constants from tensorflow_model_analysis.evaluators import metrics_plots_and_validations_evaluator -from tensorflow_model_analysis.metrics import metric_types -from tensorflow_model_analysis.metrics import metric_util -from tensorflow_model_analysis.metrics import rouge +from tensorflow_model_analysis.metrics import metric_types, metric_util, rouge +from tensorflow_model_analysis.proto import config_pb2 from tensorflow_model_analysis.slicer import slicer_lib as slicer from tensorflow_model_analysis.utils import test_util -from google.protobuf import text_format -from rouge_score import tokenizers - def _get_result(pipeline, examples, combiner): - return ( - pipeline - | 'Create' >> beam.Create(examples) - | 'Process' >> beam.Map(metric_util.to_standard_metric_inputs) - | 'AddSlice' >> beam.Map(lambda x: ((), x)) - | 'ComputeRouge' >> beam.CombinePerKey(combiner) - ) + return ( + pipeline + | "Create" >> beam.Create(examples) + | "Process" >> beam.Map(metric_util.to_standard_metric_inputs) + | "AddSlice" >> beam.Map(lambda x: ((), x)) + | "ComputeRouge" >> beam.CombinePerKey(combiner) + ) class RogueTest(test_util.TensorflowModelAnalysisTest, parameterized.TestCase): + def _check_got(self, got, rouge_computation): + """Checks that the slice key is an empty tuple and the expected MetricKey is in the metric.""" + self.assertLen(got, 1) + got_slice_key, got_metrics = got[0] + self.assertEqual(got_slice_key, ()) + self.assertIn(rouge_computation.keys[0], got_metrics) + return got_metrics + + @parameterized.parameters(["rougen", "rouge0", "rouge10"]) + def testInvalidRougeTypes(self, rouge_type): + target_text = "testing one two" + prediction_text = "testing" + example = { + constants.LABELS_KEY: target_text, + constants.PREDICTIONS_KEY: prediction_text, + } + rouge_computation = rouge.Rouge(rouge_type).computations()[0] + with self.assertRaises(ValueError): + with beam.Pipeline() as pipeline: + _get_result( + pipeline=pipeline, + examples=[example], + combiner=rouge_computation.combiner, + ) + + @parameterized.parameters( + [ + "rouge1", + "rouge2", + "rouge3", + "rouge4", + "rouge5", + "rouge6", + "rouge7", + 
"rouge8", + "rouge9", + "rougeL", + "rougeLsum", + ] + ) + def testValidRogueTypes(self, rouge_type): + target_text = "testing one two" + prediction_text = "testing" + example = { + constants.LABELS_KEY: target_text, + constants.PREDICTIONS_KEY: prediction_text, + } + rouge_computation = rouge.Rouge(rouge_type).computations()[0] + with beam.Pipeline() as pipeline: + result = _get_result( + pipeline=pipeline, + examples=[example], + combiner=rouge_computation.combiner, + ) - def _check_got(self, got, rouge_computation): - """Checks that the slice key is an empty tuple and the expected MetricKey is in the metric.""" - self.assertLen(got, 1) - got_slice_key, got_metrics = got[0] - self.assertEqual(got_slice_key, ()) - self.assertIn(rouge_computation.keys[0], got_metrics) - return got_metrics - - @parameterized.parameters(['rougen', 'rouge0', 'rouge10']) - def testInvalidRougeTypes(self, rouge_type): - target_text = 'testing one two' - prediction_text = 'testing' - example = { - constants.LABELS_KEY: target_text, - constants.PREDICTIONS_KEY: prediction_text, - } - rouge_computation = rouge.Rouge(rouge_type).computations()[0] - with self.assertRaises(ValueError): - with beam.Pipeline() as pipeline: - _get_result( - pipeline=pipeline, - examples=[example], - combiner=rouge_computation.combiner, - ) + def check_result(got): + try: + got_metrics = self._check_got(got, rouge_computation) + self.assertEqual(next(iter(got_metrics.keys())).name, rouge_type) - @parameterized.parameters([ - 'rouge1', - 'rouge2', - 'rouge3', - 'rouge4', - 'rouge5', - 'rouge6', - 'rouge7', - 'rouge8', - 'rouge9', - 'rougeL', - 'rougeLsum', - ]) - def testValidRogueTypes(self, rouge_type): - target_text = 'testing one two' - prediction_text = 'testing' - example = { - constants.LABELS_KEY: target_text, - constants.PREDICTIONS_KEY: prediction_text, - } - rouge_computation = rouge.Rouge(rouge_type).computations()[0] - with beam.Pipeline() as pipeline: - result = _get_result( - pipeline=pipeline, - examples=[example], - combiner=rouge_computation.combiner, - ) - - def check_result(got): - try: - got_metrics = self._check_got(got, rouge_computation) - self.assertEqual(next(iter(got_metrics.keys())).name, rouge_type) - - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that(result, check_result, label='result') - - @parameterized.parameters(['rouge1', 'rouge2', 'rougeL', 'rougeLsum']) - def testNameOverride(self, rouge_type): - target_text = 'testing one two' - prediction_text = 'testing' - expected_name = 'override_default_name_with_this' - example = { - constants.LABELS_KEY: target_text, - constants.PREDICTIONS_KEY: prediction_text, - } - rouge_computation = rouge.Rouge( - rouge_type, name=expected_name - ).computations()[0] - with beam.Pipeline() as pipeline: - result = _get_result( - pipeline=pipeline, - examples=[example], - combiner=rouge_computation.combiner, - ) - - def check_result(got): - try: - got_metrics = self._check_got(got, rouge_computation) - self.assertEqual(next(iter(got_metrics.keys())).name, expected_name) - - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that(result, check_result, label='result') - - @parameterized.named_parameters( - ( - 'rouge1', - 'rouge1', - ['testing one two', 'testing'], - 1, - 1 / 3, - 1 / 2, - ), - ( - 'rouge2', - 'rouge2', - ['testing one two', 'testing one'], - 1, - 1 / 2, - 2 / 3, - ), - ( - 'rougeL_consecutive', - 'rougeL', - ['testing one two', 'testing one'], - 1, - 2 / 3, - 4 / 5, - ), - ( - 
'rougeL_nonconsecutive', - 'rougeL', - ['testing one two', 'testing two'], - 1, - 2 / 3, - 4 / 5, - ), - ( - 'rougeLsum', - 'rougeLsum', - ['w1 w2 w3 w4 w5', 'w1 w2 w6 w7 w8\nw1 w3 w8 w9 w5'], - 2 / 5, - 4 / 5, - 8 / 15, - ), - ) - def testRougeSingleExample( - self, - rouge_type, - example_texts, - expected_precision, - expected_recall, - expected_fmeasure, - ): - example = { - constants.LABELS_KEY: example_texts[0], - constants.PREDICTIONS_KEY: example_texts[1], - } - rouge_key = metric_types.MetricKey(name=rouge_type) - rouge_computation = rouge.Rouge(rouge_type).computations()[0] - with beam.Pipeline() as pipeline: - result = _get_result( - pipeline=pipeline, - examples=[example], - combiner=rouge_computation.combiner, - ) - - def check_result(got): - try: - got_metrics = self._check_got(got, rouge_computation) - self.assertAlmostEqual( - expected_precision, got_metrics[rouge_key].precision - ) - self.assertAlmostEqual(expected_recall, got_metrics[rouge_key].recall) - self.assertAlmostEqual( - expected_fmeasure, got_metrics[rouge_key].fmeasure - ) - - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that(result, check_result, label='result') - - @parameterized.parameters('rouge1', 'rouge2', 'rougeL', 'rougeLsum') - def testRougeMultipleExampleWeights(self, rouge_type): - example = { - constants.LABELS_KEY: 'testing one two', - constants.PREDICTIONS_KEY: 'testing', - constants.EXAMPLE_WEIGHTS_KEY: [0.4, 0.6], - } - rouge_computation = rouge.Rouge(rouge_type).computations()[0] - with self.assertRaises(ValueError): - with beam.Pipeline() as pipeline: - _get_result( - pipeline=pipeline, - examples=[example], - combiner=rouge_computation.combiner, - ) + except AssertionError as err: + raise util.BeamAssertException(err) + + util.assert_that(result, check_result, label="result") + + @parameterized.parameters(["rouge1", "rouge2", "rougeL", "rougeLsum"]) + def testNameOverride(self, rouge_type): + target_text = "testing one two" + prediction_text = "testing" + expected_name = "override_default_name_with_this" + example = { + constants.LABELS_KEY: target_text, + constants.PREDICTIONS_KEY: prediction_text, + } + rouge_computation = rouge.Rouge(rouge_type, name=expected_name).computations()[ + 0 + ] + with beam.Pipeline() as pipeline: + result = _get_result( + pipeline=pipeline, + examples=[example], + combiner=rouge_computation.combiner, + ) - @parameterized.named_parameters([ - ( - 'rouge1', - 'rouge1', - ['testing one two', 'This is a test'], - 'This is not a test', - 4 / 5, - 1, - 8 / 9, - ), - ( - 'rouge2', - 'rouge2', - ['testing one two', 'This is a test'], - 'This is not a test', - 1 / 2, - 2 / 3, - 4 / 7, - ), - ( - 'rougeL', - 'rougeL', - ['testing one two', 'This is a test'], - 'This is not a test', - 4 / 5, - 1, - 8 / 9, - ), - ( - 'rougeLsum', - 'rougeLsum', - ['testing one two', 'This is a test'], - 'This is not a test', - # ROUGE-L == ROUGE-L-Sum for these examples - # because there is no sentence splitting - 4 / 5, - 1, - 8 / 9, - ), - ]) - def testRougeMultipleTargetTexts( - self, - rouge_type, - targets, - prediction, - expected_precision, - expected_recall, - expected_fmeasure, - ): - example = { - constants.LABELS_KEY: targets, - constants.PREDICTIONS_KEY: prediction, - } - rouge_key = metric_types.MetricKey(name=rouge_type) - rouge_computation = rouge.Rouge(rouge_type).computations()[0] - with beam.Pipeline() as pipeline: - result = _get_result( - pipeline=pipeline, - examples=[example], - combiner=rouge_computation.combiner, - ) - - def 
check_result(got): - try: - got_metrics = self._check_got(got, rouge_computation) - self.assertAlmostEqual( - expected_precision, got_metrics[rouge_key].precision - ) - self.assertAlmostEqual(expected_recall, got_metrics[rouge_key].recall) - self.assertAlmostEqual( - expected_fmeasure, got_metrics[rouge_key].fmeasure - ) - - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that(result, check_result, label='result') - - @parameterized.named_parameters([ - ( - 'rouge1', - 'rouge1', - ['testing one two', 'testing'], - ['This is a test', 'This is not a test'], - stats.mean([1, 4 / 5]), - stats.mean([1 / 3, 1]), - stats.mean([1 / 2, 8 / 9]), - ), - ( - 'rouge2', - 'rouge2', - ['testing one two', 'testing one'], - ['This is a test', 'This is not a test'], - stats.mean([1, 1 / 2]), - stats.mean([1 / 2, 2 / 3]), - stats.mean([2 / 3, 4 / 7]), - ), - ( - 'rougeL', - 'rougeL', - ['testing one two', 'testing one'], - ['This is a test', 'This is not a test'], - stats.mean([1, 4 / 5]), - stats.mean([2 / 3, 1]), - stats.mean([4 / 5, 8 / 9]), - ), - ( - 'rougeLsum', - 'rougeLsum', - ['testing one two', 'testing one'], - ['This is a test', 'This is not a test'], - # ROUGE-L == ROUGE-L-Sum for these examples - # because there is no sentence splitting - stats.mean([1, 4 / 5]), - stats.mean([2 / 3, 1]), - stats.mean([4 / 5, 8 / 9]), - ), - ]) - def testRougeMultipleExamplesUnweighted( - self, - rouge_type, - example_1_texts, - example_2_texts, - expected_precision, - expected_recall, - expected_fmeasure, - ): - example1 = { - constants.LABELS_KEY: example_1_texts[0], - constants.PREDICTIONS_KEY: example_1_texts[1], - } - example2 = { - constants.LABELS_KEY: example_2_texts[0], - constants.PREDICTIONS_KEY: example_2_texts[1], - } - rouge_key = metric_types.MetricKey(name=rouge_type) - rouge_computation = rouge.Rouge(rouge_type).computations()[0] - with beam.Pipeline() as pipeline: - result = _get_result( - pipeline=pipeline, - examples=[example1, example2], - combiner=rouge_computation.combiner, - ) - - def check_result(got): - try: - got_metrics = self._check_got(got, rouge_computation) - self.assertAlmostEqual( - expected_precision, got_metrics[rouge_key].precision, places=6 - ) - self.assertAlmostEqual( - expected_recall, got_metrics[rouge_key].recall, places=6 - ) - self.assertAlmostEqual( - expected_fmeasure, got_metrics[rouge_key].fmeasure, places=6 - ) - - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that(result, check_result, label='result') - - example_weights = [0.5, 0.7] - - @parameterized.named_parameters([ - ( - 'rouge1', - 'rouge1', - ['testing one two', 'testing'], - ['This is a test', 'This is not a test'], - np.average([1, 4 / 5], weights=example_weights), - np.average([1 / 3, 1], weights=example_weights), - np.average([1 / 2, 8 / 9], weights=example_weights), - ), - ( - 'rouge2', - 'rouge2', - ['testing one two', 'testing one'], - ['This is a test', 'This is not a test'], - np.average([1, 1 / 2], weights=example_weights), - np.average([1 / 2, 2 / 3], weights=example_weights), - np.average([2 / 3, 4 / 7], weights=example_weights), - ), - ( - 'rougeL', - 'rougeL', - ['testing one two', 'testing one'], - ['This is a test', 'This is not a test'], - np.average([1, 4 / 5], weights=example_weights), - np.average([2 / 3, 1], weights=example_weights), - np.average([4 / 5, 8 / 9], weights=example_weights), - ), - ( - 'rougeLsum', - 'rougeLsum', - ['testing one two', 'testing one'], - ['This is a test', 'This is not a test'], - # 
ROUGE-L == ROUGE-L-Sum for these examples - # because there is no sentence splitting - np.average([1, 4 / 5], weights=example_weights), - np.average([2 / 3, 1], weights=example_weights), - np.average([4 / 5, 8 / 9], weights=example_weights), - ), - ]) - def testRougeMultipleExamplesWeighted( - self, - rouge_type, - example_1_texts, - example_2_texts, - expected_precision, - expected_recall, - expected_fmeasure, - ): - example1 = { - constants.LABELS_KEY: example_1_texts[0], - constants.PREDICTIONS_KEY: example_1_texts[1], - constants.EXAMPLE_WEIGHTS_KEY: self.example_weights[0], - } - example2 = { - constants.LABELS_KEY: example_2_texts[0], - constants.PREDICTIONS_KEY: example_2_texts[1], - constants.EXAMPLE_WEIGHTS_KEY: self.example_weights[1], - } - rouge_key = metric_types.MetricKey(name=rouge_type) - rouge_computation = rouge.Rouge(rouge_type).computations()[0] - with beam.Pipeline() as pipeline: - result = _get_result( - pipeline=pipeline, - examples=[example1, example2], - combiner=rouge_computation.combiner, - ) - - def check_result(got): - try: - got_metrics = self._check_got(got, rouge_computation) - self.assertAlmostEqual( - expected_precision, got_metrics[rouge_key].precision - ) - self.assertAlmostEqual(expected_recall, got_metrics[rouge_key].recall) - self.assertAlmostEqual( - expected_fmeasure, got_metrics[rouge_key].fmeasure - ) - - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that(result, check_result, label='result') - - @parameterized.parameters('rouge1', 'rouge2', 'rougeL', 'rougeLsum') - def testRougeWeightedCountIsZero(self, rouge_type): - example = { - constants.LABELS_KEY: 'testing one two', - constants.PREDICTIONS_KEY: 'testing', - constants.EXAMPLE_WEIGHTS_KEY: [0], - } - rouge_key = metric_types.MetricKey(name=rouge_type) - rouge_computation = rouge.Rouge(rouge_type).computations()[0] - with beam.Pipeline() as pipeline: - result = _get_result( - pipeline=pipeline, - examples=[example], - combiner=rouge_computation.combiner, - ) - - def check_result(got): - try: - got_metrics = self._check_got(got, rouge_computation) - self.assertTrue(np.isnan(got_metrics[rouge_key].precision)) - self.assertTrue(np.isnan(got_metrics[rouge_key].recall)) - self.assertTrue(np.isnan(got_metrics[rouge_key].fmeasure)) - - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that(result, check_result, label='result') - - def testRougeLSumSentenceSplitting(self): - rouge_type = 'rougeLsum' - rouge_key = metric_types.MetricKey(name=rouge_type) - tokenizer_preparer_logging_message = ( - 'INFO:absl:' + rouge._LOGGING_MESSAGE_TOKENIZER_PREPARER + def check_result(got): + try: + got_metrics = self._check_got(got, rouge_computation) + self.assertEqual(next(iter(got_metrics.keys())).name, expected_name) + + except AssertionError as err: + raise util.BeamAssertException(err) + + util.assert_that(result, check_result, label="result") + + @parameterized.named_parameters( + ( + "rouge1", + "rouge1", + ["testing one two", "testing"], + 1, + 1 / 3, + 1 / 2, + ), + ( + "rouge2", + "rouge2", + ["testing one two", "testing one"], + 1, + 1 / 2, + 2 / 3, + ), + ( + "rougeL_consecutive", + "rougeL", + ["testing one two", "testing one"], + 1, + 2 / 3, + 4 / 5, + ), + ( + "rougeL_nonconsecutive", + "rougeL", + ["testing one two", "testing two"], + 1, + 2 / 3, + 4 / 5, + ), + ( + "rougeLsum", + "rougeLsum", + ["w1 w2 w3 w4 w5", "w1 w2 w6 w7 w8\nw1 w3 w8 w9 w5"], + 2 / 5, + 4 / 5, + 8 / 15, + ), ) - rouge_computation = rouge.Rouge( - 
rouge_type, use_stemmer=True - ).computations()[0] - target_text = 'First sentence.\nSecond Sentence.' - prediction_text = 'Second sentence.\nFirst Sentence.' - example = { - constants.LABELS_KEY: target_text, - constants.PREDICTIONS_KEY: prediction_text, - } - with self.assertLogs(level='INFO') as cm: - with beam.Pipeline() as pipeline: - result = _get_result( - pipeline=pipeline, - examples=[example], - combiner=rouge_computation.combiner, - ) + def testRougeSingleExample( + self, + rouge_type, + example_texts, + expected_precision, + expected_recall, + expected_fmeasure, + ): + example = { + constants.LABELS_KEY: example_texts[0], + constants.PREDICTIONS_KEY: example_texts[1], + } + rouge_key = metric_types.MetricKey(name=rouge_type) + rouge_computation = rouge.Rouge(rouge_type).computations()[0] + with beam.Pipeline() as pipeline: + result = _get_result( + pipeline=pipeline, + examples=[example], + combiner=rouge_computation.combiner, + ) - def check_result_newline(got): - try: - got_metrics = self._check_got(got, rouge_computation) - self.assertAlmostEqual(1, got_metrics[rouge_key].precision) - self.assertAlmostEqual(1, got_metrics[rouge_key].recall) - self.assertAlmostEqual(1, got_metrics[rouge_key].fmeasure) - - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that(result, check_result_newline, label='result') - self.assertNotIn(tokenizer_preparer_logging_message, cm.output) - - # Without newlines, summaries are treated as single sentences. - target_text = target_text.replace('\n', ' ') - prediction_text = prediction_text.replace('\n', ' ') - example = { - constants.LABELS_KEY: target_text, - constants.PREDICTIONS_KEY: prediction_text, - } - with self.assertLogs(level='INFO') as cm: - with beam.Pipeline() as pipeline: - result = _get_result( - pipeline=pipeline, - examples=[example], - combiner=rouge_computation.combiner, - ) + def check_result(got): + try: + got_metrics = self._check_got(got, rouge_computation) + self.assertAlmostEqual( + expected_precision, got_metrics[rouge_key].precision + ) + self.assertAlmostEqual( + expected_recall, got_metrics[rouge_key].recall + ) + self.assertAlmostEqual( + expected_fmeasure, got_metrics[rouge_key].fmeasure + ) + + except AssertionError as err: + raise util.BeamAssertException(err) + + util.assert_that(result, check_result, label="result") + + @parameterized.parameters("rouge1", "rouge2", "rougeL", "rougeLsum") + def testRougeMultipleExampleWeights(self, rouge_type): + example = { + constants.LABELS_KEY: "testing one two", + constants.PREDICTIONS_KEY: "testing", + constants.EXAMPLE_WEIGHTS_KEY: [0.4, 0.6], + } + rouge_computation = rouge.Rouge(rouge_type).computations()[0] + with self.assertRaises(ValueError): + with beam.Pipeline() as pipeline: + _get_result( + pipeline=pipeline, + examples=[example], + combiner=rouge_computation.combiner, + ) + + @parameterized.named_parameters( + [ + ( + "rouge1", + "rouge1", + ["testing one two", "This is a test"], + "This is not a test", + 4 / 5, + 1, + 8 / 9, + ), + ( + "rouge2", + "rouge2", + ["testing one two", "This is a test"], + "This is not a test", + 1 / 2, + 2 / 3, + 4 / 7, + ), + ( + "rougeL", + "rougeL", + ["testing one two", "This is a test"], + "This is not a test", + 4 / 5, + 1, + 8 / 9, + ), + ( + "rougeLsum", + "rougeLsum", + ["testing one two", "This is a test"], + "This is not a test", + # ROUGE-L == ROUGE-L-Sum for these examples + # because there is no sentence splitting + 4 / 5, + 1, + 8 / 9, + ), + ] + ) + def testRougeMultipleTargetTexts( + 
self, + rouge_type, + targets, + prediction, + expected_precision, + expected_recall, + expected_fmeasure, + ): + example = { + constants.LABELS_KEY: targets, + constants.PREDICTIONS_KEY: prediction, + } + rouge_key = metric_types.MetricKey(name=rouge_type) + rouge_computation = rouge.Rouge(rouge_type).computations()[0] + with beam.Pipeline() as pipeline: + result = _get_result( + pipeline=pipeline, + examples=[example], + combiner=rouge_computation.combiner, + ) + + def check_result(got): + try: + got_metrics = self._check_got(got, rouge_computation) + self.assertAlmostEqual( + expected_precision, got_metrics[rouge_key].precision + ) + self.assertAlmostEqual( + expected_recall, got_metrics[rouge_key].recall + ) + self.assertAlmostEqual( + expected_fmeasure, got_metrics[rouge_key].fmeasure + ) + + except AssertionError as err: + raise util.BeamAssertException(err) + + util.assert_that(result, check_result, label="result") + + @parameterized.named_parameters( + [ + ( + "rouge1", + "rouge1", + ["testing one two", "testing"], + ["This is a test", "This is not a test"], + stats.mean([1, 4 / 5]), + stats.mean([1 / 3, 1]), + stats.mean([1 / 2, 8 / 9]), + ), + ( + "rouge2", + "rouge2", + ["testing one two", "testing one"], + ["This is a test", "This is not a test"], + stats.mean([1, 1 / 2]), + stats.mean([1 / 2, 2 / 3]), + stats.mean([2 / 3, 4 / 7]), + ), + ( + "rougeL", + "rougeL", + ["testing one two", "testing one"], + ["This is a test", "This is not a test"], + stats.mean([1, 4 / 5]), + stats.mean([2 / 3, 1]), + stats.mean([4 / 5, 8 / 9]), + ), + ( + "rougeLsum", + "rougeLsum", + ["testing one two", "testing one"], + ["This is a test", "This is not a test"], + # ROUGE-L == ROUGE-L-Sum for these examples + # because there is no sentence splitting + stats.mean([1, 4 / 5]), + stats.mean([2 / 3, 1]), + stats.mean([4 / 5, 8 / 9]), + ), + ] + ) + def testRougeMultipleExamplesUnweighted( + self, + rouge_type, + example_1_texts, + example_2_texts, + expected_precision, + expected_recall, + expected_fmeasure, + ): + example1 = { + constants.LABELS_KEY: example_1_texts[0], + constants.PREDICTIONS_KEY: example_1_texts[1], + } + example2 = { + constants.LABELS_KEY: example_2_texts[0], + constants.PREDICTIONS_KEY: example_2_texts[1], + } + rouge_key = metric_types.MetricKey(name=rouge_type) + rouge_computation = rouge.Rouge(rouge_type).computations()[0] + with beam.Pipeline() as pipeline: + result = _get_result( + pipeline=pipeline, + examples=[example1, example2], + combiner=rouge_computation.combiner, + ) + + def check_result(got): + try: + got_metrics = self._check_got(got, rouge_computation) + self.assertAlmostEqual( + expected_precision, got_metrics[rouge_key].precision, places=6 + ) + self.assertAlmostEqual( + expected_recall, got_metrics[rouge_key].recall, places=6 + ) + self.assertAlmostEqual( + expected_fmeasure, got_metrics[rouge_key].fmeasure, places=6 + ) + + except AssertionError as err: + raise util.BeamAssertException(err) + + util.assert_that(result, check_result, label="result") + + example_weights = [0.5, 0.7] + + @parameterized.named_parameters( + [ + ( + "rouge1", + "rouge1", + ["testing one two", "testing"], + ["This is a test", "This is not a test"], + np.average([1, 4 / 5], weights=example_weights), + np.average([1 / 3, 1], weights=example_weights), + np.average([1 / 2, 8 / 9], weights=example_weights), + ), + ( + "rouge2", + "rouge2", + ["testing one two", "testing one"], + ["This is a test", "This is not a test"], + np.average([1, 1 / 2], weights=example_weights), + np.average([1 
/ 2, 2 / 3], weights=example_weights), + np.average([2 / 3, 4 / 7], weights=example_weights), + ), + ( + "rougeL", + "rougeL", + ["testing one two", "testing one"], + ["This is a test", "This is not a test"], + np.average([1, 4 / 5], weights=example_weights), + np.average([2 / 3, 1], weights=example_weights), + np.average([4 / 5, 8 / 9], weights=example_weights), + ), + ( + "rougeLsum", + "rougeLsum", + ["testing one two", "testing one"], + ["This is a test", "This is not a test"], + # ROUGE-L == ROUGE-L-Sum for these examples + # because there is no sentence splitting + np.average([1, 4 / 5], weights=example_weights), + np.average([2 / 3, 1], weights=example_weights), + np.average([4 / 5, 8 / 9], weights=example_weights), + ), + ] + ) + def testRougeMultipleExamplesWeighted( + self, + rouge_type, + example_1_texts, + example_2_texts, + expected_precision, + expected_recall, + expected_fmeasure, + ): + example1 = { + constants.LABELS_KEY: example_1_texts[0], + constants.PREDICTIONS_KEY: example_1_texts[1], + constants.EXAMPLE_WEIGHTS_KEY: self.example_weights[0], + } + example2 = { + constants.LABELS_KEY: example_2_texts[0], + constants.PREDICTIONS_KEY: example_2_texts[1], + constants.EXAMPLE_WEIGHTS_KEY: self.example_weights[1], + } + rouge_key = metric_types.MetricKey(name=rouge_type) + rouge_computation = rouge.Rouge(rouge_type).computations()[0] + with beam.Pipeline() as pipeline: + result = _get_result( + pipeline=pipeline, + examples=[example1, example2], + combiner=rouge_computation.combiner, + ) + + def check_result(got): + try: + got_metrics = self._check_got(got, rouge_computation) + self.assertAlmostEqual( + expected_precision, got_metrics[rouge_key].precision + ) + self.assertAlmostEqual( + expected_recall, got_metrics[rouge_key].recall + ) + self.assertAlmostEqual( + expected_fmeasure, got_metrics[rouge_key].fmeasure + ) + + except AssertionError as err: + raise util.BeamAssertException(err) + + util.assert_that(result, check_result, label="result") + + @parameterized.parameters("rouge1", "rouge2", "rougeL", "rougeLsum") + def testRougeWeightedCountIsZero(self, rouge_type): + example = { + constants.LABELS_KEY: "testing one two", + constants.PREDICTIONS_KEY: "testing", + constants.EXAMPLE_WEIGHTS_KEY: [0], + } + rouge_key = metric_types.MetricKey(name=rouge_type) + rouge_computation = rouge.Rouge(rouge_type).computations()[0] + with beam.Pipeline() as pipeline: + result = _get_result( + pipeline=pipeline, + examples=[example], + combiner=rouge_computation.combiner, + ) + + def check_result(got): + try: + got_metrics = self._check_got(got, rouge_computation) + self.assertTrue(np.isnan(got_metrics[rouge_key].precision)) + self.assertTrue(np.isnan(got_metrics[rouge_key].recall)) + self.assertTrue(np.isnan(got_metrics[rouge_key].fmeasure)) + + except AssertionError as err: + raise util.BeamAssertException(err) - def check_result_sentences(got): - try: - got_metrics = self._check_got(got, rouge_computation) - self.assertAlmostEqual(1 / 2, got_metrics[rouge_key].precision) - self.assertAlmostEqual(1 / 2, got_metrics[rouge_key].recall) - self.assertAlmostEqual(1 / 2, got_metrics[rouge_key].fmeasure) - - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that(result, check_result_sentences, label='result') - self.assertNotIn(tokenizer_preparer_logging_message, cm.output) - - def check_split_summaries_result(): - with beam.Pipeline() as pipeline: - result = _get_result( - pipeline=pipeline, - examples=[example], - combiner=rouge_computation.combiner, + 
util.assert_that(result, check_result, label="result") + + def testRougeLSumSentenceSplitting(self): + rouge_type = "rougeLsum" + rouge_key = metric_types.MetricKey(name=rouge_type) + tokenizer_preparer_logging_message = ( + "INFO:absl:" + rouge._LOGGING_MESSAGE_TOKENIZER_PREPARER ) + rouge_computation = rouge.Rouge(rouge_type, use_stemmer=True).computations()[0] + target_text = "First sentence.\nSecond Sentence." + prediction_text = "Second sentence.\nFirst Sentence." + example = { + constants.LABELS_KEY: target_text, + constants.PREDICTIONS_KEY: prediction_text, + } + with self.assertLogs(level="INFO") as cm: + with beam.Pipeline() as pipeline: + result = _get_result( + pipeline=pipeline, + examples=[example], + combiner=rouge_computation.combiner, + ) + + def check_result_newline(got): + try: + got_metrics = self._check_got(got, rouge_computation) + self.assertAlmostEqual(1, got_metrics[rouge_key].precision) + self.assertAlmostEqual(1, got_metrics[rouge_key].recall) + self.assertAlmostEqual(1, got_metrics[rouge_key].fmeasure) + + except AssertionError as err: + raise util.BeamAssertException(err) + + util.assert_that(result, check_result_newline, label="result") + self.assertNotIn(tokenizer_preparer_logging_message, cm.output) + + # Without newlines, summaries are treated as single sentences. + target_text = target_text.replace("\n", " ") + prediction_text = prediction_text.replace("\n", " ") + example = { + constants.LABELS_KEY: target_text, + constants.PREDICTIONS_KEY: prediction_text, + } + with self.assertLogs(level="INFO") as cm: + with beam.Pipeline() as pipeline: + result = _get_result( + pipeline=pipeline, + examples=[example], + combiner=rouge_computation.combiner, + ) + + def check_result_sentences(got): + try: + got_metrics = self._check_got(got, rouge_computation) + self.assertAlmostEqual(1 / 2, got_metrics[rouge_key].precision) + self.assertAlmostEqual(1 / 2, got_metrics[rouge_key].recall) + self.assertAlmostEqual(1 / 2, got_metrics[rouge_key].fmeasure) + + except AssertionError as err: + raise util.BeamAssertException(err) + + util.assert_that(result, check_result_sentences, label="result") + self.assertNotIn(tokenizer_preparer_logging_message, cm.output) + + def check_split_summaries_result(): + with beam.Pipeline() as pipeline: + result = _get_result( + pipeline=pipeline, + examples=[example], + combiner=rouge_computation.combiner, + ) + + def check_result_nltk(got): + try: + got_metrics = self._check_got(got, rouge_computation) + self.assertAlmostEqual(1, got_metrics[rouge_key].precision) + self.assertAlmostEqual(1, got_metrics[rouge_key].recall) + self.assertAlmostEqual(1, got_metrics[rouge_key].fmeasure) + + except AssertionError as err: + raise util.BeamAssertException(err) + + util.assert_that(result, check_result_nltk, label="result") + + # Split summaries into sentences using nltk + rouge_computation = rouge.Rouge( + rouge_type, use_stemmer=True, split_summaries=True + ).computations()[0] + check_split_summaries_result() + + def testRougeTokenizer(self): + rouge_type = "rouge1" + target_text = "testing one two" + prediction_text = "testing" + example = { + constants.LABELS_KEY: target_text, + constants.PREDICTIONS_KEY: prediction_text, + } + rouge_key = metric_types.MetricKey(name=rouge_type) + rouge_computation = rouge.Rouge( + rouge_type, tokenizer=tokenizers.DefaultTokenizer() + ).computations()[0] + with beam.Pipeline() as pipeline: + result = _get_result( + pipeline=pipeline, + examples=[example], + combiner=rouge_computation.combiner, + ) - def 
check_result_nltk(got): - try: - got_metrics = self._check_got(got, rouge_computation) - self.assertAlmostEqual(1, got_metrics[rouge_key].precision) - self.assertAlmostEqual(1, got_metrics[rouge_key].recall) - self.assertAlmostEqual(1, got_metrics[rouge_key].fmeasure) - - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that(result, check_result_nltk, label='result') - - # Split summaries into sentences using nltk - rouge_computation = rouge.Rouge( - rouge_type, use_stemmer=True, split_summaries=True - ).computations()[0] - check_split_summaries_result() - - def testRougeTokenizer(self): - rouge_type = 'rouge1' - target_text = 'testing one two' - prediction_text = 'testing' - example = { - constants.LABELS_KEY: target_text, - constants.PREDICTIONS_KEY: prediction_text, - } - rouge_key = metric_types.MetricKey(name=rouge_type) - rouge_computation = rouge.Rouge( - rouge_type, tokenizer=tokenizers.DefaultTokenizer() - ).computations()[0] - with beam.Pipeline() as pipeline: - result = _get_result( - pipeline=pipeline, - examples=[example], - combiner=rouge_computation.combiner, - ) - - def check_result(got): - try: - got_metrics = self._check_got(got, rouge_computation) - self.assertAlmostEqual(1, got_metrics[rouge_key].precision) - self.assertAlmostEqual(1 / 3, got_metrics[rouge_key].recall) - self.assertAlmostEqual(1 / 2, got_metrics[rouge_key].fmeasure) - - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that(result, check_result, label='result') + def check_result(got): + try: + got_metrics = self._check_got(got, rouge_computation) + self.assertAlmostEqual(1, got_metrics[rouge_key].precision) + self.assertAlmostEqual(1 / 3, got_metrics[rouge_key].recall) + self.assertAlmostEqual(1 / 2, got_metrics[rouge_key].fmeasure) + except AssertionError as err: + raise util.BeamAssertException(err) -class RougeEnd2EndTest(parameterized.TestCase): + util.assert_that(result, check_result, label="result") - def testRougeEnd2End(self): - # Same tests as RougeTest.testRougeMultipleExamplesWeighted - eval_config = text_format.Parse( - """ + +class RougeEnd2EndTest(parameterized.TestCase): + def testRougeEnd2End(self): + # Same tests as RougeTest.testRougeMultipleExamplesWeighted + eval_config = text_format.Parse( + """ model_specs { label_key: "labels" prediction_key: "predictions" @@ -660,86 +667,82 @@ def testRougeEnd2End(self): } } """, - config_pb2.EvalConfig(), - ) - rouge_types = ['rouge1', 'rouge2', 'rougeL', 'rougeLsum'] - example_weights = [0.5, 0.7] - extracts = [ - { - constants.LABELS_KEY: np.array(['testing one two']), - constants.PREDICTIONS_KEY: np.array(['testing']), - constants.EXAMPLE_WEIGHTS_KEY: np.array([example_weights[0]]), - constants.FEATURES_KEY: None, - constants.SLICE_KEY_TYPES_KEY: slicer.slice_keys_to_numpy_array( - [()] - ), - }, - { - constants.LABELS_KEY: np.array(['This is a test']), - constants.PREDICTIONS_KEY: np.array(['This is not a test']), - constants.EXAMPLE_WEIGHTS_KEY: np.array([example_weights[1]]), - constants.FEATURES_KEY: None, - constants.SLICE_KEY_TYPES_KEY: slicer.slice_keys_to_numpy_array( - [()] - ), - }, - ] - - # Values are [unweighed_score_for_example_1, unweighted_score_for_example_2] - # where the scores are precision, recall, and fmeasure. 
- expected_unweighted_scores = { - 'rouge1': ([1, 4 / 5], [1 / 3, 1], [1 / 2, 8 / 9]), - 'rouge2': ([0, 1 / 2], [0, 2 / 3], [0, 4 / 7]), - 'rougeL': ([1, 4 / 5], [1 / 3, 1], [1 / 2, 8 / 9]), - 'rougeLsum': ([1, 4 / 5], [1 / 3, 1], [1 / 2, 8 / 9]), - } - for rouge_type in rouge_types: - rouge_key = metric_types.MetricKey(name=rouge_type) - - with beam.Pipeline() as pipeline: - result = ( - pipeline - | 'LoadData' >> beam.Create(extracts) - | 'ExtractEval' - >> metrics_plots_and_validations_evaluator.MetricsPlotsAndValidationsEvaluator( - eval_config=eval_config - ).ptransform + config_pb2.EvalConfig(), ) - - def check_result(got, rouge_key=rouge_key, rouge_type=rouge_type): - try: - self.assertLen(got, 1) - got_slice_key, got_metrics = got[0] - self.assertEqual(got_slice_key, ()) - self.assertIn(rouge_key, got_metrics.keys()) - self.assertAlmostEqual( - np.average( - expected_unweighted_scores[rouge_type][0], - weights=example_weights, - ), - got_metrics[rouge_key].precision, - ) - self.assertAlmostEqual( - np.average( - expected_unweighted_scores[rouge_type][1], - weights=example_weights, - ), - got_metrics[rouge_key].recall, - ) - self.assertAlmostEqual( - np.average( - expected_unweighted_scores[rouge_type][2], - weights=example_weights, - ), - got_metrics[rouge_key].fmeasure, - ) - - except AssertionError as err: - raise util.BeamAssertException(err) - - self.assertIn('metrics', result) - util.assert_that(result['metrics'], check_result, label='result') - - -if __name__ == '__main__': - tf.test.main() + rouge_types = ["rouge1", "rouge2", "rougeL", "rougeLsum"] + example_weights = [0.5, 0.7] + extracts = [ + { + constants.LABELS_KEY: np.array(["testing one two"]), + constants.PREDICTIONS_KEY: np.array(["testing"]), + constants.EXAMPLE_WEIGHTS_KEY: np.array([example_weights[0]]), + constants.FEATURES_KEY: None, + constants.SLICE_KEY_TYPES_KEY: slicer.slice_keys_to_numpy_array([()]), + }, + { + constants.LABELS_KEY: np.array(["This is a test"]), + constants.PREDICTIONS_KEY: np.array(["This is not a test"]), + constants.EXAMPLE_WEIGHTS_KEY: np.array([example_weights[1]]), + constants.FEATURES_KEY: None, + constants.SLICE_KEY_TYPES_KEY: slicer.slice_keys_to_numpy_array([()]), + }, + ] + + # Values are [unweighed_score_for_example_1, unweighted_score_for_example_2] + # where the scores are precision, recall, and fmeasure. 
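A hedged aside on how the expectations below are derived: `RougeCombiner` accumulates `sum(weight * score)` and `sum(weight)`, so the slice-level values checked in this end-to-end test are plain weighted means of the per-example scores. A minimal check, reusing the `example_weights` defined above (illustrative, not part of the test):

```python
# Illustrative only: derivation of the weighted rouge1 precision expectation.
import numpy as np

example_weights = [0.5, 0.7]
per_example_precision = [1.0, 4 / 5]  # unweighted rouge1 precision of the two extracts

# RougeCombiner.extract_output divides total_precision by weighted_count,
# which is exactly the weighted mean that np.average computes:
weighted_mean = np.average(per_example_precision, weights=example_weights)
assert abs(weighted_mean - (0.5 * 1.0 + 0.7 * 0.8) / 1.2) < 1e-12  # ~0.8833
```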
+ expected_unweighted_scores = { + "rouge1": ([1, 4 / 5], [1 / 3, 1], [1 / 2, 8 / 9]), + "rouge2": ([0, 1 / 2], [0, 2 / 3], [0, 4 / 7]), + "rougeL": ([1, 4 / 5], [1 / 3, 1], [1 / 2, 8 / 9]), + "rougeLsum": ([1, 4 / 5], [1 / 3, 1], [1 / 2, 8 / 9]), + } + for rouge_type in rouge_types: + rouge_key = metric_types.MetricKey(name=rouge_type) + + with beam.Pipeline() as pipeline: + result = ( + pipeline + | "LoadData" >> beam.Create(extracts) + | "ExtractEval" + >> metrics_plots_and_validations_evaluator.MetricsPlotsAndValidationsEvaluator( + eval_config=eval_config + ).ptransform + ) + + def check_result(got, rouge_key=rouge_key, rouge_type=rouge_type): + try: + self.assertLen(got, 1) + got_slice_key, got_metrics = got[0] + self.assertEqual(got_slice_key, ()) + self.assertIn(rouge_key, got_metrics.keys()) + self.assertAlmostEqual( + np.average( + expected_unweighted_scores[rouge_type][0], + weights=example_weights, + ), + got_metrics[rouge_key].precision, + ) + self.assertAlmostEqual( + np.average( + expected_unweighted_scores[rouge_type][1], + weights=example_weights, + ), + got_metrics[rouge_key].recall, + ) + self.assertAlmostEqual( + np.average( + expected_unweighted_scores[rouge_type][2], + weights=example_weights, + ), + got_metrics[rouge_key].fmeasure, + ) + + except AssertionError as err: + raise util.BeamAssertException(err) + + self.assertIn("metrics", result) + util.assert_that(result["metrics"], check_result, label="result") + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow_model_analysis/metrics/sample_metrics.py b/tensorflow_model_analysis/metrics/sample_metrics.py index 10abcee60b..7f3a1b8dd5 100644 --- a/tensorflow_model_analysis/metrics/sample_metrics.py +++ b/tensorflow_model_analysis/metrics/sample_metrics.py @@ -13,121 +13,136 @@ # limitations under the License. """A collection of metrics which sample per-example values.""" -from typing import Any, List, Optional, Text, Tuple +from typing import Any, List, Optional, Tuple import apache_beam as beam import numpy as np + from tensorflow_model_analysis.metrics import metric_types -from tensorflow_model_analysis.utils import beam_util -from tensorflow_model_analysis.utils import util +from tensorflow_model_analysis.utils import beam_util, util -FIXED_SIZE_SAMPLE_NAME = 'fixed_size_sample' +FIXED_SIZE_SAMPLE_NAME = "fixed_size_sample" # This corresponds to the comments in apache_beam/transforms/combiners.py _HeapType = Tuple[bool, List[Any]] class FixedSizeSample(metric_types.Metric): - """Computes a fixed-size sample per slice.""" - - def __init__(self, - sampled_key: Text, - size: int, - name: Text = FIXED_SIZE_SAMPLE_NAME, - random_seed: Optional[int] = None): - """Initializes a FixedSizeSample metric. - - Args: - sampled_key: The key whose values should be sampled - size: The number of samples to collect (per slice) - name: Metric name. - random_seed: The random_seed to be used for intializing the per worker - np.random.RandomGenerator in the CombineFn setup. Note that when more - than one worker is used, setting this is not sufficient to guarantee - determinism. - """ - super().__init__( - _fixed_size_sample, - sampled_key=sampled_key, - size=size, - name=name, - random_seed=random_seed) + """Computes a fixed-size sample per slice.""" + + def __init__( + self, + sampled_key: str, + size: int, + name: str = FIXED_SIZE_SAMPLE_NAME, + random_seed: Optional[int] = None, + ): + """Initializes a FixedSizeSample metric. 
+ + Args: + ---- + sampled_key: The key whose values should be sampled + size: The number of samples to collect (per slice) + name: Metric name. + random_seed: The random_seed to be used for intializing the per worker + np.random.RandomGenerator in the CombineFn setup. Note that when more + than one worker is used, setting this is not sufficient to guarantee + determinism. + """ + super().__init__( + _fixed_size_sample, + sampled_key=sampled_key, + size=size, + name=name, + random_seed=random_seed, + ) metric_types.register_metric(FixedSizeSample) def _fixed_size_sample( - sampled_key: Text, + sampled_key: str, size: int, - name: Text, + name: str, random_seed: Optional[int], - model_names: Optional[List[Text]] = None, - output_names: Optional[List[Text]] = None, + model_names: Optional[List[str]] = None, + output_names: Optional[List[str]] = None, sub_keys: Optional[List[metric_types.SubKey]] = None, - example_weighted: bool = False) -> metric_types.MetricComputations: - """Returns metrics computations for FixedSizeSample metrcs.""" - keys = [] - for model_name in model_names or ['']: - for output_name in output_names or ['']: - for sub_key in sub_keys or [None]: - keys.append( - metric_types.MetricKey( - name, - model_name=model_name, - output_name=output_name, - sub_key=sub_key, - example_weighted=example_weighted)) - return [ - metric_types.MetricComputation( - keys=keys, - preprocessors=[ - metric_types.FeaturePreprocessor(feature_keys=[sampled_key]) - ], - combiner=_FixedSizeSampleCombineFn( - metric_keys=keys, - sampled_key=sampled_key, - size=size, - example_weighted=example_weighted, - random_seed=random_seed)) - ] + example_weighted: bool = False, +) -> metric_types.MetricComputations: + """Returns metrics computations for FixedSizeSample metrcs.""" + keys = [] + for model_name in model_names or [""]: + for output_name in output_names or [""]: + for sub_key in sub_keys or [None]: + keys.append( + metric_types.MetricKey( + name, + model_name=model_name, + output_name=output_name, + sub_key=sub_key, + example_weighted=example_weighted, + ) + ) + return [ + metric_types.MetricComputation( + keys=keys, + preprocessors=[ + metric_types.FeaturePreprocessor(feature_keys=[sampled_key]) + ], + combiner=_FixedSizeSampleCombineFn( + metric_keys=keys, + sampled_key=sampled_key, + size=size, + example_weighted=example_weighted, + random_seed=random_seed, + ), + ) + ] class _FixedSizeSampleCombineFn(beam_util.DelegatingCombineFn): - """A fixed size sample combiner which samples values of a specified key. - - This CombineFn is similar to beam.combiners.SampleCombineFn except it makes - use of the numpy random generator which means that it accepts a seed for use - with deterministic testing. - """ - - def __init__(self, metric_keys: List[metric_types.MetricKey], - sampled_key: Text, size: int, example_weighted: bool, - random_seed: Optional[int]): - self._metric_keys = metric_keys - self._sampled_key = sampled_key - self._example_weighted = example_weighted - self._random_seed = random_seed - # We delegate to the TopCombineFn rather than subclass because the use of a - # TopCombineFn is an implementation detail. 
- super().__init__(combine_fn=beam.combiners.TopCombineFn(n=size)) - - def setup(self): - self._random_generator = np.random.default_rng(self._random_seed) - - def add_input(self, heap: _HeapType, - element: metric_types.StandardMetricInputs) -> _HeapType: - # TODO(b/206546545): add support for sampling derived features - sampled_value = util.get_by_keys(element.features, [self._sampled_key]) - random_tag = self._random_generator.random() - if self._example_weighted: - # For details, see Weighted Random Sampling over Data Streams: - # https://arxiv.org/abs/1012.0256 - weight = element.example_weight - random_tag = random_tag**(1 / weight) - return super().add_input(heap, (random_tag, sampled_value)) - - def extract_output(self, heap: _HeapType) -> metric_types.MetricsDict: - # drop random numbers used for sampling - sampled_values = np.array([v for _, v in super().extract_output(heap)]) - return {k: sampled_values for k in self._metric_keys} + """A fixed size sample combiner which samples values of a specified key. + + This CombineFn is similar to beam.combiners.SampleCombineFn except it makes + use of the numpy random generator which means that it accepts a seed for use + with deterministic testing. + """ + + def __init__( + self, + metric_keys: List[metric_types.MetricKey], + sampled_key: str, + size: int, + example_weighted: bool, + random_seed: Optional[int], + ): + self._metric_keys = metric_keys + self._sampled_key = sampled_key + self._example_weighted = example_weighted + self._random_seed = random_seed + # We delegate to the TopCombineFn rather than subclass because the use of a + # TopCombineFn is an implementation detail. + super().__init__(combine_fn=beam.combiners.TopCombineFn(n=size)) + + def setup(self): + self._random_generator = np.random.default_rng(self._random_seed) + + def add_input( + self, heap: _HeapType, element: metric_types.StandardMetricInputs + ) -> _HeapType: + # TODO(b/206546545): add support for sampling derived features + sampled_value = util.get_by_keys(element.features, [self._sampled_key]) + random_tag = self._random_generator.random() + if self._example_weighted: + # For details, see Weighted Random Sampling over Data Streams: + # https://arxiv.org/abs/1012.0256 + weight = element.example_weight + random_tag = random_tag ** (1 / weight) + return super().add_input(heap, (random_tag, sampled_value)) + + def extract_output(self, heap: _HeapType) -> metric_types.MetricsDict: + # drop random numbers used for sampling + sampled_values = np.array([v for _, v in super().extract_output(heap)]) + return {k: sampled_values for k in self._metric_keys} diff --git a/tensorflow_model_analysis/metrics/sample_metrics_test.py b/tensorflow_model_analysis/metrics/sample_metrics_test.py index 6c8a7382d1..4d141b7630 100644 --- a/tensorflow_model_analysis/metrics/sample_metrics_test.py +++ b/tensorflow_model_analysis/metrics/sample_metrics_test.py @@ -13,102 +13,107 @@ # limitations under the License. 
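A hedged aside on the sampling scheme used by `_FixedSizeSampleCombineFn` above: each element is keyed by `u ** (1 / weight)` with `u ~ Uniform(0, 1)` and the largest keys are kept via `TopCombineFn`, i.e. the A-Res scheme from the cited Efraimidis & Spirakis paper. A self-contained sketch of the same idea, with a hypothetical `weighted_sample` helper standing in for the Beam machinery (not part of the patch):

```python
# Illustrative only: weighted reservoir sampling with keys u ** (1 / w).
import heapq

import numpy as np


def weighted_sample(items, weights, k, seed=0):
    rng = np.random.default_rng(seed)
    # Each item draws u ~ Uniform(0, 1) and is keyed by u ** (1 / w); keeping
    # the k largest keys gives a weighted sample without replacement.
    keyed = [(rng.random() ** (1.0 / w), item) for item, w in zip(items, weights)]
    # heapq.nlargest plays the role that beam.combiners.TopCombineFn plays above.
    return [item for _, item in heapq.nlargest(k, keyed)]


# Heavier weights dominate the sample, mirroring testFixedSizeSampleWeighted below.
print(weighted_sample(list(range(5)), [10**i for i in range(5)], k=2))
```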
"""Tests for sample_metrics.""" -from absl.testing import absltest import apache_beam as beam -from apache_beam.testing import util import numpy as np +from absl.testing import absltest +from apache_beam.testing import util + from tensorflow_model_analysis import constants -from tensorflow_model_analysis.metrics import metric_types -from tensorflow_model_analysis.metrics import metric_util -from tensorflow_model_analysis.metrics import sample_metrics +from tensorflow_model_analysis.metrics import metric_types, metric_util, sample_metrics class SampleTest(absltest.TestCase): - - def testFixedSizeSample(self): - metric = sample_metrics.FixedSizeSample( - sampled_key='sampled_key', size=2, random_seed=0).computations()[0] - - examples = [] - for i in range(5): - examples.append({ - constants.LABELS_KEY: np.array([0]), - constants.PREDICTIONS_KEY: np.array([1]), - constants.FEATURES_KEY: { - 'sampled_key': i - } - }) - - with beam.Pipeline() as pipeline: - # pylint: disable=no-value-for-parameter - result = ( - pipeline - | 'Create' >> beam.Create(examples, reshuffle=False) - | 'PreProcess' >> beam.ParDo(metric.preprocessors[0]) - | 'Process' >> beam.Map(metric_util.to_standard_metric_inputs) - | 'AddSlice' >> beam.Map(lambda x: ((), x)) - | 'ComputeMetric' >> beam.CombinePerKey(metric.combiner)) - - # pylint: enable=no-value-for-parameter - - def check_result(got): - try: - self.assertLen(got, 1) - got_slice_key, got_metrics = got[0] - self.assertEqual(got_slice_key, ()) - fixed_sized_sample_key = metric_types.MetricKey( - name='fixed_size_sample') - np.testing.assert_equal(got_metrics, - {fixed_sized_sample_key: np.array([4, 0])}) - - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that(result, check_result) - - def testFixedSizeSampleWeighted(self): - metric = sample_metrics.FixedSizeSample( - sampled_key='sampled_key', size=2, - random_seed=0).computations(example_weighted=True)[0] - - examples = [] - for i in range(5): - examples.append({ - constants.LABELS_KEY: np.array([0]), - constants.PREDICTIONS_KEY: np.array([1]), - constants.EXAMPLE_WEIGHTS_KEY: np.array([10**i]), - constants.FEATURES_KEY: { - 'sampled_key': i - } - }) - - with beam.Pipeline() as pipeline: - # pylint: disable=no-value-for-parameter - result = ( - pipeline - | 'Create' >> beam.Create(examples, reshuffle=False) - # | 'Process' >> beam.ParDo(metric.preprocessors[0]) - | 'Process' >> beam.Map(metric_util.to_standard_metric_inputs) - | 'AddSlice' >> beam.Map(lambda x: ((), x)) - | 'ComputeMetric' >> beam.CombinePerKey(metric.combiner)) - - # pylint: enable=no-value-for-parameter - - def check_result(got): - try: - self.assertLen(got, 1) - got_slice_key, got_metrics = got[0] - self.assertEqual(got_slice_key, ()) - fixed_sized_sample_key = metric_types.MetricKey( - name='fixed_size_sample', example_weighted=True) - np.testing.assert_equal(got_metrics, - {fixed_sized_sample_key: np.array([4, 3])}) - - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that(result, check_result) - - -if __name__ == '__main__': - absltest.main() + def testFixedSizeSample(self): + metric = sample_metrics.FixedSizeSample( + sampled_key="sampled_key", size=2, random_seed=0 + ).computations()[0] + + examples = [] + for i in range(5): + examples.append( + { + constants.LABELS_KEY: np.array([0]), + constants.PREDICTIONS_KEY: np.array([1]), + constants.FEATURES_KEY: {"sampled_key": i}, + } + ) + + with beam.Pipeline() as pipeline: + # pylint: disable=no-value-for-parameter + 
result = ( + pipeline + | "Create" >> beam.Create(examples, reshuffle=False) + | "PreProcess" >> beam.ParDo(metric.preprocessors[0]) + | "Process" >> beam.Map(metric_util.to_standard_metric_inputs) + | "AddSlice" >> beam.Map(lambda x: ((), x)) + | "ComputeMetric" >> beam.CombinePerKey(metric.combiner) + ) + + # pylint: enable=no-value-for-parameter + + def check_result(got): + try: + self.assertLen(got, 1) + got_slice_key, got_metrics = got[0] + self.assertEqual(got_slice_key, ()) + fixed_sized_sample_key = metric_types.MetricKey( + name="fixed_size_sample" + ) + np.testing.assert_equal( + got_metrics, {fixed_sized_sample_key: np.array([4, 0])} + ) + + except AssertionError as err: + raise util.BeamAssertException(err) + + util.assert_that(result, check_result) + + def testFixedSizeSampleWeighted(self): + metric = sample_metrics.FixedSizeSample( + sampled_key="sampled_key", size=2, random_seed=0 + ).computations(example_weighted=True)[0] + + examples = [] + for i in range(5): + examples.append( + { + constants.LABELS_KEY: np.array([0]), + constants.PREDICTIONS_KEY: np.array([1]), + constants.EXAMPLE_WEIGHTS_KEY: np.array([10**i]), + constants.FEATURES_KEY: {"sampled_key": i}, + } + ) + + with beam.Pipeline() as pipeline: + # pylint: disable=no-value-for-parameter + result = ( + pipeline + | "Create" >> beam.Create(examples, reshuffle=False) + # | 'Process' >> beam.ParDo(metric.preprocessors[0]) + | "Process" >> beam.Map(metric_util.to_standard_metric_inputs) + | "AddSlice" >> beam.Map(lambda x: ((), x)) + | "ComputeMetric" >> beam.CombinePerKey(metric.combiner) + ) + + # pylint: enable=no-value-for-parameter + + def check_result(got): + try: + self.assertLen(got, 1) + got_slice_key, got_metrics = got[0] + self.assertEqual(got_slice_key, ()) + fixed_sized_sample_key = metric_types.MetricKey( + name="fixed_size_sample", example_weighted=True + ) + np.testing.assert_equal( + got_metrics, {fixed_sized_sample_key: np.array([4, 3])} + ) + + except AssertionError as err: + raise util.BeamAssertException(err) + + util.assert_that(result, check_result) + + +if __name__ == "__main__": + absltest.main() diff --git a/tensorflow_model_analysis/metrics/score_distribution_plot.py b/tensorflow_model_analysis/metrics/score_distribution_plot.py index 4a0d45be08..5e753ef0f6 100644 --- a/tensorflow_model_analysis/metrics/score_distribution_plot.py +++ b/tensorflow_model_analysis/metrics/score_distribution_plot.py @@ -14,130 +14,141 @@ """Score distribution Plot.""" import copy - from typing import Any, Dict, Iterator, Optional, Tuple import numpy as np + from tensorflow_model_analysis import constants -from tensorflow_model_analysis.metrics import binary_confusion_matrices -from tensorflow_model_analysis.metrics import metric_types -from tensorflow_model_analysis.metrics import metric_util +from tensorflow_model_analysis.metrics import ( + binary_confusion_matrices, + metric_types, + metric_util, +) from tensorflow_model_analysis.proto import config_pb2 DEFAULT_NUM_THRESHOLDS = 1000 -SCORE_DISTRIBUTION_PLOT_NAME = 'score_distribution_plot' +SCORE_DISTRIBUTION_PLOT_NAME = "score_distribution_plot" class ScoreDistributionPlot(metric_types.Metric): - """Score distribution plot.""" - - def __init__(self, - num_thresholds: int = DEFAULT_NUM_THRESHOLDS, - name: str = SCORE_DISTRIBUTION_PLOT_NAME): - """Initializes confusion matrix plot. - - Args: - num_thresholds: Number of thresholds to use when discretizing the curve. - Values must be > 1. Defaults to 1000. - name: Metric name. 
- """ - super().__init__( - metric_util.merge_per_key_computations(_confusion_matrix_plot), - num_thresholds=num_thresholds, - name=name) + """Score distribution plot.""" + + def __init__( + self, + num_thresholds: int = DEFAULT_NUM_THRESHOLDS, + name: str = SCORE_DISTRIBUTION_PLOT_NAME, + ): + """Initializes confusion matrix plot. + + Args: + ---- + num_thresholds: Number of thresholds to use when discretizing the curve. + Values must be > 1. Defaults to 1000. + name: Metric name. + """ + super().__init__( + metric_util.merge_per_key_computations(_confusion_matrix_plot), + num_thresholds=num_thresholds, + name=name, + ) metric_types.register_metric(ScoreDistributionPlot) def _extract_prediction_and_weight( - inputs: metric_types.StandardMetricInputs, - **kwargs) -> Iterator[Tuple[np.ndarray, np.ndarray, np.ndarray]]: - if 'predictions' in inputs: - modified_inputs = copy.deepcopy(inputs) - modified_inputs['labels'] = modified_inputs['predictions'] - else: - modified_inputs = inputs - return metric_util.to_label_prediction_example_weight(modified_inputs, - **kwargs) + inputs: metric_types.StandardMetricInputs, **kwargs +) -> Iterator[Tuple[np.ndarray, np.ndarray, np.ndarray]]: + if "predictions" in inputs: + modified_inputs = copy.deepcopy(inputs) + modified_inputs["labels"] = modified_inputs["predictions"] + else: + modified_inputs = inputs + return metric_util.to_label_prediction_example_weight(modified_inputs, **kwargs) def _confusion_matrix_plot( num_thresholds: int = DEFAULT_NUM_THRESHOLDS, name: str = SCORE_DISTRIBUTION_PLOT_NAME, eval_config: Optional[config_pb2.EvalConfig] = None, - model_name: str = '', - output_name: str = '', + model_name: str = "", + output_name: str = "", sub_key: Optional[metric_types.SubKey] = None, aggregation_type: Optional[metric_types.AggregationType] = None, class_weights: Optional[Dict[int, float]] = None, - example_weighted: bool = False) -> metric_types.MetricComputations: - """Returns metric computations for confusion matrix plots.""" - key = metric_types.PlotKey( - name=name, - model_name=model_name, - output_name=output_name, - sub_key=sub_key, - example_weighted=example_weighted) - - # The interoploation strategy used here matches how the legacy post export - # metrics calculated its plots. - thresholds = [i * 1.0 / num_thresholds for i in range(0, num_thresholds + 1)] - thresholds = [-1e-6] + thresholds - - modified_eval_config = None - if eval_config: - modified_eval_config = copy.deepcopy(eval_config) - # We want to completely ignore the labels, and in particular not fail if the - # label_key is not specified. - for model_spec in modified_eval_config.model_specs: - model_spec.label_key = model_spec.prediction_key - else: - modified_eval_config = config_pb2.EvalConfig() - spec = config_pb2.ModelSpec() - spec.prediction_key = constants.PREDICTIONS_KEY - # Pass the prediction key as label key to avoid failing if labels are not - # present. - spec.label_key = constants.PREDICTIONS_KEY - spec.name = model_name - modified_eval_config.model_specs.append(spec) - - # Make sure matrices are calculated. - matrices_computations = binary_confusion_matrices.binary_confusion_matrices( - # Use a custom name since we have a custom interpolation strategy which - # will cause the default naming used by the binary confusion matrix to be - # very long. 
- name=(binary_confusion_matrices.BINARY_CONFUSION_MATRICES_NAME + '_' + - name), - eval_config=modified_eval_config, - extract_label_prediction_and_weight=_extract_prediction_and_weight, - model_name=model_name, - output_name=output_name, - sub_key=sub_key, - aggregation_type=aggregation_type, - class_weights=class_weights, - example_weighted=example_weighted, - thresholds=thresholds, - use_histogram=False) - matrices_key = matrices_computations[-1].keys[-1] - - def result( - metrics: Dict[metric_types.MetricKey, Any] - ) -> Dict[metric_types.MetricKey, binary_confusion_matrices.Matrices]: # pytype: disable=signature-mismatch # always-use-return-annotations - confusion_matrix = metrics[matrices_key].to_proto( - ).confusion_matrix_at_thresholds - for value in confusion_matrix.matrices: - value.true_negatives += value.false_negatives - value.false_negatives = 0 - value.true_positives += value.false_positives - value.false_positives = 0 - value.precision = 0 - value.recall = 0 - return {key: confusion_matrix} - - derived_computation = metric_types.DerivedMetricComputation( - keys=[key], result=result) - computations = matrices_computations - computations.append(derived_computation) - return computations + example_weighted: bool = False, +) -> metric_types.MetricComputations: + """Returns metric computations for confusion matrix plots.""" + key = metric_types.PlotKey( + name=name, + model_name=model_name, + output_name=output_name, + sub_key=sub_key, + example_weighted=example_weighted, + ) + + # The interoploation strategy used here matches how the legacy post export + # metrics calculated its plots. + thresholds = [i * 1.0 / num_thresholds for i in range(0, num_thresholds + 1)] + thresholds = [-1e-6] + thresholds + + modified_eval_config = None + if eval_config: + modified_eval_config = copy.deepcopy(eval_config) + # We want to completely ignore the labels, and in particular not fail if the + # label_key is not specified. + for model_spec in modified_eval_config.model_specs: + model_spec.label_key = model_spec.prediction_key + else: + modified_eval_config = config_pb2.EvalConfig() + spec = config_pb2.ModelSpec() + spec.prediction_key = constants.PREDICTIONS_KEY + # Pass the prediction key as label key to avoid failing if labels are not + # present. + spec.label_key = constants.PREDICTIONS_KEY + spec.name = model_name + modified_eval_config.model_specs.append(spec) + + # Make sure matrices are calculated. + matrices_computations = binary_confusion_matrices.binary_confusion_matrices( + # Use a custom name since we have a custom interpolation strategy which + # will cause the default naming used by the binary confusion matrix to be + # very long. 
+ name=(binary_confusion_matrices.BINARY_CONFUSION_MATRICES_NAME + "_" + name), + eval_config=modified_eval_config, + extract_label_prediction_and_weight=_extract_prediction_and_weight, + model_name=model_name, + output_name=output_name, + sub_key=sub_key, + aggregation_type=aggregation_type, + class_weights=class_weights, + example_weighted=example_weighted, + thresholds=thresholds, + use_histogram=False, + ) + matrices_key = matrices_computations[-1].keys[-1] + + def result( + metrics: Dict[metric_types.MetricKey, Any], + ) -> Dict[ + metric_types.MetricKey, binary_confusion_matrices.Matrices + ]: # pytype: disable=signature-mismatch # always-use-return-annotations + confusion_matrix = ( + metrics[matrices_key].to_proto().confusion_matrix_at_thresholds + ) + for value in confusion_matrix.matrices: + value.true_negatives += value.false_negatives + value.false_negatives = 0 + value.true_positives += value.false_positives + value.false_positives = 0 + value.precision = 0 + value.recall = 0 + return {key: confusion_matrix} + + derived_computation = metric_types.DerivedMetricComputation( + keys=[key], result=result + ) + computations = matrices_computations + computations.append(derived_computation) + return computations diff --git a/tensorflow_model_analysis/metrics/score_distribution_plot_test.py b/tensorflow_model_analysis/metrics/score_distribution_plot_test.py index d74ae730c9..ebb7acc132 100644 --- a/tensorflow_model_analysis/metrics/score_distribution_plot_test.py +++ b/tensorflow_model_analysis/metrics/score_distribution_plot_test.py @@ -14,46 +14,48 @@ """Tests for confusion matrix plot.""" import apache_beam as beam -from apache_beam.testing import util import numpy as np import tensorflow as tf -import tensorflow_model_analysis as tfma # pylint: disable=unused-import +from apache_beam.testing import util +from google.protobuf import text_format + from tensorflow_model_analysis.api import model_eval_lib from tensorflow_model_analysis.metrics import metric_types from tensorflow_model_analysis.proto import config_pb2 from tensorflow_model_analysis.utils import test_util -from google.protobuf import text_format - class ScoreDistributionPlotTest(test_util.TensorflowModelAnalysisTest): + def testScoreDistributionPlot(self): + extracts = [ + { + "features": { + "my_predictions": np.array([0.0]), + "my_weights": np.array([1.0]), + } + }, + { + "features": { + "my_predictions": np.array([0.5]), + "my_weights": np.array([1.0]), + } + }, + { + "features": { + "my_predictions": np.array([0.3]), + "my_weights": np.array([1.0]), + } + }, + { + "features": { + "my_predictions": np.array([0.9]), + "my_weights": np.array([1.0]), + } + }, + ] - def testScoreDistributionPlot(self): - - extracts = [{ - 'features': { - 'my_predictions': np.array([0.0]), - 'my_weights': np.array([1.0]), - } - }, { - 'features': { - 'my_predictions': np.array([0.5]), - 'my_weights': np.array([1.0]), - } - }, { - 'features': { - 'my_predictions': np.array([0.3]), - 'my_weights': np.array([1.0]), - } - }, { - 'features': { - 'my_predictions': np.array([0.9]), - 'my_weights': np.array([1.0]), - } - }] - - eval_config = text_format.Parse( - """ + eval_config = text_format.Parse( + """ model_specs { name: "baseline" prediction_key: "my_predictions" @@ -68,30 +70,36 @@ def testScoreDistributionPlot(self): options { compute_confidence_intervals { } - }""", config_pb2.EvalConfig()) + }""", + config_pb2.EvalConfig(), + ) - evaluators = model_eval_lib.default_evaluators(eval_config=eval_config) - extractors = 
model_eval_lib.default_extractors( - eval_shared_model=None, eval_config=eval_config) + evaluators = model_eval_lib.default_evaluators(eval_config=eval_config) + extractors = model_eval_lib.default_extractors( + eval_shared_model=None, eval_config=eval_config + ) - with beam.Pipeline() as pipeline: - result = ( - pipeline - | 'LoadData' >> beam.Create(extracts) - | 'ExtractEval' >> model_eval_lib.ExtractAndEvaluate( - extractors=extractors, evaluators=evaluators)) + with beam.Pipeline() as pipeline: + result = ( + pipeline + | "LoadData" >> beam.Create(extracts) + | "ExtractEval" + >> model_eval_lib.ExtractAndEvaluate( + extractors=extractors, evaluators=evaluators + ) + ) - def check_result(got): - try: - self.assertLen(got, 1) - got_slice_key, got_plots = got[0] - self.assertEqual(got_slice_key, ()) - self.assertLen(got_plots, 1) - key = metric_types.PlotKey(name='score_distribution_plot') - self.assertIn(key, got_plots) - got_plot = got_plots[key] - self.assertProtoEquals( - """ + def check_result(got): + try: + self.assertLen(got, 1) + got_slice_key, got_plots = got[0] + self.assertEqual(got_slice_key, ()) + self.assertLen(got_plots, 1) + key = metric_types.PlotKey(name="score_distribution_plot") + self.assertIn(key, got_plots) + got_plot = got_plots[key] + self.assertProtoEquals( + """ matrices { threshold: -1e-06 true_positives: 4.0 @@ -140,14 +148,14 @@ def check_result(got): false_omission_rate: 0.425 } """, - got_plot, - ) + got_plot, + ) - except AssertionError as err: - raise util.BeamAssertException(err) + except AssertionError as err: + raise util.BeamAssertException(err) - util.assert_that(result['plots'], check_result, label='result') + util.assert_that(result["plots"], check_result, label="result") -if __name__ == '__main__': - tf.test.main() +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow_model_analysis/metrics/semantic_segmentation_confusion_matrix_metrics.py b/tensorflow_model_analysis/metrics/semantic_segmentation_confusion_matrix_metrics.py index 5ccdba2470..073b46f984 100644 --- a/tensorflow_model_analysis/metrics/semantic_segmentation_confusion_matrix_metrics.py +++ b/tensorflow_model_analysis/metrics/semantic_segmentation_confusion_matrix_metrics.py @@ -15,245 +15,238 @@ import abc from typing import Any, Dict, Iterable, List, Optional, Union + import apache_beam as beam import numpy as np + from tensorflow_model_analysis.contrib.aggregates import binary_confusion_matrices -from tensorflow_model_analysis.metrics import metric_types -from tensorflow_model_analysis.metrics import metric_util -from tensorflow_model_analysis.metrics import preprocessors +from tensorflow_model_analysis.metrics import metric_types, metric_util, preprocessors from tensorflow_model_analysis.proto import config_pb2 # default values for object detection settings _DEFAULT_IOU_THRESHOLD = 0.5 -_DEFAULT_AREA_RANGE = (0, float('inf')) - -SEMANTIC_SEGMENTATION_TRUE_POSITIVES_NAME = ( - 'semantic_segmentation_true_positives' -) -SEMANTIC_SEGMENTATION_FALSE_POSITIVES_NAME = ( - 'semantic_segmentation_false_positives' -) +_DEFAULT_AREA_RANGE = (0, float("inf")) + +SEMANTIC_SEGMENTATION_TRUE_POSITIVES_NAME = "semantic_segmentation_true_positives" +SEMANTIC_SEGMENTATION_FALSE_POSITIVES_NAME = "semantic_segmentation_false_positives" # The confusion matrix metric is supposed to be a private metric. It should only # be as the intermediate results for other metrics. To access the entries, # please use metrics like TruePositive, TrueNegative, etc. 
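As context for the per-class counting performed by `_SemanticSegmentationConfusionMatrixCombiner` further down in this file, the sketch below shows how TP/FP/FN/TN can be derived from 2-D class-id maps while pixels whose ground truth matches an ignore id are excluded from FP and TN. The helper name and toy arrays are illustrative, not part of TFMA.

```python
from typing import Dict, Optional, Sequence, Tuple

import numpy as np


def per_class_confusion_counts(
    ground_truth: np.ndarray,
    prediction: np.ndarray,
    class_ids: Sequence[int],
    ignore_ground_truth_id: Optional[int] = None,
) -> Dict[int, Tuple[int, int, int, int]]:
    """Returns {class_id: (tp, fp, fn, tn)} for 2-D class-id maps."""
    if ignore_ground_truth_id is None:
        not_ignored = np.ones_like(ground_truth, dtype=bool)
    else:
        not_ignored = ground_truth != ignore_ground_truth_id
    counts = {}
    for class_id in class_ids:
        gt_pos = ground_truth == class_id
        pred_pos = prediction == class_id
        tp = int(np.sum(gt_pos & pred_pos))
        fp = int(np.sum(~gt_pos & pred_pos & not_ignored))  # ignored pixels never count as FP
        fn = int(np.sum(gt_pos & ~pred_pos))
        tn = int(np.sum(~gt_pos & ~pred_pos & not_ignored))  # ...or as TN
        counts[class_id] = (tp, fp, fn, tn)
    return counts


ground_truth = np.array([[1, 2, 3], [1, 2, 3]])
prediction = np.array([[1, 1, 1], [2, 2, 2]])
print(per_class_confusion_counts(ground_truth, prediction, class_ids=[1, 2]))
# {1: (1, 2, 1, 2), 2: (1, 2, 1, 2)}
```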
-SEMANTIC_SEGMENTATION_CONFUSION_MATRIX_NAME = ( - '_semantic_segmentation_confusion_matrix' -) +SEMANTIC_SEGMENTATION_CONFUSION_MATRIX_NAME = "_semantic_segmentation_confusion_matrix" Matrix = binary_confusion_matrices.Matrix class SemanticSegmentationConfusionMatrix(metric_types.Metric): - """Computes confusion matrices for semantic segmentation.""" - - def __init__( - self, - class_ids: List[int], - ground_truth_key: str, - prediction_key: str, - decode_ground_truth: bool = True, - decode_prediction: bool = False, - ignore_ground_truth_id: Optional[int] = None, - name: Optional[str] = None, - ): - """Initializes PrecisionAtRecall metric. - - Args: - class_ids: the class ids for calculating metrics. - ground_truth_key: the key for storing the ground truth of encoded image - with class ids. - prediction_key: the key for storing the predictions of encoded image with - class ids. - decode_ground_truth: If true, the ground truth is assumed to be bytes of - images and will be decoded. By default it is true assuming the label is - the bytes of image. - decode_prediction: If true, the prediction is assumed to be bytes of - images and will be decoded. By default it is false assuming the model - outputs numpy arrays or tensors. - ignore_ground_truth_id: (Optional) The id of ground truth to be ignored. - name: (Optional) string name of the metric instance. - """ + """Computes confusion matrices for semantic segmentation.""" + + def __init__( + self, + class_ids: List[int], + ground_truth_key: str, + prediction_key: str, + decode_ground_truth: bool = True, + decode_prediction: bool = False, + ignore_ground_truth_id: Optional[int] = None, + name: Optional[str] = None, + ): + """Initializes PrecisionAtRecall metric. + + Args: + ---- + class_ids: the class ids for calculating metrics. + ground_truth_key: the key for storing the ground truth of encoded image + with class ids. + prediction_key: the key for storing the predictions of encoded image with + class ids. + decode_ground_truth: If true, the ground truth is assumed to be bytes of + images and will be decoded. By default it is true assuming the label is + the bytes of image. + decode_prediction: If true, the prediction is assumed to be bytes of + images and will be decoded. By default it is false assuming the model + outputs numpy arrays or tensors. + ignore_ground_truth_id: (Optional) The id of ground truth to be ignored. + name: (Optional) string name of the metric instance. 
+ """ + super().__init__( + metric_util.merge_per_key_computations(self._metric_computations), + name=name, + class_ids=class_ids, + ground_truth_key=ground_truth_key, + prediction_key=prediction_key, + decode_ground_truth=decode_ground_truth, + decode_prediction=decode_prediction, + ignore_ground_truth_id=ignore_ground_truth_id, + ) - super().__init__( - metric_util.merge_per_key_computations(self._metric_computations), - name=name, - class_ids=class_ids, - ground_truth_key=ground_truth_key, - prediction_key=prediction_key, - decode_ground_truth=decode_ground_truth, - decode_prediction=decode_prediction, - ignore_ground_truth_id=ignore_ground_truth_id, - ) - - def _default_name(self) -> str: - return SEMANTIC_SEGMENTATION_CONFUSION_MATRIX_NAME - - def _metric_computations( - self, - class_ids: List[int], - ground_truth_key: str, - prediction_key: str, - decode_ground_truth: bool = True, - decode_prediction: bool = False, - ignore_ground_truth_id: Optional[int] = None, - name: Optional[str] = None, - eval_config: Optional[config_pb2.EvalConfig] = None, - model_name: str = '', - output_name: str = '', - example_weighted: bool = False, - **kwargs, - ) -> metric_types.MetricComputations: - preprocessor = preprocessors.DecodeImagePreprocessor( - ground_truth_key=ground_truth_key, - prediction_key=prediction_key, - decode_ground_truth=decode_ground_truth, - decode_prediction=decode_prediction, - name=name, - model_name=model_name, - ) - # The model output is the key to store images and other informations. - # The output name should not be set. - # sub_key should not be set. The class_id info will be encoded in sub_key - # in the combiners. - key = metric_types.MetricKey( - name=name, - model_name=model_name, - output_name='', - sub_key=None, - example_weighted=example_weighted, - ) - - return [ - metric_types.MetricComputation( - keys=[key], - preprocessors=[preprocessor], - combiner=_SemanticSegmentationConfusionMatrixCombiner( - key=key, - class_ids=class_ids, - ignore_ground_truth_id=ignore_ground_truth_id, - ), + def _default_name(self) -> str: + return SEMANTIC_SEGMENTATION_CONFUSION_MATRIX_NAME + + def _metric_computations( + self, + class_ids: List[int], + ground_truth_key: str, + prediction_key: str, + decode_ground_truth: bool = True, + decode_prediction: bool = False, + ignore_ground_truth_id: Optional[int] = None, + name: Optional[str] = None, + eval_config: Optional[config_pb2.EvalConfig] = None, + model_name: str = "", + output_name: str = "", + example_weighted: bool = False, + **kwargs, + ) -> metric_types.MetricComputations: + preprocessor = preprocessors.DecodeImagePreprocessor( + ground_truth_key=ground_truth_key, + prediction_key=prediction_key, + decode_ground_truth=decode_ground_truth, + decode_prediction=decode_prediction, + name=name, + model_name=model_name, ) - ] + # The model output is the key to store images and other informations. + # The output name should not be set. + # sub_key should not be set. The class_id info will be encoded in sub_key + # in the combiners. 
+ key = metric_types.MetricKey( + name=name, + model_name=model_name, + output_name="", + sub_key=None, + example_weighted=example_weighted, + ) + + return [ + metric_types.MetricComputation( + keys=[key], + preprocessors=[preprocessor], + combiner=_SemanticSegmentationConfusionMatrixCombiner( + key=key, + class_ids=class_ids, + ignore_ground_truth_id=ignore_ground_truth_id, + ), + ) + ] class _SemanticSegmentationConfusionMatrixCombiner(beam.CombineFn): - """Combines semantic segmentation confusion matrices.""" - - def __init__( - self, - key: metric_types.MetricKey, - class_ids: List[int], - ignore_ground_truth_id: Optional[int] = None, - ): - """Initializes the semantic segmentation confusion matrix combiner. - - Args: - key: The metric key to identify the metric output. - class_ids: The ids of classes to calculate metrics. - ignore_ground_truth_id: The id of the ignored class. It could be used for - the class that should not be counted (e.g. masked pixels). - """ - self._key = key - self._class_ids = class_ids - self._ignore_ground_truth_id = ignore_ground_truth_id - - def create_accumulator(self) -> Dict[int, Matrix]: - return {} - - def add_input( - self, - accumulator: Dict[int, Matrix], - element: metric_types.StandardMetricInputs, - ) -> Dict[int, Matrix]: - ground_truth = element.get_labels() - prediction = element.get_predictions() - - if ground_truth.ndim != 2: - raise ValueError( - 'The ground truth should be in 2d. ' - f'But the shape is {ground_truth.shape}' - ) - if ground_truth.shape != prediction.shape: - raise ValueError( - f'The shape of prediction {prediction.shape} does not' - f' match the shape of ground truth {ground_truth.shape}' - ) - - for class_id in self._class_ids: - class_true_positive = np.sum( - np.logical_and(ground_truth == class_id, prediction == class_id) - ) - class_false_positive = np.sum( - np.logical_and( - np.logical_and(ground_truth != class_id, prediction == class_id), - ground_truth != self._ignore_ground_truth_id, - ) - ) - class_false_negative = np.sum( - np.logical_and(ground_truth == class_id, prediction != class_id) - ) - class_true_negative = np.sum( - np.logical_and( - np.logical_and(ground_truth != class_id, prediction != class_id), - ground_truth != self._ignore_ground_truth_id, - ) - ) - - class_confusion_matrix = Matrix( - tp=class_true_positive, - tn=class_true_negative, - fp=class_false_positive, - fn=class_false_negative, - ) - if class_id not in accumulator: - accumulator[class_id] = class_confusion_matrix - else: - accumulator[class_id] = Matrix( - tp=accumulator[class_id].tp + class_confusion_matrix.tp, - tn=accumulator[class_id].tn + class_confusion_matrix.tn, - fp=accumulator[class_id].fp + class_confusion_matrix.fp, - fn=accumulator[class_id].fn + class_confusion_matrix.fn, - ) - return accumulator - - def merge_accumulators( - self, - accumulators: Iterable[Dict[int, Matrix]], - ) -> Dict[int, Matrix]: - accumulators = iter(accumulators) - result = next(accumulators) - for accumulator in accumulators: - for class_id, confusion_matrix in accumulator.items(): - if class_id in result: - result[class_id] = Matrix( - tp=result[class_id].tp + confusion_matrix.tp, - tn=result[class_id].tn + confusion_matrix.tn, - fp=result[class_id].fp + confusion_matrix.fp, - fn=result[class_id].fn + confusion_matrix.fn, - ) - else: - result[class_id] = confusion_matrix - return result - - def extract_output( - self, accumulator: Dict[int, Matrix] - ) -> Dict[metric_types.MetricKey, Matrix]: - result = {} - for class_id, confusion_matrix_matrix in 
accumulator.items(): - new_key = self._key._replace( - sub_key=metric_types.SubKey(class_id=class_id) - ) - # In semantic segmentation metrics, there is no confidence score and thus - # no thresholds. It is set to 0 as default to reuse the binary confusion - # matrices. - matrices = Matrix( - tp=confusion_matrix_matrix.tp, - tn=confusion_matrix_matrix.tn, - fp=confusion_matrix_matrix.fp, - fn=confusion_matrix_matrix.fn, - ) - result[new_key] = matrices - return result + """Combines semantic segmentation confusion matrices.""" + + def __init__( + self, + key: metric_types.MetricKey, + class_ids: List[int], + ignore_ground_truth_id: Optional[int] = None, + ): + """Initializes the semantic segmentation confusion matrix combiner. + + Args: + ---- + key: The metric key to identify the metric output. + class_ids: The ids of classes to calculate metrics. + ignore_ground_truth_id: The id of the ignored class. It could be used for + the class that should not be counted (e.g. masked pixels). + """ + self._key = key + self._class_ids = class_ids + self._ignore_ground_truth_id = ignore_ground_truth_id + + def create_accumulator(self) -> Dict[int, Matrix]: + return {} + + def add_input( + self, + accumulator: Dict[int, Matrix], + element: metric_types.StandardMetricInputs, + ) -> Dict[int, Matrix]: + ground_truth = element.get_labels() + prediction = element.get_predictions() + + if ground_truth.ndim != 2: + raise ValueError( + "The ground truth should be in 2d. " + f"But the shape is {ground_truth.shape}" + ) + if ground_truth.shape != prediction.shape: + raise ValueError( + f"The shape of prediction {prediction.shape} does not" + f" match the shape of ground truth {ground_truth.shape}" + ) + + for class_id in self._class_ids: + class_true_positive = np.sum( + np.logical_and(ground_truth == class_id, prediction == class_id) + ) + class_false_positive = np.sum( + np.logical_and( + np.logical_and(ground_truth != class_id, prediction == class_id), + ground_truth != self._ignore_ground_truth_id, + ) + ) + class_false_negative = np.sum( + np.logical_and(ground_truth == class_id, prediction != class_id) + ) + class_true_negative = np.sum( + np.logical_and( + np.logical_and(ground_truth != class_id, prediction != class_id), + ground_truth != self._ignore_ground_truth_id, + ) + ) + + class_confusion_matrix = Matrix( + tp=class_true_positive, + tn=class_true_negative, + fp=class_false_positive, + fn=class_false_negative, + ) + if class_id not in accumulator: + accumulator[class_id] = class_confusion_matrix + else: + accumulator[class_id] = Matrix( + tp=accumulator[class_id].tp + class_confusion_matrix.tp, + tn=accumulator[class_id].tn + class_confusion_matrix.tn, + fp=accumulator[class_id].fp + class_confusion_matrix.fp, + fn=accumulator[class_id].fn + class_confusion_matrix.fn, + ) + return accumulator + + def merge_accumulators( + self, + accumulators: Iterable[Dict[int, Matrix]], + ) -> Dict[int, Matrix]: + accumulators = iter(accumulators) + result = next(accumulators) + for accumulator in accumulators: + for class_id, confusion_matrix in accumulator.items(): + if class_id in result: + result[class_id] = Matrix( + tp=result[class_id].tp + confusion_matrix.tp, + tn=result[class_id].tn + confusion_matrix.tn, + fp=result[class_id].fp + confusion_matrix.fp, + fn=result[class_id].fn + confusion_matrix.fn, + ) + else: + result[class_id] = confusion_matrix + return result + + def extract_output( + self, accumulator: Dict[int, Matrix] + ) -> Dict[metric_types.MetricKey, Matrix]: + result = {} + for class_id, 
confusion_matrix_matrix in accumulator.items(): + new_key = self._key._replace(sub_key=metric_types.SubKey(class_id=class_id)) + # In semantic segmentation metrics, there is no confidence score and thus + # no thresholds. It is set to 0 as default to reuse the binary confusion + # matrices. + matrices = Matrix( + tp=confusion_matrix_matrix.tp, + tn=confusion_matrix_matrix.tn, + fp=confusion_matrix_matrix.fp, + fn=confusion_matrix_matrix.fn, + ) + result[new_key] = matrices + return result metric_types.register_metric(SemanticSegmentationConfusionMatrix) @@ -262,172 +255,170 @@ def extract_output( class SemanticSegmentationConfusionMatrixMetricBase( metric_types.Metric, metaclass=abc.ABCMeta ): - """The base metric for semantic segmentation confusion matrix based metrics. - - This is the base metric for other metrics such as true postive, true negative, - false positvie and false negative. - """ - - def __init__( - self, - class_ids: List[int], - ground_truth_key: str, - prediction_key: str, - decode_ground_truth: bool = True, - decode_prediction: bool = False, - ignore_ground_truth_id: Optional[int] = None, - name: Optional[str] = None, - ): - """Initializes PrecisionAtRecall metric. - - Args: - class_ids: the class ids for calculating metrics. - ground_truth_key: the key for storing the ground truth of encoded image - with class ids. - prediction_key: the key for storing the predictions of encoded image with - class ids. - decode_ground_truth: If true, the ground truth is assumed to be bytes of - images and will be decoded. By default it is true assuming the label is - the bytes of image. - decode_prediction: If true, the prediction is assumed to be bytes of - images and will be decoded. By default it is false assuming the model - outputs numpy arrays or tensors. - ignore_ground_truth_id: (Optional) The id of ground truth to be ignored. - name: (Optional) string name of the metric instance. - """ + """The base metric for semantic segmentation confusion matrix based metrics. - super().__init__( - metric_util.merge_per_key_computations(self._metric_computations), - name=name, - class_ids=class_ids, - ground_truth_key=ground_truth_key, - prediction_key=prediction_key, - decode_ground_truth=decode_ground_truth, - decode_prediction=decode_prediction, - ignore_ground_truth_id=ignore_ground_truth_id, - ) - - @abc.abstractmethod - def _default_name(self) -> str: - """Returns the default metric name.""" - raise NotImplementedError('Must have a default name for the metric.') - - @abc.abstractmethod - def _metric_value( - self, - matrix: Matrix, - ) -> Union[float, np.ndarray]: - """Returns the metric value based the confusion matrix. - - Subclasses must override this method. - Args: - matrix: The matrix to calculate derived values. - - Return: The values calculated based on a confusion matrix. + This is the base metric for other metrics such as true postive, true negative, + false positvie and false negative. 
""" - raise NotImplementedError('Must be implemented to return a metric value') - - def _metric_computations( - self, - class_ids: List[int], - ground_truth_key: str, - prediction_key: str, - decode_ground_truth: bool = True, - decode_prediction: bool = False, - ignore_ground_truth_id: Optional[int] = None, - name: Optional[str] = None, - eval_config: Optional[config_pb2.EvalConfig] = None, - model_name: str = '', - output_name: str = '', - example_weighted: bool = False, - **kwargs, - ) -> metric_types.MetricComputations: - # generates private name to distinguish semantic segmentation confusion - # matrix from different configs. - semantic_segmentation_confusion_matrix_name = ( - metric_util.generate_private_name_from_arguments( + + def __init__( + self, + class_ids: List[int], + ground_truth_key: str, + prediction_key: str, + decode_ground_truth: bool = True, + decode_prediction: bool = False, + ignore_ground_truth_id: Optional[int] = None, + name: Optional[str] = None, + ): + """Initializes PrecisionAtRecall metric. + + Args: + ---- + class_ids: the class ids for calculating metrics. + ground_truth_key: the key for storing the ground truth of encoded image + with class ids. + prediction_key: the key for storing the predictions of encoded image with + class ids. + decode_ground_truth: If true, the ground truth is assumed to be bytes of + images and will be decoded. By default it is true assuming the label is + the bytes of image. + decode_prediction: If true, the prediction is assumed to be bytes of + images and will be decoded. By default it is false assuming the model + outputs numpy arrays or tensors. + ignore_ground_truth_id: (Optional) The id of ground truth to be ignored. + name: (Optional) string name of the metric instance. + """ + super().__init__( + metric_util.merge_per_key_computations(self._metric_computations), + name=name, class_ids=class_ids, ground_truth_key=ground_truth_key, prediction_key=prediction_key, decode_ground_truth=decode_ground_truth, decode_prediction=decode_prediction, ignore_ground_truth_id=ignore_ground_truth_id, - name=name, ) - ) - maxtrix_computation = SemanticSegmentationConfusionMatrix( - class_ids=class_ids, - ground_truth_key=ground_truth_key, - prediction_key=prediction_key, - decode_ground_truth=decode_ground_truth, - decode_prediction=decode_prediction, - ignore_ground_truth_id=ignore_ground_truth_id, - name=semantic_segmentation_confusion_matrix_name, - ) - computations = maxtrix_computation.computations( - model_names=[model_name], - output_names=[output_name], - example_weighted=example_weighted, - ) - # This is the key to metric output for the entire matrix computation. - # sub_key/class_id info is not encoded in the metric key yet. - matrix_key = computations[-1].keys[-1] - - key = metric_types.MetricKey( - name=name, - model_name=model_name, - output_name=output_name, - ) - - def result( - metrics: Dict[metric_types.MetricKey, Any] - ) -> Dict[metric_types.MetricKey, Union[float, np.ndarray]]: - derived_output = {} - for class_id in class_ids: - class_matrix_key = matrix_key._replace( - sub_key=metric_types.SubKey(class_id=class_id) + + @abc.abstractmethod + def _default_name(self) -> str: + """Returns the default metric name.""" + raise NotImplementedError("Must have a default name for the metric.") + + @abc.abstractmethod + def _metric_value( + self, + matrix: Matrix, + ) -> Union[float, np.ndarray]: + """Returns the metric value based the confusion matrix. + + Subclasses must override this method. 
+ + Args: + ---- + matrix: The matrix to calculate derived values. + + Return: The values calculated based on a confusion matrix. + """ + raise NotImplementedError("Must be implemented to return a metric value") + + def _metric_computations( + self, + class_ids: List[int], + ground_truth_key: str, + prediction_key: str, + decode_ground_truth: bool = True, + decode_prediction: bool = False, + ignore_ground_truth_id: Optional[int] = None, + name: Optional[str] = None, + eval_config: Optional[config_pb2.EvalConfig] = None, + model_name: str = "", + output_name: str = "", + example_weighted: bool = False, + **kwargs, + ) -> metric_types.MetricComputations: + # generates private name to distinguish semantic segmentation confusion + # matrix from different configs. + semantic_segmentation_confusion_matrix_name = ( + metric_util.generate_private_name_from_arguments( + class_ids=class_ids, + ground_truth_key=ground_truth_key, + prediction_key=prediction_key, + decode_ground_truth=decode_ground_truth, + decode_prediction=decode_prediction, + ignore_ground_truth_id=ignore_ground_truth_id, + name=name, + ) ) - class_output_key = key._replace( - sub_key=metric_types.SubKey(class_id=class_id) + maxtrix_computation = SemanticSegmentationConfusionMatrix( + class_ids=class_ids, + ground_truth_key=ground_truth_key, + prediction_key=prediction_key, + decode_ground_truth=decode_ground_truth, + decode_prediction=decode_prediction, + ignore_ground_truth_id=ignore_ground_truth_id, + name=semantic_segmentation_confusion_matrix_name, + ) + computations = maxtrix_computation.computations( + model_names=[model_name], + output_names=[output_name], + example_weighted=example_weighted, ) - derived_output[class_output_key] = self._metric_value( - metrics[class_matrix_key] + # This is the key to metric output for the entire matrix computation. + # sub_key/class_id info is not encoded in the metric key yet. 
+ matrix_key = computations[-1].keys[-1] + + key = metric_types.MetricKey( + name=name, + model_name=model_name, + output_name=output_name, ) - return derived_output - derived_computation = metric_types.DerivedMetricComputation( - keys=[key], result=result - ) - computations.append(derived_computation) + def result( + metrics: Dict[metric_types.MetricKey, Any], + ) -> Dict[metric_types.MetricKey, Union[float, np.ndarray]]: + derived_output = {} + for class_id in class_ids: + class_matrix_key = matrix_key._replace( + sub_key=metric_types.SubKey(class_id=class_id) + ) + class_output_key = key._replace( + sub_key=metric_types.SubKey(class_id=class_id) + ) + derived_output[class_output_key] = self._metric_value( + metrics[class_matrix_key] + ) + return derived_output + + derived_computation = metric_types.DerivedMetricComputation( + keys=[key], result=result + ) + computations.append(derived_computation) - return computations + return computations -class SemanticSegmentationTruePositive( - SemanticSegmentationConfusionMatrixMetricBase -): - """Calculates the true postive for semantic segmentation.""" +class SemanticSegmentationTruePositive(SemanticSegmentationConfusionMatrixMetricBase): + """Calculates the true postive for semantic segmentation.""" - def _default_name(self) -> str: - return SEMANTIC_SEGMENTATION_TRUE_POSITIVES_NAME + def _default_name(self) -> str: + return SEMANTIC_SEGMENTATION_TRUE_POSITIVES_NAME - def _metric_value(self, matrix: Matrix) -> float: - return matrix.tp + def _metric_value(self, matrix: Matrix) -> float: + return matrix.tp metric_types.register_metric(SemanticSegmentationTruePositive) -class SemanticSegmentationFalsePositive( - SemanticSegmentationConfusionMatrixMetricBase -): - """Calculates the true postive for semantic segmentation.""" +class SemanticSegmentationFalsePositive(SemanticSegmentationConfusionMatrixMetricBase): + """Calculates the true postive for semantic segmentation.""" - def _default_name(self) -> str: - return SEMANTIC_SEGMENTATION_FALSE_POSITIVES_NAME + def _default_name(self) -> str: + return SEMANTIC_SEGMENTATION_FALSE_POSITIVES_NAME - def _metric_value(self, matrix: Matrix) -> float: - return matrix.fp + def _metric_value(self, matrix: Matrix) -> float: + return matrix.fp metric_types.register_metric(SemanticSegmentationFalsePositive) diff --git a/tensorflow_model_analysis/metrics/semantic_segmentation_confusion_matrix_metrics_test.py b/tensorflow_model_analysis/metrics/semantic_segmentation_confusion_matrix_metrics_test.py index b19af57e4b..a0efe165aa 100644 --- a/tensorflow_model_analysis/metrics/semantic_segmentation_confusion_matrix_metrics_test.py +++ b/tensorflow_model_analysis/metrics/semantic_segmentation_confusion_matrix_metrics_test.py @@ -15,76 +15,70 @@ import io -from absl.testing import absltest -from absl.testing import parameterized import apache_beam as beam -from apache_beam.testing import util import numpy as np +from absl.testing import absltest, parameterized +from apache_beam.testing import util +from google.protobuf import text_format from PIL import Image + import tensorflow_model_analysis as tfma -from tensorflow_model_analysis.proto import config_pb2 from tensorflow_model_analysis import constants from tensorflow_model_analysis.contrib.aggregates import binary_confusion_matrices from tensorflow_model_analysis.metrics import metric_types - -from google.protobuf import text_format +from tensorflow_model_analysis.proto import config_pb2 Matrix = binary_confusion_matrices.Matrix def 
_encode_image_from_nparray(image_array: np.ndarray) -> bytes: - image = Image.fromarray(image_array) - encoded_buffer = io.BytesIO() - image.save(encoded_buffer, format='PNG') - return encoded_buffer.getvalue() + image = Image.fromarray(image_array) + encoded_buffer = io.BytesIO() + image.save(encoded_buffer, format="PNG") + return encoded_buffer.getvalue() class SegmentationConfusionMatrixTest(parameterized.TestCase): - - def setUp(self): - super().setUp() - label_image_array = np.array([[1, 2, 3], [1, 2, 3]], dtype=np.uint8) - prediction_image_array = np.array([[1, 1, 1], [2, 2, 2]], dtype=np.uint8) - label_encoded_image = _encode_image_from_nparray(label_image_array) - prediction_encoded_image = _encode_image_from_nparray( - prediction_image_array - ) - - label_image_array2 = np.array([[2, 2, 3], [1, 2, 3]], dtype=np.uint8) - prediction_image_array2 = np.array([[2, 1, 1], [2, 2, 2]], dtype=np.uint8) - label_encoded_image2 = _encode_image_from_nparray(label_image_array2) - prediction_encoded_image2 = _encode_image_from_nparray( - prediction_image_array2 - ) - - self._extracts = [ - { - 'features': { - constants.LABELS_KEY: { - 'image/encoded': np.array([label_encoded_image]), - }, - constants.PREDICTIONS_KEY: { - 'image/pred/encoded': np.array([prediction_encoded_image]), - }, - } - }, - { - 'features': { - constants.LABELS_KEY: { - 'image/encoded': np.array([label_encoded_image2]), - }, - constants.PREDICTIONS_KEY: { - 'image/pred/encoded': np.array([prediction_encoded_image2]), - }, - } - }, - ] - - @parameterized.named_parameters( - dict( - testcase_name='_two_class', - eval_config=text_format.Parse( - """ + def setUp(self): + super().setUp() + label_image_array = np.array([[1, 2, 3], [1, 2, 3]], dtype=np.uint8) + prediction_image_array = np.array([[1, 1, 1], [2, 2, 2]], dtype=np.uint8) + label_encoded_image = _encode_image_from_nparray(label_image_array) + prediction_encoded_image = _encode_image_from_nparray(prediction_image_array) + + label_image_array2 = np.array([[2, 2, 3], [1, 2, 3]], dtype=np.uint8) + prediction_image_array2 = np.array([[2, 1, 1], [2, 2, 2]], dtype=np.uint8) + label_encoded_image2 = _encode_image_from_nparray(label_image_array2) + prediction_encoded_image2 = _encode_image_from_nparray(prediction_image_array2) + + self._extracts = [ + { + "features": { + constants.LABELS_KEY: { + "image/encoded": np.array([label_encoded_image]), + }, + constants.PREDICTIONS_KEY: { + "image/pred/encoded": np.array([prediction_encoded_image]), + }, + } + }, + { + "features": { + constants.LABELS_KEY: { + "image/encoded": np.array([label_encoded_image2]), + }, + constants.PREDICTIONS_KEY: { + "image/pred/encoded": np.array([prediction_encoded_image2]), + }, + } + }, + ] + + @parameterized.named_parameters( + dict( + testcase_name="_two_class", + eval_config=text_format.Parse( + """ model_specs { signature_name: "serving_default" prediction_key: "predictions" # placeholder @@ -103,18 +97,18 @@ def setUp(self): } } """, - config_pb2.EvalConfig(), - ), - name='SegConfusionMatrix', - expected_result={ - 1: Matrix(tp=1, tn=5, fp=4, fn=2), - 2: Matrix(tp=3, tn=3, fp=4, fn=2), - }, - ), - dict( - testcase_name='_two_class_with_ignore', - eval_config=text_format.Parse( - """ + config_pb2.EvalConfig(), + ), + name="SegConfusionMatrix", + expected_result={ + 1: Matrix(tp=1, tn=5, fp=4, fn=2), + 2: Matrix(tp=3, tn=3, fp=4, fn=2), + }, + ), + dict( + testcase_name="_two_class_with_ignore", + eval_config=text_format.Parse( + """ model_specs { signature_name: "serving_default" prediction_key: 
"predictions" # placeholder @@ -134,18 +128,18 @@ def setUp(self): } } """, - config_pb2.EvalConfig(), - ), - name='SegConfusionMatrix', - expected_result={ - 1: Matrix(tp=1, tn=3, fp=2, fn=2), - 2: Matrix(tp=3, tn=1, fp=2, fn=2), - }, - ), - dict( - testcase_name='_tp_two_class_with_ignore', - eval_config=text_format.Parse( - """ + config_pb2.EvalConfig(), + ), + name="SegConfusionMatrix", + expected_result={ + 1: Matrix(tp=1, tn=3, fp=2, fn=2), + 2: Matrix(tp=3, tn=1, fp=2, fn=2), + }, + ), + dict( + testcase_name="_tp_two_class_with_ignore", + eval_config=text_format.Parse( + """ model_specs { signature_name: "serving_default" prediction_key: "predictions" # placeholder @@ -165,18 +159,18 @@ def setUp(self): } } """, - config_pb2.EvalConfig(), - ), - name='SegTruePositive', - expected_result={ - 1: np.array([1]), - 2: np.array([3]), - }, - ), - dict( - testcase_name='_fp_two_class_with_ignore', - eval_config=text_format.Parse( - """ + config_pb2.EvalConfig(), + ), + name="SegTruePositive", + expected_result={ + 1: np.array([1]), + 2: np.array([3]), + }, + ), + dict( + testcase_name="_fp_two_class_with_ignore", + eval_config=text_format.Parse( + """ model_specs { signature_name: "serving_default" prediction_key: "predictions" # placeholder @@ -196,58 +190,56 @@ def setUp(self): } } """, - config_pb2.EvalConfig(), - ), - name='SegFalsePositive', - expected_result={ - 1: np.array([2]), - 2: np.array([2]), - }, - ), - ) - def testEncodedImage(self, eval_config, name, expected_result): - extracts = self._extracts - - evaluators = tfma.default_evaluators(eval_config=eval_config) - extractors = tfma.default_extractors( - eval_shared_model=None, eval_config=eval_config + config_pb2.EvalConfig(), + ), + name="SegFalsePositive", + expected_result={ + 1: np.array([2]), + 2: np.array([2]), + }, + ), ) - - with beam.Pipeline() as p: - result = ( - p - | 'LoadData' >> beam.Create(extracts) - | 'ExtractEval' - >> tfma.ExtractAndEvaluate( - extractors=extractors, evaluators=evaluators - ) - ) - - def check_result(got): - try: - self.assertLen(got, 1) - got_slice_key, got_metrics = got[0] - self.assertEqual(got_slice_key, ()) - self.assertLen(got_metrics, 2) - - for class_id, expected_matrix in expected_result.items(): - key = metric_types.MetricKey( - name=name, sub_key=metric_types.SubKey(class_id=class_id) - ) - self.assertIn(key, got_metrics) - got_metric = got_metrics[key] - np.testing.assert_allclose( - expected_matrix, - got_metric, - rtol=1e-3, - err_msg=f'This {name} metric fails.', + def testEncodedImage(self, eval_config, name, expected_result): + extracts = self._extracts + + evaluators = tfma.default_evaluators(eval_config=eval_config) + extractors = tfma.default_extractors( + eval_shared_model=None, eval_config=eval_config + ) + + with beam.Pipeline() as p: + result = ( + p + | "LoadData" >> beam.Create(extracts) + | "ExtractEval" + >> tfma.ExtractAndEvaluate(extractors=extractors, evaluators=evaluators) ) - except AssertionError as err: - raise util.BeamAssertException(err) - - self.assertIn('metrics', result) - util.assert_that(result['metrics'], check_result, label='result') - -if __name__ == '__main__': - absltest.main() + def check_result(got): + try: + self.assertLen(got, 1) + got_slice_key, got_metrics = got[0] + self.assertEqual(got_slice_key, ()) + self.assertLen(got_metrics, 2) + + for class_id, expected_matrix in expected_result.items(): + key = metric_types.MetricKey( + name=name, sub_key=metric_types.SubKey(class_id=class_id) + ) + self.assertIn(key, got_metrics) + got_metric 
= got_metrics[key] + np.testing.assert_allclose( + expected_matrix, + got_metric, + rtol=1e-3, + err_msg=f"This {name} metric fails.", + ) + except AssertionError as err: + raise util.BeamAssertException(err) + + self.assertIn("metrics", result) + util.assert_that(result["metrics"], check_result, label="result") + + +if __name__ == "__main__": + absltest.main() diff --git a/tensorflow_model_analysis/metrics/set_match_confusion_matrix_metrics.py b/tensorflow_model_analysis/metrics/set_match_confusion_matrix_metrics.py index 87cf222484..38cf1dd7c6 100644 --- a/tensorflow_model_analysis/metrics/set_match_confusion_matrix_metrics.py +++ b/tensorflow_model_analysis/metrics/set_match_confusion_matrix_metrics.py @@ -13,261 +13,265 @@ # limitations under the License. """set match confusion matrices.""" -from typing import List, Optional, Union, Dict +from typing import Dict, List, Optional, Union -from tensorflow_model_analysis.metrics import confusion_matrix_metrics -from tensorflow_model_analysis.metrics import metric_types -from tensorflow_model_analysis.metrics import preprocessors +from tensorflow_model_analysis.metrics import ( + confusion_matrix_metrics, + metric_types, + preprocessors, +) from tensorflow_model_analysis.proto import config_pb2 -SET_MATCH_PRECISION_NAME = 'set_match_precision' -SET_MATCH_RECALL_NAME = 'set_match_recall' +SET_MATCH_PRECISION_NAME = "set_match_precision" +SET_MATCH_RECALL_NAME = "set_match_recall" class SetMatchPrecision(confusion_matrix_metrics.Precision): - """Computes precision for sets of labels and predictions. + """Computes precision for sets of labels and predictions. - The metric deals with labels and predictions which are provided in the format - of sets (stored as variable length numpy arrays). The precision is the - micro averaged classification precision. The metric is suitable for the case - where the number of classes is large or the list of classes could not be - provided in advance. + The metric deals with labels and predictions which are provided in the format + of sets (stored as variable length numpy arrays). The precision is the + micro averaged classification precision. The metric is suitable for the case + where the number of classes is large or the list of classes could not be + provided in advance. - Example: - Label: ['cats'], - Predictions: {'classes': ['cats, dogs']} + Example: + ------- + Label: ['cats'], + Predictions: {'classes': ['cats, dogs']} - The precision is 0.5. - """ - - def __init__( - self, - thresholds: Optional[Union[float, List[float]]] = None, - top_k: Optional[int] = None, - name: Optional[str] = None, - prediction_class_key: str = 'classes', - prediction_score_key: str = 'scores', - class_key: Optional[str] = None, - weight_key: Optional[str] = None, - **kwargs, - ): - """Initializes Precision metric. - - Args: - thresholds: (Optional) A float value or a python list/tuple of float - threshold values in [0, 1]. A threshold is compared with prediction - values to determine the truth value of predictions (i.e., above the - threshold is `true`, below is `false`). One metric value is generated - for each threshold value. If neither thresholds nor top_k are set, the - default is to calculate precision with `thresholds=0.5`. - top_k: (Optional) Used with a multi-class model to specify that the top-k - values should be used to compute the confusion matrix. The net effect is - that the non-top-k values are truncated and the matrix is then - constructed from the average TP, FP, TN, FN across the classes. 
When - top_k is used, metrics_specs.binarize settings must not be present. When - top_k is used, the default threshold is float('-inf'). In this case, - unmatched labels are still considered false negative, since they have - prediction with confidence score float('-inf'), - name: (Optional) string name of the metric instance. - prediction_class_key: the key name of the classes in prediction. - prediction_score_key: the key name of the scores in prediction. - class_key: (Optional) The key name of the classes in class-weight pairs. - If it is not provided, the classes are assumed to be the label classes. - weight_key: (Optional) The key name of the weights of classes in - class-weight pairs. The value in this key should be a numpy array of the - same length as the classes in class_key. The key should be stored under - the features key. - **kwargs: (Optional) Additional args to pass along to init (and eventually - on to _metric_computations and _metric_values). The args are passed to - the precision metric, the confusion matrix metric and binary - classification metric. + The precision is 0.5. """ - super().__init__( - thresholds=thresholds, - top_k=top_k, - name=name, - prediction_class_key=prediction_class_key, - prediction_score_key=prediction_score_key, - class_key=class_key, - weight_key=weight_key, + def __init__( + self, + thresholds: Optional[Union[float, List[float]]] = None, + top_k: Optional[int] = None, + name: Optional[str] = None, + prediction_class_key: str = "classes", + prediction_score_key: str = "scores", + class_key: Optional[str] = None, + weight_key: Optional[str] = None, **kwargs, - ) + ): + """Initializes Precision metric. - def _default_name(self) -> str: - return SET_MATCH_PRECISION_NAME + Args: + ---- + thresholds: (Optional) A float value or a python list/tuple of float + threshold values in [0, 1]. A threshold is compared with prediction + values to determine the truth value of predictions (i.e., above the + threshold is `true`, below is `false`). One metric value is generated + for each threshold value. If neither thresholds nor top_k are set, the + default is to calculate precision with `thresholds=0.5`. + top_k: (Optional) Used with a multi-class model to specify that the top-k + values should be used to compute the confusion matrix. The net effect is + that the non-top-k values are truncated and the matrix is then + constructed from the average TP, FP, TN, FN across the classes. When + top_k is used, metrics_specs.binarize settings must not be present. When + top_k is used, the default threshold is float('-inf'). In this case, + unmatched labels are still considered false negative, since they have + prediction with confidence score float('-inf'), + name: (Optional) string name of the metric instance. + prediction_class_key: the key name of the classes in prediction. + prediction_score_key: the key name of the scores in prediction. + class_key: (Optional) The key name of the classes in class-weight pairs. + If it is not provided, the classes are assumed to be the label classes. + weight_key: (Optional) The key name of the weights of classes in + class-weight pairs. The value in this key should be a numpy array of the + same length as the classes in class_key. The key should be stored under + the features key. + **kwargs: (Optional) Additional args to pass along to init (and eventually + on to _metric_computations and _metric_values). The args are passed to + the precision metric, the confusion matrix metric and binary + classification metric. 
+ """ + super().__init__( + thresholds=thresholds, + top_k=top_k, + name=name, + prediction_class_key=prediction_class_key, + prediction_score_key=prediction_score_key, + class_key=class_key, + weight_key=weight_key, + **kwargs, + ) - def _metric_computations( - self, - thresholds: Optional[Union[float, List[float]]] = None, - top_k: Optional[int] = None, - name: Optional[str] = None, - prediction_class_key: str = 'classes', - prediction_score_key: str = 'scores', - class_key: Optional[str] = None, - weight_key: Optional[str] = None, - eval_config: Optional[config_pb2.EvalConfig] = None, - model_name: str = '', - sub_key: Optional[metric_types.SubKey] = None, - aggregation_type: Optional[metric_types.AggregationType] = None, - class_weights: Optional[Dict[int, float]] = None, - example_weighted: bool = False, - **kwargs, - ) -> metric_types.MetricComputations: - preprocessor = preprocessors.SetMatchPreprocessor( - top_k=top_k, - model_name=model_name, - prediction_class_key=prediction_class_key, - prediction_score_key=prediction_score_key, - class_key=class_key, - weight_key=weight_key, - ) - if top_k is not None and thresholds is None: - thresholds = float('-inf') + def _default_name(self) -> str: + return SET_MATCH_PRECISION_NAME - if weight_key: - # If example_weighted is False, it will by default set the example weights - # to 1.0. - # example_weighted could only be turned on from model_specs. However, in - # this case, the example_weights is not provided in the models. It should - # be turned on when per class weights are given. - example_weighted = True - return super()._metric_computations( - thresholds=thresholds, - name=name, - eval_config=eval_config, - model_name=model_name, - preprocessors=[preprocessor], - sub_key=sub_key, - aggregation_type=aggregation_type, - class_weights=class_weights, - example_weighted=example_weighted, + def _metric_computations( + self, + thresholds: Optional[Union[float, List[float]]] = None, + top_k: Optional[int] = None, + name: Optional[str] = None, + prediction_class_key: str = "classes", + prediction_score_key: str = "scores", + class_key: Optional[str] = None, + weight_key: Optional[str] = None, + eval_config: Optional[config_pb2.EvalConfig] = None, + model_name: str = "", + sub_key: Optional[metric_types.SubKey] = None, + aggregation_type: Optional[metric_types.AggregationType] = None, + class_weights: Optional[Dict[int, float]] = None, + example_weighted: bool = False, **kwargs, - ) + ) -> metric_types.MetricComputations: + preprocessor = preprocessors.SetMatchPreprocessor( + top_k=top_k, + model_name=model_name, + prediction_class_key=prediction_class_key, + prediction_score_key=prediction_score_key, + class_key=class_key, + weight_key=weight_key, + ) + if top_k is not None and thresholds is None: + thresholds = float("-inf") + + if weight_key: + # If example_weighted is False, it will by default set the example weights + # to 1.0. + # example_weighted could only be turned on from model_specs. However, in + # this case, the example_weights is not provided in the models. It should + # be turned on when per class weights are given. 
+ example_weighted = True + return super()._metric_computations( + thresholds=thresholds, + name=name, + eval_config=eval_config, + model_name=model_name, + preprocessors=[preprocessor], + sub_key=sub_key, + aggregation_type=aggregation_type, + class_weights=class_weights, + example_weighted=example_weighted, + **kwargs, + ) metric_types.register_metric(SetMatchPrecision) class SetMatchRecall(confusion_matrix_metrics.Recall): - """Computes recall for sets of labels and predictions. + """Computes recall for sets of labels and predictions. - The metric deals with labels and predictions which are provided in the format - of sets (stored as variable length numpy arrays). The recall is the - micro averaged classification recall. The metric is suitable for the case - where the number of classes is large or the list of classes could not be - provided in advance. + The metric deals with labels and predictions which are provided in the format + of sets (stored as variable length numpy arrays). The recall is the + micro averaged classification recall. The metric is suitable for the case + where the number of classes is large or the list of classes could not be + provided in advance. - Example: - Label: ['cats'], - Predictions: {'classes': ['cats, dogs']} + Example: + ------- + Label: ['cats'], + Predictions: {'classes': ['cats, dogs']} - The recall is 1. - """ - - def __init__( - self, - thresholds: Optional[Union[float, List[float]]] = None, - top_k: Optional[int] = None, - name: Optional[str] = None, - prediction_class_key: str = 'classes', - prediction_score_key: str = 'scores', - class_key: Optional[str] = None, - weight_key: Optional[str] = None, - **kwargs, - ): - """Initializes recall metric. - - Args: - thresholds: (Optional) A float value or a python list/tuple of float - threshold values in [0, 1]. A threshold is compared with prediction - values to determine the truth value of predictions (i.e., above the - threshold is `true`, below is `false`). One metric value is generated - for each threshold value. If neither thresholds nor top_k are set, the - default is to calculate precision with `thresholds=0.5`. - top_k: (Optional) Used with a multi-class model to specify that the top-k - values should be used to compute the confusion matrix. The net effect is - that the non-top-k values are truncated and the matrix is then - constructed from the average TP, FP, TN, FN across the classes. When - top_k is used, metrics_specs.binarize settings must not be present. When - top_k is used, the default threshold is float('-inf'). In this case, - unmatched labels are still considered false negative, since they have - prediction with confidence score float('-inf'), - name: (Optional) string name of the metric instance. - prediction_class_key: the key name of the classes in prediction. - prediction_score_key: the key name of the scores in prediction. - class_key: (Optional) The key name of the classes in class-weight pairs. - If it is not provided, the classes are assumed to be the label classes. - weight_key: (Optional) The key name of the weights of classes in - class-weight pairs. The value in this key should be a numpy array of the - same length as the classes in class_key. The key should be stored under - the features key. - **kwargs: (Optional) Additional args to pass along to init (and eventually - on to _metric_computations and _metric_values). The args are passed to - the recall metric, the confusion matrix metric and binary classification - metric. + The recall is 1. 
""" - super().__init__( - thresholds=thresholds, - top_k=top_k, - name=name, - prediction_class_key=prediction_class_key, - prediction_score_key=prediction_score_key, - class_key=class_key, - weight_key=weight_key, + def __init__( + self, + thresholds: Optional[Union[float, List[float]]] = None, + top_k: Optional[int] = None, + name: Optional[str] = None, + prediction_class_key: str = "classes", + prediction_score_key: str = "scores", + class_key: Optional[str] = None, + weight_key: Optional[str] = None, **kwargs, - ) + ): + """Initializes recall metric. + + Args: + ---- + thresholds: (Optional) A float value or a python list/tuple of float + threshold values in [0, 1]. A threshold is compared with prediction + values to determine the truth value of predictions (i.e., above the + threshold is `true`, below is `false`). One metric value is generated + for each threshold value. If neither thresholds nor top_k are set, the + default is to calculate precision with `thresholds=0.5`. + top_k: (Optional) Used with a multi-class model to specify that the top-k + values should be used to compute the confusion matrix. The net effect is + that the non-top-k values are truncated and the matrix is then + constructed from the average TP, FP, TN, FN across the classes. When + top_k is used, metrics_specs.binarize settings must not be present. When + top_k is used, the default threshold is float('-inf'). In this case, + unmatched labels are still considered false negative, since they have + prediction with confidence score float('-inf'), + name: (Optional) string name of the metric instance. + prediction_class_key: the key name of the classes in prediction. + prediction_score_key: the key name of the scores in prediction. + class_key: (Optional) The key name of the classes in class-weight pairs. + If it is not provided, the classes are assumed to be the label classes. + weight_key: (Optional) The key name of the weights of classes in + class-weight pairs. The value in this key should be a numpy array of the + same length as the classes in class_key. The key should be stored under + the features key. + **kwargs: (Optional) Additional args to pass along to init (and eventually + on to _metric_computations and _metric_values). The args are passed to + the recall metric, the confusion matrix metric and binary classification + metric. 
+ """ + super().__init__( + thresholds=thresholds, + top_k=top_k, + name=name, + prediction_class_key=prediction_class_key, + prediction_score_key=prediction_score_key, + class_key=class_key, + weight_key=weight_key, + **kwargs, + ) - def _default_name(self) -> str: - return SET_MATCH_RECALL_NAME + def _default_name(self) -> str: + return SET_MATCH_RECALL_NAME - def _metric_computations( - self, - thresholds: Optional[Union[float, List[float]]] = None, - top_k: Optional[int] = None, - name: Optional[str] = None, - prediction_class_key: str = 'classes', - prediction_score_key: str = 'scores', - class_key: Optional[str] = None, - weight_key: Optional[str] = None, - eval_config: Optional[config_pb2.EvalConfig] = None, - model_name: str = '', - sub_key: Optional[metric_types.SubKey] = None, - aggregation_type: Optional[metric_types.AggregationType] = None, - class_weights: Optional[Dict[int, float]] = None, - example_weighted: bool = False, - **kwargs, - ) -> metric_types.MetricComputations: - preprocessor = preprocessors.SetMatchPreprocessor( - top_k=top_k, - model_name=model_name, - prediction_class_key=prediction_class_key, - prediction_score_key=prediction_score_key, - class_key=class_key, - weight_key=weight_key, - ) - if top_k is not None and thresholds is None: - thresholds = float('-inf') - if weight_key: - # If example_weighted is False, it will by default set the example weights - # to 1.0. - # example_weighted could only be turned on from model_specs. However, in - # this case, the example_weights is not provided in the models. It should - # be turned on when per class weights are given. - example_weighted = True - return super()._metric_computations( - thresholds=thresholds, - name=name, - eval_config=eval_config, - model_name=model_name, - preprocessors=[preprocessor], - sub_key=sub_key, - aggregation_type=aggregation_type, - class_weights=class_weights, - example_weighted=example_weighted, + def _metric_computations( + self, + thresholds: Optional[Union[float, List[float]]] = None, + top_k: Optional[int] = None, + name: Optional[str] = None, + prediction_class_key: str = "classes", + prediction_score_key: str = "scores", + class_key: Optional[str] = None, + weight_key: Optional[str] = None, + eval_config: Optional[config_pb2.EvalConfig] = None, + model_name: str = "", + sub_key: Optional[metric_types.SubKey] = None, + aggregation_type: Optional[metric_types.AggregationType] = None, + class_weights: Optional[Dict[int, float]] = None, + example_weighted: bool = False, **kwargs, - ) + ) -> metric_types.MetricComputations: + preprocessor = preprocessors.SetMatchPreprocessor( + top_k=top_k, + model_name=model_name, + prediction_class_key=prediction_class_key, + prediction_score_key=prediction_score_key, + class_key=class_key, + weight_key=weight_key, + ) + if top_k is not None and thresholds is None: + thresholds = float("-inf") + if weight_key: + # If example_weighted is False, it will by default set the example weights + # to 1.0. + # example_weighted could only be turned on from model_specs. However, in + # this case, the example_weights is not provided in the models. It should + # be turned on when per class weights are given. 
+ example_weighted = True + return super()._metric_computations( + thresholds=thresholds, + name=name, + eval_config=eval_config, + model_name=model_name, + preprocessors=[preprocessor], + sub_key=sub_key, + aggregation_type=aggregation_type, + class_weights=class_weights, + example_weighted=example_weighted, + **kwargs, + ) metric_types.register_metric(SetMatchRecall) diff --git a/tensorflow_model_analysis/metrics/set_match_confusion_matrix_metrics_test.py b/tensorflow_model_analysis/metrics/set_match_confusion_matrix_metrics_test.py index 7c22a3b0ba..78737fdabd 100644 --- a/tensorflow_model_analysis/metrics/set_match_confusion_matrix_metrics_test.py +++ b/tensorflow_model_analysis/metrics/set_match_confusion_matrix_metrics_test.py @@ -12,24 +12,24 @@ # See the License for the specific language governing permissions and # limitations under the License. """Tests for set match related confusion matrix metrics.""" -from absl.testing import absltest -from absl.testing import parameterized + import apache_beam as beam -from apache_beam.testing import util import numpy as np +from absl.testing import absltest, parameterized +from apache_beam.testing import util +from google.protobuf import text_format + import tensorflow_model_analysis as tfma -from tensorflow_model_analysis.proto import config_pb2 from tensorflow_model_analysis.metrics import metric_types -from google.protobuf import text_format +from tensorflow_model_analysis.proto import config_pb2 class SetMatchConfusionMatrixMetricsTest(parameterized.TestCase): - - @parameterized.named_parameters( - ( - '_precision', - text_format.Parse( - """ + @parameterized.named_parameters( + ( + "_precision", + text_format.Parse( + """ model_specs { signature_name: "serving_default" prediction_key: "predictions" @@ -44,15 +44,15 @@ class SetMatchConfusionMatrixMetricsTest(parameterized.TestCase): } } """, - config_pb2.EvalConfig(), - ), - ['set_match_precision'], - [0.4], - ), - ( - '_recall', - text_format.Parse( - """ + config_pb2.EvalConfig(), + ), + ["set_match_precision"], + [0.4], + ), + ( + "_recall", + text_format.Parse( + """ model_specs { signature_name: "serving_default" prediction_key: "predictions" @@ -67,15 +67,15 @@ class SetMatchConfusionMatrixMetricsTest(parameterized.TestCase): } } """, - config_pb2.EvalConfig(), - ), - ['recall'], - [0.5], - ), - ( - '_precision_top_k', - text_format.Parse( - """ + config_pb2.EvalConfig(), + ), + ["recall"], + [0.5], + ), + ( + "_precision_top_k", + text_format.Parse( + """ model_specs { signature_name: "serving_default" prediction_key: "predictions" @@ -90,15 +90,15 @@ class SetMatchConfusionMatrixMetricsTest(parameterized.TestCase): } } """, - config_pb2.EvalConfig(), - ), - ['precision'], - [0.25], - ), - ( - '_recall_top_k', - text_format.Parse( - """ + config_pb2.EvalConfig(), + ), + ["precision"], + [0.25], + ), + ( + "_recall_top_k", + text_format.Parse( + """ model_specs { signature_name: "serving_default" prediction_key: "predictions" @@ -113,15 +113,15 @@ class SetMatchConfusionMatrixMetricsTest(parameterized.TestCase): } } """, - config_pb2.EvalConfig(), - ), - ['recall'], - [0.25], - ), - ( - '_recall_top_k_with_threshold_set', - text_format.Parse( - """ + config_pb2.EvalConfig(), + ), + ["recall"], + [0.25], + ), + ( + "_recall_top_k_with_threshold_set", + text_format.Parse( + """ model_specs { signature_name: "serving_default" prediction_key: "predictions" @@ -136,76 +136,74 @@ class SetMatchConfusionMatrixMetricsTest(parameterized.TestCase): } } """, - config_pb2.EvalConfig(), - ), - 
['recall'], - [0.25], - ), - ) - def testSetMatchMetrics(self, eval_config, name_list, expected_results): - extracts = [ - { - 'features': { - 'labels': np.array([['dogs', 'cats']]), - 'predictions': { - 'classes': np.array([['dogs', 'pigs']]), - 'scores': np.array([[0.1, 0.3]]), - }, - } - }, - { - 'features': { - 'labels': np.array([['birds', 'cats']]), - 'predictions': { - 'classes': np.array([['dogs', 'pigs', 'birds']]), - 'scores': np.array([[0.1, 0.3, 0.4]]), - }, - } - }, - ] - - evaluators = tfma.default_evaluators(eval_config=eval_config) - extractors = tfma.default_extractors( - eval_shared_model=None, eval_config=eval_config + config_pb2.EvalConfig(), + ), + ["recall"], + [0.25], + ), ) + def testSetMatchMetrics(self, eval_config, name_list, expected_results): + extracts = [ + { + "features": { + "labels": np.array([["dogs", "cats"]]), + "predictions": { + "classes": np.array([["dogs", "pigs"]]), + "scores": np.array([[0.1, 0.3]]), + }, + } + }, + { + "features": { + "labels": np.array([["birds", "cats"]]), + "predictions": { + "classes": np.array([["dogs", "pigs", "birds"]]), + "scores": np.array([[0.1, 0.3, 0.4]]), + }, + } + }, + ] - with beam.Pipeline() as p: - result = ( - p - | 'LoadData' >> beam.Create(extracts) - | 'ExtractEval' - >> tfma.ExtractAndEvaluate( - extractors=extractors, evaluators=evaluators - ) - ) + evaluators = tfma.default_evaluators(eval_config=eval_config) + extractors = tfma.default_extractors( + eval_shared_model=None, eval_config=eval_config + ) - def check_result(got): - try: - self.assertLen(got, 1) - got_slice_key, got_metrics = got[0] - self.assertEqual(got_slice_key, ()) - self.assertLen(got_metrics, len(name_list)) - for name, expected_result in zip(name_list, expected_results): - key = metric_types.MetricKey(name=name) - self.assertIn(key, got_metrics) - got_metric = got_metrics[key] - np.testing.assert_allclose( - expected_result, - got_metric, - rtol=1e-3, - err_msg=f'This {name} metric fails.', + with beam.Pipeline() as p: + result = ( + p + | "LoadData" >> beam.Create(extracts) + | "ExtractEval" + >> tfma.ExtractAndEvaluate(extractors=extractors, evaluators=evaluators) ) - except AssertionError as err: - raise util.BeamAssertException(err) - self.assertIn('metrics', result) - util.assert_that(result['metrics'], check_result, label='result') + def check_result(got): + try: + self.assertLen(got, 1) + got_slice_key, got_metrics = got[0] + self.assertEqual(got_slice_key, ()) + self.assertLen(got_metrics, len(name_list)) + for name, expected_result in zip(name_list, expected_results): + key = metric_types.MetricKey(name=name) + self.assertIn(key, got_metrics) + got_metric = got_metrics[key] + np.testing.assert_allclose( + expected_result, + got_metric, + rtol=1e-3, + err_msg=f"This {name} metric fails.", + ) + except AssertionError as err: + raise util.BeamAssertException(err) - @parameterized.named_parameters( - ( - '_precision_with_class_weight', - text_format.Parse( - """ + self.assertIn("metrics", result) + util.assert_that(result["metrics"], check_result, label="result") + + @parameterized.named_parameters( + ( + "_precision_with_class_weight", + text_format.Parse( + """ model_specs { signature_name: "serving_default" prediction_key: "predictions" @@ -221,15 +219,15 @@ def check_result(got): } } """, - config_pb2.EvalConfig(), - ), - ['set_match_precision'], - [0.25], - ), - ( - '_recall_with_class_weight', - text_format.Parse( - """ + config_pb2.EvalConfig(), + ), + ["set_match_precision"], + [0.25], + ), + ( + 
"_recall_with_class_weight", + text_format.Parse( + """ model_specs { signature_name: "serving_default" prediction_key: "predictions" @@ -245,78 +243,76 @@ def check_result(got): } } """, - config_pb2.EvalConfig(), - ), - ['recall'], - [0.294118], - ), - ) - def testSetMatchMetricsWithClassWeights( - self, eval_config, name_list, expected_results - ): - extracts = [ - { - 'features': { - 'labels': np.array([['dogs', 'cats']]), - 'predictions': { - 'classes': np.array([['dogs', 'pigs']]), - 'scores': np.array([[0.1, 0.3]]), - 'weights': np.array([[0.1, 0.9]]), - }, - 'classes': np.array([['dogs', 'cats']]), - 'weights': np.array([[0.5, 1.2]]), - } - }, - { - 'features': { - 'labels': np.array([['birds', 'cats']]), - 'predictions': { - 'classes': np.array([['dogs', 'pigs', 'birds']]), - 'scores': np.array([[0.1, 0.3, 0.4]]), - }, - 'classes': np.array([['birds', 'cats']]), - 'weights': np.array([[0.5, 1.2]]), - } - }, - ] - - evaluators = tfma.default_evaluators(eval_config=eval_config) - extractors = tfma.default_extractors( - eval_shared_model=None, eval_config=eval_config + config_pb2.EvalConfig(), + ), + ["recall"], + [0.294118], + ), ) + def testSetMatchMetricsWithClassWeights( + self, eval_config, name_list, expected_results + ): + extracts = [ + { + "features": { + "labels": np.array([["dogs", "cats"]]), + "predictions": { + "classes": np.array([["dogs", "pigs"]]), + "scores": np.array([[0.1, 0.3]]), + "weights": np.array([[0.1, 0.9]]), + }, + "classes": np.array([["dogs", "cats"]]), + "weights": np.array([[0.5, 1.2]]), + } + }, + { + "features": { + "labels": np.array([["birds", "cats"]]), + "predictions": { + "classes": np.array([["dogs", "pigs", "birds"]]), + "scores": np.array([[0.1, 0.3, 0.4]]), + }, + "classes": np.array([["birds", "cats"]]), + "weights": np.array([[0.5, 1.2]]), + } + }, + ] - with beam.Pipeline() as p: - result = ( - p - | 'LoadData' >> beam.Create(extracts) - | 'ExtractEval' - >> tfma.ExtractAndEvaluate( - extractors=extractors, evaluators=evaluators - ) - ) + evaluators = tfma.default_evaluators(eval_config=eval_config) + extractors = tfma.default_extractors( + eval_shared_model=None, eval_config=eval_config + ) - def check_result(got): - try: - self.assertLen(got, 1) - got_slice_key, got_metrics = got[0] - self.assertEqual(got_slice_key, ()) - self.assertLen(got_metrics, len(name_list)) - for name, expected_result in zip(name_list, expected_results): - key = metric_types.MetricKey(name=name, example_weighted=True) - self.assertIn(key, got_metrics) - got_metric = got_metrics[key] - np.testing.assert_allclose( - expected_result, - got_metric, - rtol=1e-3, - err_msg=f'This {name} metric fails.', + with beam.Pipeline() as p: + result = ( + p + | "LoadData" >> beam.Create(extracts) + | "ExtractEval" + >> tfma.ExtractAndEvaluate(extractors=extractors, evaluators=evaluators) ) - except AssertionError as err: - raise util.BeamAssertException(err) - self.assertIn('metrics', result) - util.assert_that(result['metrics'], check_result, label='result') + def check_result(got): + try: + self.assertLen(got, 1) + got_slice_key, got_metrics = got[0] + self.assertEqual(got_slice_key, ()) + self.assertLen(got_metrics, len(name_list)) + for name, expected_result in zip(name_list, expected_results): + key = metric_types.MetricKey(name=name, example_weighted=True) + self.assertIn(key, got_metrics) + got_metric = got_metrics[key] + np.testing.assert_allclose( + expected_result, + got_metric, + rtol=1e-3, + err_msg=f"This {name} metric fails.", + ) + except AssertionError as err: + 
raise util.BeamAssertException(err) + + self.assertIn("metrics", result) + util.assert_that(result["metrics"], check_result, label="result") -if __name__ == '__main__': - absltest.main() +if __name__ == "__main__": + absltest.main() diff --git a/tensorflow_model_analysis/metrics/squared_pearson_correlation.py b/tensorflow_model_analysis/metrics/squared_pearson_correlation.py index 5e61595530..e493f265b7 100644 --- a/tensorflow_model_analysis/metrics/squared_pearson_correlation.py +++ b/tensorflow_model_analysis/metrics/squared_pearson_correlation.py @@ -16,25 +16,27 @@ from typing import Dict, Iterable, Optional import apache_beam as beam -from tensorflow_model_analysis.metrics import metric_types -from tensorflow_model_analysis.metrics import metric_util + +from tensorflow_model_analysis.metrics import metric_types, metric_util from tensorflow_model_analysis.proto import config_pb2 -SQUARED_PEARSON_CORRELATION_NAME = 'squared_pearson_correlation' +SQUARED_PEARSON_CORRELATION_NAME = "squared_pearson_correlation" class SquaredPearsonCorrelation(metric_types.Metric): - """Squared pearson correlation (r^2) metric.""" + """Squared pearson correlation (r^2) metric.""" - def __init__(self, name: str = SQUARED_PEARSON_CORRELATION_NAME): - """Initializes squared pearson correlation (r^2) metric. + def __init__(self, name: str = SQUARED_PEARSON_CORRELATION_NAME): + """Initializes squared pearson correlation (r^2) metric. - Args: - name: Metric name. - """ - super().__init__( - metric_util.merge_per_key_computations(_squared_pearson_correlation), - name=name) + Args: + ---- + name: Metric name. + """ + super().__init__( + metric_util.merge_per_key_computations(_squared_pearson_correlation), + name=name, + ) metric_types.register_metric(SquaredPearsonCorrelation) @@ -43,150 +45,172 @@ def __init__(self, name: str = SQUARED_PEARSON_CORRELATION_NAME): def _squared_pearson_correlation( name: str = SQUARED_PEARSON_CORRELATION_NAME, eval_config: Optional[config_pb2.EvalConfig] = None, - model_name: str = '', - output_name: str = '', + model_name: str = "", + output_name: str = "", sub_key: Optional[metric_types.SubKey] = None, aggregation_type: Optional[metric_types.AggregationType] = None, class_weights: Optional[Dict[int, float]] = None, - example_weighted: bool = False) -> metric_types.MetricComputations: - """Returns metric computations for squared pearson correlation (r^2).""" - key = metric_types.MetricKey( - name=name, - model_name=model_name, - output_name=output_name, - sub_key=sub_key, - example_weighted=example_weighted) - return [ - metric_types.MetricComputation( - keys=[key], - preprocessors=None, - combiner=_SquaredPearsonCorrelationCombiner(key, eval_config, - aggregation_type, - class_weights, - example_weighted)) - ] + example_weighted: bool = False, +) -> metric_types.MetricComputations: + """Returns metric computations for squared pearson correlation (r^2).""" + key = metric_types.MetricKey( + name=name, + model_name=model_name, + output_name=output_name, + sub_key=sub_key, + example_weighted=example_weighted, + ) + return [ + metric_types.MetricComputation( + keys=[key], + preprocessors=None, + combiner=_SquaredPearsonCorrelationCombiner( + key, eval_config, aggregation_type, class_weights, example_weighted + ), + ) + ] class _SquaredPearsonCorrelationAccumulator: - """Squared pearson correlation (r^2) accumulator.""" - __slots__ = [ - 'total_weighted_labels', 'total_weighted_predictions', - 'total_weighted_squared_labels', 'total_weighted_squared_predictions', - 
'total_weighted_labels_times_predictions', 'total_weighted_examples' - ] - - def __init__(self): - self.total_weighted_labels = 0.0 - self.total_weighted_predictions = 0.0 - self.total_weighted_squared_labels = 0.0 - self.total_weighted_squared_predictions = 0.0 - self.total_weighted_labels_times_predictions = 0.0 - self.total_weighted_examples = 0.0 + """Squared pearson correlation (r^2) accumulator.""" + + __slots__ = [ + "total_weighted_labels", + "total_weighted_predictions", + "total_weighted_squared_labels", + "total_weighted_squared_predictions", + "total_weighted_labels_times_predictions", + "total_weighted_examples", + ] + + def __init__(self): + self.total_weighted_labels = 0.0 + self.total_weighted_predictions = 0.0 + self.total_weighted_squared_labels = 0.0 + self.total_weighted_squared_predictions = 0.0 + self.total_weighted_labels_times_predictions = 0.0 + self.total_weighted_examples = 0.0 class _SquaredPearsonCorrelationCombiner(beam.CombineFn): - """Computes squared pearson correlation (r^2) metric.""" - - def __init__(self, key: metric_types.MetricKey, - eval_config: Optional[config_pb2.EvalConfig], - aggregation_type: Optional[metric_types.AggregationType], - class_weights: Optional[Dict[int, - float]], example_weighted: bool): - self._key = key - self._eval_config = eval_config - self._aggregation_type = aggregation_type - self._class_weights = class_weights - self._example_weighted = example_weighted - - def create_accumulator(self) -> _SquaredPearsonCorrelationAccumulator: - return _SquaredPearsonCorrelationAccumulator() - - def add_input( - self, accumulator: _SquaredPearsonCorrelationAccumulator, - element: metric_types.StandardMetricInputs - ) -> _SquaredPearsonCorrelationAccumulator: - for label, prediction, example_weight in ( - metric_util.to_label_prediction_example_weight( + """Computes squared pearson correlation (r^2) metric.""" + + def __init__( + self, + key: metric_types.MetricKey, + eval_config: Optional[config_pb2.EvalConfig], + aggregation_type: Optional[metric_types.AggregationType], + class_weights: Optional[Dict[int, float]], + example_weighted: bool, + ): + self._key = key + self._eval_config = eval_config + self._aggregation_type = aggregation_type + self._class_weights = class_weights + self._example_weighted = example_weighted + + def create_accumulator(self) -> _SquaredPearsonCorrelationAccumulator: + return _SquaredPearsonCorrelationAccumulator() + + def add_input( + self, + accumulator: _SquaredPearsonCorrelationAccumulator, + element: metric_types.StandardMetricInputs, + ) -> _SquaredPearsonCorrelationAccumulator: + for ( + label, + prediction, + example_weight, + ) in metric_util.to_label_prediction_example_weight( element, eval_config=self._eval_config, model_name=self._key.model_name, output_name=self._key.output_name, aggregation_type=self._aggregation_type, class_weights=self._class_weights, - example_weighted=self._example_weighted)): - example_weight = float(example_weight) - label = float(label) - prediction = float(prediction) - accumulator.total_weighted_labels += example_weight * label - accumulator.total_weighted_predictions += example_weight * prediction - accumulator.total_weighted_squared_labels += example_weight * label**2 - accumulator.total_weighted_squared_predictions += ( - example_weight * prediction**2) - accumulator.total_weighted_labels_times_predictions += ( - example_weight * label * prediction) - accumulator.total_weighted_examples += example_weight - return accumulator - - def merge_accumulators( - self, 
accumulators: Iterable[_SquaredPearsonCorrelationAccumulator] - ) -> _SquaredPearsonCorrelationAccumulator: - accumulators = iter(accumulators) - result = next(accumulators) - for accumulator in accumulators: - result.total_weighted_labels += accumulator.total_weighted_labels - result.total_weighted_predictions += ( - accumulator.total_weighted_predictions) - result.total_weighted_squared_labels += ( - accumulator.total_weighted_squared_labels) - result.total_weighted_squared_predictions += ( - accumulator.total_weighted_squared_predictions) - result.total_weighted_labels_times_predictions += ( - accumulator.total_weighted_labels_times_predictions) - result.total_weighted_examples += accumulator.total_weighted_examples - return result - - def extract_output( - self, accumulator: _SquaredPearsonCorrelationAccumulator - ) -> Dict[metric_types.MetricKey, float]: - result = float('nan') - - if accumulator.total_weighted_examples > 0.0: - # See https://en.wikipedia.org/wiki/Pearson_correlation_coefficient - # r^2 = Cov(X, Y)^2 / VAR(X) * VAR(Y) - # = (E[XY] - E[X]E[Y])^2 / (E[X^2] - E[X]^2) * (E[Y^2] - E[Y]^2) - # = [SUM(xy) - n*mean(x)*mean(y)]^2 / - # [SUM(x^2) - n*mean(x)^2 * SUM(y^2) - n*mean(y)^2] - # n = total_weighted_examples - # SUM(x) = total_weighted_labels - # SUM(y) = total_weighted_predictions - # SUM(xy) = total_weighted_labels_times_predictions - # SUM(x^2) = total_weighted_squared_labels - # SUM(y^2) = total_weighted_squared_predictions - - # numerator = [SUM(xy) - n*mean(x)*mean(y)]^2 - # = [SUM(xy) - n*SUM(x)/n*SUM(y)/n]^2 - # = [SUM(xy) - SUM(x)*SUM(y)/n]^2 - numerator = (accumulator.total_weighted_labels_times_predictions - - accumulator.total_weighted_labels * - accumulator.total_weighted_predictions / - accumulator.total_weighted_examples)**2 - # denominator_y = SUM(y^2) - n*mean(y)^2 - # = SUM(y^2) - n*(SUM(y)/n)^2 - # = SUM(y^2) - SUM(y)^2/n - denominator_y = ( - accumulator.total_weighted_squared_predictions - - accumulator.total_weighted_predictions**2 / - accumulator.total_weighted_examples) - - # denominator_x = SUM(x^2) - n*mean(x)^2 - # = SUM(x^2) - n*(SUM(x)/n)^2 - # = SUM(x^2) - SUM(x)^2/n - denominator_x = ( - accumulator.total_weighted_squared_labels - - accumulator.total_weighted_labels**2 / - accumulator.total_weighted_examples) - denominator = denominator_x * denominator_y - if denominator > 0.0: - result = numerator / denominator - - return {self._key: result} + example_weighted=self._example_weighted, + ): + example_weight = float(example_weight) + label = float(label) + prediction = float(prediction) + accumulator.total_weighted_labels += example_weight * label + accumulator.total_weighted_predictions += example_weight * prediction + accumulator.total_weighted_squared_labels += example_weight * label**2 + accumulator.total_weighted_squared_predictions += ( + example_weight * prediction**2 + ) + accumulator.total_weighted_labels_times_predictions += ( + example_weight * label * prediction + ) + accumulator.total_weighted_examples += example_weight + return accumulator + + def merge_accumulators( + self, accumulators: Iterable[_SquaredPearsonCorrelationAccumulator] + ) -> _SquaredPearsonCorrelationAccumulator: + accumulators = iter(accumulators) + result = next(accumulators) + for accumulator in accumulators: + result.total_weighted_labels += accumulator.total_weighted_labels + result.total_weighted_predictions += accumulator.total_weighted_predictions + result.total_weighted_squared_labels += ( + accumulator.total_weighted_squared_labels + ) + 
result.total_weighted_squared_predictions += ( + accumulator.total_weighted_squared_predictions + ) + result.total_weighted_labels_times_predictions += ( + accumulator.total_weighted_labels_times_predictions + ) + result.total_weighted_examples += accumulator.total_weighted_examples + return result + + def extract_output( + self, accumulator: _SquaredPearsonCorrelationAccumulator + ) -> Dict[metric_types.MetricKey, float]: + result = float("nan") + + if accumulator.total_weighted_examples > 0.0: + # See https://en.wikipedia.org/wiki/Pearson_correlation_coefficient + # r^2 = Cov(X, Y)^2 / VAR(X) * VAR(Y) + # = (E[XY] - E[X]E[Y])^2 / (E[X^2] - E[X]^2) * (E[Y^2] - E[Y]^2) + # = [SUM(xy) - n*mean(x)*mean(y)]^2 / + # [SUM(x^2) - n*mean(x)^2 * SUM(y^2) - n*mean(y)^2] + # n = total_weighted_examples + # SUM(x) = total_weighted_labels + # SUM(y) = total_weighted_predictions + # SUM(xy) = total_weighted_labels_times_predictions + # SUM(x^2) = total_weighted_squared_labels + # SUM(y^2) = total_weighted_squared_predictions + + # numerator = [SUM(xy) - n*mean(x)*mean(y)]^2 + # = [SUM(xy) - n*SUM(x)/n*SUM(y)/n]^2 + # = [SUM(xy) - SUM(x)*SUM(y)/n]^2 + numerator = ( + accumulator.total_weighted_labels_times_predictions + - accumulator.total_weighted_labels + * accumulator.total_weighted_predictions + / accumulator.total_weighted_examples + ) ** 2 + # denominator_y = SUM(y^2) - n*mean(y)^2 + # = SUM(y^2) - n*(SUM(y)/n)^2 + # = SUM(y^2) - SUM(y)^2/n + denominator_y = ( + accumulator.total_weighted_squared_predictions + - accumulator.total_weighted_predictions**2 + / accumulator.total_weighted_examples + ) + + # denominator_x = SUM(x^2) - n*mean(x)^2 + # = SUM(x^2) - n*(SUM(x)/n)^2 + # = SUM(x^2) - SUM(x)^2/n + denominator_x = ( + accumulator.total_weighted_squared_labels + - accumulator.total_weighted_labels**2 + / accumulator.total_weighted_examples + ) + denominator = denominator_x * denominator_y + if denominator > 0.0: + result = numerator / denominator + + return {self._key: result} diff --git a/tensorflow_model_analysis/metrics/squared_pearson_correlation_test.py b/tensorflow_model_analysis/metrics/squared_pearson_correlation_test.py index de4e30cfd9..5f97ecf1ad 100644 --- a/tensorflow_model_analysis/metrics/squared_pearson_correlation_test.py +++ b/tensorflow_model_analysis/metrics/squared_pearson_correlation_test.py @@ -16,182 +16,191 @@ import math import apache_beam as beam -from apache_beam.testing import util import numpy as np import tensorflow as tf -from tensorflow_model_analysis.metrics import metric_util -from tensorflow_model_analysis.metrics import squared_pearson_correlation +from apache_beam.testing import util + +from tensorflow_model_analysis.metrics import metric_util, squared_pearson_correlation from tensorflow_model_analysis.utils import test_util class SquaredPearsonCorrelationTest(test_util.TensorflowModelAnalysisTest): - - def testSquaredPearsonCorrelationWithoutWeights(self): - computations = ( - squared_pearson_correlation.SquaredPearsonCorrelation().computations()) - metric = computations[0] - - example1 = { - 'labels': np.array([2.0]), - 'predictions': np.array([1.0]), - 'example_weights': np.array([1.0]), - } - example2 = { - 'labels': np.array([1.0]), - 'predictions': np.array([2.0]), - 'example_weights': np.array([1.0]), - } - example3 = { - 'labels': np.array([2.0]), - 'predictions': np.array([3.0]), - 'example_weights': np.array([1.0]), - } - example4 = { - 'labels': np.array([3.0]), - 'predictions': np.array([4.0]), - 'example_weights': np.array([1.0]), - } - - with 
beam.Pipeline() as pipeline: - # pylint: disable=no-value-for-parameter - result = ( - pipeline - | 'Create' >> beam.Create([example1, example2, example3, example4]) - | 'Process' >> beam.Map(metric_util.to_standard_metric_inputs) - | 'AddSlice' >> beam.Map(lambda x: ((), x)) - | 'ComputeMetric' >> beam.CombinePerKey(metric.combiner)) - - # pylint: enable=no-value-for-parameter - - def check_result(got): - try: - self.assertLen(got, 1) - got_slice_key, got_metrics = got[0] - self.assertEqual(got_slice_key, ()) - key = metric.keys[0] - # 1: prediction = 1, label = 2 - # 2: prediction = 2, label = 1 - # 3: prediction = 3, label = 2 - # 4: prediction = 4, label = 3 - # - # pred_x_labels = 2 + 2 + 6 + 12 = 22 - # labels = 2 + 1 + 2 + 3 = 8 - # preds = 1 + 2 + 3 + 4 = 10 - # sq_labels = 4 + 1 + 4 + 9 = 18 - # sq_preds = 1 + 4 + 9 + 16 = 30 - # examples = 4 - # - # r^2 = (22 - 8 * 10 / 4)^2 / (30 - 10^2 / 4) * (18 - 8^2 / 4) - # r^2 = 4 / (5 * 2) = 0.4 - self.assertDictElementsAlmostEqual(got_metrics, {key: 0.4}, places=5) - - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that(result, check_result, label='result') - - def testSquaredPearsonCorrelationWithWeights(self): - computations = ( - squared_pearson_correlation.SquaredPearsonCorrelation().computations( - example_weighted=True)) - metric = computations[0] - - example1 = { - 'labels': np.array([1.0]), - 'predictions': np.array([1.0]), - 'example_weights': np.array([1.0]), - } - example2 = { - 'labels': np.array([4.0]), - 'predictions': np.array([2.0]), - 'example_weights': np.array([2.0]), - } - example3 = { - 'labels': np.array([3.0]), - 'predictions': np.array([3.0]), - 'example_weights': np.array([3.0]), - } - example4 = { - 'labels': np.array([3.0]), - 'predictions': np.array([4.0]), - 'example_weights': np.array([4.0]), - } - - with beam.Pipeline() as pipeline: - # pylint: disable=no-value-for-parameter - result = ( - pipeline - | 'Create' >> beam.Create([example1, example2, example3, example4]) - | 'Process' >> beam.Map(metric_util.to_standard_metric_inputs) - | 'AddSlice' >> beam.Map(lambda x: ((), x)) - | 'ComputeMetric' >> beam.CombinePerKey(metric.combiner)) - - # pylint: enable=no-value-for-parameter - - def check_result(got): - try: - self.assertLen(got, 1) - got_slice_key, got_metrics = got[0] - self.assertEqual(got_slice_key, ()) - key = metric.keys[0] - # 1: prediction = 1, label = 1 - # 2: prediction = 2, label = 4 - # 3: prediction = 3, label = 3 - # 4: prediction = 4, label = 3 - # - # pred_x_labels = 1x1x1 + 2x2x4 + 3x3x3 + 4x4x3 = 92 - # labels = 1x1 + 2x4 + 3x3 + 4x3 = 30 - # preds = 1 + 2x2 + 3x3 + 4x4= 30 - # sq_labels = 1x1x1 + 2x4x4+ 3x3x3 + 4x3x3 = 96 - # sq_preds = 1x1x1 + 2x2x2 + 3x3x3 + 4x4x4 = 100 - # examples = 1 + 2 + 3 + 4 = 10 - # - # r^2 = (92 - 30 * 30 / 10)^2 / (100 - 30^2 / 10) * (96 - 30^2 / 10) - # r^2 = 4 / (10 * 6) = 0.06667 - self.assertDictElementsAlmostEqual( - got_metrics, {key: 0.06667}, places=5) - - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that(result, check_result, label='result') - - def testSquaredPearsonCorrelationMetricsWithNan(self): - computations = ( - squared_pearson_correlation.SquaredPearsonCorrelation().computations()) - metric = computations[0] - - example = { - 'labels': np.array([0.0]), - 'predictions': np.array([1.0]), - 'example_weights': np.array([1.0]), - } - - with beam.Pipeline() as pipeline: - # pylint: disable=no-value-for-parameter - result = ( - pipeline - | 'Create' >> 
beam.Create([example]) - | 'Process' >> beam.Map(metric_util.to_standard_metric_inputs) - | 'AddSlice' >> beam.Map(lambda x: ((), x)) - | 'ComputeMetric' >> beam.CombinePerKey(metric.combiner)) - - # pylint: enable=no-value-for-parameter - - def check_result(got): - try: - self.assertLen(got, 1) - got_slice_key, got_metrics = got[0] - self.assertEqual(got_slice_key, ()) - key = metric.keys[0] - self.assertIn(key, got_metrics) - self.assertTrue(math.isnan(got_metrics[key])) - - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that(result, check_result, label='result') - - -if __name__ == '__main__': - tf.test.main() + def testSquaredPearsonCorrelationWithoutWeights(self): + computations = ( + squared_pearson_correlation.SquaredPearsonCorrelation().computations() + ) + metric = computations[0] + + example1 = { + "labels": np.array([2.0]), + "predictions": np.array([1.0]), + "example_weights": np.array([1.0]), + } + example2 = { + "labels": np.array([1.0]), + "predictions": np.array([2.0]), + "example_weights": np.array([1.0]), + } + example3 = { + "labels": np.array([2.0]), + "predictions": np.array([3.0]), + "example_weights": np.array([1.0]), + } + example4 = { + "labels": np.array([3.0]), + "predictions": np.array([4.0]), + "example_weights": np.array([1.0]), + } + + with beam.Pipeline() as pipeline: + # pylint: disable=no-value-for-parameter + result = ( + pipeline + | "Create" >> beam.Create([example1, example2, example3, example4]) + | "Process" >> beam.Map(metric_util.to_standard_metric_inputs) + | "AddSlice" >> beam.Map(lambda x: ((), x)) + | "ComputeMetric" >> beam.CombinePerKey(metric.combiner) + ) + + # pylint: enable=no-value-for-parameter + + def check_result(got): + try: + self.assertLen(got, 1) + got_slice_key, got_metrics = got[0] + self.assertEqual(got_slice_key, ()) + key = metric.keys[0] + # 1: prediction = 1, label = 2 + # 2: prediction = 2, label = 1 + # 3: prediction = 3, label = 2 + # 4: prediction = 4, label = 3 + # + # pred_x_labels = 2 + 2 + 6 + 12 = 22 + # labels = 2 + 1 + 2 + 3 = 8 + # preds = 1 + 2 + 3 + 4 = 10 + # sq_labels = 4 + 1 + 4 + 9 = 18 + # sq_preds = 1 + 4 + 9 + 16 = 30 + # examples = 4 + # + # r^2 = (22 - 8 * 10 / 4)^2 / (30 - 10^2 / 4) * (18 - 8^2 / 4) + # r^2 = 4 / (5 * 2) = 0.4 + self.assertDictElementsAlmostEqual( + got_metrics, {key: 0.4}, places=5 + ) + + except AssertionError as err: + raise util.BeamAssertException(err) + + util.assert_that(result, check_result, label="result") + + def testSquaredPearsonCorrelationWithWeights(self): + computations = ( + squared_pearson_correlation.SquaredPearsonCorrelation().computations( + example_weighted=True + ) + ) + metric = computations[0] + + example1 = { + "labels": np.array([1.0]), + "predictions": np.array([1.0]), + "example_weights": np.array([1.0]), + } + example2 = { + "labels": np.array([4.0]), + "predictions": np.array([2.0]), + "example_weights": np.array([2.0]), + } + example3 = { + "labels": np.array([3.0]), + "predictions": np.array([3.0]), + "example_weights": np.array([3.0]), + } + example4 = { + "labels": np.array([3.0]), + "predictions": np.array([4.0]), + "example_weights": np.array([4.0]), + } + + with beam.Pipeline() as pipeline: + # pylint: disable=no-value-for-parameter + result = ( + pipeline + | "Create" >> beam.Create([example1, example2, example3, example4]) + | "Process" >> beam.Map(metric_util.to_standard_metric_inputs) + | "AddSlice" >> beam.Map(lambda x: ((), x)) + | "ComputeMetric" >> beam.CombinePerKey(metric.combiner) + ) + + # pylint: 
enable=no-value-for-parameter + + def check_result(got): + try: + self.assertLen(got, 1) + got_slice_key, got_metrics = got[0] + self.assertEqual(got_slice_key, ()) + key = metric.keys[0] + # 1: prediction = 1, label = 1 + # 2: prediction = 2, label = 4 + # 3: prediction = 3, label = 3 + # 4: prediction = 4, label = 3 + # + # pred_x_labels = 1x1x1 + 2x2x4 + 3x3x3 + 4x4x3 = 92 + # labels = 1x1 + 2x4 + 3x3 + 4x3 = 30 + # preds = 1 + 2x2 + 3x3 + 4x4= 30 + # sq_labels = 1x1x1 + 2x4x4+ 3x3x3 + 4x3x3 = 96 + # sq_preds = 1x1x1 + 2x2x2 + 3x3x3 + 4x4x4 = 100 + # examples = 1 + 2 + 3 + 4 = 10 + # + # r^2 = (92 - 30 * 30 / 10)^2 / (100 - 30^2 / 10) * (96 - 30^2 / 10) + # r^2 = 4 / (10 * 6) = 0.06667 + self.assertDictElementsAlmostEqual( + got_metrics, {key: 0.06667}, places=5 + ) + + except AssertionError as err: + raise util.BeamAssertException(err) + + util.assert_that(result, check_result, label="result") + + def testSquaredPearsonCorrelationMetricsWithNan(self): + computations = ( + squared_pearson_correlation.SquaredPearsonCorrelation().computations() + ) + metric = computations[0] + + example = { + "labels": np.array([0.0]), + "predictions": np.array([1.0]), + "example_weights": np.array([1.0]), + } + + with beam.Pipeline() as pipeline: + # pylint: disable=no-value-for-parameter + result = ( + pipeline + | "Create" >> beam.Create([example]) + | "Process" >> beam.Map(metric_util.to_standard_metric_inputs) + | "AddSlice" >> beam.Map(lambda x: ((), x)) + | "ComputeMetric" >> beam.CombinePerKey(metric.combiner) + ) + + # pylint: enable=no-value-for-parameter + + def check_result(got): + try: + self.assertLen(got, 1) + got_slice_key, got_metrics = got[0] + self.assertEqual(got_slice_key, ()) + key = metric.keys[0] + self.assertIn(key, got_metrics) + self.assertTrue(math.isnan(got_metrics[key])) + + except AssertionError as err: + raise util.BeamAssertException(err) + + util.assert_that(result, check_result, label="result") + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow_model_analysis/metrics/stats.py b/tensorflow_model_analysis/metrics/stats.py index be8cf1021d..2206d41ce6 100644 --- a/tensorflow_model_analysis/metrics/stats.py +++ b/tensorflow_model_analysis/metrics/stats.py @@ -17,94 +17,86 @@ from typing import Any, Dict, Iterable, Optional, Tuple, Union import apache_beam as beam + from tensorflow_model_analysis.metrics import metric_types from tensorflow_model_analysis.utils import util - _KeyPath = Tuple[str] -_MEAN_METRIC_BASE_NAME = 'mean' +_MEAN_METRIC_BASE_NAME = "mean" # TODO(b/287700355): Add __slots__ to _Accumulator @dataclasses.dataclass class _Accumulator: - count: float = 0.0 - total: float = 0.0 + count: float = 0.0 + total: float = 0.0 class _MeanCombiner(beam.CombineFn): - """Computes mean metric.""" - - def __init__( - self, - key: metric_types.MetricKey, - feature_key_path: _KeyPath, - example_weights_key_path: _KeyPath, - ): - self._key = key - self._feature_key_path = feature_key_path - self._example_weights_key_path = example_weights_key_path - - def create_accumulator(self) -> _Accumulator: - return _Accumulator() - - def add_input( - self, - accumulator: _Accumulator, - element: metric_types.StandardMetricInputs, - ) -> _Accumulator: - # Get feature value - features = util.get_by_keys(element, self._feature_key_path) - assert len(features) == 1, ( - 'Mean() is only supported for scalar features, but found features = ' - f'{features}' - ) - - # Get example weight - if self._example_weights_key_path is None: - example_weight = 1.0 - else: - 
example_weights = util.get_by_keys( - element, self._example_weights_key_path - ) - assert len(example_weights) == 1, ( - 'Expected 1 (scalar) example weight for each example, ' - f'but found example weight = {example_weights}' - ) - example_weight = example_weights[0] - - # Update accumulator - accumulator.count += example_weight - accumulator.total += example_weight * features[0] - - return accumulator - - def merge_accumulators( - self, accumulators: Iterable[_Accumulator] - ) -> _Accumulator: - accumulators = iter(accumulators) - result = next(accumulators) - for accumulator in accumulators: - result.count += accumulator.count - result.total += accumulator.total - return result - - def extract_output( - self, accumulator: _Accumulator - ) -> Dict[metric_types.MetricKey, float]: - if accumulator.count == 0.0: - return {self._key: float('nan')} - return {self._key: accumulator.total / accumulator.count} - - -def _convert_key_path_to_dict( - key_path: Union[_KeyPath, Tuple[()]] -) -> Dict[str, Any]: - """Recursively converts _KeyPath to nested dict.""" - return ( - {key_path[0]: _convert_key_path_to_dict(key_path[1:])} if key_path else {} - ) + """Computes mean metric.""" + + def __init__( + self, + key: metric_types.MetricKey, + feature_key_path: _KeyPath, + example_weights_key_path: _KeyPath, + ): + self._key = key + self._feature_key_path = feature_key_path + self._example_weights_key_path = example_weights_key_path + + def create_accumulator(self) -> _Accumulator: + return _Accumulator() + + def add_input( + self, + accumulator: _Accumulator, + element: metric_types.StandardMetricInputs, + ) -> _Accumulator: + # Get feature value + features = util.get_by_keys(element, self._feature_key_path) + assert len(features) == 1, ( + "Mean() is only supported for scalar features, but found features = " + f"{features}" + ) + + # Get example weight + if self._example_weights_key_path is None: + example_weight = 1.0 + else: + example_weights = util.get_by_keys(element, self._example_weights_key_path) + assert len(example_weights) == 1, ( + "Expected 1 (scalar) example weight for each example, " + f"but found example weight = {example_weights}" + ) + example_weight = example_weights[0] + + # Update accumulator + accumulator.count += example_weight + accumulator.total += example_weight * features[0] + + return accumulator + + def merge_accumulators(self, accumulators: Iterable[_Accumulator]) -> _Accumulator: + accumulators = iter(accumulators) + result = next(accumulators) + for accumulator in accumulators: + result.count += accumulator.count + result.total += accumulator.total + return result + + def extract_output( + self, accumulator: _Accumulator + ) -> Dict[metric_types.MetricKey, float]: + if accumulator.count == 0.0: + return {self._key: float("nan")} + return {self._key: accumulator.total / accumulator.count} + + +def _convert_key_path_to_dict(key_path: Union[_KeyPath, Tuple[()]]) -> Dict[str, Any]: + """Recursively converts _KeyPath to nested dict.""" + return {key_path[0]: _convert_key_path_to_dict(key_path[1:])} if key_path else {} def _mean_metric( @@ -112,56 +104,55 @@ def _mean_metric( example_weights_key_path: _KeyPath, name: str, ) -> metric_types.MetricComputations: - """Returns metric computation for mean metric.""" - key = metric_types.MetricKey( - name=name, example_weighted=example_weights_key_path is not None - ) - - include_filter = _convert_key_path_to_dict(feature_key_path) - if example_weights_key_path: - include_filter = util.merge_filters( - include_filter, - 
_convert_key_path_to_dict(example_weights_key_path), + """Returns metric computation for mean metric.""" + key = metric_types.MetricKey( + name=name, example_weighted=example_weights_key_path is not None ) - return [ - metric_types.MetricComputation( - keys=[key], - preprocessors=[ - metric_types.StandardMetricInputsPreprocessor( - include_filter=include_filter, - include_default_inputs=False, - ) - ], - combiner=_MeanCombiner( - key, feature_key_path, example_weights_key_path - ), - ) - ] + include_filter = _convert_key_path_to_dict(feature_key_path) + if example_weights_key_path: + include_filter = util.merge_filters( + include_filter, + _convert_key_path_to_dict(example_weights_key_path), + ) + + return [ + metric_types.MetricComputation( + keys=[key], + preprocessors=[ + metric_types.StandardMetricInputsPreprocessor( + include_filter=include_filter, + include_default_inputs=False, + ) + ], + combiner=_MeanCombiner(key, feature_key_path, example_weights_key_path), + ) + ] class Mean(metric_types.Metric): - """Mean metric.""" - - def __init__( - self, - feature_key_path: _KeyPath, - example_weights_key_path: Optional[_KeyPath] = None, - name: Optional[str] = None, - ): - """Initializes mean metric. - - Args: - feature_key_path: key path to feature to calculate the mean of. - example_weights_key_path: key path to example weights. - name: Metric base name. - """ - super().__init__( - _mean_metric, - feature_key_path=feature_key_path, - example_weights_key_path=example_weights_key_path, - name=name or f"{_MEAN_METRIC_BASE_NAME}_{'.'.join(feature_key_path)}", - ) + """Mean metric.""" + + def __init__( + self, + feature_key_path: _KeyPath, + example_weights_key_path: Optional[_KeyPath] = None, + name: Optional[str] = None, + ): + """Initializes mean metric. + + Args: + ---- + feature_key_path: key path to feature to calculate the mean of. + example_weights_key_path: key path to example weights. + name: Metric base name. + """ + super().__init__( + _mean_metric, + feature_key_path=feature_key_path, + example_weights_key_path=example_weights_key_path, + name=name or f"{_MEAN_METRIC_BASE_NAME}_{'.'.join(feature_key_path)}", + ) metric_types.register_metric(Mean) diff --git a/tensorflow_model_analysis/metrics/stats_test.py b/tensorflow_model_analysis/metrics/stats_test.py index c26dc633b2..5c26550db7 100644 --- a/tensorflow_model_analysis/metrics/stats_test.py +++ b/tensorflow_model_analysis/metrics/stats_test.py @@ -13,304 +13,299 @@ # limitations under the License. 
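As a quick illustration of what the `Mean` metric defined above reports (a plain-Python sketch of the weighted mean, not the Beam combiner itself), using the same `age` / `example_weights` values that appear in the test fixtures below:

    # Weighted mean of a scalar feature, as accumulated by _MeanCombiner:
    # total += weight * value, count += weight, result = total / count.
    ages = [18, 21, 50, 65]
    weights = [2.2, 6.8, 9.2, 6.7]
    weighted_total = sum(w * a for w, a in zip(weights, ages))  # 1077.9
    weight_sum = sum(weights)                                   # 24.9
    weighted_mean = weighted_total / weight_sum                 # ~43.29

    # The equivalent TFMA metric, reported under the default key
    # "mean_features.age":
    # stats.Mean(feature_key_path=["features", "age"],
    #            example_weights_key_path=["features", "example_weights"])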
"""Tests for stats metrics.""" -from absl.testing import parameterized import apache_beam as beam -from apache_beam.testing import util import numpy as np import tensorflow as tf +from absl.testing import parameterized +from apache_beam.testing import util +from google.protobuf import text_format + import tensorflow_model_analysis as tfma +from tensorflow_model_analysis.metrics import metric_types, metric_util, stats from tensorflow_model_analysis.proto import config_pb2 -from tensorflow_model_analysis.metrics import metric_types -from tensorflow_model_analysis.metrics import metric_util -from tensorflow_model_analysis.metrics import stats from tensorflow_model_analysis.utils import test_util -from google.protobuf import text_format - def _get_examples(): - example0 = { - 'labels': None, - 'predictions': None, - 'features': { - 'example_weights': [2.2], - 'age': [18], - 'income': [50000], - }, - } - example1 = { - 'labels': None, - 'predictions': None, - 'features': { - 'example_weights': [6.8], - 'age': [21], - 'income': [100000], - }, - } - example2 = { - 'labels': None, - 'predictions': None, - 'features': { - 'example_weights': [9.2], - 'age': [50], - 'income': [300000], - }, - } - example3 = { - 'labels': None, - 'predictions': None, - 'features': { - 'example_weights': [6.7], - 'age': [65], - 'income': [400000], - }, - } - return [example0, example1, example2, example3] + example0 = { + "labels": None, + "predictions": None, + "features": { + "example_weights": [2.2], + "age": [18], + "income": [50000], + }, + } + example1 = { + "labels": None, + "predictions": None, + "features": { + "example_weights": [6.8], + "age": [21], + "income": [100000], + }, + } + example2 = { + "labels": None, + "predictions": None, + "features": { + "example_weights": [9.2], + "age": [50], + "income": [300000], + }, + } + example3 = { + "labels": None, + "predictions": None, + "features": { + "example_weights": [6.7], + "age": [65], + "income": [400000], + }, + } + return [example0, example1, example2, example3] def _compute_mean_metric(pipeline, computation): - return ( - pipeline - | 'Create' >> beam.Create(_get_examples()) - | 'Process' >> beam.Map(metric_util.to_standard_metric_inputs) - | 'AddSlice' >> beam.Map(lambda x: ((), x)) - | 'ComputeMeanMetric' >> beam.CombinePerKey(computation.combiner) - ) + return ( + pipeline + | "Create" >> beam.Create(_get_examples()) + | "Process" >> beam.Map(metric_util.to_standard_metric_inputs) + | "AddSlice" >> beam.Map(lambda x: ((), x)) + | "ComputeMeanMetric" >> beam.CombinePerKey(computation.combiner) + ) class MeanTestValidExamples( test_util.TensorflowModelAnalysisTest, parameterized.TestCase ): - - def _check_got(self, got, rouge_computation): - self.assertLen(got, 1) - got_slice_key, got_metrics = got[0] - self.assertEqual(got_slice_key, ()) - self.assertIn(rouge_computation.keys[0], got_metrics) - return got_metrics - - @parameterized.named_parameters( - ('Age', ['features', 'age'], 'mean_features.age', 38.5), - ('Income', ['features', 'income'], 'mean_features.income', 212500), - ) - def testMeanUnweighted( - self, feature_key_path, expected_metric_key_name, expected_mean - ): - mean_metric_key = metric_types.MetricKey(name=expected_metric_key_name) - mean_metric_computation = stats.Mean(feature_key_path).computations()[0] - - with beam.Pipeline() as pipeline: - result = _compute_mean_metric(pipeline, mean_metric_computation) - - def check_result(got): - try: - got_metrics = self._check_got(got, mean_metric_computation) - 
self.assertDictElementsAlmostEqual( - got_metrics, {mean_metric_key: expected_mean} - ) - - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that(result, check_result, label='result') - - @parameterized.named_parameters( - ('Age', ['features', 'age'], 'mean_features.age', 1077.9 / 24.9), - ( - 'Income', - ['features', 'income'], - 'mean_features.income', - 6230000 / 24.9, - ), - ) - def testMeanWeighted( - self, feature_key_path, expected_metric_key_name, expected_mean - ): - mean_metric_key = metric_types.MetricKey( - name=expected_metric_key_name, example_weighted=True + def _check_got(self, got, rouge_computation): + self.assertLen(got, 1) + got_slice_key, got_metrics = got[0] + self.assertEqual(got_slice_key, ()) + self.assertIn(rouge_computation.keys[0], got_metrics) + return got_metrics + + @parameterized.named_parameters( + ("Age", ["features", "age"], "mean_features.age", 38.5), + ("Income", ["features", "income"], "mean_features.income", 212500), + ) + def testMeanUnweighted( + self, feature_key_path, expected_metric_key_name, expected_mean + ): + mean_metric_key = metric_types.MetricKey(name=expected_metric_key_name) + mean_metric_computation = stats.Mean(feature_key_path).computations()[0] + + with beam.Pipeline() as pipeline: + result = _compute_mean_metric(pipeline, mean_metric_computation) + + def check_result(got): + try: + got_metrics = self._check_got(got, mean_metric_computation) + self.assertDictElementsAlmostEqual( + got_metrics, {mean_metric_key: expected_mean} + ) + + except AssertionError as err: + raise util.BeamAssertException(err) + + util.assert_that(result, check_result, label="result") + + @parameterized.named_parameters( + ("Age", ["features", "age"], "mean_features.age", 1077.9 / 24.9), + ( + "Income", + ["features", "income"], + "mean_features.income", + 6230000 / 24.9, + ), ) - mean_metric_computation = stats.Mean( - feature_key_path, - example_weights_key_path=['features', 'example_weights'], - ).computations()[0] + def testMeanWeighted( + self, feature_key_path, expected_metric_key_name, expected_mean + ): + mean_metric_key = metric_types.MetricKey( + name=expected_metric_key_name, example_weighted=True + ) + mean_metric_computation = stats.Mean( + feature_key_path, + example_weights_key_path=["features", "example_weights"], + ).computations()[0] - with beam.Pipeline() as pipeline: - result = _compute_mean_metric(pipeline, mean_metric_computation) + with beam.Pipeline() as pipeline: + result = _compute_mean_metric(pipeline, mean_metric_computation) - def check_result(got): - try: - got_metrics = self._check_got(got, mean_metric_computation) - self.assertDictElementsAlmostEqual( - got_metrics, {mean_metric_key: expected_mean} - ) + def check_result(got): + try: + got_metrics = self._check_got(got, mean_metric_computation) + self.assertDictElementsAlmostEqual( + got_metrics, {mean_metric_key: expected_mean} + ) - except AssertionError as err: - raise util.BeamAssertException(err) + except AssertionError as err: + raise util.BeamAssertException(err) - util.assert_that(result, check_result, label='result') + util.assert_that(result, check_result, label="result") - def testMeanName(self): - feature_key_path = ['features', 'age'] - name = 'name_to_verify_123_!@#' - expected_mean = 38.5 - mean_metric_key = metric_types.MetricKey(name=name) - mean_metric_computation = stats.Mean( - feature_key_path, name=name - ).computations()[0] + def testMeanName(self): + feature_key_path = ["features", "age"] + name = 
"name_to_verify_123_!@#" + expected_mean = 38.5 + mean_metric_key = metric_types.MetricKey(name=name) + mean_metric_computation = stats.Mean( + feature_key_path, name=name + ).computations()[0] - with beam.Pipeline() as pipeline: - result = _compute_mean_metric(pipeline, mean_metric_computation) + with beam.Pipeline() as pipeline: + result = _compute_mean_metric(pipeline, mean_metric_computation) - def check_result(got): - try: - got_metrics = self._check_got(got, mean_metric_computation) - self.assertDictElementsAlmostEqual( - got_metrics, {mean_metric_key: expected_mean} - ) + def check_result(got): + try: + got_metrics = self._check_got(got, mean_metric_computation) + self.assertDictElementsAlmostEqual( + got_metrics, {mean_metric_key: expected_mean} + ) - except AssertionError as err: - raise util.BeamAssertException(err) + except AssertionError as err: + raise util.BeamAssertException(err) - util.assert_that(result, check_result, label='result') + util.assert_that(result, check_result, label="result") class MeanTestInvalidExamples( test_util.TensorflowModelAnalysisTest, parameterized.TestCase ): + def testMeanNotOneFeatureValue(self): + # This example should cause a ValueError because + # the feature (age) contains more than one value. + example = { + "labels": None, + "predictions": None, + "features": { + "example_weights": [2.2], + "age": [18, 21], + }, + } - def testMeanNotOneFeatureValue(self): - # This example should cause a ValueError because - # the feature (age) contains more than one value. - example = { - 'labels': None, - 'predictions': None, - 'features': { - 'example_weights': [2.2], - 'age': [18, 21], - }, - } - - feature_key_path = ['features', 'age'] - example_weights_key_path = ['features', 'example_weights'] - - mean_metric_computation = stats.Mean( - feature_key_path, example_weights_key_path=example_weights_key_path - ).computations()[0] - - with self.assertRaisesRegex( - AssertionError, - r'Mean\(\) is only supported for scalar features, but found features = ' - r'\[18, 21\]', - ): - with beam.Pipeline() as pipeline: - _ = ( - pipeline - | 'Create' >> beam.Create([example]) - | 'Process' >> beam.Map(metric_util.to_standard_metric_inputs) - | 'AddSlice' >> beam.Map(lambda x: ((), x)) - | 'ComputeMeanMetric' - >> beam.CombinePerKey(mean_metric_computation.combiner) - ) - - def testMeanNotOneExampleWeight(self): - # This example should cause a ValueError - # because it has multiple example weights. 
- example = { - 'labels': None, - 'predictions': None, - 'features': { - 'example_weights': [4.6, 8.5], - 'age': [18], - }, - } - - feature_key_path = ['features', 'age'] - example_weights_key_path = ['features', 'example_weights'] - - mean_metric_computation = stats.Mean( - feature_key_path, example_weights_key_path=example_weights_key_path - ).computations()[0] - - with self.assertRaisesRegex( - AssertionError, - r'Expected 1 \(scalar\) example weight for each example, but found ' - r'example weight = \[4.6, 8.5\]', - ): - with beam.Pipeline() as pipeline: - _ = ( - pipeline - | 'Create' >> beam.Create([example]) - | 'Process' >> beam.Map(metric_util.to_standard_metric_inputs) - | 'AddSlice' >> beam.Map(lambda x: ((), x)) - | 'ComputeMeanMetric' - >> beam.CombinePerKey(mean_metric_computation.combiner) - ) + feature_key_path = ["features", "age"] + example_weights_key_path = ["features", "example_weights"] + + mean_metric_computation = stats.Mean( + feature_key_path, example_weights_key_path=example_weights_key_path + ).computations()[0] + + with self.assertRaisesRegex( + AssertionError, + r"Mean\(\) is only supported for scalar features, but found features = " + r"\[18, 21\]", + ): + with beam.Pipeline() as pipeline: + _ = ( + pipeline + | "Create" >> beam.Create([example]) + | "Process" >> beam.Map(metric_util.to_standard_metric_inputs) + | "AddSlice" >> beam.Map(lambda x: ((), x)) + | "ComputeMeanMetric" + >> beam.CombinePerKey(mean_metric_computation.combiner) + ) + + def testMeanNotOneExampleWeight(self): + # This example should cause a ValueError + # because it has multiple example weights. + example = { + "labels": None, + "predictions": None, + "features": { + "example_weights": [4.6, 8.5], + "age": [18], + }, + } - def testMeanExampleCountIsZero(self): - example = { - 'labels': None, - 'predictions': None, - 'features': { - 'example_weights': [0.0], - 'age': [18], - }, - } + feature_key_path = ["features", "age"] + example_weights_key_path = ["features", "example_weights"] + + mean_metric_computation = stats.Mean( + feature_key_path, example_weights_key_path=example_weights_key_path + ).computations()[0] + + with self.assertRaisesRegex( + AssertionError, + r"Expected 1 \(scalar\) example weight for each example, but found " + r"example weight = \[4.6, 8.5\]", + ): + with beam.Pipeline() as pipeline: + _ = ( + pipeline + | "Create" >> beam.Create([example]) + | "Process" >> beam.Map(metric_util.to_standard_metric_inputs) + | "AddSlice" >> beam.Map(lambda x: ((), x)) + | "ComputeMeanMetric" + >> beam.CombinePerKey(mean_metric_computation.combiner) + ) + + def testMeanExampleCountIsZero(self): + example = { + "labels": None, + "predictions": None, + "features": { + "example_weights": [0.0], + "age": [18], + }, + } - feature_key_path = ['features', 'age'] - example_weights_key_path = ['features', 'example_weights'] + feature_key_path = ["features", "age"] + example_weights_key_path = ["features", "example_weights"] - mean_metric_computation = stats.Mean( - feature_key_path, example_weights_key_path=example_weights_key_path - ).computations()[0] - key = mean_metric_computation.keys[0] + mean_metric_computation = stats.Mean( + feature_key_path, example_weights_key_path=example_weights_key_path + ).computations()[0] + key = mean_metric_computation.keys[0] - with beam.Pipeline() as pipeline: - result = ( - pipeline - | 'Create' >> beam.Create([example]) - | 'Process' >> beam.Map(metric_util.to_standard_metric_inputs) - | 'AddSlice' >> beam.Map(lambda x: ((), x)) - | 'ComputeMeanMetric' 
- >> beam.CombinePerKey(mean_metric_computation.combiner) - ) + with beam.Pipeline() as pipeline: + result = ( + pipeline + | "Create" >> beam.Create([example]) + | "Process" >> beam.Map(metric_util.to_standard_metric_inputs) + | "AddSlice" >> beam.Map(lambda x: ((), x)) + | "ComputeMeanMetric" + >> beam.CombinePerKey(mean_metric_computation.combiner) + ) - def check_result(got): - try: - self.assertLen(got, 1) - got_slice_key, got_metrics = got[0] - self.assertEqual(got_slice_key, ()) - self.assertLen(got_metrics, 1) - self.assertIn(key, got_metrics) - self.assertTrue(np.isnan(got_metrics[key])) + def check_result(got): + try: + self.assertLen(got, 1) + got_slice_key, got_metrics = got[0] + self.assertEqual(got_slice_key, ()) + self.assertLen(got_metrics, 1) + self.assertIn(key, got_metrics) + self.assertTrue(np.isnan(got_metrics[key])) - except AssertionError as err: - raise util.BeamAssertException(err) + except AssertionError as err: + raise util.BeamAssertException(err) - util.assert_that(result, check_result, label='result') + util.assert_that(result, check_result, label="result") class MeanEnd2EndTest(parameterized.TestCase): - - def testMeanEnd2End(self): - extracts = [ - { - 'features': { - 'example_weights': np.array([0.5]), - 'age': np.array([30]), - 'income': np.array([150000]), + def testMeanEnd2End(self): + extracts = [ + { + "features": { + "example_weights": np.array([0.5]), + "age": np.array([30]), + "income": np.array([150000]), + }, }, - }, - { - 'features': { - 'example_weights': np.array([0.3]), - 'age': np.array([40]), - 'income': np.array([200000]), + { + "features": { + "example_weights": np.array([0.3]), + "age": np.array([40]), + "income": np.array([200000]), + }, }, - }, - ] + ] - eval_config = text_format.Parse( - """ + eval_config = text_format.Parse( + """ metrics_specs { metrics { class_name: "Mean" @@ -324,71 +319,69 @@ def testMeanEnd2End(self): } , } """, - config_pb2.EvalConfig(), - ) + config_pb2.EvalConfig(), + ) - extractors = tfma.default_extractors(eval_config=eval_config) - evaluators = tfma.default_evaluators(eval_config=eval_config) + extractors = tfma.default_extractors(eval_config=eval_config) + evaluators = tfma.default_evaluators(eval_config=eval_config) - expected_key_age = metric_types.MetricKey( - name='mean_features.age', example_weighted=True - ) - expected_key_income = metric_types.MetricKey( - name='mean_features.income', example_weighted=True - ) + expected_key_age = metric_types.MetricKey( + name="mean_features.age", example_weighted=True + ) + expected_key_income = metric_types.MetricKey( + name="mean_features.income", example_weighted=True + ) - expected_result_age = 33.75 # (30 * 0.5 + 40 * 0.3) / (0.5 + 0.3) = 33.75 - # (150k * 0.5 + 200k * 0.3) / (0.5 + 0.3) = 168,750 - expected_result_income = 168750 - - with beam.Pipeline() as pipeline: - result = ( - pipeline - | 'LoadData' >> beam.Create(extracts) - | 'ExtractEval' - >> tfma.ExtractAndEvaluate( - extractors=extractors, evaluators=evaluators - ) - ) - - def check_result(got): - try: - self.assertLen(got, 1) - got_slice_key, got_metrics = got[0] - self.assertEqual(got_slice_key, ()) - self.assertLen(got_metrics, 2) - self.assertIn(expected_key_age, got_metrics) - self.assertIn(expected_key_income, got_metrics) - self.assertAlmostEqual( - expected_result_age, got_metrics[expected_key_age] - ) - self.assertAlmostEqual( - expected_result_income, got_metrics[expected_key_income] - ) - except AssertionError as err: - raise util.BeamAssertException(err) - - self.assertIn('metrics', 
result) - util.assert_that(result['metrics'], check_result, label='result') - - def testMeanEnd2EndWithoutExampleWeights(self): - extracts = [ - { - 'features': { - 'age': np.array([30]), - 'income': np.array([150000]), + expected_result_age = 33.75 # (30 * 0.5 + 40 * 0.3) / (0.5 + 0.3) = 33.75 + # (150k * 0.5 + 200k * 0.3) / (0.5 + 0.3) = 168,750 + expected_result_income = 168750 + + with beam.Pipeline() as pipeline: + result = ( + pipeline + | "LoadData" >> beam.Create(extracts) + | "ExtractEval" + >> tfma.ExtractAndEvaluate(extractors=extractors, evaluators=evaluators) + ) + + def check_result(got): + try: + self.assertLen(got, 1) + got_slice_key, got_metrics = got[0] + self.assertEqual(got_slice_key, ()) + self.assertLen(got_metrics, 2) + self.assertIn(expected_key_age, got_metrics) + self.assertIn(expected_key_income, got_metrics) + self.assertAlmostEqual( + expected_result_age, got_metrics[expected_key_age] + ) + self.assertAlmostEqual( + expected_result_income, got_metrics[expected_key_income] + ) + except AssertionError as err: + raise util.BeamAssertException(err) + + self.assertIn("metrics", result) + util.assert_that(result["metrics"], check_result, label="result") + + def testMeanEnd2EndWithoutExampleWeights(self): + extracts = [ + { + "features": { + "age": np.array([30]), + "income": np.array([150000]), + }, }, - }, - { - 'features': { - 'age': np.array([40]), - 'income': np.array([200000]), + { + "features": { + "age": np.array([40]), + "income": np.array([200000]), + }, }, - }, - ] + ] - eval_config = text_format.Parse( - """ + eval_config = text_format.Parse( + """ metrics_specs { metrics { class_name: "Mean" @@ -400,53 +393,51 @@ def testMeanEnd2EndWithoutExampleWeights(self): } , } """, - config_pb2.EvalConfig(), - ) + config_pb2.EvalConfig(), + ) - extractors = tfma.default_extractors(eval_config=eval_config) - evaluators = tfma.default_evaluators(eval_config=eval_config) + extractors = tfma.default_extractors(eval_config=eval_config) + evaluators = tfma.default_evaluators(eval_config=eval_config) - expected_key_age = metric_types.MetricKey( - name='mean_features.age', example_weighted=False - ) - expected_key_income = metric_types.MetricKey( - name='mean_features.income', example_weighted=False - ) + expected_key_age = metric_types.MetricKey( + name="mean_features.age", example_weighted=False + ) + expected_key_income = metric_types.MetricKey( + name="mean_features.income", example_weighted=False + ) - expected_result_age = 35 # (30 + 40) / (1 + 1) = 35 - # (150k + 200k) / (1 + 1) = 175000 - expected_result_income = 175000 - - with beam.Pipeline() as pipeline: - result = ( - pipeline - | 'LoadData' >> beam.Create(extracts) - | 'ExtractEval' - >> tfma.ExtractAndEvaluate( - extractors=extractors, evaluators=evaluators - ) - ) - - def check_result(got): - try: - self.assertLen(got, 1) - got_slice_key, got_metrics = got[0] - self.assertEqual(got_slice_key, ()) - self.assertLen(got_metrics, 2) - self.assertIn(expected_key_age, got_metrics) - self.assertIn(expected_key_income, got_metrics) - self.assertAlmostEqual( - expected_result_age, got_metrics[expected_key_age] - ) - self.assertAlmostEqual( - expected_result_income, got_metrics[expected_key_income] - ) - except AssertionError as err: - raise util.BeamAssertException(err) - - self.assertIn('metrics', result) - util.assert_that(result['metrics'], check_result, label='result') - - -if __name__ == '__main__': - tf.test.main() + expected_result_age = 35 # (30 + 40) / (1 + 1) = 35 + # (150k + 200k) / (1 + 1) = 175000 + 
expected_result_income = 175000 + + with beam.Pipeline() as pipeline: + result = ( + pipeline + | "LoadData" >> beam.Create(extracts) + | "ExtractEval" + >> tfma.ExtractAndEvaluate(extractors=extractors, evaluators=evaluators) + ) + + def check_result(got): + try: + self.assertLen(got, 1) + got_slice_key, got_metrics = got[0] + self.assertEqual(got_slice_key, ()) + self.assertLen(got_metrics, 2) + self.assertIn(expected_key_age, got_metrics) + self.assertIn(expected_key_income, got_metrics) + self.assertAlmostEqual( + expected_result_age, got_metrics[expected_key_age] + ) + self.assertAlmostEqual( + expected_result_income, got_metrics[expected_key_income] + ) + except AssertionError as err: + raise util.BeamAssertException(err) + + self.assertIn("metrics", result) + util.assert_that(result["metrics"], check_result, label="result") + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow_model_analysis/metrics/test_util.py b/tensorflow_model_analysis/metrics/test_util.py index 6fa3334f5c..64b40ba697 100644 --- a/tensorflow_model_analysis/metrics/test_util.py +++ b/tensorflow_model_analysis/metrics/test_util.py @@ -15,134 +15,140 @@ from typing import Iterable -from absl.testing import absltest import apache_beam as beam import numpy as np +from absl.testing import absltest + from tensorflow_model_analysis import constants from tensorflow_model_analysis.api import types -from tensorflow_model_analysis.evaluators import metrics_plots_and_validations_evaluator as evaluator +from tensorflow_model_analysis.evaluators import ( + metrics_plots_and_validations_evaluator as evaluator, +) from tensorflow_model_analysis.metrics import metric_types from tensorflow_model_analysis.slicer import slicer_lib class TestCase(absltest.TestCase): - """Base class for metric tests which provides assertMetricEqual.""" - - def assertDerivedMetricsEqual( # pylint: disable=invalid-name - self, - expected_metrics: metric_types.MetricsDict, - metric: metric_types.Metric, - extracts: Iterable[types.Extracts], - example_weighted: bool = True, - enable_debug_print: bool = False, - ): - """Asserts that the given metric has the expected values. + """Base class for metric tests which provides assertMetricEqual.""" - This method exists to allow metric authors to easily test that their code - behaves correctly when excercised by the standard evaluator. This utility - relies heavily on the actual evaluator implementation due to the complexity - of the metric-evaluator contract. Though this pattern is in conflict with - the principles of unit testing, we consider this to be preferable to many, - scattered and incorrect versions of the metric-evaluator contract. + def assertDerivedMetricsEqual( # pylint: disable=invalid-name + self, + expected_metrics: metric_types.MetricsDict, + metric: metric_types.Metric, + extracts: Iterable[types.Extracts], + example_weighted: bool = True, + enable_debug_print: bool = False, + ): + """Asserts that the given metric has the expected values. 
- Schematically, this method: - - generates the computations from the metric instance - - filters and separates the different types of computations - - applies those computations in the same way that the evaluator would - - non-derived: applies preprocessors and a merged combine_fn which - possibly includes multiple metric CombineFns - - derived: applies the derived metric computations to the - non-derived metric results - - removes any private metrics from the result - - asserts that the result matches the expected metrics + This method exists to allow metric authors to easily test that their code + behaves correctly when excercised by the standard evaluator. This utility + relies heavily on the actual evaluator implementation due to the complexity + of the metric-evaluator contract. Though this pattern is in conflict with + the principles of unit testing, we consider this to be preferable to many, + scattered and incorrect versions of the metric-evaluator contract. - Args: - expected_metrics: The expected metrics dict containing the exact metric - keys and value. - metric: The metric instance to test. - extracts: The extracts to use as input to the evaluator. These should be - of the format that would be produced by applying the Input-, Features-, - Predictions-, Labels- and ExampleWeight- Extractors. - example_weighted: Whether the metric is example weighted. - enable_debug_print: Whether to print the beam PCollections after each - stage. + Schematically, this method: + - generates the computations from the metric instance + - filters and separates the different types of computations + - applies those computations in the same way that the evaluator would + - non-derived: applies preprocessors and a merged combine_fn which + possibly includes multiple metric CombineFns + - derived: applies the derived metric computations to the + non-derived metric results + - removes any private metrics from the result + - asserts that the result matches the expected metrics - Raises: - AssertionError: If the metric does not have the expected values. - """ + Args: + ---- + expected_metrics: The expected metrics dict containing the exact metric + keys and value. + metric: The metric instance to test. + extracts: The extracts to use as input to the evaluator. These should be + of the format that would be produced by applying the Input-, Features-, + Predictions-, Labels- and ExampleWeight- Extractors. + example_weighted: Whether the metric is example weighted. + enable_debug_print: Whether to print the beam PCollections after each + stage. - def debug_print(element, stage_name): - if enable_debug_print: - print(f'[{stage_name}]\t{element}') - return element + Raises: + ------ + AssertionError: If the metric does not have the expected values. 
+ """ - computations = evaluator._filter_and_separate_computations( # pylint: disable=protected-access - metric.computations(example_weighted=example_weighted) - ) - with beam.Pipeline() as pipeline: - result = ( - pipeline - | 'Create' >> beam.Create(extracts) - | 'PrintAfterCreate' >> beam.Map(debug_print, 'AfterCreate') - | 'AddSlice' - >> beam.Map( - lambda x: x - | { - constants.SLICE_KEY_TYPES_KEY: np.array( - [slicer_lib.slice_keys_to_numpy_array([()])] - ) - } - ) - | 'PrintAfterAddSlice' >> beam.Map(debug_print, 'AfterAddSlice') - | 'Preprocess' - >> beam.ParDo( - evaluator._PreprocessorDoFn( # pylint: disable=protected-access - computations.non_derived_computations - ) - ) - | 'PrintAfterPreprocess' >> beam.Map(debug_print, 'AfterPreprocess') - | 'FanoutSlices' >> slicer_lib.FanoutSlices() - | 'PrintAfterFanoutSlices' - >> beam.Map(debug_print, 'AfterFanoutSlices') - | 'ComputeNonDerivedMetrics' - >> beam.CombinePerKey( - evaluator._ComputationsCombineFn( # pylint: disable=protected-access - computations=computations.non_derived_computations - ) - ) - | 'PrintAfterComputeNonDerivedMetrics' - >> beam.Map(debug_print, 'AfterComputeNonDerivedMetrics') - | 'ComputeDerivedMetrics' - >> evaluator._AddDerivedCrossSliceAndDiffMetrics( # pylint: disable=protected-access - derived_computations=computations.derived_computations, - cross_slice_computations=[], - cross_slice_specs=[], - ) - | 'PrintAfterComputeDerivedMetrics' - >> beam.Map(debug_print, 'AfterComputeDerivedMetrics') - | 'RemovePrivateMetrics' - >> beam.MapTuple(evaluator._remove_private_metrics) # pylint: disable=protected-access - ) + def debug_print(element, stage_name): + if enable_debug_print: + print(f"[{stage_name}]\t{element}") + return element - # pylint: enable=no-value-for-parameter - def check_result(got): - try: - self.assertLen(got, 1) - got_slice_key, got_metrics = got[0] - self.assertEqual((), got_slice_key) - self.assertEqual(expected_metrics.keys(), got_metrics.keys()) - for key, expected_value in expected_metrics.items(): - self.assertIn(key, got_metrics) - if isinstance(expected_value, np.ndarray): - if np.issubdtype(expected_value.dtype, np.floating): - np.testing.assert_almost_equal( - expected_value, got_metrics[key], decimal=5 + computations = evaluator._filter_and_separate_computations( # pylint: disable=protected-access + metric.computations(example_weighted=example_weighted) + ) + with beam.Pipeline() as pipeline: + result = ( + pipeline + | "Create" >> beam.Create(extracts) + | "PrintAfterCreate" >> beam.Map(debug_print, "AfterCreate") + | "AddSlice" + >> beam.Map( + lambda x: x + | { + constants.SLICE_KEY_TYPES_KEY: np.array( + [slicer_lib.slice_keys_to_numpy_array([()])] + ) + } ) - else: - np.testing.assert_array_equal(expected_value, got_metrics[key]) - else: - self.assertEqual(expected_value, got_metrics[key]) - except AssertionError as err: - raise beam.testing.util.BeamAssertException(err) + | "PrintAfterAddSlice" >> beam.Map(debug_print, "AfterAddSlice") + | "Preprocess" + >> beam.ParDo( + evaluator._PreprocessorDoFn( # pylint: disable=protected-access + computations.non_derived_computations + ) + ) + | "PrintAfterPreprocess" >> beam.Map(debug_print, "AfterPreprocess") + | "FanoutSlices" >> slicer_lib.FanoutSlices() + | "PrintAfterFanoutSlices" >> beam.Map(debug_print, "AfterFanoutSlices") + | "ComputeNonDerivedMetrics" + >> beam.CombinePerKey( + evaluator._ComputationsCombineFn( # pylint: disable=protected-access + computations=computations.non_derived_computations + ) + ) + | 
"PrintAfterComputeNonDerivedMetrics" + >> beam.Map(debug_print, "AfterComputeNonDerivedMetrics") + | "ComputeDerivedMetrics" + >> evaluator._AddDerivedCrossSliceAndDiffMetrics( # pylint: disable=protected-access + derived_computations=computations.derived_computations, + cross_slice_computations=[], + cross_slice_specs=[], + ) + | "PrintAfterComputeDerivedMetrics" + >> beam.Map(debug_print, "AfterComputeDerivedMetrics") + | "RemovePrivateMetrics" + >> beam.MapTuple(evaluator._remove_private_metrics) # pylint: disable=protected-access + ) + + # pylint: enable=no-value-for-parameter + def check_result(got): + try: + self.assertLen(got, 1) + got_slice_key, got_metrics = got[0] + self.assertEqual((), got_slice_key) + self.assertEqual(expected_metrics.keys(), got_metrics.keys()) + for key, expected_value in expected_metrics.items(): + self.assertIn(key, got_metrics) + if isinstance(expected_value, np.ndarray): + if np.issubdtype(expected_value.dtype, np.floating): + np.testing.assert_almost_equal( + expected_value, got_metrics[key], decimal=5 + ) + else: + np.testing.assert_array_equal( + expected_value, got_metrics[key] + ) + else: + self.assertEqual(expected_value, got_metrics[key]) + except AssertionError as err: + raise beam.testing.util.BeamAssertException(err) - beam.testing.util.assert_that(result, check_result, label='result') + beam.testing.util.assert_that(result, check_result, label="result") diff --git a/tensorflow_model_analysis/metrics/tf_metric_accumulators.py b/tensorflow_model_analysis/metrics/tf_metric_accumulators.py index 12f7d950a9..06332d34bb 100644 --- a/tensorflow_model_analysis/metrics/tf_metric_accumulators.py +++ b/tensorflow_model_analysis/metrics/tf_metric_accumulators.py @@ -16,213 +16,242 @@ from typing import Any, Callable, List, Optional, Tuple, Union import numpy as np + from tensorflow_model_analysis.metrics import metric_util from tensorflow_model_analysis.proto import config_pb2 from tensorflow_model_analysis.utils import size_estimator class TFMetricsAccumulator: - """Accumulator for TF metrics. - - Attributes: - inputs: Accumulated batch of inputs. The inputs are stored in a - multi-dimensional list. The first dimension is used to index the - associated output (for single-output models this will only have one item). - The second dimension is used to store the args used by the combiner. For - example the args might be a tf.Example if feeding a model or they might be - (y_true, y_pred, example_weight) for calling update_state directly. - Batching is done on the last dimension. - weights: Accumulated weights. The weights are stored in a multi-dimensional - list where the first dimension is used to index the associated output (for - single-output models this will only have one item). The second dimension - is used to store the accumulated weights for each metric associated with - the output dimension. - size_estimator: Batch size estimator. - desired_batch_size: Desired batch size. - """ - - # We really want the batch size to be adaptive like it is in - # beam.BatchElements(), but there isn't an easy way to make it so. For now - # we will limit stored inputs to a max overall byte size. - # TODO(b/73789023): Figure out how to make this batch size dynamic. 
- _TOTAL_INPUT_BYTE_SIZE_THRESHOLD = 16 << 20 # 16MiB - _DEFAULT_DESIRED_BATCH_SIZE = 1000 - - __slots__ = ['_inputs', '_weights', '_size_estimator', '_desired_batch_size'] - - def __init__(self, - input_counts: List[int], - metric_counts: List[int], - size_estimator_fn: Callable[[Any], int], - desired_batch_size: Optional[int] = None): - """Initializes accumulator using a list of metric counts per output. - - Args: - input_counts: Number of inputs associated with each output index. - metric_counts: Number of metrics associated with each output index. - size_estimator_fn: Function to use for estimating the size of the inputs. - desired_batch_size: FOR TESTING ONLY. + """Accumulator for TF metrics. + + Attributes + ---------- + inputs: Accumulated batch of inputs. The inputs are stored in a + multi-dimensional list. The first dimension is used to index the + associated output (for single-output models this will only have one item). + The second dimension is used to store the args used by the combiner. For + example the args might be a tf.Example if feeding a model or they might be + (y_true, y_pred, example_weight) for calling update_state directly. + Batching is done on the last dimension. + weights: Accumulated weights. The weights are stored in a multi-dimensional + list where the first dimension is used to index the associated output (for + single-output models this will only have one item). The second dimension + is used to store the accumulated weights for each metric associated with + the output dimension. + size_estimator: Batch size estimator. + desired_batch_size: Desired batch size. """ - # Inputs have shape (num_outputs, num_metrics, num_accumulated_inputs) - self._inputs = [] - # Weights have shape (num_outputs, num_metrics) - self._weights = [] # type: List[List[Optional[np.ndarray]]] - for input_count in input_counts: - self._inputs.append(tuple([] for _ in range(input_count))) - for output_metric_count in metric_counts: - self._weights.append([None] * output_metric_count) - self._size_estimator = size_estimator.SizeEstimator( - size_threshold=self._TOTAL_INPUT_BYTE_SIZE_THRESHOLD, - size_fn=size_estimator_fn) - if desired_batch_size and desired_batch_size > 0: - self._desired_batch_size = desired_batch_size - else: - self._desired_batch_size = self._DEFAULT_DESIRED_BATCH_SIZE - - def len_inputs(self) -> int: - """Returns length of inputs.""" - return len(self._inputs[0][0]) - - def add_input(self, output_index: int, *args): - """Adds new inputs to the lists of input args stored at output_index.""" - for i, v in enumerate(args): - self._inputs[output_index][i].append(v) - if v is not None: - self._size_estimator.update(v) - - def get_inputs(self, output_index: int) -> Any: - """Returns input args for output at given offset.""" - return self._inputs[output_index] - - def clear_inputs(self): - """Clears currently stored inputs.""" - for output_index in range(len(self._inputs)): - for i in range(len(self._inputs[output_index])): - del self._inputs[output_index][i][:] - self._size_estimator.clear() - - def add_weights(self, output_index: int, metric_index: int, - weights: np.ndarray): - """Adds weights for metric at given metric_index and output_index.""" - cur_weights = self._weights[output_index][metric_index] - if cur_weights is None: - self._weights[output_index][metric_index] = weights - else: - self._weights[output_index][metric_index] = np.add(cur_weights, weights) - - def get_weights(self, output_index: int, - metric_index: int) -> Optional[np.ndarray]: - """Gets 
currently stored weights for given metric_index and output_index.""" - return self._weights[output_index][metric_index] - - def should_flush(self) -> bool: - """Returns true if size estimator indicates flush is needed.""" - return (self.len_inputs() >= self._desired_batch_size or - self._size_estimator.should_flush()) - - def get_size_estimate(self) -> int: - """Returns size estimator associated with accumulator.""" - return self._size_estimator.get_estimate() + + # We really want the batch size to be adaptive like it is in + # beam.BatchElements(), but there isn't an easy way to make it so. For now + # we will limit stored inputs to a max overall byte size. + # TODO(b/73789023): Figure out how to make this batch size dynamic. + _TOTAL_INPUT_BYTE_SIZE_THRESHOLD = 16 << 20 # 16MiB + _DEFAULT_DESIRED_BATCH_SIZE = 1000 + + __slots__ = ["_inputs", "_weights", "_size_estimator", "_desired_batch_size"] + + def __init__( + self, + input_counts: List[int], + metric_counts: List[int], + size_estimator_fn: Callable[[Any], int], + desired_batch_size: Optional[int] = None, + ): + """Initializes accumulator using a list of metric counts per output. + + Args: + ---- + input_counts: Number of inputs associated with each output index. + metric_counts: Number of metrics associated with each output index. + size_estimator_fn: Function to use for estimating the size of the inputs. + desired_batch_size: FOR TESTING ONLY. + """ + # Inputs have shape (num_outputs, num_metrics, num_accumulated_inputs) + self._inputs = [] + # Weights have shape (num_outputs, num_metrics) + self._weights = [] # type: List[List[Optional[np.ndarray]]] + for input_count in input_counts: + self._inputs.append(tuple([] for _ in range(input_count))) + for output_metric_count in metric_counts: + self._weights.append([None] * output_metric_count) + self._size_estimator = size_estimator.SizeEstimator( + size_threshold=self._TOTAL_INPUT_BYTE_SIZE_THRESHOLD, + size_fn=size_estimator_fn, + ) + if desired_batch_size and desired_batch_size > 0: + self._desired_batch_size = desired_batch_size + else: + self._desired_batch_size = self._DEFAULT_DESIRED_BATCH_SIZE + + def len_inputs(self) -> int: + """Returns length of inputs.""" + return len(self._inputs[0][0]) + + def add_input(self, output_index: int, *args): + """Adds new inputs to the lists of input args stored at output_index.""" + for i, v in enumerate(args): + self._inputs[output_index][i].append(v) + if v is not None: + self._size_estimator.update(v) + + def get_inputs(self, output_index: int) -> Any: + """Returns input args for output at given offset.""" + return self._inputs[output_index] + + def clear_inputs(self): + """Clears currently stored inputs.""" + for output_index in range(len(self._inputs)): + for i in range(len(self._inputs[output_index])): + del self._inputs[output_index][i][:] + self._size_estimator.clear() + + def add_weights(self, output_index: int, metric_index: int, weights: np.ndarray): + """Adds weights for metric at given metric_index and output_index.""" + cur_weights = self._weights[output_index][metric_index] + if cur_weights is None: + self._weights[output_index][metric_index] = weights + else: + self._weights[output_index][metric_index] = np.add(cur_weights, weights) + + def get_weights(self, output_index: int, metric_index: int) -> Optional[np.ndarray]: + """Gets currently stored weights for given metric_index and output_index.""" + return self._weights[output_index][metric_index] + + def should_flush(self) -> bool: + """Returns true if size estimator indicates 
flush is needed.""" + return ( + self.len_inputs() >= self._desired_batch_size + or self._size_estimator.should_flush() + ) + + def get_size_estimate(self) -> int: + """Returns size estimator associated with accumulator.""" + return self._size_estimator.get_estimate() def _numpy_array_size_fn(array: np.ndarray) -> int: - """Size estimator for numpy arrays.""" - return array.nbytes + """Size estimator for numpy arrays.""" + return array.nbytes class TFCompilableMetricsAccumulator(TFMetricsAccumulator): - """Accumulator for compilable TF metrics. - - Attributes: - inputs: Accumulated batch of inputs. The inputs are stored in a - multi-dimensional list. The first dimension is used to index the - associated output (for single-output models this will only have one item). - The second dimension is used to store the args passed to update_state - (i.e. (y_true, y_pred, example_weight)). Batching is done on the last - dimension.calling update_state directly. Batching is done on the last - dimension. - weights: Accumulated weights. The weights are stored in a multi-dimensional - list where the first dimension is used to index the associated output (for - single-output models this will only have one item). The second dimension - is used to store the accumulated weights for each metric associated with - the output dimension. - pad: True if padding needed. - last_dim: Max size of the last dimension of labels or predictions (used with - padding). - size_estimator: Batch size estimator. - desired_batch_size: Desired batch size. - """ - - __slots__ = [ - '_inputs', '_weights', '_pad', '_pad_to_dim', '_label_padding', - '_prediction_padding', '_size_estimator', '_desired_batch_size' - ] - - def __init__(self, - padding_options: Optional[config_pb2.PaddingOptions], - metric_counts: List[int], - desired_batch_size: Optional[int] = None): - """Initializes accumulator using a list of metric counts per output.""" - super().__init__( - # Input args of labels, predictions, example_weights for each output. - input_counts=[3] * len(metric_counts), - metric_counts=metric_counts, - size_estimator_fn=_numpy_array_size_fn, - desired_batch_size=desired_batch_size) - - self._pad = False - if padding_options is not None: - - def get_padding_value(oneof_name): - oneof = padding_options.WhichOneof(oneof_name) - return None if oneof is None else getattr(padding_options, oneof) - - self._pad = True - self._label_padding = get_padding_value('label_padding') - self._prediction_padding = get_padding_value('prediction_padding') - self._pad_to_dim = 0 - - def add_input(self, output_index: int, label: np.ndarray, - prediction: np.ndarray, example_weight: np.ndarray): - """Adds label, prediction, and example weight to output_index.""" - super().add_input(output_index, label, prediction, example_weight) - # The first output for multi-output models is not for inputs and the label - # will be None. 
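# Minimal sketch, not part of the patch: the flush condition implemented in
# TFMetricsAccumulator.should_flush() above. desired_batch_size is documented
# as a test-only override; it is used here only to keep the illustration small.
from tensorflow_model_analysis.metrics import tf_metric_accumulators

acc = tf_metric_accumulators.TFMetricsAccumulator(
    input_counts=[1],
    metric_counts=[1],
    size_estimator_fn=len,
    desired_batch_size=2,
)
acc.add_input(0, "first input")
assert not acc.should_flush()  # one buffered input; batch size of two not reached
acc.add_input(0, "second input")
assert acc.should_flush()  # desired batch size reached; the combiner would flush
acc.clear_inputs()
assert acc.len_inputs() == 0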
- if self._pad and label is not None: - self._pad_to_dim = max(self._pad_to_dim, label.shape[-1], - prediction.shape[-1]) - - def get_inputs( - self, output_index: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: - """Returns labels, predictions, and weights for output at given offset.""" - labels, preds, example_weights = super().get_inputs(output_index) - if self._pad: - - def pad_value( - name: str, a: np.ndarray, - configured_value: Optional[Union[float, int]]) -> Union[int, float]: - if configured_value is None: - return 0 if a.dtype.kind == 'i' else .0 - if isinstance(configured_value, int) and a.dtype.kind == 'i': - return configured_value - if isinstance(configured_value, float) and a.dtype.kind == 'f': - return configured_value - raise ValueError('%s padding is configured to be %s but data is %s' % - (name, type(configured_value), a.dtype)) - - labels = [ - metric_util.pad(l, self._pad_to_dim, - pad_value('label', l, self._label_padding)) - for l in labels - ] - preds = [ - metric_util.pad(p, self._pad_to_dim, - pad_value('prediction', p, self._prediction_padding)) - for p in preds - ] - return (np.array(labels), np.array(preds), np.array(example_weights)) - - def clear_inputs(self): - """Clears currently stored inputs.""" - super().clear_inputs() - self._pad_to_dim = 0 + """Accumulator for compilable TF metrics. + + Attributes + ---------- + inputs: Accumulated batch of inputs. The inputs are stored in a + multi-dimensional list. The first dimension is used to index the + associated output (for single-output models this will only have one item). + The second dimension is used to store the args passed to update_state + (i.e. (y_true, y_pred, example_weight)). Batching is done on the last + dimension.calling update_state directly. Batching is done on the last + dimension. + weights: Accumulated weights. The weights are stored in a multi-dimensional + list where the first dimension is used to index the associated output (for + single-output models this will only have one item). The second dimension + is used to store the accumulated weights for each metric associated with + the output dimension. + pad: True if padding needed. + last_dim: Max size of the last dimension of labels or predictions (used with + padding). + size_estimator: Batch size estimator. + desired_batch_size: Desired batch size. + """ + + __slots__ = [ + "_inputs", + "_weights", + "_pad", + "_pad_to_dim", + "_label_padding", + "_prediction_padding", + "_size_estimator", + "_desired_batch_size", + ] + + def __init__( + self, + padding_options: Optional[config_pb2.PaddingOptions], + metric_counts: List[int], + desired_batch_size: Optional[int] = None, + ): + """Initializes accumulator using a list of metric counts per output.""" + super().__init__( + # Input args of labels, predictions, example_weights for each output. 
+ input_counts=[3] * len(metric_counts), + metric_counts=metric_counts, + size_estimator_fn=_numpy_array_size_fn, + desired_batch_size=desired_batch_size, + ) + + self._pad = False + if padding_options is not None: + + def get_padding_value(oneof_name): + oneof = padding_options.WhichOneof(oneof_name) + return None if oneof is None else getattr(padding_options, oneof) + + self._pad = True + self._label_padding = get_padding_value("label_padding") + self._prediction_padding = get_padding_value("prediction_padding") + self._pad_to_dim = 0 + + def add_input( + self, + output_index: int, + label: np.ndarray, + prediction: np.ndarray, + example_weight: np.ndarray, + ): + """Adds label, prediction, and example weight to output_index.""" + super().add_input(output_index, label, prediction, example_weight) + # The first output for multi-output models is not for inputs and the label + # will be None. + if self._pad and label is not None: + self._pad_to_dim = max( + self._pad_to_dim, label.shape[-1], prediction.shape[-1] + ) + + def get_inputs( + self, output_index: int + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """Returns labels, predictions, and weights for output at given offset.""" + labels, preds, example_weights = super().get_inputs(output_index) + if self._pad: + + def pad_value( + name: str, a: np.ndarray, configured_value: Optional[Union[float, int]] + ) -> Union[int, float]: + if configured_value is None: + return 0 if a.dtype.kind == "i" else 0.0 + if isinstance(configured_value, int) and a.dtype.kind == "i": + return configured_value + if isinstance(configured_value, float) and a.dtype.kind == "f": + return configured_value + raise ValueError( + "%s padding is configured to be %s but data is %s" + % (name, type(configured_value), a.dtype) + ) + + labels = [ + metric_util.pad( + l, self._pad_to_dim, pad_value("label", l, self._label_padding) + ) + for l in labels + ] + preds = [ + metric_util.pad( + p, + self._pad_to_dim, + pad_value("prediction", p, self._prediction_padding), + ) + for p in preds + ] + return (np.array(labels), np.array(preds), np.array(example_weights)) + + def clear_inputs(self): + """Clears currently stored inputs.""" + super().clear_inputs() + self._pad_to_dim = 0 diff --git a/tensorflow_model_analysis/metrics/tf_metric_accumulators_test.py b/tensorflow_model_analysis/metrics/tf_metric_accumulators_test.py index b711927985..59925cbda8 100644 --- a/tensorflow_model_analysis/metrics/tf_metric_accumulators_test.py +++ b/tensorflow_model_analysis/metrics/tf_metric_accumulators_test.py @@ -15,128 +15,124 @@ import numpy as np import tensorflow as tf + from tensorflow_model_analysis.metrics import tf_metric_accumulators from tensorflow_model_analysis.proto import config_pb2 from tensorflow_model_analysis.utils import test_util class TfMetricAccumulatorsTest(test_util.TensorflowModelAnalysisTest): - - def testTFMetricsAccumulator(self): - # This test uses strings as inputs, but it works similarly to how an - # accumulator based on tf.Examples would work. 
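# Worked illustration, not part of the patch: the right-padding performed via
# metric_util.pad() in get_inputs() above when PaddingOptions sets float
# padding to -1.0 and the largest last dimension seen so far is 3. The
# accumulator test further below exercises the same behaviour end to end.
import numpy as np
from tensorflow_model_analysis.metrics import metric_util

padded = metric_util.pad(np.array([0.2, 0.8]), 3, -1.0)
np.testing.assert_allclose(padded, np.array([0.2, 0.8, -1.0]))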
- acc = tf_metric_accumulators.TFMetricsAccumulator( - input_counts=[1, 1], metric_counts=[1, 2], size_estimator_fn=len) - - self.assertEqual(0, acc.len_inputs()) - - acc.add_input(0, 'output_1_input_1') - acc.add_input(0, 'output_1_input_2') - acc.add_input(1, 'output_2_input_1') - acc.add_input(1, 'output_2_input_2') - acc.add_input(1, 'output_2_input_3') - self.assertEqual( - acc.get_inputs(0), (['output_1_input_1', 'output_1_input_2'],)) - self.assertEqual( - acc.get_inputs(1), - (['output_2_input_1', 'output_2_input_2', 'output_2_input_3'],)) - - acc.clear_inputs() - self.assertEqual(0, acc.len_inputs()) - - acc.add_weights(0, 0, np.array([1.0, 2.0])) - acc.add_weights(0, 0, np.array([3.0, 4.0])) - acc.add_weights(1, 0, np.array([5.0, 6.0])) - acc.add_weights(1, 1, np.array([7.0, 8.0])) - acc.add_weights(1, 1, np.array([9.0, 10.0])) - self.assertAllClose(acc.get_weights(0, 0), np.array([4.0, 6.0])) - self.assertAllClose(acc.get_weights(1, 0), np.array([5.0, 6.0])) - self.assertAllClose(acc.get_weights(1, 1), np.array([16.0, 18.0])) - - def testTFCompilableMetricsAccumulator(self): - acc = tf_metric_accumulators.TFCompilableMetricsAccumulator( - metric_counts=[1, 2], padding_options=None) - - self.assertEqual(0, acc.len_inputs()) - - acc.add_input(0, np.array([1.0, 0.0]), np.array([0.5, 0.5]), - np.array([1.0])) - acc.add_input(0, np.array([1.0, 1.0]), np.array([0.3, 0.7]), - np.array([1.0])) - acc.add_input(1, np.array([0.0, 0.0]), np.array([0.2, 0.8]), - np.array([0.5])) - acc.add_input(1, np.array([0.0, 1.0]), np.array([0.1, 0.9]), - np.array([0.5])) - acc.add_input(1, np.array([1.0, 1.0]), np.array([0.6, 0.4]), - np.array([0.7])) - self.assertAllClose( - acc.get_inputs(0), (np.array([ - np.array([1.0, 0.0]), np.array([1.0, 1.0]) - ]), np.array([ - np.array([0.5, 0.5]), np.array([0.3, 0.7]) - ]), np.array([np.array([1.0]), np.array([1.0])]))) - self.assertAllClose( - acc.get_inputs(1), - (np.array( - [np.array([0.0, 0.0]), - np.array([0.0, 1.0]), - np.array([1.0, 1.0])]), - np.array([ - np.array([0.2, 0.8]), - np.array([0.1, 0.9]), - np.array([0.6, 0.4]) - ]), np.array([np.array([0.5]), - np.array([0.5]), - np.array([0.7])]))) - - acc.clear_inputs() - self.assertEqual(0, acc.len_inputs()) - - acc.add_weights(0, 0, np.array([1.0, 2.0])) - acc.add_weights(1, 0, np.array([3.0, 4.0])) - acc.add_weights(1, 1, np.array([5.0, 6.0])) - acc.add_weights(1, 1, np.array([7.0, 8.0])) - self.assertAllClose(acc.get_weights(0, 0), np.array([1.0, 2.0])) - self.assertAllClose(acc.get_weights(1, 0), np.array([3.0, 4.0])) - self.assertAllClose(acc.get_weights(1, 1), np.array([12.0, 14.0])) - - def testTFCompilableMetricsAccumulatorWithFirstEmptyInput(self): - acc = tf_metric_accumulators.TFCompilableMetricsAccumulator( - metric_counts=[1, 2, 3], - padding_options=config_pb2.PaddingOptions( - label_float_padding=-1.0, - prediction_float_padding=-1.0, - ), - ) - - self.assertEqual(0, acc.len_inputs()) - - acc.add_input(0, None, None, None) - - acc.add_input( - 1, np.array([1.0, 1.0, 1.0]), np.array([0.3, 0.7]), np.array([1.0]) - ) - acc.add_input( - 2, np.array([0.0, 0.0]), np.array([0.2, 0.8]), np.array([0.5]) - ) - self.assertAllClose( - acc.get_inputs(1), - ( - np.array([[1.0, 1.0, 1.0]]), - np.array([[0.3, 0.7, -1.0]]), - np.array([[1.0]]), - ), - ) - self.assertAllClose( - acc.get_inputs(2), - ( - np.array([[0.0, 0.0, -1.0]]), - np.array([[0.2, 0.8, -1.0]]), - np.array([[0.5]]), - ), - ) - - -if __name__ == '__main__': - tf.compat.v1.enable_v2_behavior() - tf.test.main() + def 
testTFMetricsAccumulator(self): + # This test uses strings as inputs, but it works similarly to how an + # accumulator based on tf.Examples would work. + acc = tf_metric_accumulators.TFMetricsAccumulator( + input_counts=[1, 1], metric_counts=[1, 2], size_estimator_fn=len + ) + + self.assertEqual(0, acc.len_inputs()) + + acc.add_input(0, "output_1_input_1") + acc.add_input(0, "output_1_input_2") + acc.add_input(1, "output_2_input_1") + acc.add_input(1, "output_2_input_2") + acc.add_input(1, "output_2_input_3") + self.assertEqual(acc.get_inputs(0), (["output_1_input_1", "output_1_input_2"],)) + self.assertEqual( + acc.get_inputs(1), + (["output_2_input_1", "output_2_input_2", "output_2_input_3"],), + ) + + acc.clear_inputs() + self.assertEqual(0, acc.len_inputs()) + + acc.add_weights(0, 0, np.array([1.0, 2.0])) + acc.add_weights(0, 0, np.array([3.0, 4.0])) + acc.add_weights(1, 0, np.array([5.0, 6.0])) + acc.add_weights(1, 1, np.array([7.0, 8.0])) + acc.add_weights(1, 1, np.array([9.0, 10.0])) + self.assertAllClose(acc.get_weights(0, 0), np.array([4.0, 6.0])) + self.assertAllClose(acc.get_weights(1, 0), np.array([5.0, 6.0])) + self.assertAllClose(acc.get_weights(1, 1), np.array([16.0, 18.0])) + + def testTFCompilableMetricsAccumulator(self): + acc = tf_metric_accumulators.TFCompilableMetricsAccumulator( + metric_counts=[1, 2], padding_options=None + ) + + self.assertEqual(0, acc.len_inputs()) + + acc.add_input(0, np.array([1.0, 0.0]), np.array([0.5, 0.5]), np.array([1.0])) + acc.add_input(0, np.array([1.0, 1.0]), np.array([0.3, 0.7]), np.array([1.0])) + acc.add_input(1, np.array([0.0, 0.0]), np.array([0.2, 0.8]), np.array([0.5])) + acc.add_input(1, np.array([0.0, 1.0]), np.array([0.1, 0.9]), np.array([0.5])) + acc.add_input(1, np.array([1.0, 1.0]), np.array([0.6, 0.4]), np.array([0.7])) + self.assertAllClose( + acc.get_inputs(0), + ( + np.array([np.array([1.0, 0.0]), np.array([1.0, 1.0])]), + np.array([np.array([0.5, 0.5]), np.array([0.3, 0.7])]), + np.array([np.array([1.0]), np.array([1.0])]), + ), + ) + self.assertAllClose( + acc.get_inputs(1), + ( + np.array( + [np.array([0.0, 0.0]), np.array([0.0, 1.0]), np.array([1.0, 1.0])] + ), + np.array( + [np.array([0.2, 0.8]), np.array([0.1, 0.9]), np.array([0.6, 0.4])] + ), + np.array([np.array([0.5]), np.array([0.5]), np.array([0.7])]), + ), + ) + + acc.clear_inputs() + self.assertEqual(0, acc.len_inputs()) + + acc.add_weights(0, 0, np.array([1.0, 2.0])) + acc.add_weights(1, 0, np.array([3.0, 4.0])) + acc.add_weights(1, 1, np.array([5.0, 6.0])) + acc.add_weights(1, 1, np.array([7.0, 8.0])) + self.assertAllClose(acc.get_weights(0, 0), np.array([1.0, 2.0])) + self.assertAllClose(acc.get_weights(1, 0), np.array([3.0, 4.0])) + self.assertAllClose(acc.get_weights(1, 1), np.array([12.0, 14.0])) + + def testTFCompilableMetricsAccumulatorWithFirstEmptyInput(self): + acc = tf_metric_accumulators.TFCompilableMetricsAccumulator( + metric_counts=[1, 2, 3], + padding_options=config_pb2.PaddingOptions( + label_float_padding=-1.0, + prediction_float_padding=-1.0, + ), + ) + + self.assertEqual(0, acc.len_inputs()) + + acc.add_input(0, None, None, None) + + acc.add_input( + 1, np.array([1.0, 1.0, 1.0]), np.array([0.3, 0.7]), np.array([1.0]) + ) + acc.add_input(2, np.array([0.0, 0.0]), np.array([0.2, 0.8]), np.array([0.5])) + self.assertAllClose( + acc.get_inputs(1), + ( + np.array([[1.0, 1.0, 1.0]]), + np.array([[0.3, 0.7, -1.0]]), + np.array([[1.0]]), + ), + ) + self.assertAllClose( + acc.get_inputs(2), + ( + np.array([[0.0, 0.0, -1.0]]), + np.array([[0.2, 0.8, 
-1.0]]), + np.array([[0.5]]), + ), + ) + + +if __name__ == "__main__": + tf.compat.v1.enable_v2_behavior() + tf.test.main() diff --git a/tensorflow_model_analysis/metrics/tf_metric_wrapper.py b/tensorflow_model_analysis/metrics/tf_metric_wrapper.py index fa3f5ed205..5cc672591a 100644 --- a/tensorflow_model_analysis/metrics/tf_metric_wrapper.py +++ b/tensorflow_model_analysis/metrics/tf_metric_wrapper.py @@ -20,21 +20,23 @@ import apache_beam as beam import numpy as np + from tensorflow_model_analysis import constants -from tensorflow_model_analysis.metrics import binary_confusion_matrices -from tensorflow_model_analysis.metrics import metric_types -from tensorflow_model_analysis.metrics import metric_util -from tensorflow_model_analysis.metrics import tf_metric_accumulators +from tensorflow_model_analysis.metrics import ( + binary_confusion_matrices, + metric_types, + metric_util, + tf_metric_accumulators, +) from tensorflow_model_analysis.proto import config_pb2 from tensorflow_model_analysis.utils import model_util from tensorflow_model_analysis.utils.keras_lib import tf_keras - -_CONFIG_KEY = 'config' -_NUM_THRESHOLDS_KEY = 'num_thresholds' -_THRESHOLDS_KEY = 'thresholds' -_CLASS_ID_KEY = 'class_id' -_TOP_K_KEY = 'top_k' +_CONFIG_KEY = "config" +_NUM_THRESHOLDS_KEY = "num_thresholds" +_THRESHOLDS_KEY = "thresholds" +_CLASS_ID_KEY = "class_id" +_TOP_K_KEY = "top_k" _DEFAULT_NUM_THRESHOLDS_IN_KERAS = 200 _TFMetricOrLoss = Union[tf_keras.metrics.Metric, tf_keras.losses.Loss] @@ -43,105 +45,123 @@ def tf_metric_computations( metrics: Union[List[_TFMetricOrLoss], Dict[str, List[_TFMetricOrLoss]]], eval_config: Optional[config_pb2.EvalConfig] = None, - model_name: str = '', + model_name: str = "", sub_key: Optional[metric_types.SubKey] = None, aggregation_type: Optional[metric_types.AggregationType] = None, class_weights: Optional[Dict[int, float]] = None, example_weighted: bool = False, - desired_batch_size: Optional[int] = None + desired_batch_size: Optional[int] = None, ) -> metric_types.MetricComputations: - """Returns metric computations for the given TF metrics. - - Note that there is no requirement that a one to one mapping exist between the - input metrics and the output metric computations. The implementation may - combine multiple metrics into a single computation for efficency. - - Args: - metrics: Dict from metric name to tf_keras.metrics.Metric or - tf_keras.metrics.Loss. For multi-output models a dict of dicts may be - passed where the first dict is indexed by the output_name. - eval_config: Eval config. - model_name: Optional model name (if multi-model evaluation). - sub_key: Optional sub key. - aggregation_type: Optional aggregation type. - class_weights: Optional class weights to apply to multi-class / multi-label - labels and predictions. This should only be used when the aggregation_type - is set. - example_weighted: True if example weights should be applied. - desired_batch_size: Batch size to use when calling TF metrics (testing - only). - - Returns: - Metric computations. - """ - if not isinstance(metrics, dict): - metrics = {'': metrics} - - if aggregation_type is not None: - sparse_metrics = _sparse_metrics(metrics) - if sparse_metrics: - raise ValueError( - 'sparse metrics cannot be used with aggregation options. 
Either ' - 'disable aggregation settings or replace the sparse metrics with' - 'non-sparse versions: {}'.format(sparse_metrics)) - - metrics = _filter_duplicate_metrics(metrics, model_name, sub_key) - - computations = [] - - # For efficency, metrics are separated into confusion matrix based vs - # non-confusion matrix based metrics. Since the confusion matrix based metrics - # can all be calculated from the calibration histogram, these metrics are - # computed separately as derived metrics. The remaining non-confusion matrix - # metrics are calculated using batches of predictions/labels in eager mode - # (possibly with additional pre-processing of the values to perform - # binarization, etc). - # - # Note that in theory if a model was provided, all the metrics could be - # calculated by calling model.evaluate(). However, this call is inefficient - # for confusion matrix based metrics given the large number of weights that - # need to be calculated and the overlapping computations between the metrics. - # In addition, some metrics and plots are only defined in TFMA so a separate - # evaluation step would still be required. Lastly, if the metrics have any - # binarization, etc applied the inputs and outputs will not match those - # expected by the model. For these reasons, a separate implementation is used - # for each specific use case. It also allows evaluations that are not - # associated with a model (i.e. raw predictions are passed as input) to share - # the same code path as model based evaluations where possible. - confusion_matrix_metrics, non_confusion_matrix_metrics = ( - _separate_confusion_matrix_metrics(metrics)) - - for output_name, metrics in confusion_matrix_metrics.items(): - for metric in metrics: - computations.extend( - _wrap_confusion_matrix_metric(metric, eval_config, model_name, - output_name, sub_key, aggregation_type, - class_weights, example_weighted)) - - if non_confusion_matrix_metrics: - custom_objects = _custom_objects(non_confusion_matrix_metrics) - metric_keys, metric_configs, loss_configs = _metric_keys_and_configs( - non_confusion_matrix_metrics, model_name, sub_key, aggregation_type, - example_weighted) - for sub_key, keys in metric_keys.items(): - computations.append( - metric_types.MetricComputation( - keys=keys, - preprocessors=None, - combiner=_CompilableMetricsCombiner( - metric_configs[sub_key], - loss_configs[sub_key], - custom_objects, - eval_config, - model_name, - sub_key, - aggregation_type, - class_weights, - example_weighted, - desired_batch_size, - ))) - - return computations + """Returns metric computations for the given TF metrics. + + Note that there is no requirement that a one to one mapping exist between the + input metrics and the output metric computations. The implementation may + combine multiple metrics into a single computation for efficency. + + Args: + ---- + metrics: Dict from metric name to tf_keras.metrics.Metric or + tf_keras.metrics.Loss. For multi-output models a dict of dicts may be + passed where the first dict is indexed by the output_name. + eval_config: Eval config. + model_name: Optional model name (if multi-model evaluation). + sub_key: Optional sub key. + aggregation_type: Optional aggregation type. + class_weights: Optional class weights to apply to multi-class / multi-label + labels and predictions. This should only be used when the aggregation_type + is set. + example_weighted: True if example weights should be applied. + desired_batch_size: Batch size to use when calling TF metrics (testing + only). 
+ + Returns: + ------- + Metric computations. + """ + if not isinstance(metrics, dict): + metrics = {"": metrics} + + if aggregation_type is not None: + sparse_metrics = _sparse_metrics(metrics) + if sparse_metrics: + raise ValueError( + "sparse metrics cannot be used with aggregation options. Either " + "disable aggregation settings or replace the sparse metrics with" + f"non-sparse versions: {sparse_metrics}" + ) + + metrics = _filter_duplicate_metrics(metrics, model_name, sub_key) + + computations = [] + + # For efficency, metrics are separated into confusion matrix based vs + # non-confusion matrix based metrics. Since the confusion matrix based metrics + # can all be calculated from the calibration histogram, these metrics are + # computed separately as derived metrics. The remaining non-confusion matrix + # metrics are calculated using batches of predictions/labels in eager mode + # (possibly with additional pre-processing of the values to perform + # binarization, etc). + # + # Note that in theory if a model was provided, all the metrics could be + # calculated by calling model.evaluate(). However, this call is inefficient + # for confusion matrix based metrics given the large number of weights that + # need to be calculated and the overlapping computations between the metrics. + # In addition, some metrics and plots are only defined in TFMA so a separate + # evaluation step would still be required. Lastly, if the metrics have any + # binarization, etc applied the inputs and outputs will not match those + # expected by the model. For these reasons, a separate implementation is used + # for each specific use case. It also allows evaluations that are not + # associated with a model (i.e. raw predictions are passed as input) to share + # the same code path as model based evaluations where possible. + confusion_matrix_metrics, non_confusion_matrix_metrics = ( + _separate_confusion_matrix_metrics(metrics) + ) + + for output_name, metrics in confusion_matrix_metrics.items(): + for metric in metrics: + computations.extend( + _wrap_confusion_matrix_metric( + metric, + eval_config, + model_name, + output_name, + sub_key, + aggregation_type, + class_weights, + example_weighted, + ) + ) + + if non_confusion_matrix_metrics: + custom_objects = _custom_objects(non_confusion_matrix_metrics) + metric_keys, metric_configs, loss_configs = _metric_keys_and_configs( + non_confusion_matrix_metrics, + model_name, + sub_key, + aggregation_type, + example_weighted, + ) + for sub_key, keys in metric_keys.items(): + computations.append( + metric_types.MetricComputation( + keys=keys, + preprocessors=None, + combiner=_CompilableMetricsCombiner( + metric_configs[sub_key], + loss_configs[sub_key], + custom_objects, + eval_config, + model_name, + sub_key, + aggregation_type, + class_weights, + example_weighted, + desired_batch_size, + ), + ) + ) + + return computations def _filter_duplicate_metrics( @@ -149,192 +169,197 @@ def _filter_duplicate_metrics( model_name: str, sub_key: Optional[metric_types.SubKey] = None, ) -> Dict[str, List[tf_keras.metrics.Metric]]: - """Filters duplicate metrics from the metrics.""" - for output_name, metrics_list in metrics.items(): - unique_metrics = {} - for metric in metrics_list: - key = metric_types.MetricKey( - name=metric.name, - model_name=model_name, - output_name=output_name, - sub_key=_verify_and_update_sub_key(model_name, output_name, sub_key, - metric)) - # Replace any previous metric (i.e. last added metric wins). 
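For orientation, a minimal sketch of how the reformatted `tf_metric_computations` is typically invoked (a hedged sketch assuming TFMA is installed; it mirrors the call pattern exercised in `tf_metric_wrapper_test.py` further down in this patch rather than prescribing an API):

```python
# Sketch only: mirrors the usage in tf_metric_wrapper_test.py below.
from tensorflow_model_analysis.metrics import tf_metric_wrapper
from tensorflow_model_analysis.utils.keras_lib import tf_keras

computations = tf_metric_wrapper.tf_metric_computations(
    [tf_keras.metrics.AUC(name="auc")], example_weighted=False
)
# A single confusion-matrix metric expands into a calibration-histogram
# combiner plus derived computations; the tests below index them as follows.
histogram = computations[0]
matrices = computations[1]
derived_metric = computations[2]
```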
- unique_metrics[key] = metric - metrics[output_name] = list(unique_metrics.values()) - return metrics + """Filters duplicate metrics from the metrics.""" + for output_name, metrics_list in metrics.items(): + unique_metrics = {} + for metric in metrics_list: + key = metric_types.MetricKey( + name=metric.name, + model_name=model_name, + output_name=output_name, + sub_key=_verify_and_update_sub_key( + model_name, output_name, sub_key, metric + ), + ) + # Replace any previous metric (i.e. last added metric wins). + unique_metrics[key] = metric + metrics[output_name] = list(unique_metrics.values()) + return metrics def _sparse_metrics( - metrics: Dict[str, List[tf_keras.metrics.Metric]] + metrics: Dict[str, List[tf_keras.metrics.Metric]], ) -> Dict[str, List[tf_keras.metrics.Metric]]: - """Returns input metrics filtered to contain only the sparse metrics.""" - results = {} - for k, v in metrics.items(): - for m in v: - if m.__class__.__name__.startswith('Sparse'): - if k not in results: - results[k] = [] - results[k].append(m) - return results + """Returns input metrics filtered to contain only the sparse metrics.""" + results = {} + for k, v in metrics.items(): + for m in v: + if m.__class__.__name__.startswith("Sparse"): + if k not in results: + results[k] = [] + results[k].append(m) + return results def _separate_confusion_matrix_metrics( - metrics: Dict[Optional[str], List[_TFMetricOrLoss]] + metrics: Dict[Optional[str], List[_TFMetricOrLoss]], ) -> Tuple[ Dict[Optional[str], List[tf_keras.metrics.Metric]], Dict[Optional[str], List[_TFMetricOrLoss]], ]: - """Separates the confusion matrix metrics from the other metrics.""" - confusion_matrix_metrics = {} - non_confusion_matrix_metrics = {} - for output_name, metrics in metrics.items(): - for metric in metrics: - # We are using type instead of isinstance here because we only want to - # match specific types and not their subclasses. Note that if the top_k - # setting is specified as part of the keras metric directly, then we - # compute the value directly in keras. Otherwise, if the top_k setting is - # only provided via BinarizeOptions then we compute the value using the - # the confusion matrix. 
- if type(metric) in ( # pylint: disable=unidiomatic-typecheck - tf_keras.metrics.AUC, - tf_keras.metrics.SpecificityAtSensitivity, - tf_keras.metrics.SensitivityAtSpecificity, - tf_keras.metrics.TruePositives, - tf_keras.metrics.FalsePositives, - tf_keras.metrics.TrueNegatives, - tf_keras.metrics.FalseNegatives, - tf_keras.metrics.Precision, - tf_keras.metrics.Recall, - ) and not (hasattr(metric, _TOP_K_KEY) and metric.top_k is not None): - if output_name not in confusion_matrix_metrics: - confusion_matrix_metrics[output_name] = [] - confusion_matrix_metrics[output_name].append(metric) - else: - if output_name not in non_confusion_matrix_metrics: - non_confusion_matrix_metrics[output_name] = [] - non_confusion_matrix_metrics[output_name].append(metric) - return confusion_matrix_metrics, non_confusion_matrix_metrics # pytype: disable=bad-return-type # typed-keras - - -def _verify_and_update_sub_key(model_name: str, output_name: str, - sub_key: metric_types.SubKey, - metric: _TFMetricOrLoss): - """Verifies the multi-class metric key matches settings used by the metric.""" - if hasattr(metric, _CLASS_ID_KEY) and metric.class_id is not None: - if sub_key and sub_key.class_id != metric.class_id: - raise ValueError( - '{} tf_keras.metric has class_id = {}, but the metric is being added ' - 'using sub_key = {}: model_name={}, output_name={}'.format( - metric.name, metric.class_id, sub_key, model_name, output_name - ) - ) - return metric_types.SubKey(class_id=metric.class_id) - elif hasattr(metric, _TOP_K_KEY) and metric.top_k is not None: - if sub_key and sub_key.top_k != metric.top_k: - raise ValueError( - '{} tf_keras.metric has top_k = {}, but the metric is being added ' - 'using sub_key = {}: model_name={}, output_name={}'.format( - metric.name, metric.top_k, sub_key, model_name, output_name - ) - ) - return metric_types.SubKey(top_k=metric.top_k) - else: - return sub_key - - -_KeysBySubKey = Dict[Optional[metric_types.SubKey], - List[metric_types.MetricKey]] -_ConfigsBySubKey = Dict[Optional[metric_types.SubKey], - Dict[str, List[Dict[str, Any]]]] + """Separates the confusion matrix metrics from the other metrics.""" + confusion_matrix_metrics = {} + non_confusion_matrix_metrics = {} + for output_name, metrics in metrics.items(): + for metric in metrics: + # We are using type instead of isinstance here because we only want to + # match specific types and not their subclasses. Note that if the top_k + # setting is specified as part of the keras metric directly, then we + # compute the value directly in keras. Otherwise, if the top_k setting is + # only provided via BinarizeOptions then we compute the value using the + # the confusion matrix. 
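A short aside on the `type(metric) in (...)` membership test immediately below: unlike `isinstance`, it deliberately excludes subclasses, so a user-defined metric deriving from, say, `tf_keras.metrics.Precision` is routed to the non-confusion-matrix path. A self-contained illustration with hypothetical classes:

```python
# Hypothetical classes, purely to illustrate the exact-type membership check.
class Precision:
    pass


class CustomPrecision(Precision):
    pass


metric = CustomPrecision()
print(isinstance(metric, Precision))  # True  - matches subclasses as well
print(type(metric) is Precision)      # False - exact type only
print(type(metric) in (Precision,))   # False - the membership form used here
```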
+ if type(metric) in ( # pylint: disable=unidiomatic-typecheck + tf_keras.metrics.AUC, + tf_keras.metrics.SpecificityAtSensitivity, + tf_keras.metrics.SensitivityAtSpecificity, + tf_keras.metrics.TruePositives, + tf_keras.metrics.FalsePositives, + tf_keras.metrics.TrueNegatives, + tf_keras.metrics.FalseNegatives, + tf_keras.metrics.Precision, + tf_keras.metrics.Recall, + ) and not (hasattr(metric, _TOP_K_KEY) and metric.top_k is not None): + if output_name not in confusion_matrix_metrics: + confusion_matrix_metrics[output_name] = [] + confusion_matrix_metrics[output_name].append(metric) + else: + if output_name not in non_confusion_matrix_metrics: + non_confusion_matrix_metrics[output_name] = [] + non_confusion_matrix_metrics[output_name].append(metric) + return ( + confusion_matrix_metrics, + non_confusion_matrix_metrics, + ) # pytype: disable=bad-return-type # typed-keras + + +def _verify_and_update_sub_key( + model_name: str, + output_name: str, + sub_key: metric_types.SubKey, + metric: _TFMetricOrLoss, +): + """Verifies the multi-class metric key matches settings used by the metric.""" + if hasattr(metric, _CLASS_ID_KEY) and metric.class_id is not None: + if sub_key and sub_key.class_id != metric.class_id: + raise ValueError( + f"{metric.name} tf_keras.metric has class_id = {metric.class_id}, but the metric is being added " + f"using sub_key = {sub_key}: model_name={model_name}, output_name={output_name}" + ) + return metric_types.SubKey(class_id=metric.class_id) + elif hasattr(metric, _TOP_K_KEY) and metric.top_k is not None: + if sub_key and sub_key.top_k != metric.top_k: + raise ValueError( + f"{metric.name} tf_keras.metric has top_k = {metric.top_k}, but the metric is being added " + f"using sub_key = {sub_key}: model_name={model_name}, output_name={output_name}" + ) + return metric_types.SubKey(top_k=metric.top_k) + else: + return sub_key + + +_KeysBySubKey = Dict[Optional[metric_types.SubKey], List[metric_types.MetricKey]] +_ConfigsBySubKey = Dict[Optional[metric_types.SubKey], Dict[str, List[Dict[str, Any]]]] def _metric_keys_and_configs( - metrics: Dict[str, List[_TFMetricOrLoss]], model_name: str, + metrics: Dict[str, List[_TFMetricOrLoss]], + model_name: str, sub_key: Optional[metric_types.SubKey], aggregation_type: Optional[metric_types.AggregationType], - example_weighted: bool + example_weighted: bool, ) -> Tuple[_KeysBySubKey, _ConfigsBySubKey, _ConfigsBySubKey]: - """Returns metric keys, metric configs, and loss configs by sub key.""" - metric_keys = collections.defaultdict(list) - metric_configs = collections.defaultdict(dict) - loss_configs = collections.defaultdict(dict) - for output_name, metrics_list in metrics.items(): - for metric in metrics_list: - updated_sub_key = _verify_and_update_sub_key(model_name, output_name, - sub_key, metric) - if output_name not in metric_configs[updated_sub_key]: - metric_configs[updated_sub_key][output_name] = [] - if output_name not in loss_configs[updated_sub_key]: - loss_configs[updated_sub_key][output_name] = [] - metric_keys[updated_sub_key].append( - metric_types.MetricKey( - name=metric.name, - model_name=model_name, - output_name=output_name, - sub_key=updated_sub_key, - aggregation_type=aggregation_type, - example_weighted=example_weighted)) - if isinstance(metric, tf_keras.metrics.Metric): - metric_configs[updated_sub_key][output_name].append( - metric_util.serialize_metric(metric, use_legacy_format=True) - ) - elif isinstance(metric, tf_keras.losses.Loss): - loss_configs[updated_sub_key][output_name].append( - 
metric_util.serialize_loss(metric, use_legacy_format=True) - ) - return metric_keys, metric_configs, loss_configs + """Returns metric keys, metric configs, and loss configs by sub key.""" + metric_keys = collections.defaultdict(list) + metric_configs = collections.defaultdict(dict) + loss_configs = collections.defaultdict(dict) + for output_name, metrics_list in metrics.items(): + for metric in metrics_list: + updated_sub_key = _verify_and_update_sub_key( + model_name, output_name, sub_key, metric + ) + if output_name not in metric_configs[updated_sub_key]: + metric_configs[updated_sub_key][output_name] = [] + if output_name not in loss_configs[updated_sub_key]: + loss_configs[updated_sub_key][output_name] = [] + metric_keys[updated_sub_key].append( + metric_types.MetricKey( + name=metric.name, + model_name=model_name, + output_name=output_name, + sub_key=updated_sub_key, + aggregation_type=aggregation_type, + example_weighted=example_weighted, + ) + ) + if isinstance(metric, tf_keras.metrics.Metric): + metric_configs[updated_sub_key][output_name].append( + metric_util.serialize_metric(metric, use_legacy_format=True) + ) + elif isinstance(metric, tf_keras.losses.Loss): + loss_configs[updated_sub_key][output_name].append( + metric_util.serialize_loss(metric, use_legacy_format=True) + ) + return metric_keys, metric_configs, loss_configs def _deserialize_metrics( - metric_configs: List[Dict[str, Any]] + metric_configs: List[Dict[str, Any]], ) -> List[tf_keras.metrics.Metric]: - return [ - metric_util.deserialize_metric(c, use_legacy_format=True) - for c in metric_configs - ] + return [ + metric_util.deserialize_metric(c, use_legacy_format=True) + for c in metric_configs + ] def _deserialize_losses( - loss_configs: List[Dict[str, Any]] + loss_configs: List[Dict[str, Any]], ) -> List[tf_keras.losses.Loss]: - return [ - metric_util.deserialize_loss(c, use_legacy_format=True) - for c in loss_configs - ] + return [ + metric_util.deserialize_loss(c, use_legacy_format=True) for c in loss_configs + ] def _custom_objects( - metrics: Dict[str, List[tf_keras.metrics.Metric]] + metrics: Dict[str, List[tf_keras.metrics.Metric]], ) -> List[Tuple[str, str]]: - """Returns list of (module, class_name) tuples for custom objects.""" - custom_objects = [] - for metric_list in metrics.values(): - for metric in metric_list: - if ( - metric.__class__.__module__ != tf_keras.metrics.__name__ - and metric.__class__.__module__ != tf_keras.losses.__name__ - ): - custom_objects.append( - (metric.__class__.__module__, metric.__class__.__name__)) - return custom_objects - - -def _load_custom_objects( - custom_objects: List[Tuple[str, str]]) -> Dict[str, Type[Any]]: - """Loads custom metric options.""" - loaded_custom_objects = {} - for module_name, class_name in custom_objects: - module = importlib.import_module(module_name) - loaded_custom_objects[class_name] = getattr(module, class_name) - return loaded_custom_objects + """Returns list of (module, class_name) tuples for custom objects.""" + custom_objects = [] + for metric_list in metrics.values(): + for metric in metric_list: + if ( + metric.__class__.__module__ != tf_keras.metrics.__name__ + and metric.__class__.__module__ != tf_keras.losses.__name__ + ): + custom_objects.append( + (metric.__class__.__module__, metric.__class__.__name__) + ) + return custom_objects + + +def _load_custom_objects(custom_objects: List[Tuple[str, str]]) -> Dict[str, Type[Any]]: + """Loads custom metric options.""" + loaded_custom_objects = {} + for module_name, class_name in 
custom_objects: + module = importlib.import_module(module_name) + loaded_custom_objects[class_name] = getattr(module, class_name) + return loaded_custom_objects def _get_config_value(key: str, metric_config: Dict[str, Any]) -> Optional[Any]: - """Returns value for key within config or None.""" - if _CONFIG_KEY in metric_config and key in metric_config[_CONFIG_KEY]: - return metric_config[_CONFIG_KEY][key] - return None + """Returns value for key within config or None.""" + if _CONFIG_KEY in metric_config and key in metric_config[_CONFIG_KEY]: + return metric_config[_CONFIG_KEY][key] + return None def _wrap_confusion_matrix_metric( @@ -347,303 +372,320 @@ def _wrap_confusion_matrix_metric( class_weights: Optional[Dict[int, float]], example_weighted: bool, ) -> metric_types.MetricComputations: - """Returns confusion matrix metric wrapped in a more efficient computation.""" - - # Special handling for AUC metric which supports aggregation inherently via - # multi_label flag. - if isinstance(metric, tf_keras.metrics.AUC) and hasattr( - metric, 'label_weights' - ): - if metric.label_weights: - if class_weights: - raise ValueError( - 'class weights are configured in two different places: (1) via the ' - 'tf_keras.metrics.AUC class (using "label_weights") and (2) via ' - 'the MetricsSpecs (using "aggregate.class_weights"). Either remove ' - 'the label_weights settings in the AUC class or remove the ' - 'class_weights from the AggregationOptions: metric={}, ' - 'class_weights={}'.format(metric, class_weights) - ) - class_weights = {i: v for i, v in enumerate(metric.label_weights)} - if metric.multi_label: - raise NotImplementedError('AUC.multi_label=True is not implemented yet.') - - sub_key = _verify_and_update_sub_key(model_name, output_name, sub_key, metric) - key = metric_types.MetricKey( - name=metric.name, - model_name=model_name, - output_name=output_name, - aggregation_type=aggregation_type, - sub_key=sub_key, - example_weighted=example_weighted) - - metric_config = metric_util.serialize_metric(metric, use_legacy_format=True) - - thresholds = None - num_thresholds = None - # The top_k metrics have special settings. If we are setting the top_k value - # outside of keras (i.e. using BinarizeOptions), then we need to set the - # special threshold ourselves otherwise the default threshold of 0.5 is used. - if (sub_key and sub_key.top_k is not None and - _get_config_value(_TOP_K_KEY, metric_config) is None and - _get_config_value(_THRESHOLDS_KEY, metric_config) is None and - _get_config_value(_NUM_THRESHOLDS_KEY, metric_config) is None): - thresholds = [float('-inf')] - elif hasattr(metric, _THRESHOLDS_KEY): - thresholds = metric.thresholds - # Only one of either thresholds or num_thresholds should be used. Keras AUC - # allows both but thresholds has more precedence. - if thresholds is None and hasattr(metric, _NUM_THRESHOLDS_KEY): - num_thresholds = metric.num_thresholds - - # Make sure matrices are calculated. 
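`_load_custom_objects` above reconstructs custom metric classes from their `(module, class_name)` pairs via `importlib`; a minimal standalone illustration of that pattern (the pair below is a stand-in chosen for the example, not a real TFMA custom object):

```python
import importlib

# Stand-in (module, class_name) pair, for illustration only.
module_name, class_name = "collections", "OrderedDict"
cls = getattr(importlib.import_module(module_name), class_name)
print(cls)  # <class 'collections.OrderedDict'>
```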
- computations = binary_confusion_matrices.binary_confusion_matrices( - num_thresholds=num_thresholds, - thresholds=thresholds, - eval_config=eval_config, - model_name=model_name, - output_name=output_name, - sub_key=sub_key, - aggregation_type=aggregation_type, - class_weights=class_weights, - example_weighted=example_weighted) - matrices_key = computations[-1].keys[-1] - - def result( - metrics: Dict[metric_types.MetricKey, Any] - ) -> Dict[metric_types.MetricKey, Any]: - """Returns result derived from binary confusion matrices.""" - matrices = metrics[matrices_key] - - metric = metric_util.deserialize_metric( - metric_config, use_legacy_format=True + """Returns confusion matrix metric wrapped in a more efficient computation.""" + # Special handling for AUC metric which supports aggregation inherently via + # multi_label flag. + if isinstance(metric, tf_keras.metrics.AUC) and hasattr(metric, "label_weights"): + if metric.label_weights: + if class_weights: + raise ValueError( + "class weights are configured in two different places: (1) via the " + 'tf_keras.metrics.AUC class (using "label_weights") and (2) via ' + 'the MetricsSpecs (using "aggregate.class_weights"). Either remove ' + "the label_weights settings in the AUC class or remove the " + f"class_weights from the AggregationOptions: metric={metric}, " + f"class_weights={class_weights}" + ) + class_weights = {i: v for i, v in enumerate(metric.label_weights)} + if metric.multi_label: + raise NotImplementedError("AUC.multi_label=True is not implemented yet.") + + sub_key = _verify_and_update_sub_key(model_name, output_name, sub_key, metric) + key = metric_types.MetricKey( + name=metric.name, + model_name=model_name, + output_name=output_name, + aggregation_type=aggregation_type, + sub_key=sub_key, + example_weighted=example_weighted, ) + + metric_config = metric_util.serialize_metric(metric, use_legacy_format=True) + + thresholds = None + num_thresholds = None + # The top_k metrics have special settings. If we are setting the top_k value + # outside of keras (i.e. using BinarizeOptions), then we need to set the + # special threshold ourselves otherwise the default threshold of 0.5 is used. 
if ( - isinstance(metric, tf_keras.metrics.AUC) - or isinstance(metric, tf_keras.metrics.SpecificityAtSensitivity) - or isinstance(metric, tf_keras.metrics.SensitivityAtSpecificity) + sub_key + and sub_key.top_k is not None + and _get_config_value(_TOP_K_KEY, metric_config) is None + and _get_config_value(_THRESHOLDS_KEY, metric_config) is None + and _get_config_value(_NUM_THRESHOLDS_KEY, metric_config) is None ): - metric.true_positives.assign(np.array(matrices.tp)) - metric.true_negatives.assign(np.array(matrices.tn)) - metric.false_positives.assign(np.array(matrices.fp)) - metric.false_negatives.assign(np.array(matrices.fn)) - elif isinstance(metric, tf_keras.metrics.Precision): - metric.true_positives.assign(np.array(matrices.tp)) - metric.false_positives.assign(np.array(matrices.fp)) - elif isinstance(metric, tf_keras.metrics.Recall): - metric.true_positives.assign(np.array(matrices.tp)) - metric.false_negatives.assign(np.array(matrices.fn)) - elif isinstance(metric, tf_keras.metrics.TruePositives): - metric.accumulator.assign(np.array(matrices.tp)) - elif isinstance(metric, tf_keras.metrics.FalsePositives): - metric.accumulator.assign(np.array(matrices.fp)) - elif isinstance(metric, tf_keras.metrics.TrueNegatives): - metric.accumulator.assign(np.array(matrices.tn)) - elif isinstance(metric, tf_keras.metrics.FalseNegatives): - metric.accumulator.assign(np.array(matrices.fn)) - return {key: metric.result().numpy()} - - derived_computation = metric_types.DerivedMetricComputation( - keys=[key], result=result) - computations.append(derived_computation) - return computations + thresholds = [float("-inf")] + elif hasattr(metric, _THRESHOLDS_KEY): + thresholds = metric.thresholds + # Only one of either thresholds or num_thresholds should be used. Keras AUC + # allows both but thresholds has more precedence. + if thresholds is None and hasattr(metric, _NUM_THRESHOLDS_KEY): + num_thresholds = metric.num_thresholds + + # Make sure matrices are calculated. 
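The derived `result` callback defined just below does not stream examples through `update_state`; it deserializes the Keras metric and assigns the precomputed confusion-matrix counts directly into its weight variables before calling `result()`. A standalone sketch of that trick with made-up counts (assuming standard Keras metric internals):

```python
import numpy as np

from tensorflow_model_analysis.utils.keras_lib import tf_keras

# Made-up counts for the single default threshold (0.5), for illustration only.
precision = tf_keras.metrics.Precision()
precision.true_positives.assign(np.array([3.0]))
precision.false_positives.assign(np.array([1.0]))
print(precision.result().numpy())  # 3 / (3 + 1) = 0.75
```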
+ computations = binary_confusion_matrices.binary_confusion_matrices( + num_thresholds=num_thresholds, + thresholds=thresholds, + eval_config=eval_config, + model_name=model_name, + output_name=output_name, + sub_key=sub_key, + aggregation_type=aggregation_type, + class_weights=class_weights, + example_weighted=example_weighted, + ) + matrices_key = computations[-1].keys[-1] + + def result( + metrics: Dict[metric_types.MetricKey, Any], + ) -> Dict[metric_types.MetricKey, Any]: + """Returns result derived from binary confusion matrices.""" + matrices = metrics[matrices_key] + + metric = metric_util.deserialize_metric(metric_config, use_legacy_format=True) + if ( + isinstance(metric, tf_keras.metrics.AUC) + or isinstance(metric, tf_keras.metrics.SpecificityAtSensitivity) + or isinstance(metric, tf_keras.metrics.SensitivityAtSpecificity) + ): + metric.true_positives.assign(np.array(matrices.tp)) + metric.true_negatives.assign(np.array(matrices.tn)) + metric.false_positives.assign(np.array(matrices.fp)) + metric.false_negatives.assign(np.array(matrices.fn)) + elif isinstance(metric, tf_keras.metrics.Precision): + metric.true_positives.assign(np.array(matrices.tp)) + metric.false_positives.assign(np.array(matrices.fp)) + elif isinstance(metric, tf_keras.metrics.Recall): + metric.true_positives.assign(np.array(matrices.tp)) + metric.false_negatives.assign(np.array(matrices.fn)) + elif isinstance(metric, tf_keras.metrics.TruePositives): + metric.accumulator.assign(np.array(matrices.tp)) + elif isinstance(metric, tf_keras.metrics.FalsePositives): + metric.accumulator.assign(np.array(matrices.fp)) + elif isinstance(metric, tf_keras.metrics.TrueNegatives): + metric.accumulator.assign(np.array(matrices.tn)) + elif isinstance(metric, tf_keras.metrics.FalseNegatives): + metric.accumulator.assign(np.array(matrices.fn)) + return {key: metric.result().numpy()} + + derived_computation = metric_types.DerivedMetricComputation( + keys=[key], result=result + ) + computations.append(derived_computation) + return computations class _LossMetric(tf_keras.metrics.Mean): - """Converts a loss function into a metric.""" - - def __init__(self, loss, name=None, dtype=None): - if name is None: - name = loss.name - super().__init__(name=name, dtype=dtype) - self.loss = loss - - def update_state(self, y_true, y_pred, sample_weight): # pytype: disable=signature-mismatch # overriding-parameter-count-checks - return super().update_state( - self.loss(y_true, y_pred), sample_weight=sample_weight) + """Converts a loss function into a metric.""" + + def __init__(self, loss, name=None, dtype=None): + if name is None: + name = loss.name + super().__init__(name=name, dtype=dtype) + self.loss = loss + + def update_state( + self, y_true, y_pred, sample_weight + ): # pytype: disable=signature-mismatch # overriding-parameter-count-checks + return super().update_state( + self.loss(y_true, y_pred), sample_weight=sample_weight + ) class _CompilableMetricsCombiner(beam.CombineFn): - """Combines compilable metric weights and computes result.""" - - # TODO(b/173811366): Consider removing the desired_batch_size knob and - # only use input size. 
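`_CompilableMetricsCombiner`, introduced above, follows the standard `beam.CombineFn` lifecycle: `create_accumulator`, repeated `add_input`, `merge_accumulators`, and `extract_output`, with `compact` as an optional flush point. A toy combiner showing the same shape, independent of TFMA:

```python
import apache_beam as beam


class MeanCombineFn(beam.CombineFn):
    """Toy combiner mirroring the lifecycle used by _CompilableMetricsCombiner."""

    def create_accumulator(self):
        return (0.0, 0)  # (running sum, count)

    def add_input(self, accumulator, element):
        total, count = accumulator
        return total + element, count + 1

    def merge_accumulators(self, accumulators):
        totals, counts = zip(*accumulators)
        return sum(totals), sum(counts)

    def extract_output(self, accumulator):
        total, count = accumulator
        return total / count if count else float("nan")


with beam.Pipeline() as pipeline:
    _ = (
        pipeline
        | beam.Create([1.0, 2.0, 3.0])
        | beam.CombineGlobally(MeanCombineFn())
        | beam.Map(print)  # prints 2.0
    )
```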
- def __init__(self, - metric_configs: Dict[str, List[Dict[str, Any]]], - loss_configs: Dict[str, List[Dict[str, Any]]], - custom_objects: List[Tuple[str, str]], - eval_config: Optional[config_pb2.EvalConfig], - model_name: Optional[str], - sub_key: Optional[metric_types.SubKey], - aggregation_type: Optional[metric_types.AggregationType], - class_weights: Dict[int, float], - example_weighted: bool, - desired_batch_size: Optional[int] = None): - # Use parallel lists to store output_names and configs to guarantee - # consistent ordering and for natural alignment with the accumulator where - # lists are used instead of dicts for efficency. - self._eval_config = eval_config - self._model_name = model_name - self._output_names = sorted(metric_configs) - self._metric_configs = [metric_configs[n] for n in self._output_names] - self._loss_configs = [loss_configs[n] for n in self._output_names] - self._custom_objects = custom_objects - self._sub_key = sub_key - self._aggregation_type = aggregation_type - self._class_weights = class_weights - self._example_weighted = example_weighted - # True if the sub_key is part of the metric config already (i.e. top_k). - self._sub_key_in_config = sub_key and sub_key.top_k is not None - for cfg in itertools.chain.from_iterable(metric_configs.values()): - if _get_config_value(_TOP_K_KEY, cfg) is None: - self._sub_key_in_config = False - break - self._metrics = None # type: Dict[str, List[tf_keras.metrics.Metric]] - self._desired_batch_size = desired_batch_size - self._batch_size_beam_metric = ( - beam.metrics.Metrics.distribution( + """Combines compilable metric weights and computes result.""" + + # TODO(b/173811366): Consider removing the desired_batch_size knob and + # only use input size. + def __init__( + self, + metric_configs: Dict[str, List[Dict[str, Any]]], + loss_configs: Dict[str, List[Dict[str, Any]]], + custom_objects: List[Tuple[str, str]], + eval_config: Optional[config_pb2.EvalConfig], + model_name: Optional[str], + sub_key: Optional[metric_types.SubKey], + aggregation_type: Optional[metric_types.AggregationType], + class_weights: Dict[int, float], + example_weighted: bool, + desired_batch_size: Optional[int] = None, + ): + # Use parallel lists to store output_names and configs to guarantee + # consistent ordering and for natural alignment with the accumulator where + # lists are used instead of dicts for efficency. + self._eval_config = eval_config + self._model_name = model_name + self._output_names = sorted(metric_configs) + self._metric_configs = [metric_configs[n] for n in self._output_names] + self._loss_configs = [loss_configs[n] for n in self._output_names] + self._custom_objects = custom_objects + self._sub_key = sub_key + self._aggregation_type = aggregation_type + self._class_weights = class_weights + self._example_weighted = example_weighted + # True if the sub_key is part of the metric config already (i.e. top_k). 
+ self._sub_key_in_config = sub_key and sub_key.top_k is not None + for cfg in itertools.chain.from_iterable(metric_configs.values()): + if _get_config_value(_TOP_K_KEY, cfg) is None: + self._sub_key_in_config = False + break + self._metrics = None # type: Dict[str, List[tf_keras.metrics.Metric]] + self._desired_batch_size = desired_batch_size + self._batch_size_beam_metric = beam.metrics.Metrics.distribution( + constants.METRICS_NAMESPACE, "keras_compilable_metrics_combine_batch_size" + ) + self._total_input_byte_size_beam_metric = beam.metrics.Metrics.distribution( constants.METRICS_NAMESPACE, - 'keras_compilable_metrics_combine_batch_size')) - self._total_input_byte_size_beam_metric = beam.metrics.Metrics.distribution( - constants.METRICS_NAMESPACE, - 'keras_compilable_metrics_combine_batch_bytes_size') - self._num_compacts = beam.metrics.Metrics.counter( - constants.METRICS_NAMESPACE, 'num_compacts') - - def setup(self): - if self._metrics is None: - self._metrics = {} - with tf_keras.utils.custom_object_scope( - _load_custom_objects(self._custom_objects) - ): + "keras_compilable_metrics_combine_batch_bytes_size", + ) + self._num_compacts = beam.metrics.Metrics.counter( + constants.METRICS_NAMESPACE, "num_compacts" + ) + + def setup(self): + if self._metrics is None: + self._metrics = {} + with tf_keras.utils.custom_object_scope( + _load_custom_objects(self._custom_objects) + ): + for i, output_name in enumerate(self._output_names): + self._metrics[output_name] = _deserialize_metrics( + self._metric_configs[i] + ) + for loss in _deserialize_losses(self._loss_configs[i]): + self._metrics[output_name].append(_LossMetric(loss)) + + def _process_batch( + self, accumulator: tf_metric_accumulators.TFCompilableMetricsAccumulator + ): + if accumulator.len_inputs() == 0: + return + self._batch_size_beam_metric.update(accumulator.len_inputs()) + self._total_input_byte_size_beam_metric.update(accumulator.get_size_estimate()) + for output_index, output_name in enumerate(self._output_names): + inputs = accumulator.get_inputs(output_index) + for metric_index, metric in enumerate(self._metrics[output_name]): + try: + metric.reset_states() + metric.update_state(*inputs) + except Exception as e: + raise ValueError( + f"TF Metric {metric.name} fails to update with inputs:\n{inputs}," + f"\nMetric full config: {metric.get_config()}" + ) from e + accumulator.add_weights( + output_index, metric_index, metric.get_weights() + ) + accumulator.clear_inputs() + + def create_accumulator( + self, + ) -> tf_metric_accumulators.TFCompilableMetricsAccumulator: + configs = zip(self._metric_configs, self._loss_configs) + padding_options = None + if self._eval_config is not None: + model_spec = model_util.get_model_spec(self._eval_config, self._model_name) + if model_spec is not None and model_spec.HasField("padding_options"): + padding_options = model_spec.padding_options + + return tf_metric_accumulators.TFCompilableMetricsAccumulator( + padding_options, + [len(m) + len(l) for m, l in configs], + desired_batch_size=self._desired_batch_size, + ) + + def add_input( + self, + accumulator: tf_metric_accumulators.TFCompilableMetricsAccumulator, + element: metric_types.StandardMetricInputs, + ) -> tf_metric_accumulators.TFCompilableMetricsAccumulator: for i, output_name in enumerate(self._output_names): - self._metrics[output_name] = ( - _deserialize_metrics(self._metric_configs[i])) - for loss in _deserialize_losses(self._loss_configs[i]): - self._metrics[output_name].append(_LossMetric(loss)) - - def _process_batch( - 
self, accumulator: tf_metric_accumulators.TFCompilableMetricsAccumulator): - if accumulator.len_inputs() == 0: - return - self._batch_size_beam_metric.update(accumulator.len_inputs()) - self._total_input_byte_size_beam_metric.update( - accumulator.get_size_estimate()) - for output_index, output_name in enumerate(self._output_names): - inputs = accumulator.get_inputs(output_index) - for metric_index, metric in enumerate(self._metrics[output_name]): - try: - metric.reset_states() - metric.update_state(*inputs) - except Exception as e: - raise ValueError( - f'TF Metric {metric.name} fails to update with inputs:\n{inputs},' - f'\nMetric full config: {metric.get_config()}' - ) from e - accumulator.add_weights(output_index, metric_index, - metric.get_weights()) - accumulator.clear_inputs() - - def create_accumulator( - self) -> tf_metric_accumulators.TFCompilableMetricsAccumulator: - configs = zip(self._metric_configs, self._loss_configs) - padding_options = None - if self._eval_config is not None: - model_spec = model_util.get_model_spec(self._eval_config, - self._model_name) - if model_spec is not None and model_spec.HasField('padding_options'): - padding_options = model_spec.padding_options - - return tf_metric_accumulators.TFCompilableMetricsAccumulator( - padding_options, [len(m) + len(l) for m, l in configs], - desired_batch_size=self._desired_batch_size) - - def add_input( - self, accumulator: tf_metric_accumulators.TFCompilableMetricsAccumulator, - element: metric_types.StandardMetricInputs - ) -> tf_metric_accumulators.TFCompilableMetricsAccumulator: - for i, output_name in enumerate(self._output_names): - # When micro averaging is being used, flatten should be set to True so - # that each class is treated as though it was an independent example. - micro_average = ( - self._aggregation_type and self._aggregation_type.micro_average) - for label, prediction, example_weight in ( - metric_util.to_label_prediction_example_weight( - element, - eval_config=self._eval_config, - model_name=self._model_name, - output_name=output_name, - # Skip sub_key processing if part of the keras config - sub_key=self._sub_key if not self._sub_key_in_config else None, - aggregation_type=self._aggregation_type, - class_weights=self._class_weights, - example_weighted=self._example_weighted, - flatten=micro_average)): - # Keras requires non-sparse keys for its calcuations. - if self._sub_key_in_config and label.shape != prediction.shape: - label = metric_util.one_hot(label, prediction) - accumulator.add_input(i, label, prediction, example_weight) - if accumulator.should_flush(): - self._process_batch(accumulator) - return accumulator - - def merge_accumulators( - self, accumulators: Iterable[ - tf_metric_accumulators.TFCompilableMetricsAccumulator] - ) -> tf_metric_accumulators.TFCompilableMetricsAccumulator: - accumulators = iter(accumulators) - result = next(accumulators) - self._process_batch(result) - for accumulator in accumulators: - # Finish processing last batch - self._process_batch(accumulator) - # Merge the weights - for output_index, output_name in enumerate(self._output_names): - for metric_index in range(len(self._metrics[output_name])): - weights = accumulator.get_weights(output_index, metric_index) - if weights is None: - # It is possible for beam to create an accumulator but pass no - # inputs to it resulting in in empty weights. In theory all weights - # should be empty but we check on a per metric weights basis. 
- continue - result.add_weights(output_index, metric_index, weights) - return result - - def compact( - self, accumulator: tf_metric_accumulators.TFCompilableMetricsAccumulator - ) -> tf_metric_accumulators.TFCompilableMetricsAccumulator: - self._process_batch(accumulator) - self._num_compacts.inc(1) - return accumulator - - def extract_output( - self, accumulator: tf_metric_accumulators.TFCompilableMetricsAccumulator - ) -> Dict[metric_types.MetricKey, Any]: - self._process_batch(accumulator) - - def make_metric_key(metric_name, output_name): - return metric_types.MetricKey( - name=metric_name, - model_name=self._model_name, - output_name=output_name, - sub_key=self._sub_key, - example_weighted=self._example_weighted) - - result = {} - for output_index, output_name in enumerate(self._output_names): - for metric_index, metric in enumerate(self._metrics[output_name]): - - weights = accumulator.get_weights(output_index, metric_index) - if weights is not None: - metric.set_weights(weights) - else: - metric.reset_states() - metric_result = metric.result() - if isinstance(metric_result, dict): - for name, value in metric_result.items(): - key = make_metric_key(f'{metric.name}/{name}', output_name) - result[key] = value.numpy() - else: - key = make_metric_key(metric.name, output_name) - result[key] = metric_result.numpy() - return result + # When micro averaging is being used, flatten should be set to True so + # that each class is treated as though it was an independent example. + micro_average = ( + self._aggregation_type and self._aggregation_type.micro_average + ) + for ( + label, + prediction, + example_weight, + ) in metric_util.to_label_prediction_example_weight( + element, + eval_config=self._eval_config, + model_name=self._model_name, + output_name=output_name, + # Skip sub_key processing if part of the keras config + sub_key=self._sub_key if not self._sub_key_in_config else None, + aggregation_type=self._aggregation_type, + class_weights=self._class_weights, + example_weighted=self._example_weighted, + flatten=micro_average, + ): + # Keras requires non-sparse keys for its calcuations. + if self._sub_key_in_config and label.shape != prediction.shape: + label = metric_util.one_hot(label, prediction) + accumulator.add_input(i, label, prediction, example_weight) + if accumulator.should_flush(): + self._process_batch(accumulator) + return accumulator + + def merge_accumulators( + self, + accumulators: Iterable[tf_metric_accumulators.TFCompilableMetricsAccumulator], + ) -> tf_metric_accumulators.TFCompilableMetricsAccumulator: + accumulators = iter(accumulators) + result = next(accumulators) + self._process_batch(result) + for accumulator in accumulators: + # Finish processing last batch + self._process_batch(accumulator) + # Merge the weights + for output_index, output_name in enumerate(self._output_names): + for metric_index in range(len(self._metrics[output_name])): + weights = accumulator.get_weights(output_index, metric_index) + if weights is None: + # It is possible for beam to create an accumulator but pass no + # inputs to it resulting in in empty weights. In theory all weights + # should be empty but we check on a per metric weights basis. 
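Note that the accumulator never holds raw examples for long: as `_process_batch` and `extract_output` show, each metric's state is captured with `get_weights()` and later restored with `set_weights()` before `result()` is read. A standalone sketch of that round trip (assuming standard Keras metric behaviour):

```python
from tensorflow_model_analysis.utils.keras_lib import tf_keras

m = tf_keras.metrics.Precision()
m.update_state([0, 1, 1, 1], [0.2, 0.8, 0.6, 0.4])
saved = m.get_weights()  # [true_positives, false_positives]

restored = tf_keras.metrics.Precision()
restored.set_weights(saved)  # state transfers without replaying the data
print(restored.result().numpy())  # same value as m.result().numpy()
```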
+ continue + result.add_weights(output_index, metric_index, weights) + return result + + def compact( + self, accumulator: tf_metric_accumulators.TFCompilableMetricsAccumulator + ) -> tf_metric_accumulators.TFCompilableMetricsAccumulator: + self._process_batch(accumulator) + self._num_compacts.inc(1) + return accumulator + + def extract_output( + self, accumulator: tf_metric_accumulators.TFCompilableMetricsAccumulator + ) -> Dict[metric_types.MetricKey, Any]: + self._process_batch(accumulator) + + def make_metric_key(metric_name, output_name): + return metric_types.MetricKey( + name=metric_name, + model_name=self._model_name, + output_name=output_name, + sub_key=self._sub_key, + example_weighted=self._example_weighted, + ) + + result = {} + for output_index, output_name in enumerate(self._output_names): + for metric_index, metric in enumerate(self._metrics[output_name]): + weights = accumulator.get_weights(output_index, metric_index) + if weights is not None: + metric.set_weights(weights) + else: + metric.reset_states() + metric_result = metric.result() + if isinstance(metric_result, dict): + for name, value in metric_result.items(): + key = make_metric_key(f"{metric.name}/{name}", output_name) + result[key] = value.numpy() + else: + key = make_metric_key(metric.name, output_name) + result[key] = metric_result.numpy() + return result diff --git a/tensorflow_model_analysis/metrics/tf_metric_wrapper_test.py b/tensorflow_model_analysis/metrics/tf_metric_wrapper_test.py index 73084c071a..5be1a515c4 100644 --- a/tensorflow_model_analysis/metrics/tf_metric_wrapper_test.py +++ b/tensorflow_model_analysis/metrics/tf_metric_wrapper_test.py @@ -13,1134 +13,1142 @@ # limitations under the License. """Tests for TF metric wrapper.""" -from absl.testing import parameterized import apache_beam as beam -from apache_beam.testing import util import numpy as np import tensorflow as tf -from tensorflow_model_analysis.metrics import metric_types -from tensorflow_model_analysis.metrics import metric_util -from tensorflow_model_analysis.metrics import tf_metric_wrapper +from absl.testing import parameterized +from apache_beam.testing import util + +from tensorflow_model_analysis.metrics import ( + metric_types, + metric_util, + tf_metric_wrapper, +) from tensorflow_model_analysis.proto import config_pb2 from tensorflow_model_analysis.utils import test_util from tensorflow_model_analysis.utils.keras_lib import tf_keras class _CustomMetric(tf_keras.metrics.Mean): + def __init__(self, name="custom", dtype=None, update_y_pred=True): + super().__init__(name=name, dtype=dtype) + self.update_y_pred = update_y_pred - def __init__(self, name='custom', dtype=None, update_y_pred=True): - super().__init__(name=name, dtype=dtype) - self.update_y_pred = update_y_pred - - def update_state(self, y_true, y_pred, sample_weight): - return super().update_state( - y_pred if self.update_y_pred else y_true, sample_weight=sample_weight - ) + def update_state(self, y_true, y_pred, sample_weight): + return super().update_state( + y_pred if self.update_y_pred else y_true, sample_weight=sample_weight + ) - def get_config(self): - cfg = super().get_config() - cfg.update({'update_y_pred': self.update_y_pred}) - return cfg + def get_config(self): + cfg = super().get_config() + cfg.update({"update_y_pred": self.update_y_pred}) + return cfg class _CustomConfusionMatrixMetric(tf_keras.metrics.Precision): + def __init__(self, name="custom", dtype=None): + super().__init__(name=name, dtype=dtype) - def __init__(self, name='custom', dtype=None): - 
super().__init__(name=name, dtype=dtype) - - def update_state(self, y_true, y_pred, sample_weight): - super().update_state(y_true, y_pred, sample_weight=sample_weight) + def update_state(self, y_true, y_pred, sample_weight): + super().update_state(y_true, y_pred, sample_weight=sample_weight) - def get_config(self): - # Remove config items we don't accept or they will be passed to __init__. - base_config = super().get_config() - return {'name': base_config['name'], 'dtype': base_config['dtype']} + def get_config(self): + # Remove config items we don't accept or they will be passed to __init__. + base_config = super().get_config() + return {"name": base_config["name"], "dtype": base_config["dtype"]} class _CustomMeanSquaredError(tf_keras.metrics.MeanSquaredError): + def __init__(self, name, dtype=None): + super().__init__(name=name, dtype=dtype) - def __init__(self, name, dtype=None): - super().__init__(name=name, dtype=dtype) - - def result(self): - mse = super().result() - return {'mse': mse, 'one_minus_mse': 1 - mse} + def result(self): + mse = super().result() + return {"mse": mse, "one_minus_mse": 1 - mse} class ConfusionMatrixMetricsTest( test_util.TensorflowModelAnalysisTest, parameterized.TestCase ): + # This is needed because of pickling errors when using + # parameterized.named_parameters with TF metric types. + def _tf_metric_by_name(self, metric_name): + """Returns instance of tf_keras.metric with default args given name.""" + if metric_name == "auc": + return tf_keras.metrics.AUC(name="auc") + elif metric_name == "auc_pr": + return tf_keras.metrics.AUC(name="auc_pr", curve="PR") + elif metric_name == "precision": + return tf_keras.metrics.Precision(name="precision") + elif metric_name == "precision@2": + return tf_keras.metrics.Precision(name="precision@2", top_k=2) + elif metric_name == "precision@3": + return tf_keras.metrics.Precision(name="precision@3", top_k=3) + elif metric_name == "recall": + return tf_keras.metrics.Recall(name="recall") + elif metric_name == "recall@2": + return tf_keras.metrics.Recall(name="recall@2", top_k=2) + elif metric_name == "recall@3": + return tf_keras.metrics.Recall(name="recall@3", top_k=3) + elif metric_name == "true_positives": + return tf_keras.metrics.TruePositives(name="true_positives") + elif metric_name == "false_positives": + return tf_keras.metrics.FalsePositives(name="false_positives") + elif metric_name == "true_negatives": + return tf_keras.metrics.TrueNegatives(name="true_negatives") + elif metric_name == "false_negatives": + return tf_keras.metrics.FalseNegatives(name="false_negatives") + elif metric_name == "specificity_at_sensitivity": + return tf_keras.metrics.SpecificityAtSensitivity( + 0.5, name="specificity_at_sensitivity" + ) + elif metric_name == "sensitivity_at_specificity": + return tf_keras.metrics.SensitivityAtSpecificity( + 0.5, name="sensitivity_at_specificity" + ) + + @parameterized.named_parameters( + ("auc", "auc", 0.75), + ("auc_pr", "auc_pr", 0.79727), + ("precision", "precision", 1.0), + ("recall", "recall", 0.5), + ("true_positives", "true_positives", 1.0), + ("false_positives", "false_positives", 0.0), + ("true_negatives", "true_negatives", 2.0), + ("false_negatives", "false_negatives", 1.0), + ("specificity_at_sensitivity", "specificity_at_sensitivity", 1.0), + ("sensitivity_at_specificity", "sensitivity_at_specificity", 1.0), + ) + def testMetricsWithoutWeights(self, metric_name, expected_value): + # TODO (b/151636380): remove when CL/299961405 is propagated through Kokoro. 
+ if metric_name == "specificity_at_sensitivity": + fix_present = hasattr( + tf_keras.metrics.SpecificityAtSensitivity, + "_find_max_under_constraint", + ) + if not fix_present: + expected_value = 0.5 + computations = tf_metric_wrapper.tf_metric_computations( + [self._tf_metric_by_name(metric_name)], example_weighted=False + ) + histogram = computations[0] + matrix = computations[1] + metric = computations[2] + + example1 = { + "labels": np.array([0.0]), + "predictions": np.array([0.0]), + "example_weights": np.array([0.1]), # ignored, example_weighted=False + } + example2 = { + "labels": np.array([0.0]), + "predictions": np.array([0.5]), + "example_weights": np.array([0.2]), # ignored, example_weighted=False + } + example3 = { + "labels": np.array([1.0]), + "predictions": np.array([0.3]), + "example_weights": np.array([0.3]), # ignored, example_weighted=False + } + example4 = { + "labels": np.array([1.0]), + "predictions": np.array([0.9]), + "example_weights": np.array([0.4]), # ignored, example_weighted=False + } + + with beam.Pipeline() as pipeline: + # pylint: disable=no-value-for-parameter + result = ( + pipeline + | "Create" >> beam.Create([example1, example2, example3, example4]) + | "Process" >> beam.Map(metric_util.to_standard_metric_inputs) + | "AddSlice" >> beam.Map(lambda x: ((), x)) + | "ComputeHistogram" >> beam.CombinePerKey(histogram.combiner) + | "ComputeConfusionMatrix" + >> beam.Map(lambda x: (x[0], matrix.result(x[1]))) + | "ComputeMetric" >> beam.Map(lambda x: (x[0], metric.result(x[1]))) + ) + + # pylint: enable=no-value-for-parameter + + def check_result(got): + try: + self.assertLen(got, 1) + got_slice_key, got_metrics = got[0] + self.assertEqual(got_slice_key, ()) + key = metric_types.MetricKey( + name=metric_name, example_weighted=False + ) + self.assertDictElementsAlmostEqual( + got_metrics, {key: expected_value}, places=5 + ) - # This is needed because of pickling errors when using - # parameterized.named_parameters with TF metric types. 
- def _tf_metric_by_name(self, metric_name): - """Returns instance of tf_keras.metric with default args given name.""" - if metric_name == 'auc': - return tf_keras.metrics.AUC(name='auc') - elif metric_name == 'auc_pr': - return tf_keras.metrics.AUC(name='auc_pr', curve='PR') - elif metric_name == 'precision': - return tf_keras.metrics.Precision(name='precision') - elif metric_name == 'precision@2': - return tf_keras.metrics.Precision(name='precision@2', top_k=2) - elif metric_name == 'precision@3': - return tf_keras.metrics.Precision(name='precision@3', top_k=3) - elif metric_name == 'recall': - return tf_keras.metrics.Recall(name='recall') - elif metric_name == 'recall@2': - return tf_keras.metrics.Recall(name='recall@2', top_k=2) - elif metric_name == 'recall@3': - return tf_keras.metrics.Recall(name='recall@3', top_k=3) - elif metric_name == 'true_positives': - return tf_keras.metrics.TruePositives(name='true_positives') - elif metric_name == 'false_positives': - return tf_keras.metrics.FalsePositives(name='false_positives') - elif metric_name == 'true_negatives': - return tf_keras.metrics.TrueNegatives(name='true_negatives') - elif metric_name == 'false_negatives': - return tf_keras.metrics.FalseNegatives(name='false_negatives') - elif metric_name == 'specificity_at_sensitivity': - return tf_keras.metrics.SpecificityAtSensitivity( - 0.5, name='specificity_at_sensitivity' - ) - elif metric_name == 'sensitivity_at_specificity': - return tf_keras.metrics.SensitivityAtSpecificity( - 0.5, name='sensitivity_at_specificity' - ) - - @parameterized.named_parameters( - ('auc', 'auc', 0.75), - ('auc_pr', 'auc_pr', 0.79727), - ('precision', 'precision', 1.0), - ('recall', 'recall', 0.5), - ('true_positives', 'true_positives', 1.0), - ('false_positives', 'false_positives', 0.0), - ('true_negatives', 'true_negatives', 2.0), - ('false_negatives', 'false_negatives', 1.0), - ('specificity_at_sensitivity', 'specificity_at_sensitivity', 1.0), - ('sensitivity_at_specificity', 'sensitivity_at_specificity', 1.0), - ) - def testMetricsWithoutWeights(self, metric_name, expected_value): - # TODO (b/151636380): remove when CL/299961405 is propagated through Kokoro. 
- if metric_name == 'specificity_at_sensitivity': - fix_present = hasattr( - tf_keras.metrics.SpecificityAtSensitivity, - '_find_max_under_constraint', - ) - if not fix_present: - expected_value = 0.5 - computations = tf_metric_wrapper.tf_metric_computations( - [self._tf_metric_by_name(metric_name)], example_weighted=False + except AssertionError as err: + raise util.BeamAssertException(err) + + util.assert_that(result, check_result, label="result") + + @parameterized.named_parameters( + ("auc", "auc", 0.64286), + ("auc_pr", "auc_pr", 0.37467), + ("precision", "precision", 0.5833333), + ("recall", "recall", 1.0), + ("true_positives", "true_positives", 0.7), + ("false_positives", "false_positives", 0.5), + ("true_negatives", "true_negatives", 0.9), + ("false_negatives", "false_negatives", 0.0), + ("specificity_at_sensitivity", "specificity_at_sensitivity", 0.642857), + ("sensitivity_at_specificity", "sensitivity_at_specificity", 1.0), ) - histogram = computations[0] - matrix = computations[1] - metric = computations[2] - - example1 = { - 'labels': np.array([0.0]), - 'predictions': np.array([0.0]), - 'example_weights': np.array([0.1]), # ignored, example_weighted=False - } - example2 = { - 'labels': np.array([0.0]), - 'predictions': np.array([0.5]), - 'example_weights': np.array([0.2]), # ignored, example_weighted=False - } - example3 = { - 'labels': np.array([1.0]), - 'predictions': np.array([0.3]), - 'example_weights': np.array([0.3]), # ignored, example_weighted=False - } - example4 = { - 'labels': np.array([1.0]), - 'predictions': np.array([0.9]), - 'example_weights': np.array([0.4]), # ignored, example_weighted=False - } - - with beam.Pipeline() as pipeline: - # pylint: disable=no-value-for-parameter - result = ( - pipeline - | 'Create' >> beam.Create([example1, example2, example3, example4]) - | 'Process' >> beam.Map(metric_util.to_standard_metric_inputs) - | 'AddSlice' >> beam.Map(lambda x: ((), x)) - | 'ComputeHistogram' >> beam.CombinePerKey(histogram.combiner) - | 'ComputeConfusionMatrix' - >> beam.Map(lambda x: (x[0], matrix.result(x[1]))) - | 'ComputeMetric' >> beam.Map(lambda x: (x[0], metric.result(x[1]))) - ) - - # pylint: enable=no-value-for-parameter - - def check_result(got): - try: - self.assertLen(got, 1) - got_slice_key, got_metrics = got[0] - self.assertEqual(got_slice_key, ()) - key = metric_types.MetricKey(name=metric_name, example_weighted=False) - self.assertDictElementsAlmostEqual( - got_metrics, {key: expected_value}, places=5 - ) - - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that(result, check_result, label='result') - - @parameterized.named_parameters( - ('auc', 'auc', 0.64286), - ('auc_pr', 'auc_pr', 0.37467), - ('precision', 'precision', 0.5833333), - ('recall', 'recall', 1.0), - ('true_positives', 'true_positives', 0.7), - ('false_positives', 'false_positives', 0.5), - ('true_negatives', 'true_negatives', 0.9), - ('false_negatives', 'false_negatives', 0.0), - ('specificity_at_sensitivity', 'specificity_at_sensitivity', 0.642857), - ('sensitivity_at_specificity', 'sensitivity_at_specificity', 1.0), - ) - def testMetricsWithWeights(self, metric_name, expected_value): - # TODO (b/151636380): remove when CL/299961405 is propagated through Kokoro. 
- if metric_name == 'specificity_at_sensitivity': - fix_present = hasattr( - tf_keras.metrics.SpecificityAtSensitivity, - '_find_max_under_constraint', - ) - if not fix_present: - expected_value = 0.0 - - computations = tf_metric_wrapper.tf_metric_computations( - [self._tf_metric_by_name(metric_name)], example_weighted=True + def testMetricsWithWeights(self, metric_name, expected_value): + # TODO (b/151636380): remove when CL/299961405 is propagated through Kokoro. + if metric_name == "specificity_at_sensitivity": + fix_present = hasattr( + tf_keras.metrics.SpecificityAtSensitivity, + "_find_max_under_constraint", + ) + if not fix_present: + expected_value = 0.0 + + computations = tf_metric_wrapper.tf_metric_computations( + [self._tf_metric_by_name(metric_name)], example_weighted=True + ) + histogram = computations[0] + matrix = computations[1] + metric = computations[2] + + example1 = { + "labels": np.array([0.0]), + "predictions": np.array([1.0]), + "example_weights": np.array([0.5]), + } + example2 = { + "labels": np.array([1.0]), + "predictions": np.array([0.7]), + "example_weights": np.array([0.7]), + } + example3 = { + "labels": np.array([0.0]), + "predictions": np.array([0.5]), + "example_weights": np.array([0.9]), + } + + with beam.Pipeline() as pipeline: + # pylint: disable=no-value-for-parameter + result = ( + pipeline + | "Create" >> beam.Create([example1, example2, example3]) + | "Process" >> beam.Map(metric_util.to_standard_metric_inputs) + | "AddSlice" >> beam.Map(lambda x: ((), x)) + | "ComputeHistogram" >> beam.CombinePerKey(histogram.combiner) + | "ComputeConfusionMatrix" + >> beam.Map(lambda x: (x[0], matrix.result(x[1]))) + | "ComputeMetric" >> beam.Map(lambda x: (x[0], metric.result(x[1]))) + ) + + # pylint: enable=no-value-for-parameter + + def check_result(got): + try: + self.assertLen(got, 1) + got_slice_key, got_metrics = got[0] + self.assertEqual(got_slice_key, ()) + key = metric_types.MetricKey( + name=metric_name, example_weighted=True + ) + self.assertDictElementsAlmostEqual( + got_metrics, {key: expected_value}, places=5 + ) + + except AssertionError as err: + raise util.BeamAssertException(err) + + util.assert_that(result, check_result, label="result") + + @parameterized.named_parameters( + ("auc", "auc", 0.8571428), + ("auc_pr", "auc_pr", 0.77369833), + ("true_positives", "true_positives", 1.4), + ("false_positives", "false_positives", 0.6), + ("true_negatives", "true_negatives", 1.0), + ("false_negatives", "false_negatives", 0.0), ) - histogram = computations[0] - matrix = computations[1] - metric = computations[2] - - example1 = { - 'labels': np.array([0.0]), - 'predictions': np.array([1.0]), - 'example_weights': np.array([0.5]), - } - example2 = { - 'labels': np.array([1.0]), - 'predictions': np.array([0.7]), - 'example_weights': np.array([0.7]), - } - example3 = { - 'labels': np.array([0.0]), - 'predictions': np.array([0.5]), - 'example_weights': np.array([0.9]), - } - - with beam.Pipeline() as pipeline: - # pylint: disable=no-value-for-parameter - result = ( - pipeline - | 'Create' >> beam.Create([example1, example2, example3]) - | 'Process' >> beam.Map(metric_util.to_standard_metric_inputs) - | 'AddSlice' >> beam.Map(lambda x: ((), x)) - | 'ComputeHistogram' >> beam.CombinePerKey(histogram.combiner) - | 'ComputeConfusionMatrix' - >> beam.Map(lambda x: (x[0], matrix.result(x[1]))) - | 'ComputeMetric' >> beam.Map(lambda x: (x[0], metric.result(x[1]))) - ) - - # pylint: enable=no-value-for-parameter - - def check_result(got): - try: - self.assertLen(got, 
1) - got_slice_key, got_metrics = got[0] - self.assertEqual(got_slice_key, ()) - key = metric_types.MetricKey(name=metric_name, example_weighted=True) - self.assertDictElementsAlmostEqual( - got_metrics, {key: expected_value}, places=5 - ) - - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that(result, check_result, label='result') - - @parameterized.named_parameters( - ('auc', 'auc', 0.8571428), - ('auc_pr', 'auc_pr', 0.77369833), - ('true_positives', 'true_positives', 1.4), - ('false_positives', 'false_positives', 0.6), - ('true_negatives', 'true_negatives', 1.0), - ('false_negatives', 'false_negatives', 0.0), - ) - def testMetricsWithFractionalLabels(self, metric_name, expected_value): - computations = tf_metric_wrapper.tf_metric_computations( - [self._tf_metric_by_name(metric_name)] + def testMetricsWithFractionalLabels(self, metric_name, expected_value): + computations = tf_metric_wrapper.tf_metric_computations( + [self._tf_metric_by_name(metric_name)] + ) + histogram = computations[0] + matrix = computations[1] + metric = computations[2] + + # The following examples will be expanded to: + # + # prediction | label | weight + # 0.0 | - | 1.0 + # 0.7 | - | 0.4 + # 0.7 | + | 0.6 + # 1.0 | - | 0.2 + # 1.0 | + | 0.8 + example1 = { + "labels": np.array([0.0]), + "predictions": np.array([0.0]), + "example_weights": np.array([1.0]), + } + example2 = { + "labels": np.array([0.6]), + "predictions": np.array([0.7]), + "example_weights": np.array([1.0]), + } + example3 = { + "labels": np.array([0.8]), + "predictions": np.array([1.0]), + "example_weights": np.array([1.0]), + } + + with beam.Pipeline() as pipeline: + # pylint: disable=no-value-for-parameter + result = ( + pipeline + | "Create" >> beam.Create([example1, example2, example3]) + | "Process" >> beam.Map(metric_util.to_standard_metric_inputs) + | "AddSlice" >> beam.Map(lambda x: ((), x)) + | "ComputeHistogram" >> beam.CombinePerKey(histogram.combiner) + | "ComputeConfusionMatrix" + >> beam.Map(lambda x: (x[0], matrix.result(x[1]))) + | "ComputeMetric" >> beam.Map(lambda x: (x[0], metric.result(x[1]))) + ) + + # pylint: enable=no-value-for-parameter + + def check_result(got): + try: + self.assertLen(got, 1) + got_slice_key, got_metrics = got[0] + self.assertEqual(got_slice_key, ()) + key = metric_types.MetricKey(name=metric_name) + self.assertDictElementsAlmostEqual( + got_metrics, {key: expected_value}, places=5 + ) + + except AssertionError as err: + raise util.BeamAssertException(err) + + util.assert_that(result, check_result, label="result") + + @parameterized.named_parameters( + ("precision@2", "precision", 2, 1.6 / (1.6 + 3.2)), + ("recall@2", "recall", 2, 1.6 / (1.6 + 0.8)), + ("precision@3", "precision", 3, 1.9 / (1.9 + 5.3)), + ("recall@3", "recall", 3, 1.9 / (1.9 + 0.5)), ) - histogram = computations[0] - matrix = computations[1] - metric = computations[2] - - # The following examples will be expanded to: - # - # prediction | label | weight - # 0.0 | - | 1.0 - # 0.7 | - | 0.4 - # 0.7 | + | 0.6 - # 1.0 | - | 0.2 - # 1.0 | + | 0.8 - example1 = { - 'labels': np.array([0.0]), - 'predictions': np.array([0.0]), - 'example_weights': np.array([1.0]), - } - example2 = { - 'labels': np.array([0.6]), - 'predictions': np.array([0.7]), - 'example_weights': np.array([1.0]), - } - example3 = { - 'labels': np.array([0.8]), - 'predictions': np.array([1.0]), - 'example_weights': np.array([1.0]), - } - - with beam.Pipeline() as pipeline: - # pylint: disable=no-value-for-parameter - result = ( - pipeline - | 
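# --- Illustrative sketch (not part of the patch): how a fractional label is
# expanded into a weighted negative/positive pair for confusion-matrix metrics,
# mirroring the prediction | label | weight table in testMetricsWithFractionalLabels.
# `expand_fractional_label` is a hypothetical helper, not TFMA's implementation.
def expand_fractional_label(prediction, label, weight=1.0):
    """Split one example with label in [0, 1] into weighted binary examples."""
    return [
        (prediction, 0.0, weight * (1.0 - label)),  # negative part
        (prediction, 1.0, weight * label),          # positive part
    ]

# label=0.6 at prediction 0.7 -> (0.7, -, 0.4) and (0.7, +, 0.6), as tabulated.
print(expand_fractional_label(0.7, 0.6))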
'Create' >> beam.Create([example1, example2, example3]) - | 'Process' >> beam.Map(metric_util.to_standard_metric_inputs) - | 'AddSlice' >> beam.Map(lambda x: ((), x)) - | 'ComputeHistogram' >> beam.CombinePerKey(histogram.combiner) - | 'ComputeConfusionMatrix' - >> beam.Map(lambda x: (x[0], matrix.result(x[1]))) - | 'ComputeMetric' >> beam.Map(lambda x: (x[0], metric.result(x[1]))) - ) - - # pylint: enable=no-value-for-parameter - - def check_result(got): - try: - self.assertLen(got, 1) - got_slice_key, got_metrics = got[0] - self.assertEqual(got_slice_key, ()) - key = metric_types.MetricKey(name=metric_name) - self.assertDictElementsAlmostEqual( - got_metrics, {key: expected_value}, places=5 - ) - - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that(result, check_result, label='result') - - @parameterized.named_parameters( - ('precision@2', 'precision', 2, 1.6 / (1.6 + 3.2)), - ('recall@2', 'recall', 2, 1.6 / (1.6 + 0.8)), - ('precision@3', 'precision', 3, 1.9 / (1.9 + 5.3)), - ('recall@3', 'recall', 3, 1.9 / (1.9 + 0.5)), - ) - def testMultiClassMetricsUsingConfusionMatrix( - self, metric_name, top_k, expected_value - ): - computations = tf_metric_wrapper.tf_metric_computations( - [self._tf_metric_by_name(metric_name)], - sub_key=metric_types.SubKey(top_k=top_k), - example_weighted=True, + def testMultiClassMetricsUsingConfusionMatrix( + self, metric_name, top_k, expected_value + ): + computations = tf_metric_wrapper.tf_metric_computations( + [self._tf_metric_by_name(metric_name)], + sub_key=metric_types.SubKey(top_k=top_k), + example_weighted=True, + ) + histogram = computations[0] + matrix = computations[1] + metric = computations[2] + + # top_k = 2 + # TP = 0.5*0 + 0.7*1 + 0.9*1 + 0.3*0 = 1.6 + # FP = 0.5*2 + 0.7*1 + 0.9*1 + 0.3*2 = 3.2 + # FN = 0.5*1 + 0.7*0 + 0.9*0 + 0.3*1 = 0.8 + # + # top_k = 3 + # TP = 0.5*0 + 0.7*1 + 0.9*1 + 0.3*1 = 1.9 + # FP = 0.5*3 + 0.7*2 + 0.9*2 + 0.3*2 = 5.3 + # FN = 0.5*1 + 0.7*0 + 0.9*0 + 0.3*0 = 0.5 + example1 = { + "labels": np.array([2]), + "predictions": np.array([0.1, 0.2, 0.1, 0.25, 0.35]), + "example_weights": np.array([0.5]), + } + example2 = { + "labels": np.array([1]), + "predictions": np.array([0.2, 0.3, 0.05, 0.15, 0.3]), + "example_weights": np.array([0.7]), + } + example3 = { + "labels": np.array([3]), + "predictions": np.array([0.01, 0.2, 0.09, 0.5, 0.2]), + "example_weights": np.array([0.9]), + } + example4 = { + "labels": np.array([1]), + "predictions": np.array([0.3, 0.2, 0.05, 0.4, 0.05]), + # This tests that multi-dimensional weights are allowed. 
+ "example_weights": np.array([0.3, 0.3, 0.3, 0.3, 0.3]), + } + + with beam.Pipeline() as pipeline: + # pylint: disable=no-value-for-parameter + result = ( + pipeline + | "Create" >> beam.Create([example1, example2, example3, example4]) + | "Process" >> beam.Map(metric_util.to_standard_metric_inputs) + | "AddSlice" >> beam.Map(lambda x: ((), x)) + | "ComputeHistogram" >> beam.CombinePerKey(histogram.combiner) + | "ComputeConfusionMatrix" + >> beam.Map(lambda x: (x[0], matrix.result(x[1]))) + | "ComputeMetric" >> beam.Map(lambda x: (x[0], metric.result(x[1]))) + ) + + # pylint: enable=no-value-for-parameter + + def check_result(got): + try: + self.assertLen(got, 1) + got_slice_key, got_metrics = got[0] + self.assertEqual(got_slice_key, ()) + key = metric_types.MetricKey( + name=metric_name, + sub_key=metric_types.SubKey(top_k=top_k), + example_weighted=True, + ) + self.assertDictElementsAlmostEqual( + got_metrics, {key: expected_value}, places=5 + ) + + except AssertionError as err: + raise util.BeamAssertException(err) + + util.assert_that(result, check_result, label="result") + + @parameterized.named_parameters( + ("precision@2", "precision@2", 1.6 / (1.6 + 3.2)), + ("recall@2", "recall@2", 1.6 / (1.6 + 0.8)), + ("precision@3", "precision@3", 1.9 / (1.9 + 5.3)), + ("recall@3", "recall@3", 1.9 / (1.9 + 0.5)), ) - histogram = computations[0] - matrix = computations[1] - metric = computations[2] - - # top_k = 2 - # TP = 0.5*0 + 0.7*1 + 0.9*1 + 0.3*0 = 1.6 - # FP = 0.5*2 + 0.7*1 + 0.9*1 + 0.3*2 = 3.2 - # FN = 0.5*1 + 0.7*0 + 0.9*0 + 0.3*1 = 0.8 - # - # top_k = 3 - # TP = 0.5*0 + 0.7*1 + 0.9*1 + 0.3*1 = 1.9 - # FP = 0.5*3 + 0.7*2 + 0.9*2 + 0.3*2 = 5.3 - # FN = 0.5*1 + 0.7*0 + 0.9*0 + 0.3*0 = 0.5 - example1 = { - 'labels': np.array([2]), - 'predictions': np.array([0.1, 0.2, 0.1, 0.25, 0.35]), - 'example_weights': np.array([0.5]), - } - example2 = { - 'labels': np.array([1]), - 'predictions': np.array([0.2, 0.3, 0.05, 0.15, 0.3]), - 'example_weights': np.array([0.7]), - } - example3 = { - 'labels': np.array([3]), - 'predictions': np.array([0.01, 0.2, 0.09, 0.5, 0.2]), - 'example_weights': np.array([0.9]), - } - example4 = { - 'labels': np.array([1]), - 'predictions': np.array([0.3, 0.2, 0.05, 0.4, 0.05]), - # This tests that multi-dimensional weights are allowed. 
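# --- Illustrative sketch (not part of the patch): reproducing the weighted
# top_k confusion-matrix counts (TP/FP/FN) quoted in the comments of
# testMultiClassMetricsUsingConfusionMatrix with plain numpy. The helper
# `topk_counts` is hypothetical and only mirrors the arithmetic in those comments.
import numpy as np

def topk_counts(examples, k):
    tp = fp = fn = 0.0
    for label, preds, weight in examples:
        top_k = np.argsort(preds)[-k:]      # indices of the k highest scores
        hit = float(label in top_k)
        tp += weight * hit                  # label appears in the top k
        fp += weight * (k - hit)            # predicted classes that are not the label
        fn += weight * (1.0 - hit)          # label missed by the top k
    return tp, fp, fn

examples = [
    (2, np.array([0.1, 0.2, 0.1, 0.25, 0.35]), 0.5),
    (1, np.array([0.2, 0.3, 0.05, 0.15, 0.3]), 0.7),
    (3, np.array([0.01, 0.2, 0.09, 0.5, 0.2]), 0.9),
    (1, np.array([0.3, 0.2, 0.05, 0.4, 0.05]), 0.3),
]
tp, fp, fn = topk_counts(examples, k=2)  # -> 1.6, 3.2, 0.8
print(tp / (tp + fp), tp / (tp + fn))    # precision@2, recall@2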
- 'example_weights': np.array([0.3, 0.3, 0.3, 0.3, 0.3]), - } - - with beam.Pipeline() as pipeline: - # pylint: disable=no-value-for-parameter - result = ( - pipeline - | 'Create' >> beam.Create([example1, example2, example3, example4]) - | 'Process' >> beam.Map(metric_util.to_standard_metric_inputs) - | 'AddSlice' >> beam.Map(lambda x: ((), x)) - | 'ComputeHistogram' >> beam.CombinePerKey(histogram.combiner) - | 'ComputeConfusionMatrix' - >> beam.Map(lambda x: (x[0], matrix.result(x[1]))) - | 'ComputeMetric' >> beam.Map(lambda x: (x[0], metric.result(x[1]))) - ) - - # pylint: enable=no-value-for-parameter - - def check_result(got): - try: - self.assertLen(got, 1) - got_slice_key, got_metrics = got[0] - self.assertEqual(got_slice_key, ()) - key = metric_types.MetricKey( - name=metric_name, - sub_key=metric_types.SubKey(top_k=top_k), - example_weighted=True, - ) - self.assertDictElementsAlmostEqual( - got_metrics, {key: expected_value}, places=5 - ) - - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that(result, check_result, label='result') - - @parameterized.named_parameters( - ('precision@2', 'precision@2', 1.6 / (1.6 + 3.2)), - ('recall@2', 'recall@2', 1.6 / (1.6 + 0.8)), - ('precision@3', 'precision@3', 1.9 / (1.9 + 5.3)), - ('recall@3', 'recall@3', 1.9 / (1.9 + 0.5)), - ) - def testMultiClassMetricsUsingKerasConfig(self, metric_name, expected_value): - metric = tf_metric_wrapper.tf_metric_computations( - [self._tf_metric_by_name(metric_name)], example_weighted=True - )[0] - - # top_k = 2 - # TP = 0.5*0 + 0.7*1 + 0.9*1 + 0.3*0 = 1.6 - # FP = 0.5*2 + 0.7*1 + 0.9*1 + 0.3*2 = 3.2 - # FN = 0.5*1 + 0.7*0 + 0.9*0 + 0.3*1 = 0.8 - # - # top_k = 3 - # TP = 0.5*0 + 0.7*1 + 0.9*1 + 0.3*1 = 1.9 - # FP = 0.5*3 + 0.7*2 + 0.9*2 + 0.3*2 = 5.3 - # FN = 0.5*1 + 0.7*0 + 0.9*0 + 0.3*0 = 0.5 - example1 = { - 'labels': np.array([2]), - 'predictions': np.array([0.1, 0.2, 0.1, 0.25, 0.35]), - 'example_weights': np.array([0.5]), - } - example2 = { - 'labels': np.array([1]), - 'predictions': np.array([0.2, 0.3, 0.05, 0.15, 0.3]), - 'example_weights': np.array([0.7]), - } - example3 = { - 'labels': np.array([3]), - 'predictions': np.array([0.01, 0.2, 0.09, 0.5, 0.2]), - 'example_weights': np.array([0.9]), - } - example4 = { - 'labels': np.array([1]), - 'predictions': np.array([0.3, 0.2, 0.05, 0.4, 0.05]), - 'example_weights': np.array([0.3]), - } - - with beam.Pipeline() as pipeline: - # pylint: disable=no-value-for-parameter - result = ( - pipeline - | 'Create' >> beam.Create([example1, example2, example3, example4]) - | 'Process' >> beam.Map(metric_util.to_standard_metric_inputs) - | 'AddSlice' >> beam.Map(lambda x: ((), x)) - | 'Combine' >> beam.CombinePerKey(metric.combiner) - ) - - # pylint: enable=no-value-for-parameter - - def check_result(got): - try: - self.assertLen(got, 1) - got_slice_key, got_metrics = got[0] - self.assertEqual(got_slice_key, ()) - top_k = int(metric_name.split('@')[1]) - key = metric_types.MetricKey( - name=metric_name, - sub_key=metric_types.SubKey(top_k=top_k), - example_weighted=True, - ) - self.assertDictElementsAlmostEqual( - got_metrics, {key: expected_value}, places=5 - ) - - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that(result, check_result, label='result') + def testMultiClassMetricsUsingKerasConfig(self, metric_name, expected_value): + metric = tf_metric_wrapper.tf_metric_computations( + [self._tf_metric_by_name(metric_name)], example_weighted=True + )[0] + + # top_k = 2 + # TP = 0.5*0 + 0.7*1 
+ 0.9*1 + 0.3*0 = 1.6 + # FP = 0.5*2 + 0.7*1 + 0.9*1 + 0.3*2 = 3.2 + # FN = 0.5*1 + 0.7*0 + 0.9*0 + 0.3*1 = 0.8 + # + # top_k = 3 + # TP = 0.5*0 + 0.7*1 + 0.9*1 + 0.3*1 = 1.9 + # FP = 0.5*3 + 0.7*2 + 0.9*2 + 0.3*2 = 5.3 + # FN = 0.5*1 + 0.7*0 + 0.9*0 + 0.3*0 = 0.5 + example1 = { + "labels": np.array([2]), + "predictions": np.array([0.1, 0.2, 0.1, 0.25, 0.35]), + "example_weights": np.array([0.5]), + } + example2 = { + "labels": np.array([1]), + "predictions": np.array([0.2, 0.3, 0.05, 0.15, 0.3]), + "example_weights": np.array([0.7]), + } + example3 = { + "labels": np.array([3]), + "predictions": np.array([0.01, 0.2, 0.09, 0.5, 0.2]), + "example_weights": np.array([0.9]), + } + example4 = { + "labels": np.array([1]), + "predictions": np.array([0.3, 0.2, 0.05, 0.4, 0.05]), + "example_weights": np.array([0.3]), + } + + with beam.Pipeline() as pipeline: + # pylint: disable=no-value-for-parameter + result = ( + pipeline + | "Create" >> beam.Create([example1, example2, example3, example4]) + | "Process" >> beam.Map(metric_util.to_standard_metric_inputs) + | "AddSlice" >> beam.Map(lambda x: ((), x)) + | "Combine" >> beam.CombinePerKey(metric.combiner) + ) + + # pylint: enable=no-value-for-parameter + + def check_result(got): + try: + self.assertLen(got, 1) + got_slice_key, got_metrics = got[0] + self.assertEqual(got_slice_key, ()) + top_k = int(metric_name.split("@")[1]) + key = metric_types.MetricKey( + name=metric_name, + sub_key=metric_types.SubKey(top_k=top_k), + example_weighted=True, + ) + self.assertDictElementsAlmostEqual( + got_metrics, {key: expected_value}, places=5 + ) + + except AssertionError as err: + raise util.BeamAssertException(err) + + util.assert_that(result, check_result, label="result") class NonConfusionMatrixMetricsTest( test_util.TensorflowModelAnalysisTest, parameterized.TestCase ): - - def testSimpleMetric(self): - computation = tf_metric_wrapper.tf_metric_computations( - [tf_keras.metrics.MeanSquaredError(name='mse')] - )[0] - - example = { - 'labels': [0, 0, 1, 1], - 'predictions': [0, 0.5, 0.3, 0.9], - 'example_weights': [1.0], - } - - with beam.Pipeline() as pipeline: - # pylint: disable=no-value-for-parameter - result = ( - pipeline - | 'Create' >> beam.Create([example]) - | 'Process' >> beam.Map(metric_util.to_standard_metric_inputs) - | 'AddSlice' >> beam.Map(lambda x: ((), x)) - | 'Combine' >> beam.CombinePerKey(computation.combiner) - ) - - # pylint: enable=no-value-for-parameter - - def check_result(got): - try: - self.assertLen(got, 1) - got_slice_key, got_metrics = got[0] - self.assertEqual(got_slice_key, ()) - mse_key = metric_types.MetricKey(name='mse') - self.assertDictElementsAlmostEqual(got_metrics, {mse_key: 0.1875}) - - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that(result, check_result, label='result') - - def testSparseMetric(self): - computation = tf_metric_wrapper.tf_metric_computations([ - tf_keras.metrics.SparseCategoricalCrossentropy( - name='sparse_categorical_crossentropy' - ) - ])[0] - - # Simulate a multi-class problem with 3 labels. 
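# --- Illustrative sketch (not part of the patch): the 0.51083 expected in
# testSparseMetric is simply -log(p[label]) for the single three-class example.
import numpy as np

probs = np.array([0.3, 0.6, 0.1])
label = 1
print(-np.log(probs[label]))  # -> 0.5108256 (sparse categorical crossentropy)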
- example = { - 'labels': [1], - 'predictions': [0.3, 0.6, 0.1], - 'example_weights': [1.0], - } - - with beam.Pipeline() as pipeline: - # pylint: disable=no-value-for-parameter - result = ( - pipeline - | 'Create' >> beam.Create([example]) - | 'Process' >> beam.Map(metric_util.to_standard_metric_inputs) - | 'AddSlice' >> beam.Map(lambda x: ((), x)) - | 'Combine' >> beam.CombinePerKey(computation.combiner) - ) - - # pylint: enable=no-value-for-parameter - - def check_result(got): - try: - self.assertLen(got, 1) - got_slice_key, got_metrics = got[0] - self.assertEqual(got_slice_key, ()) - key = metric_types.MetricKey(name='sparse_categorical_crossentropy') - # 0*log(.3) -1*log(0.6)-0*log(.1) = 0.51 - self.assertDictElementsAlmostEqual(got_metrics, {key: 0.51083}) - - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that(result, check_result, label='result') - - def testRaisesErrorForInvalidNonSparseSettings(self): - with self.assertRaises(ValueError): - tf_metric_wrapper.tf_metric_computations( - [ - tf_keras.metrics.SparseCategoricalCrossentropy( - name='sparse_categorical_crossentropy' - ) - ], - aggregation_type=metric_types.AggregationType(micro_average=True), - ) - - def testMetricWithClassWeights(self): - computation = tf_metric_wrapper.tf_metric_computations( - [tf_keras.metrics.MeanSquaredError(name='mse')], - aggregation_type=metric_types.AggregationType(micro_average=True), - class_weights={0: 0.1, 1: 0.2, 2: 0.3, 3: 0.4}, - )[0] - - # Simulate a multi-class problem with 4 labels. The use of class weights - # implies micro averaging which only makes sense for multi-class metrics. - example = { - 'labels': [0, 0, 1, 0], - 'predictions': [0, 0.5, 0.3, 0.9], - 'example_weights': [1.0], - } - - with beam.Pipeline() as pipeline: - # pylint: disable=no-value-for-parameter - result = ( - pipeline - | 'Create' >> beam.Create([example]) - | 'Process' >> beam.Map(metric_util.to_standard_metric_inputs) - | 'AddSlice' >> beam.Map(lambda x: ((), x)) - | 'Combine' >> beam.CombinePerKey(computation.combiner) - ) - - # pylint: enable=no-value-for-parameter - - def check_result(got): - try: - self.assertLen(got, 1) - got_slice_key, got_metrics = got[0] - self.assertEqual(got_slice_key, ()) - mse_key = metric_types.MetricKey(name='mse') - # numerator = (0.1*0**2 + 0.2*0.5**2 + 0.3*0.7**2 + 0.4*0.9**2) - # denominator = (.1 + .2 + 0.3 + 0.4) - # numerator / denominator = 0.521 - self.assertDictElementsAlmostEqual(got_metrics, {mse_key: 0.521}) - - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that(result, check_result, label='result') - - def testCustomTFMetric(self): - metric = tf_metric_wrapper.tf_metric_computations( - [_CustomMetric()], example_weighted=True - )[0] - - example1 = {'labels': [0.0], 'predictions': [0.2], 'example_weights': [1.0]} - example2 = {'labels': [0.0], 'predictions': [0.8], 'example_weights': [1.0]} - example3 = {'labels': [0.0], 'predictions': [0.5], 'example_weights': [2.0]} - - with beam.Pipeline() as pipeline: - # pylint: disable=no-value-for-parameter - result = ( - pipeline - | 'Create' >> beam.Create([example1, example2, example3]) - | 'Process' >> beam.Map(metric_util.to_standard_metric_inputs) - | 'AddSlice' >> beam.Map(lambda x: ((), x)) - | 'Combine' >> beam.CombinePerKey(metric.combiner) - ) - - # pylint: enable=no-value-for-parameter - - def check_result(got): - try: - self.assertLen(got, 1) - got_slice_key, got_metrics = got[0] - self.assertEqual(got_slice_key, ()) - - custom_key = 
metric_types.MetricKey( - name='custom', example_weighted=True - ) - self.assertDictElementsAlmostEqual( - got_metrics, - {custom_key: (0.2 + 0.8 + 2 * 0.5) / (1.0 + 1.0 + 2.0)}, - ) - - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that(result, check_result, label='result') - - def testCustomConfusionMatrixTFMetric(self): - metric = tf_metric_wrapper.tf_metric_computations( - [_CustomConfusionMatrixMetric()] - )[0] - - # tp = 1 - # fp = 1 - example1 = {'labels': [0.0], 'predictions': [0.7], 'example_weights': [1.0]} - example2 = {'labels': [1.0], 'predictions': [0.8], 'example_weights': [1.0]} - - with beam.Pipeline() as pipeline: - # pylint: disable=no-value-for-parameter - result = ( - pipeline - | 'Create' >> beam.Create([example1, example2]) - | 'Process' >> beam.Map(metric_util.to_standard_metric_inputs) - | 'AddSlice' >> beam.Map(lambda x: ((), x)) - | 'Combine' >> beam.CombinePerKey(metric.combiner) - ) - - # pylint: enable=no-value-for-parameter - - def check_result(got): - try: - self.assertLen(got, 1) - got_slice_key, got_metrics = got[0] - self.assertEqual(got_slice_key, ()) - - custom_key = metric_types.MetricKey(name='custom') - self.assertDictElementsAlmostEqual( - got_metrics, {custom_key: 1.0 / (1.0 + 1.0)} - ) - - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that(result, check_result, label='result') - - @parameterized.named_parameters(*[ - dict( - testcase_name='within_example', - example_indices=[0], - # label_sum = (1 - 1 - 1 - 1) * 1.0 = -2.0 - # pred_sum = (0.1 + 0.2 + 0.3 + 0.0) = 0.6 - # weights_total = 1.0 * 4 = 4.0 - expected={ - metric_types.MetricKey( - name='custom_label', example_weighted=True - ): (-2.0 / 4.0), - metric_types.MetricKey( - name='custom_pred', example_weighted=True - ): (0.6 / 4.0), - }, - ), - dict( - testcase_name='across_examples', - # label_sum = (1 - 1 - 1 - 1) * 1.0 + - # (1 + 2 - 1.0 - 1) * 1.0 + - # (1 + 2 + 3 - 1) * 2.0 - # = 9.0 - # - # pred_sum = (0.1 + 0.2 + 0.3 + 0.0) * 1.0 + - # (0.1 + 0.2 + 0.0 - 1.0) * 1.0 + - # (0.1 + 0.2 + 0.3 - 1.0) * 2.0 - # = -0.9 - # - # weights_total = (1.0 * 4 + 1.0 * 4 + 2.0 * 4) = 16.0 - example_indices=[0, 1, 2], - expected={ - metric_types.MetricKey( - name='custom_label', example_weighted=True - ): (9.0 / 16.0), - metric_types.MetricKey( - name='custom_pred', example_weighted=True - ): (-0.9 / 16.0), - }, - ), - ]) - def testCustomTFMetricWithPadding(self, example_indices, expected): - computation = tf_metric_wrapper.tf_metric_computations( - [ - _CustomMetric(name='custom_label', update_y_pred=False), - _CustomMetric(name='custom_pred', update_y_pred=True), - ], - eval_config=config_pb2.EvalConfig( - model_specs=[ - config_pb2.ModelSpec( - padding_options=config_pb2.PaddingOptions( - label_int_padding=-1, - prediction_float_padding=-1.0, - ) + def testSimpleMetric(self): + computation = tf_metric_wrapper.tf_metric_computations( + [tf_keras.metrics.MeanSquaredError(name="mse")] + )[0] + + example = { + "labels": [0, 0, 1, 1], + "predictions": [0, 0.5, 0.3, 0.9], + "example_weights": [1.0], + } + + with beam.Pipeline() as pipeline: + # pylint: disable=no-value-for-parameter + result = ( + pipeline + | "Create" >> beam.Create([example]) + | "Process" >> beam.Map(metric_util.to_standard_metric_inputs) + | "AddSlice" >> beam.Map(lambda x: ((), x)) + | "Combine" >> beam.CombinePerKey(computation.combiner) + ) + + # pylint: enable=no-value-for-parameter + + def check_result(got): + try: + self.assertLen(got, 1) + got_slice_key, 
got_metrics = got[0] + self.assertEqual(got_slice_key, ()) + mse_key = metric_types.MetricKey(name="mse") + self.assertDictElementsAlmostEqual(got_metrics, {mse_key: 0.1875}) + + except AssertionError as err: + raise util.BeamAssertException(err) + + util.assert_that(result, check_result, label="result") + + def testSparseMetric(self): + computation = tf_metric_wrapper.tf_metric_computations( + [ + tf_keras.metrics.SparseCategoricalCrossentropy( + name="sparse_categorical_crossentropy" ) ] - ), - example_weighted=True, - )[0] - - examples = [ - { - 'labels': np.array([1], dtype=np.int64), - 'predictions': np.array([0.1, 0.2, 0.3, 0.0]), - 'example_weights': np.array([1.0]), - }, - { - 'labels': np.array([1, 2], dtype=np.int64), - 'predictions': np.array([0.1, 0.2, 0.0]), - 'example_weights': np.array([1.0]), - }, - { - 'labels': np.array([1, 2, 3], dtype=np.int64), - 'predictions': np.array([0.1, 0.2, 0.3]), - 'example_weights': np.array([2.0]), - }, - ] - - with beam.Pipeline() as pipeline: - # pylint: disable=no-value-for-parameter - result = ( - pipeline - | 'Create' >> beam.Create([examples[i] for i in example_indices]) - | 'Process' >> beam.Map(metric_util.to_standard_metric_inputs) - | 'AddSlice' >> beam.Map(lambda x: ((), x)) - | 'Combine' >> beam.CombinePerKey(computation.combiner) - ) - - # pylint: enable=no-value-for-parameter - - def check_result(got): - try: - self.assertLen(got, 1) - got_slice_key, got_metrics = got[0] - self.assertEqual(got_slice_key, ()) - self.assertDictElementsAlmostEqual(got_metrics, expected) - - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that(result, check_result, label='result') - - def testMultiOutputTFMetric(self): - computation = tf_metric_wrapper.tf_metric_computations({ - 'output_name': [tf_keras.metrics.MeanSquaredError(name='mse')], - })[0] - - extracts = { - 'labels': { - 'output_name': [0, 0, 1, 1], - }, - 'predictions': { - 'output_name': [0, 0.5, 0.3, 0.9], - }, - 'example_weights': {'output_name': [1.0]}, - } - - with beam.Pipeline() as pipeline: - # pylint: disable=no-value-for-parameter - result = ( - pipeline - | 'Create' >> beam.Create([extracts]) - | 'Process' >> beam.Map(metric_util.to_standard_metric_inputs) - | 'AddSlice' >> beam.Map(lambda x: ((), x)) - | 'Combine' >> beam.CombinePerKey(computation.combiner) - ) - - # pylint: enable=no-value-for-parameter - - def check_result(got): - try: - self.assertLen(got, 1) - got_slice_key, got_metrics = got[0] - self.assertEqual(got_slice_key, ()) - mse_key = metric_types.MetricKey( - name='mse', output_name='output_name' - ) - self.assertDictElementsAlmostEqual( - got_metrics, - { - mse_key: 0.1875, - }, - ) - - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that(result, check_result, label='result') - - def testTFMetricWithDictResult(self): - computation = tf_metric_wrapper.tf_metric_computations({ - 'output_name': [_CustomMeanSquaredError(name='mse')], - })[0] - - extracts = { - 'labels': { - 'output_name': [0, 0, 1, 1], - }, - 'predictions': { - 'output_name': [0, 0.5, 0.3, 0.9], - }, - 'example_weights': {'output_name': [1.0]}, - } - - with beam.Pipeline() as pipeline: - # pylint: disable=no-value-for-parameter - result = ( - pipeline - | 'Create' >> beam.Create([extracts]) - | 'Process' >> beam.Map(metric_util.to_standard_metric_inputs) - | 'AddSlice' >> beam.Map(lambda x: ((), x)) - | 'Combine' >> beam.CombinePerKey(computation.combiner) - ) - - # pylint: enable=no-value-for-parameter - - def 
check_result(got): - try: - self.assertLen(got, 1) - got_slice_key, got_metrics = got[0] - self.assertEqual(got_slice_key, ()) - mse_key = metric_types.MetricKey( - name='mse/mse', output_name='output_name' - ) - one_minus_mse_key = metric_types.MetricKey( - name='mse/one_minus_mse', output_name='output_name' - ) - self.assertDictElementsAlmostEqual( - got_metrics, {mse_key: 0.1875, one_minus_mse_key: 0.8125} - ) - - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that(result, check_result, label='result') - - def testTFMetricWithClassID(self): - computation = tf_metric_wrapper.tf_metric_computations( - [tf_keras.metrics.MeanSquaredError(name='mse')], - sub_key=metric_types.SubKey(class_id=1), - example_weighted=False, - )[0] - - example1 = { - 'labels': [2], - 'predictions': [0.5, 0.0, 0.5], - 'example_weights': [0.1], # ignored, example_weighted=False - } - example2 = { - 'labels': [0], - 'predictions': [0.2, 0.5, 0.3], - 'example_weights': [0.2], # ignored, example_weighted=False - } - example3 = { - 'labels': [1], - 'predictions': [0.2, 0.3, 0.5], - 'example_weights': [0.3], # ignored, example_weighted=False - } - example4 = { - 'labels': [1], - 'predictions': [0.0, 0.9, 0.1], - 'example_weights': [0.4], # ignored, example_weighted=False - } - - with beam.Pipeline() as pipeline: - # pylint: disable=no-value-for-parameter - result = ( - pipeline - | 'Create' >> beam.Create([example1, example2, example3, example4]) - | 'Process' >> beam.Map(metric_util.to_standard_metric_inputs) - | 'AddSlice' >> beam.Map(lambda x: ((), x)) - | 'Combine' >> beam.CombinePerKey(computation.combiner) - ) - - # pylint: enable=no-value-for-parameter - - def check_result(got): - try: - self.assertLen(got, 1) - got_slice_key, got_metrics = got[0] - self.assertEqual(got_slice_key, ()) - mse_key = metric_types.MetricKey( - name='mse', - sub_key=metric_types.SubKey(class_id=1), - example_weighted=False, - ) - self.assertDictElementsAlmostEqual( - got_metrics, - { - mse_key: 0.1875, - }, - ) - - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that(result, check_result, label='result') - - def testBatching(self): - computation = tf_metric_wrapper.tf_metric_computations( - [_CustomMetric(), tf_keras.metrics.MeanSquaredError(name='mse')], - desired_batch_size=2, - example_weighted=True, - )[0] - - example1 = {'labels': [0.0], 'predictions': [0.0], 'example_weights': [1.0]} - example2 = {'labels': [0.0], 'predictions': [0.5], 'example_weights': [1.0]} - example3 = {'labels': [1.0], 'predictions': [0.3], 'example_weights': [1.0]} - example4 = {'labels': [1.0], 'predictions': [0.9], 'example_weights': [1.0]} - example5 = {'labels': [1.0], 'predictions': [0.5], 'example_weights': [0.0]} - - with beam.Pipeline() as pipeline: - # pylint: disable=no-value-for-parameter - result = ( - pipeline - | 'Create' - >> beam.Create([example1, example2, example3, example4, example5]) - | 'Process' >> beam.Map(metric_util.to_standard_metric_inputs) - | 'AddSlice' >> beam.Map(lambda x: ((), x)) - | 'Combine' >> beam.CombinePerKey(computation.combiner) - ) - - # pylint: enable=no-value-for-parameter - - def check_result(got): - try: - self.assertLen(got, 1, 'got: %s' % got) - got_slice_key, got_metrics = got[0] - self.assertEqual(got_slice_key, ()) - - custom_key = metric_types.MetricKey( - name='custom', example_weighted=True - ) - mse_key = metric_types.MetricKey(name='mse', example_weighted=True) - self.assertDictElementsAlmostEqual( - got_metrics, - { - 
custom_key: (0.0 + 0.5 + 0.3 + 0.9 + 0.0) / ( - 1.0 + 1.0 + 1.0 + 1.0 + 0.0 - ), - mse_key: 0.1875, - }, - ) - - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that(result, check_result, label='result') - - def testMergeAccumulators(self): - computation = tf_metric_wrapper.tf_metric_computations( - [tf_keras.metrics.MeanSquaredError(name='mse')], - desired_batch_size=2, - example_weighted=True, - )[0] - - example1 = {'labels': [0.0], 'predictions': [0.0], 'example_weights': [1.0]} - example2 = {'labels': [0.0], 'predictions': [0.5], 'example_weights': [1.0]} - example3 = {'labels': [1.0], 'predictions': [0.3], 'example_weights': [1.0]} - example4 = {'labels': [1.0], 'predictions': [0.9], 'example_weights': [1.0]} - example5 = {'labels': [1.0], 'predictions': [0.5], 'example_weights': [0.0]} - - computation.combiner.setup() - combiner_inputs = [] - for e in (example1, example2, example3, example4, example5): - combiner_inputs.append(metric_util.to_standard_metric_inputs(e)) - acc1 = computation.combiner.create_accumulator() - acc1 = computation.combiner.add_input(acc1, combiner_inputs[0]) - acc1 = computation.combiner.add_input(acc1, combiner_inputs[1]) - acc1 = computation.combiner.add_input(acc1, combiner_inputs[2]) - acc2 = computation.combiner.create_accumulator() - acc2 = computation.combiner.add_input(acc2, combiner_inputs[3]) - acc2 = computation.combiner.add_input(acc2, combiner_inputs[4]) - acc = computation.combiner.merge_accumulators([acc1, acc2]) - - got_metrics = computation.combiner.extract_output(acc) - mse_key = metric_types.MetricKey(name='mse', example_weighted=True) - self.assertDictElementsAlmostEqual(got_metrics, {mse_key: 0.1875}) + )[0] + + # Simulate a multi-class problem with 3 labels. + example = { + "labels": [1], + "predictions": [0.3, 0.6, 0.1], + "example_weights": [1.0], + } + + with beam.Pipeline() as pipeline: + # pylint: disable=no-value-for-parameter + result = ( + pipeline + | "Create" >> beam.Create([example]) + | "Process" >> beam.Map(metric_util.to_standard_metric_inputs) + | "AddSlice" >> beam.Map(lambda x: ((), x)) + | "Combine" >> beam.CombinePerKey(computation.combiner) + ) + + # pylint: enable=no-value-for-parameter + + def check_result(got): + try: + self.assertLen(got, 1) + got_slice_key, got_metrics = got[0] + self.assertEqual(got_slice_key, ()) + key = metric_types.MetricKey(name="sparse_categorical_crossentropy") + # 0*log(.3) -1*log(0.6)-0*log(.1) = 0.51 + self.assertDictElementsAlmostEqual(got_metrics, {key: 0.51083}) + + except AssertionError as err: + raise util.BeamAssertException(err) + + util.assert_that(result, check_result, label="result") + + def testRaisesErrorForInvalidNonSparseSettings(self): + with self.assertRaises(ValueError): + tf_metric_wrapper.tf_metric_computations( + [ + tf_keras.metrics.SparseCategoricalCrossentropy( + name="sparse_categorical_crossentropy" + ) + ], + aggregation_type=metric_types.AggregationType(micro_average=True), + ) + + def testMetricWithClassWeights(self): + computation = tf_metric_wrapper.tf_metric_computations( + [tf_keras.metrics.MeanSquaredError(name="mse")], + aggregation_type=metric_types.AggregationType(micro_average=True), + class_weights={0: 0.1, 1: 0.2, 2: 0.3, 3: 0.4}, + )[0] + + # Simulate a multi-class problem with 4 labels. The use of class weights + # implies micro averaging which only makes sense for multi-class metrics. 
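# --- Illustrative sketch (not part of the patch): the micro-averaged,
# class-weighted MSE value (0.521) asserted in testMetricWithClassWeights,
# computed directly with numpy using the numerator/denominator spelled out
# in the check_result comments.
import numpy as np

labels = np.array([0.0, 0.0, 1.0, 0.0])
preds = np.array([0.0, 0.5, 0.3, 0.9])
class_weights = np.array([0.1, 0.2, 0.3, 0.4])  # weight for class index 0..3

numerator = np.sum(class_weights * (preds - labels) ** 2)
denominator = np.sum(class_weights)
print(numerator / denominator)  # -> 0.521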
+ example = { + "labels": [0, 0, 1, 0], + "predictions": [0, 0.5, 0.3, 0.9], + "example_weights": [1.0], + } + + with beam.Pipeline() as pipeline: + # pylint: disable=no-value-for-parameter + result = ( + pipeline + | "Create" >> beam.Create([example]) + | "Process" >> beam.Map(metric_util.to_standard_metric_inputs) + | "AddSlice" >> beam.Map(lambda x: ((), x)) + | "Combine" >> beam.CombinePerKey(computation.combiner) + ) + + # pylint: enable=no-value-for-parameter + + def check_result(got): + try: + self.assertLen(got, 1) + got_slice_key, got_metrics = got[0] + self.assertEqual(got_slice_key, ()) + mse_key = metric_types.MetricKey(name="mse") + # numerator = (0.1*0**2 + 0.2*0.5**2 + 0.3*0.7**2 + 0.4*0.9**2) + # denominator = (.1 + .2 + 0.3 + 0.4) + # numerator / denominator = 0.521 + self.assertDictElementsAlmostEqual(got_metrics, {mse_key: 0.521}) + + except AssertionError as err: + raise util.BeamAssertException(err) + + util.assert_that(result, check_result, label="result") + + def testCustomTFMetric(self): + metric = tf_metric_wrapper.tf_metric_computations( + [_CustomMetric()], example_weighted=True + )[0] + + example1 = {"labels": [0.0], "predictions": [0.2], "example_weights": [1.0]} + example2 = {"labels": [0.0], "predictions": [0.8], "example_weights": [1.0]} + example3 = {"labels": [0.0], "predictions": [0.5], "example_weights": [2.0]} + + with beam.Pipeline() as pipeline: + # pylint: disable=no-value-for-parameter + result = ( + pipeline + | "Create" >> beam.Create([example1, example2, example3]) + | "Process" >> beam.Map(metric_util.to_standard_metric_inputs) + | "AddSlice" >> beam.Map(lambda x: ((), x)) + | "Combine" >> beam.CombinePerKey(metric.combiner) + ) + + # pylint: enable=no-value-for-parameter + + def check_result(got): + try: + self.assertLen(got, 1) + got_slice_key, got_metrics = got[0] + self.assertEqual(got_slice_key, ()) + + custom_key = metric_types.MetricKey( + name="custom", example_weighted=True + ) + self.assertDictElementsAlmostEqual( + got_metrics, + {custom_key: (0.2 + 0.8 + 2 * 0.5) / (1.0 + 1.0 + 2.0)}, + ) + + except AssertionError as err: + raise util.BeamAssertException(err) + + util.assert_that(result, check_result, label="result") + + def testCustomConfusionMatrixTFMetric(self): + metric = tf_metric_wrapper.tf_metric_computations( + [_CustomConfusionMatrixMetric()] + )[0] + + # tp = 1 + # fp = 1 + example1 = {"labels": [0.0], "predictions": [0.7], "example_weights": [1.0]} + example2 = {"labels": [1.0], "predictions": [0.8], "example_weights": [1.0]} + + with beam.Pipeline() as pipeline: + # pylint: disable=no-value-for-parameter + result = ( + pipeline + | "Create" >> beam.Create([example1, example2]) + | "Process" >> beam.Map(metric_util.to_standard_metric_inputs) + | "AddSlice" >> beam.Map(lambda x: ((), x)) + | "Combine" >> beam.CombinePerKey(metric.combiner) + ) + + # pylint: enable=no-value-for-parameter + + def check_result(got): + try: + self.assertLen(got, 1) + got_slice_key, got_metrics = got[0] + self.assertEqual(got_slice_key, ()) + + custom_key = metric_types.MetricKey(name="custom") + self.assertDictElementsAlmostEqual( + got_metrics, {custom_key: 1.0 / (1.0 + 1.0)} + ) + + except AssertionError as err: + raise util.BeamAssertException(err) + + util.assert_that(result, check_result, label="result") + + @parameterized.named_parameters( + *[ + dict( + testcase_name="within_example", + example_indices=[0], + # label_sum = (1 - 1 - 1 - 1) * 1.0 = -2.0 + # pred_sum = (0.1 + 0.2 + 0.3 + 0.0) = 0.6 + # weights_total = 1.0 * 4 = 4.0 + 
expected={ + metric_types.MetricKey( + name="custom_label", example_weighted=True + ): (-2.0 / 4.0), + metric_types.MetricKey(name="custom_pred", example_weighted=True): ( + 0.6 / 4.0 + ), + }, + ), + dict( + testcase_name="across_examples", + # label_sum = (1 - 1 - 1 - 1) * 1.0 + + # (1 + 2 - 1.0 - 1) * 1.0 + + # (1 + 2 + 3 - 1) * 2.0 + # = 9.0 + # + # pred_sum = (0.1 + 0.2 + 0.3 + 0.0) * 1.0 + + # (0.1 + 0.2 + 0.0 - 1.0) * 1.0 + + # (0.1 + 0.2 + 0.3 - 1.0) * 2.0 + # = -0.9 + # + # weights_total = (1.0 * 4 + 1.0 * 4 + 2.0 * 4) = 16.0 + example_indices=[0, 1, 2], + expected={ + metric_types.MetricKey( + name="custom_label", example_weighted=True + ): (9.0 / 16.0), + metric_types.MetricKey(name="custom_pred", example_weighted=True): ( + -0.9 / 16.0 + ), + }, + ), + ] + ) + def testCustomTFMetricWithPadding(self, example_indices, expected): + computation = tf_metric_wrapper.tf_metric_computations( + [ + _CustomMetric(name="custom_label", update_y_pred=False), + _CustomMetric(name="custom_pred", update_y_pred=True), + ], + eval_config=config_pb2.EvalConfig( + model_specs=[ + config_pb2.ModelSpec( + padding_options=config_pb2.PaddingOptions( + label_int_padding=-1, + prediction_float_padding=-1.0, + ) + ) + ] + ), + example_weighted=True, + )[0] + + examples = [ + { + "labels": np.array([1], dtype=np.int64), + "predictions": np.array([0.1, 0.2, 0.3, 0.0]), + "example_weights": np.array([1.0]), + }, + { + "labels": np.array([1, 2], dtype=np.int64), + "predictions": np.array([0.1, 0.2, 0.0]), + "example_weights": np.array([1.0]), + }, + { + "labels": np.array([1, 2, 3], dtype=np.int64), + "predictions": np.array([0.1, 0.2, 0.3]), + "example_weights": np.array([2.0]), + }, + ] + + with beam.Pipeline() as pipeline: + # pylint: disable=no-value-for-parameter + result = ( + pipeline + | "Create" >> beam.Create([examples[i] for i in example_indices]) + | "Process" >> beam.Map(metric_util.to_standard_metric_inputs) + | "AddSlice" >> beam.Map(lambda x: ((), x)) + | "Combine" >> beam.CombinePerKey(computation.combiner) + ) + + # pylint: enable=no-value-for-parameter + + def check_result(got): + try: + self.assertLen(got, 1) + got_slice_key, got_metrics = got[0] + self.assertEqual(got_slice_key, ()) + self.assertDictElementsAlmostEqual(got_metrics, expected) + + except AssertionError as err: + raise util.BeamAssertException(err) + + util.assert_that(result, check_result, label="result") + + def testMultiOutputTFMetric(self): + computation = tf_metric_wrapper.tf_metric_computations( + { + "output_name": [tf_keras.metrics.MeanSquaredError(name="mse")], + } + )[0] + + extracts = { + "labels": { + "output_name": [0, 0, 1, 1], + }, + "predictions": { + "output_name": [0, 0.5, 0.3, 0.9], + }, + "example_weights": {"output_name": [1.0]}, + } + + with beam.Pipeline() as pipeline: + # pylint: disable=no-value-for-parameter + result = ( + pipeline + | "Create" >> beam.Create([extracts]) + | "Process" >> beam.Map(metric_util.to_standard_metric_inputs) + | "AddSlice" >> beam.Map(lambda x: ((), x)) + | "Combine" >> beam.CombinePerKey(computation.combiner) + ) + + # pylint: enable=no-value-for-parameter + + def check_result(got): + try: + self.assertLen(got, 1) + got_slice_key, got_metrics = got[0] + self.assertEqual(got_slice_key, ()) + mse_key = metric_types.MetricKey( + name="mse", output_name="output_name" + ) + self.assertDictElementsAlmostEqual( + got_metrics, + { + mse_key: 0.1875, + }, + ) + + except AssertionError as err: + raise util.BeamAssertException(err) + + util.assert_that(result, check_result, 
label="result") + + def testTFMetricWithDictResult(self): + computation = tf_metric_wrapper.tf_metric_computations( + { + "output_name": [_CustomMeanSquaredError(name="mse")], + } + )[0] + + extracts = { + "labels": { + "output_name": [0, 0, 1, 1], + }, + "predictions": { + "output_name": [0, 0.5, 0.3, 0.9], + }, + "example_weights": {"output_name": [1.0]}, + } + + with beam.Pipeline() as pipeline: + # pylint: disable=no-value-for-parameter + result = ( + pipeline + | "Create" >> beam.Create([extracts]) + | "Process" >> beam.Map(metric_util.to_standard_metric_inputs) + | "AddSlice" >> beam.Map(lambda x: ((), x)) + | "Combine" >> beam.CombinePerKey(computation.combiner) + ) + + # pylint: enable=no-value-for-parameter + + def check_result(got): + try: + self.assertLen(got, 1) + got_slice_key, got_metrics = got[0] + self.assertEqual(got_slice_key, ()) + mse_key = metric_types.MetricKey( + name="mse/mse", output_name="output_name" + ) + one_minus_mse_key = metric_types.MetricKey( + name="mse/one_minus_mse", output_name="output_name" + ) + self.assertDictElementsAlmostEqual( + got_metrics, {mse_key: 0.1875, one_minus_mse_key: 0.8125} + ) + + except AssertionError as err: + raise util.BeamAssertException(err) + + util.assert_that(result, check_result, label="result") + + def testTFMetricWithClassID(self): + computation = tf_metric_wrapper.tf_metric_computations( + [tf_keras.metrics.MeanSquaredError(name="mse")], + sub_key=metric_types.SubKey(class_id=1), + example_weighted=False, + )[0] + + example1 = { + "labels": [2], + "predictions": [0.5, 0.0, 0.5], + "example_weights": [0.1], # ignored, example_weighted=False + } + example2 = { + "labels": [0], + "predictions": [0.2, 0.5, 0.3], + "example_weights": [0.2], # ignored, example_weighted=False + } + example3 = { + "labels": [1], + "predictions": [0.2, 0.3, 0.5], + "example_weights": [0.3], # ignored, example_weighted=False + } + example4 = { + "labels": [1], + "predictions": [0.0, 0.9, 0.1], + "example_weights": [0.4], # ignored, example_weighted=False + } + + with beam.Pipeline() as pipeline: + # pylint: disable=no-value-for-parameter + result = ( + pipeline + | "Create" >> beam.Create([example1, example2, example3, example4]) + | "Process" >> beam.Map(metric_util.to_standard_metric_inputs) + | "AddSlice" >> beam.Map(lambda x: ((), x)) + | "Combine" >> beam.CombinePerKey(computation.combiner) + ) + + # pylint: enable=no-value-for-parameter + + def check_result(got): + try: + self.assertLen(got, 1) + got_slice_key, got_metrics = got[0] + self.assertEqual(got_slice_key, ()) + mse_key = metric_types.MetricKey( + name="mse", + sub_key=metric_types.SubKey(class_id=1), + example_weighted=False, + ) + self.assertDictElementsAlmostEqual( + got_metrics, + { + mse_key: 0.1875, + }, + ) + + except AssertionError as err: + raise util.BeamAssertException(err) + + util.assert_that(result, check_result, label="result") + + def testBatching(self): + computation = tf_metric_wrapper.tf_metric_computations( + [_CustomMetric(), tf_keras.metrics.MeanSquaredError(name="mse")], + desired_batch_size=2, + example_weighted=True, + )[0] + + example1 = {"labels": [0.0], "predictions": [0.0], "example_weights": [1.0]} + example2 = {"labels": [0.0], "predictions": [0.5], "example_weights": [1.0]} + example3 = {"labels": [1.0], "predictions": [0.3], "example_weights": [1.0]} + example4 = {"labels": [1.0], "predictions": [0.9], "example_weights": [1.0]} + example5 = {"labels": [1.0], "predictions": [0.5], "example_weights": [0.0]} + + with beam.Pipeline() as pipeline: + # 
pylint: disable=no-value-for-parameter + result = ( + pipeline + | "Create" + >> beam.Create([example1, example2, example3, example4, example5]) + | "Process" >> beam.Map(metric_util.to_standard_metric_inputs) + | "AddSlice" >> beam.Map(lambda x: ((), x)) + | "Combine" >> beam.CombinePerKey(computation.combiner) + ) + + # pylint: enable=no-value-for-parameter + + def check_result(got): + try: + self.assertLen(got, 1, "got: %s" % got) + got_slice_key, got_metrics = got[0] + self.assertEqual(got_slice_key, ()) + + custom_key = metric_types.MetricKey( + name="custom", example_weighted=True + ) + mse_key = metric_types.MetricKey(name="mse", example_weighted=True) + self.assertDictElementsAlmostEqual( + got_metrics, + { + custom_key: (0.0 + 0.5 + 0.3 + 0.9 + 0.0) + / (1.0 + 1.0 + 1.0 + 1.0 + 0.0), + mse_key: 0.1875, + }, + ) + + except AssertionError as err: + raise util.BeamAssertException(err) + + util.assert_that(result, check_result, label="result") + + def testMergeAccumulators(self): + computation = tf_metric_wrapper.tf_metric_computations( + [tf_keras.metrics.MeanSquaredError(name="mse")], + desired_batch_size=2, + example_weighted=True, + )[0] + + example1 = {"labels": [0.0], "predictions": [0.0], "example_weights": [1.0]} + example2 = {"labels": [0.0], "predictions": [0.5], "example_weights": [1.0]} + example3 = {"labels": [1.0], "predictions": [0.3], "example_weights": [1.0]} + example4 = {"labels": [1.0], "predictions": [0.9], "example_weights": [1.0]} + example5 = {"labels": [1.0], "predictions": [0.5], "example_weights": [0.0]} + + computation.combiner.setup() + combiner_inputs = [] + for e in (example1, example2, example3, example4, example5): + combiner_inputs.append(metric_util.to_standard_metric_inputs(e)) + acc1 = computation.combiner.create_accumulator() + acc1 = computation.combiner.add_input(acc1, combiner_inputs[0]) + acc1 = computation.combiner.add_input(acc1, combiner_inputs[1]) + acc1 = computation.combiner.add_input(acc1, combiner_inputs[2]) + acc2 = computation.combiner.create_accumulator() + acc2 = computation.combiner.add_input(acc2, combiner_inputs[3]) + acc2 = computation.combiner.add_input(acc2, combiner_inputs[4]) + acc = computation.combiner.merge_accumulators([acc1, acc2]) + + got_metrics = computation.combiner.extract_output(acc) + mse_key = metric_types.MetricKey(name="mse", example_weighted=True) + self.assertDictElementsAlmostEqual(got_metrics, {mse_key: 0.1875}) class MixedMetricsTest(test_util.TensorflowModelAnalysisTest): + def testWithMixedMetrics(self): + computations = tf_metric_wrapper.tf_metric_computations( + [ + tf_keras.metrics.AUC(name="auc"), + tf_keras.losses.BinaryCrossentropy(name="binary_crossentropy"), + tf_keras.metrics.MeanSquaredError(name="mse"), + ] + ) + + confusion_histogram = computations[0] + confusion_matrix = computations[1].result + confusion_metrics = computations[2].result + non_confusion_metrics = computations[3] + + example1 = { + "labels": np.array([0.0]), + "predictions": np.array([0.0]), + "example_weights": np.array([1.0]), + } + example2 = { + "labels": np.array([0.0]), + "predictions": np.array([0.5]), + "example_weights": np.array([1.0]), + } + example3 = { + "labels": np.array([1.0]), + "predictions": np.array([0.3]), + "example_weights": np.array([1.0]), + } + example4 = { + "labels": np.array([1.0]), + "predictions": np.array([0.9]), + "example_weights": np.array([1.0]), + } + + with beam.Pipeline() as pipeline: + # pylint: disable=no-value-for-parameter + sliced_examples = ( + pipeline + | "Create" >> 
beam.Create([example1, example2, example3, example4]) + | "Process" >> beam.Map(metric_util.to_standard_metric_inputs) + | "AddSlice" >> beam.Map(lambda x: ((), x)) + ) + + confusion_result = ( + sliced_examples + | "ComputeHistogram" >> beam.CombinePerKey(confusion_histogram.combiner) + | "ComputeConfusionMatrix" + >> beam.Map(lambda x: (x[0], confusion_matrix(x[1]))) + | "ComputeMetric" >> beam.Map(lambda x: (x[0], confusion_metrics(x[1]))) + ) + + non_confusion_result = sliced_examples | "Combine" >> beam.CombinePerKey( + non_confusion_metrics.combiner + ) + + # pylint: enable=no-value-for-parameter + + def check_confusion_result(got): + try: + self.assertLen(got, 1) + got_slice_key, got_metrics = got[0] + self.assertEqual(got_slice_key, ()) + auc_key = metric_types.MetricKey(name="auc") + self.assertDictElementsAlmostEqual( + got_metrics, {auc_key: 0.75}, places=5 + ) + + except AssertionError as err: + raise util.BeamAssertException(err) + + def check_non_confusion_result(got): + try: + self.assertLen(got, 1) + got_slice_key, got_metrics = got[0] + self.assertEqual(got_slice_key, ()) + mse_key = metric_types.MetricKey(name="mse") + binary_crossentropy_key = metric_types.MetricKey( + name="binary_crossentropy" + ) + self.assertDictElementsAlmostEqual( + got_metrics, + {mse_key: 0.1875, binary_crossentropy_key: 0.50061995}, + places=5, + ) + + except AssertionError as err: + raise util.BeamAssertException(err) + + util.assert_that( + confusion_result, check_confusion_result, label="confusion" + ) + util.assert_that( + non_confusion_result, + check_non_confusion_result, + label="non_confusion", + ) + - def testWithMixedMetrics(self): - computations = tf_metric_wrapper.tf_metric_computations([ - tf_keras.metrics.AUC(name='auc'), - tf_keras.losses.BinaryCrossentropy(name='binary_crossentropy'), - tf_keras.metrics.MeanSquaredError(name='mse'), - ]) - - confusion_histogram = computations[0] - confusion_matrix = computations[1].result - confusion_metrics = computations[2].result - non_confusion_metrics = computations[3] - - example1 = { - 'labels': np.array([0.0]), - 'predictions': np.array([0.0]), - 'example_weights': np.array([1.0]), - } - example2 = { - 'labels': np.array([0.0]), - 'predictions': np.array([0.5]), - 'example_weights': np.array([1.0]), - } - example3 = { - 'labels': np.array([1.0]), - 'predictions': np.array([0.3]), - 'example_weights': np.array([1.0]), - } - example4 = { - 'labels': np.array([1.0]), - 'predictions': np.array([0.9]), - 'example_weights': np.array([1.0]), - } - - with beam.Pipeline() as pipeline: - # pylint: disable=no-value-for-parameter - sliced_examples = ( - pipeline - | 'Create' >> beam.Create([example1, example2, example3, example4]) - | 'Process' >> beam.Map(metric_util.to_standard_metric_inputs) - | 'AddSlice' >> beam.Map(lambda x: ((), x)) - ) - - confusion_result = ( - sliced_examples - | 'ComputeHistogram' - >> beam.CombinePerKey(confusion_histogram.combiner) - | 'ComputeConfusionMatrix' - >> beam.Map(lambda x: (x[0], confusion_matrix(x[1]))) - | 'ComputeMetric' - >> beam.Map(lambda x: (x[0], confusion_metrics(x[1]))) - ) - - non_confusion_result = sliced_examples | 'Combine' >> beam.CombinePerKey( - non_confusion_metrics.combiner - ) - - # pylint: enable=no-value-for-parameter - - def check_confusion_result(got): - try: - self.assertLen(got, 1) - got_slice_key, got_metrics = got[0] - self.assertEqual(got_slice_key, ()) - auc_key = metric_types.MetricKey(name='auc') - self.assertDictElementsAlmostEqual( - got_metrics, {auc_key: 0.75}, places=5 - ) 
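# --- Illustrative sketch (not part of the patch): the mse (0.1875) and
# binary_crossentropy (~0.50062) values asserted in testWithMixedMetrics,
# computed directly with numpy; predictions are clipped before the log
# (Keras applies a similar epsilon clip) to avoid log(0).
import numpy as np

labels = np.array([0.0, 0.0, 1.0, 1.0])
preds = np.array([0.0, 0.5, 0.3, 0.9])

mse = np.mean((preds - labels) ** 2)                                   # -> 0.1875
p = np.clip(preds, 1e-7, 1.0 - 1e-7)
bce = -np.mean(labels * np.log(p) + (1.0 - labels) * np.log(1.0 - p))  # -> ~0.50062
print(mse, bce)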
- - except AssertionError as err: - raise util.BeamAssertException(err) - - def check_non_confusion_result(got): - try: - self.assertLen(got, 1) - got_slice_key, got_metrics = got[0] - self.assertEqual(got_slice_key, ()) - mse_key = metric_types.MetricKey(name='mse') - binary_crossentropy_key = metric_types.MetricKey( - name='binary_crossentropy' - ) - self.assertDictElementsAlmostEqual( - got_metrics, - {mse_key: 0.1875, binary_crossentropy_key: 0.50061995}, - places=5, - ) - - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that( - confusion_result, check_confusion_result, label='confusion' - ) - util.assert_that( - non_confusion_result, - check_non_confusion_result, - label='non_confusion', - ) - - -if __name__ == '__main__': - tf.compat.v1.enable_v2_behavior() - tf.test.main() +if __name__ == "__main__": + tf.compat.v1.enable_v2_behavior() + tf.test.main() diff --git a/tensorflow_model_analysis/metrics/tjur_discrimination.py b/tensorflow_model_analysis/metrics/tjur_discrimination.py index c62239855a..ac0e532684 100644 --- a/tensorflow_model_analysis/metrics/tjur_discrimination.py +++ b/tensorflow_model_analysis/metrics/tjur_discrimination.py @@ -20,38 +20,38 @@ from typing import Any, Dict, Iterable, Optional import apache_beam as beam -from tensorflow_model_analysis.metrics import metric_types -from tensorflow_model_analysis.metrics import metric_util + +from tensorflow_model_analysis.metrics import metric_types, metric_util from tensorflow_model_analysis.proto import config_pb2 -COEFFICIENT_OF_DISCRIMINATION_NAME = 'coefficient_of_discrimination' -RELATIVE_COEFFICIENT_OF_DISCRIMINATION_NAME = ( - 'relative_coefficient_of_discrimination' -) -_TJUR_DISCRIMINATION_NAME = '_tjur_discrimination' +COEFFICIENT_OF_DISCRIMINATION_NAME = "coefficient_of_discrimination" +RELATIVE_COEFFICIENT_OF_DISCRIMINATION_NAME = "relative_coefficient_of_discrimination" +_TJUR_DISCRIMINATION_NAME = "_tjur_discrimination" class CoefficientOfDiscrimination(metric_types.Metric): - """Coefficient of discrimination metric. + """Coefficient of discrimination metric. - The coefficient of discrimination measures the differences between the average - prediction for the positive examples and the average prediction for the - negative examples. + The coefficient of discrimination measures the differences between the average + prediction for the positive examples and the average prediction for the + negative examples. - The formula is: AVG(pred | label = 1) - AVG(pred | label = 0) - More details can be found in the following paper: - https://www.tandfonline.com/doi/abs/10.1198/tast.2009.08210 - """ + The formula is: AVG(pred | label = 1) - AVG(pred | label = 0) + More details can be found in the following paper: + https://www.tandfonline.com/doi/abs/10.1198/tast.2009.08210 + """ - def __init__(self, name: str = COEFFICIENT_OF_DISCRIMINATION_NAME): - """Initializes coefficient of discrimination metric. + def __init__(self, name: str = COEFFICIENT_OF_DISCRIMINATION_NAME): + """Initializes coefficient of discrimination metric. - Args: - name: Metric name. - """ - super().__init__( - metric_util.merge_per_key_computations(_coefficient_of_discrimination), - name=name) + Args: + ---- + name: Metric name. 
+ """ + super().__init__( + metric_util.merge_per_key_computations(_coefficient_of_discrimination), + name=name, + ) metric_types.register_metric(CoefficientOfDiscrimination) @@ -60,75 +60,86 @@ def __init__(self, name: str = COEFFICIENT_OF_DISCRIMINATION_NAME): def _coefficient_of_discrimination( name: str = COEFFICIENT_OF_DISCRIMINATION_NAME, eval_config: Optional[config_pb2.EvalConfig] = None, - model_name: str = '', - output_name: str = '', + model_name: str = "", + output_name: str = "", sub_key: Optional[metric_types.SubKey] = None, aggregation_type: Optional[metric_types.AggregationType] = None, class_weights: Optional[Dict[int, float]] = None, - example_weighted: bool = False) -> metric_types.MetricComputations: - """Returns metric computations for coefficient of discrimination.""" - key = metric_types.MetricKey( - name=name, - model_name=model_name, - output_name=output_name, - sub_key=sub_key, - example_weighted=example_weighted) - - # Compute shared Tjur discrimination metrics. - computations = _tjur_discrimination( - eval_config=eval_config, - model_name=model_name, - output_name=output_name, - aggregation_type=aggregation_type, - class_weights=class_weights, - example_weighted=example_weighted) - # Shared metrics are based on a single computation and key. - tjur_discrimination_key = computations[0].keys[0] - - def result( - metrics: Dict[metric_types.MetricKey, Any] - ) -> Dict[metric_types.MetricKey, float]: - """Returns coefficient of discrimination.""" - metric = metrics[tjur_discrimination_key] - if (metric.total_negative_weighted_labels == 0 or - metric.total_positive_weighted_labels == 0): - value = float('nan') - else: - avg_pos_label = ( - metric.total_positive_weighted_predictions / - metric.total_positive_weighted_labels) - avg_neg_label = ( - metric.total_negative_weighted_predictions / - metric.total_negative_weighted_labels) - value = avg_pos_label - avg_neg_label - return {key: value} - - derived_computation = metric_types.DerivedMetricComputation( - keys=[key], result=result) - computations.append(derived_computation) - return computations + example_weighted: bool = False, +) -> metric_types.MetricComputations: + """Returns metric computations for coefficient of discrimination.""" + key = metric_types.MetricKey( + name=name, + model_name=model_name, + output_name=output_name, + sub_key=sub_key, + example_weighted=example_weighted, + ) + + # Compute shared Tjur discrimination metrics. + computations = _tjur_discrimination( + eval_config=eval_config, + model_name=model_name, + output_name=output_name, + aggregation_type=aggregation_type, + class_weights=class_weights, + example_weighted=example_weighted, + ) + # Shared metrics are based on a single computation and key. 
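# --- Illustrative sketch (not part of the patch): the coefficient of
# discrimination computed directly from the definition in the docstring,
# AVG(pred | label = 1) - AVG(pred | label = 0), with optional example weights.
# The helper below is hypothetical and independent of the shared Tjur
# discrimination computation used here.
import numpy as np

def coefficient_of_discrimination(labels, preds, weights=None):
    labels = np.asarray(labels, dtype=float)
    preds = np.asarray(preds, dtype=float)
    weights = np.ones_like(labels) if weights is None else np.asarray(weights, dtype=float)
    pos_w = np.sum(weights * labels)          # total positive weighted labels
    neg_w = np.sum(weights * (1.0 - labels))  # total negative weighted labels
    if pos_w == 0 or neg_w == 0:
        return float("nan")                   # undefined without both classes
    avg_pos = np.sum(weights * labels * preds) / pos_w
    avg_neg = np.sum(weights * (1.0 - labels) * preds) / neg_w
    return avg_pos - avg_neg

print(coefficient_of_discrimination([0, 0, 1, 1], [0.0, 0.5, 0.3, 0.9]))  # -> 0.35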
+ tjur_discrimination_key = computations[0].keys[0] + + def result( + metrics: Dict[metric_types.MetricKey, Any], + ) -> Dict[metric_types.MetricKey, float]: + """Returns coefficient of discrimination.""" + metric = metrics[tjur_discrimination_key] + if ( + metric.total_negative_weighted_labels == 0 + or metric.total_positive_weighted_labels == 0 + ): + value = float("nan") + else: + avg_pos_label = ( + metric.total_positive_weighted_predictions + / metric.total_positive_weighted_labels + ) + avg_neg_label = ( + metric.total_negative_weighted_predictions + / metric.total_negative_weighted_labels + ) + value = avg_pos_label - avg_neg_label + return {key: value} + + derived_computation = metric_types.DerivedMetricComputation( + keys=[key], result=result + ) + computations.append(derived_computation) + return computations class RelativeCoefficientOfDiscrimination(metric_types.Metric): - """Relative coefficient of discrimination metric. + """Relative coefficient of discrimination metric. - The relative coefficient of discrimination measures the ratio between the - average prediction for the positive examples and the average prediction for - the negative examples. This has a very simple intuitive explanation, measuring - how much higher is the prediction going to be for a positive example than for - a negative example. - """ + The relative coefficient of discrimination measures the ratio between the + average prediction for the positive examples and the average prediction for + the negative examples. This has a very simple intuitive explanation, measuring + how much higher is the prediction going to be for a positive example than for + a negative example. + """ - def __init__(self, name: str = RELATIVE_COEFFICIENT_OF_DISCRIMINATION_NAME): - """Initializes relative coefficient of discrimination metric. + def __init__(self, name: str = RELATIVE_COEFFICIENT_OF_DISCRIMINATION_NAME): + """Initializes relative coefficient of discrimination metric. - Args: - name: Metric name. - """ - super().__init__( - metric_util.merge_per_key_computations( - _relative_coefficient_of_discrimination), - name=name) + Args: + ---- + name: Metric name. + """ + super().__init__( + metric_util.merge_per_key_computations( + _relative_coefficient_of_discrimination + ), + name=name, + ) metric_types.register_metric(RelativeCoefficientOfDiscrimination) @@ -137,159 +148,182 @@ def __init__(self, name: str = RELATIVE_COEFFICIENT_OF_DISCRIMINATION_NAME): def _relative_coefficient_of_discrimination( name: str = RELATIVE_COEFFICIENT_OF_DISCRIMINATION_NAME, eval_config: Optional[config_pb2.EvalConfig] = None, - model_name: str = '', - output_name: str = '', + model_name: str = "", + output_name: str = "", aggregation_type: Optional[metric_types.AggregationType] = None, class_weights: Optional[Dict[float, int]] = None, - example_weighted: bool = False) -> metric_types.MetricComputations: - """Returns metric computations for coefficient of discrimination.""" - key = metric_types.MetricKey( - name=name, - model_name=model_name, - output_name=output_name, - example_weighted=example_weighted) - - # Compute shared Tjur discrimination metrics. - computations = _tjur_discrimination( - eval_config=eval_config, - model_name=model_name, - output_name=output_name, - aggregation_type=aggregation_type, - class_weights=class_weights, - example_weighted=example_weighted) - # Shared metrics are based on a single computation and key. 
- tjur_discrimination_key = computations[0].keys[0] - - def result( - metrics: Dict[metric_types.MetricKey, Any] - ) -> Dict[metric_types.MetricKey, float]: - """Returns coefficient of discrimination.""" - metric = metrics[tjur_discrimination_key] - if (metric.total_negative_weighted_labels == 0 or - metric.total_positive_weighted_labels == 0 or - metric.total_negative_weighted_predictions == 0): - value = float('nan') - else: - avg_pos_label = ( - metric.total_positive_weighted_predictions / - metric.total_positive_weighted_labels) - avg_neg_label = ( - metric.total_negative_weighted_predictions / - metric.total_negative_weighted_labels) - value = avg_pos_label / avg_neg_label - return {key: value} - - derived_computation = metric_types.DerivedMetricComputation( - keys=[key], result=result) - computations.append(derived_computation) - return computations + example_weighted: bool = False, +) -> metric_types.MetricComputations: + """Returns metric computations for coefficient of discrimination.""" + key = metric_types.MetricKey( + name=name, + model_name=model_name, + output_name=output_name, + example_weighted=example_weighted, + ) + + # Compute shared Tjur discrimination metrics. + computations = _tjur_discrimination( + eval_config=eval_config, + model_name=model_name, + output_name=output_name, + aggregation_type=aggregation_type, + class_weights=class_weights, + example_weighted=example_weighted, + ) + # Shared metrics are based on a single computation and key. + tjur_discrimination_key = computations[0].keys[0] + + def result( + metrics: Dict[metric_types.MetricKey, Any], + ) -> Dict[metric_types.MetricKey, float]: + """Returns coefficient of discrimination.""" + metric = metrics[tjur_discrimination_key] + if ( + metric.total_negative_weighted_labels == 0 + or metric.total_positive_weighted_labels == 0 + or metric.total_negative_weighted_predictions == 0 + ): + value = float("nan") + else: + avg_pos_label = ( + metric.total_positive_weighted_predictions + / metric.total_positive_weighted_labels + ) + avg_neg_label = ( + metric.total_negative_weighted_predictions + / metric.total_negative_weighted_labels + ) + value = avg_pos_label / avg_neg_label + return {key: value} + + derived_computation = metric_types.DerivedMetricComputation( + keys=[key], result=result + ) + computations.append(derived_computation) + return computations def _tjur_discrimination( name: str = _TJUR_DISCRIMINATION_NAME, eval_config: Optional[config_pb2.EvalConfig] = None, - model_name: str = '', - output_name: str = '', + model_name: str = "", + output_name: str = "", aggregation_type: Optional[metric_types.AggregationType] = None, class_weights: Optional[Dict[int, float]] = None, - example_weighted: bool = False) -> metric_types.MetricComputations: - """Returns metric computations for Tjur discrimination.""" - key = metric_types.MetricKey( - name=name, - model_name=model_name, - output_name=output_name, - example_weighted=example_weighted) - return [ - metric_types.MetricComputation( - keys=[key], - preprocessors=None, - combiner=_TjurDiscriminationCombiner( - key, - eval_config, - aggregation_type, - class_weights, - example_weighted, - ), - ) - ] + example_weighted: bool = False, +) -> metric_types.MetricComputations: + """Returns metric computations for Tjur discrimination.""" + key = metric_types.MetricKey( + name=name, + model_name=model_name, + output_name=output_name, + example_weighted=example_weighted, + ) + return [ + metric_types.MetricComputation( + keys=[key], + preprocessors=None, + 
combiner=_TjurDiscriminationCombiner( + key, + eval_config, + aggregation_type, + class_weights, + example_weighted, + ), + ) + ] class _TjurDiscriminationAccumulator: - """Tjur discrimination accumulator.""" + """Tjur discrimination accumulator.""" - __slots__ = [ - 'total_negative_weighted_predictions', 'total_negative_weighted_labels', - 'total_positive_weighted_predictions', 'total_positive_weighted_labels' - ] + __slots__ = [ + "total_negative_weighted_predictions", + "total_negative_weighted_labels", + "total_positive_weighted_predictions", + "total_positive_weighted_labels", + ] - def __init__(self): - self.total_negative_weighted_predictions = 0.0 - self.total_negative_weighted_labels = 0.0 - self.total_positive_weighted_predictions = 0.0 - self.total_positive_weighted_labels = 0.0 + def __init__(self): + self.total_negative_weighted_predictions = 0.0 + self.total_negative_weighted_labels = 0.0 + self.total_positive_weighted_predictions = 0.0 + self.total_positive_weighted_labels = 0.0 class _TjurDiscriminationCombiner(beam.CombineFn): - """Computes min label position metric.""" - - def __init__(self, key: metric_types.MetricKey, - eval_config: Optional[config_pb2.EvalConfig], - aggregation_type: Optional[metric_types.AggregationType], - class_weights: Optional[Dict[int, - float]], example_weighted: bool): - self._key = key - self._eval_config = eval_config - self._aggregation_type = aggregation_type - self._class_weights = class_weights - self._example_weighted = example_weighted - - def create_accumulator(self) -> _TjurDiscriminationAccumulator: - return _TjurDiscriminationAccumulator() - - def add_input( - self, - accumulator: _TjurDiscriminationAccumulator, - element: metric_types.StandardMetricInputs, - ) -> _TjurDiscriminationAccumulator: - for label, prediction, example_weight in ( - metric_util.to_label_prediction_example_weight( + """Computes min label position metric.""" + + def __init__( + self, + key: metric_types.MetricKey, + eval_config: Optional[config_pb2.EvalConfig], + aggregation_type: Optional[metric_types.AggregationType], + class_weights: Optional[Dict[int, float]], + example_weighted: bool, + ): + self._key = key + self._eval_config = eval_config + self._aggregation_type = aggregation_type + self._class_weights = class_weights + self._example_weighted = example_weighted + + def create_accumulator(self) -> _TjurDiscriminationAccumulator: + return _TjurDiscriminationAccumulator() + + def add_input( + self, + accumulator: _TjurDiscriminationAccumulator, + element: metric_types.StandardMetricInputs, + ) -> _TjurDiscriminationAccumulator: + for ( + label, + prediction, + example_weight, + ) in metric_util.to_label_prediction_example_weight( element, eval_config=self._eval_config, model_name=self._key.model_name, output_name=self._key.output_name, aggregation_type=self._aggregation_type, class_weights=self._class_weights, - example_weighted=self._example_weighted)): - label = float(label) - prediction = float(prediction) - example_weight = float(example_weight) - accumulator.total_negative_weighted_labels += ((1.0 - label) * - example_weight) - accumulator.total_positive_weighted_labels += label * example_weight - accumulator.total_negative_weighted_predictions += ((1.0 - label) * - prediction * - example_weight) - accumulator.total_positive_weighted_predictions += ( - label * prediction * example_weight) - return accumulator - - def merge_accumulators( - self, accumulators: Iterable[_TjurDiscriminationAccumulator] - ) -> _TjurDiscriminationAccumulator: - 
accumulators = iter(accumulators) - result = next(accumulators) - for accumulator in accumulators: - result.total_negative_weighted_predictions += ( - accumulator.total_negative_weighted_predictions) - result.total_negative_weighted_labels += ( - accumulator.total_negative_weighted_labels) - result.total_positive_weighted_predictions += ( - accumulator.total_positive_weighted_predictions) - result.total_positive_weighted_labels += ( - accumulator.total_positive_weighted_labels) - return result - - def extract_output( - self, accumulator: _TjurDiscriminationAccumulator - ) -> Dict[metric_types.MetricKey, _TjurDiscriminationAccumulator]: - return {self._key: accumulator} + example_weighted=self._example_weighted, + ): + label = float(label) + prediction = float(prediction) + example_weight = float(example_weight) + accumulator.total_negative_weighted_labels += (1.0 - label) * example_weight + accumulator.total_positive_weighted_labels += label * example_weight + accumulator.total_negative_weighted_predictions += ( + (1.0 - label) * prediction * example_weight + ) + accumulator.total_positive_weighted_predictions += ( + label * prediction * example_weight + ) + return accumulator + + def merge_accumulators( + self, accumulators: Iterable[_TjurDiscriminationAccumulator] + ) -> _TjurDiscriminationAccumulator: + accumulators = iter(accumulators) + result = next(accumulators) + for accumulator in accumulators: + result.total_negative_weighted_predictions += ( + accumulator.total_negative_weighted_predictions + ) + result.total_negative_weighted_labels += ( + accumulator.total_negative_weighted_labels + ) + result.total_positive_weighted_predictions += ( + accumulator.total_positive_weighted_predictions + ) + result.total_positive_weighted_labels += ( + accumulator.total_positive_weighted_labels + ) + return result + + def extract_output( + self, accumulator: _TjurDiscriminationAccumulator + ) -> Dict[metric_types.MetricKey, _TjurDiscriminationAccumulator]: + return {self._key: accumulator} diff --git a/tensorflow_model_analysis/metrics/tjur_discrimination_test.py b/tensorflow_model_analysis/metrics/tjur_discrimination_test.py index 6577d3eacb..3d0940f8e6 100644 --- a/tensorflow_model_analysis/metrics/tjur_discrimination_test.py +++ b/tensorflow_model_analysis/metrics/tjur_discrimination_test.py @@ -14,188 +14,205 @@ """Tests for Tjur discrimination metrics.""" import math -from absl.testing import parameterized + import apache_beam as beam -from apache_beam.testing import util import numpy as np import tensorflow as tf -from tensorflow_model_analysis.metrics import metric_util -from tensorflow_model_analysis.metrics import tjur_discrimination +from absl.testing import parameterized +from apache_beam.testing import util + +from tensorflow_model_analysis.metrics import metric_util, tjur_discrimination from tensorflow_model_analysis.utils import test_util class TjurDisriminationTest( test_util.TensorflowModelAnalysisTest, parameterized.TestCase ): - - @parameterized.named_parameters( - ('coefficient_of_discrimination', - tjur_discrimination.CoefficientOfDiscrimination(), - (1.2 / 2.0) - (0.8 / 1.0)), - ('relative_coefficient_of_discrimination', - tjur_discrimination.RelativeCoefficientOfDiscrimination(), - (1.2 / 2.0) / (0.8 / 1.0))) - def testTjuDicriminationMetricsWithoutWeights(self, metric, expected_value): - computations = metric.computations() - shared_metrics = computations[0] - metric = computations[1] - - # Positive labels: 0.0 + 1.0 + 1.0 = 2.0 - # Negative labels: 1.0 + 0.0 + 0.0 = 1.0 - 
# Positive predictions: 0.0 * 0.8 + 1.0 * 0.3 + 1.0 * 0.9 = 1.2 - # Negative predictions: 1.0 * 0.8 + 0.0 * 0.3 + 0.0 * 0.9 = 0.8 - example1 = { - 'labels': np.array([0.0]), - 'predictions': np.array([0.8]), - 'example_weights': np.array([1.0]), - } - example2 = { - 'labels': np.array([1.0]), - 'predictions': np.array([0.3]), - 'example_weights': np.array([1.0]), - } - example3 = { - 'labels': np.array([1.0]), - 'predictions': np.array([0.9]), - 'example_weights': np.array([1.0]), - } - - with beam.Pipeline() as pipeline: - # pylint: disable=no-value-for-parameter - result = ( - pipeline - | 'Create' >> beam.Create([example1, example2, example3]) - | 'Process' >> beam.Map(metric_util.to_standard_metric_inputs) - | 'AddSlice' >> beam.Map(lambda x: ((), x)) - | - 'ComputeWeightedTotals' >> beam.CombinePerKey(shared_metrics.combiner) - | 'ComputeMetric' >> beam.Map(lambda x: (x[0], metric.result(x[1])))) - - # pylint: enable=no-value-for-parameter - - def check_result(got): - try: - self.assertLen(got, 1) - got_slice_key, got_metrics = got[0] - self.assertEqual(got_slice_key, ()) - key = metric.keys[0] - self.assertDictElementsAlmostEqual( - got_metrics, {key: expected_value}, places=5) - - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that(result, check_result, label='result') - - @parameterized.named_parameters( - ('coefficient_of_discrimination', - tjur_discrimination.CoefficientOfDiscrimination(), - (3.3 / 5.0) - (1.6 / 5.0)), - ('relative_coefficient_of_discrimination', - tjur_discrimination.RelativeCoefficientOfDiscrimination(), - (3.3 / 5.0) / (1.6 / 5.0))) - def testTjuDicriminationMetricsWithWeights(self, metric, expected_value): - computations = metric.computations(example_weighted=True) - shared_metrics = computations[0] - metric = computations[1] - - # Positive labels: 1.0 * 0.0 + 2.0 * 1.0 + 3.0 * 1.0 + 4.0 * 0.0 = 5.0 - # Negative labels: 1.0 * 1.0 + 2.0 * 0.0 + 3.0 * 0.0 + 4.0 * 1.0 = 5.0 - # Positive predictions: 1.0 * 0.0 * 0.8 + 2.0 * 1.0 * 0.3 + 3.0 * 1.0 * 0.9 - # + 4.0 * 0.0 * 0.2 = 3.3 - # Negative predictions: 1.0 * 1.0 * 0.8 + 2.0 * 0.0 * 0.7 + 3.0 * 0.0 * 0.1 - # + 4.0 * 1.0 * 0.2 = 1.6 - example1 = { - 'labels': np.array([0.0]), - 'predictions': np.array([0.8]), - 'example_weights': np.array([1.0]), - } - example2 = { - 'labels': np.array([1.0]), - 'predictions': np.array([0.3]), - 'example_weights': np.array([2.0]), - } - example3 = { - 'labels': np.array([1.0]), - 'predictions': np.array([0.9]), - 'example_weights': np.array([3.0]), - } - example4 = { - 'labels': np.array([0.0]), - 'predictions': np.array([0.2]), - 'example_weights': np.array([4.0]), - } - - with beam.Pipeline() as pipeline: - # pylint: disable=no-value-for-parameter - result = ( - pipeline - | 'Create' >> beam.Create([example1, example2, example3, example4]) - | 'Process' >> beam.Map(metric_util.to_standard_metric_inputs) - | 'AddSlice' >> beam.Map(lambda x: ((), x)) - | - 'ComputeWeightedTotals' >> beam.CombinePerKey(shared_metrics.combiner) - | 'ComputeMetric' >> beam.Map(lambda x: (x[0], metric.result(x[1])))) - - # pylint: enable=no-value-for-parameter - - def check_result(got): - try: - self.assertLen(got, 1) - got_slice_key, got_metrics = got[0] - self.assertEqual(got_slice_key, ()) - key = metric.keys[0] - self.assertDictElementsAlmostEqual( - got_metrics, {key: expected_value}, places=5) - - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that(result, check_result, label='result') - - @parameterized.named_parameters( 
- ('coefficient_of_discrimination', - tjur_discrimination.CoefficientOfDiscrimination()), - ('relative_coefficient_of_discrimination', - tjur_discrimination.RelativeCoefficientOfDiscrimination())) - def testTjurDiscriminationMetricsWithNan(self, metric): - computations = metric.computations() - shared_metrics = computations[0] - metric = computations[1] - - example = { - 'labels': np.array([0.0]), - 'predictions': np.array([1.0]), - 'example_weights': np.array([1.0]), - } - - with beam.Pipeline() as pipeline: - # pylint: disable=no-value-for-parameter - result = ( - pipeline - | 'Create' >> beam.Create([example]) - | 'Process' >> beam.Map(metric_util.to_standard_metric_inputs) - | 'AddSlice' >> beam.Map(lambda x: ((), x)) - | - 'ComputeWeightedTotals' >> beam.CombinePerKey(shared_metrics.combiner) - | 'ComputeMetric' >> beam.Map(lambda x: (x[0], metric.result(x[1])))) - - # pylint: enable=no-value-for-parameter - - def check_result(got): - try: - self.assertLen(got, 1) - got_slice_key, got_metrics = got[0] - self.assertEqual(got_slice_key, ()) - key = metric.keys[0] - self.assertIn(key, got_metrics) - self.assertTrue(math.isnan(got_metrics[key])) - - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that(result, check_result, label='result') - - -if __name__ == '__main__': - tf.test.main() + @parameterized.named_parameters( + ( + "coefficient_of_discrimination", + tjur_discrimination.CoefficientOfDiscrimination(), + (1.2 / 2.0) - (0.8 / 1.0), + ), + ( + "relative_coefficient_of_discrimination", + tjur_discrimination.RelativeCoefficientOfDiscrimination(), + (1.2 / 2.0) / (0.8 / 1.0), + ), + ) + def testTjuDicriminationMetricsWithoutWeights(self, metric, expected_value): + computations = metric.computations() + shared_metrics = computations[0] + metric = computations[1] + + # Positive labels: 0.0 + 1.0 + 1.0 = 2.0 + # Negative labels: 1.0 + 0.0 + 0.0 = 1.0 + # Positive predictions: 0.0 * 0.8 + 1.0 * 0.3 + 1.0 * 0.9 = 1.2 + # Negative predictions: 1.0 * 0.8 + 0.0 * 0.3 + 0.0 * 0.9 = 0.8 + example1 = { + "labels": np.array([0.0]), + "predictions": np.array([0.8]), + "example_weights": np.array([1.0]), + } + example2 = { + "labels": np.array([1.0]), + "predictions": np.array([0.3]), + "example_weights": np.array([1.0]), + } + example3 = { + "labels": np.array([1.0]), + "predictions": np.array([0.9]), + "example_weights": np.array([1.0]), + } + + with beam.Pipeline() as pipeline: + # pylint: disable=no-value-for-parameter + result = ( + pipeline + | "Create" >> beam.Create([example1, example2, example3]) + | "Process" >> beam.Map(metric_util.to_standard_metric_inputs) + | "AddSlice" >> beam.Map(lambda x: ((), x)) + | "ComputeWeightedTotals" >> beam.CombinePerKey(shared_metrics.combiner) + | "ComputeMetric" >> beam.Map(lambda x: (x[0], metric.result(x[1]))) + ) + + # pylint: enable=no-value-for-parameter + + def check_result(got): + try: + self.assertLen(got, 1) + got_slice_key, got_metrics = got[0] + self.assertEqual(got_slice_key, ()) + key = metric.keys[0] + self.assertDictElementsAlmostEqual( + got_metrics, {key: expected_value}, places=5 + ) + + except AssertionError as err: + raise util.BeamAssertException(err) + + util.assert_that(result, check_result, label="result") + + @parameterized.named_parameters( + ( + "coefficient_of_discrimination", + tjur_discrimination.CoefficientOfDiscrimination(), + (3.3 / 5.0) - (1.6 / 5.0), + ), + ( + "relative_coefficient_of_discrimination", + tjur_discrimination.RelativeCoefficientOfDiscrimination(), + (3.3 / 5.0) / (1.6 
/ 5.0), + ), + ) + def testTjuDicriminationMetricsWithWeights(self, metric, expected_value): + computations = metric.computations(example_weighted=True) + shared_metrics = computations[0] + metric = computations[1] + + # Positive labels: 1.0 * 0.0 + 2.0 * 1.0 + 3.0 * 1.0 + 4.0 * 0.0 = 5.0 + # Negative labels: 1.0 * 1.0 + 2.0 * 0.0 + 3.0 * 0.0 + 4.0 * 1.0 = 5.0 + # Positive predictions: 1.0 * 0.0 * 0.8 + 2.0 * 1.0 * 0.3 + 3.0 * 1.0 * 0.9 + # + 4.0 * 0.0 * 0.2 = 3.3 + # Negative predictions: 1.0 * 1.0 * 0.8 + 2.0 * 0.0 * 0.7 + 3.0 * 0.0 * 0.1 + # + 4.0 * 1.0 * 0.2 = 1.6 + example1 = { + "labels": np.array([0.0]), + "predictions": np.array([0.8]), + "example_weights": np.array([1.0]), + } + example2 = { + "labels": np.array([1.0]), + "predictions": np.array([0.3]), + "example_weights": np.array([2.0]), + } + example3 = { + "labels": np.array([1.0]), + "predictions": np.array([0.9]), + "example_weights": np.array([3.0]), + } + example4 = { + "labels": np.array([0.0]), + "predictions": np.array([0.2]), + "example_weights": np.array([4.0]), + } + + with beam.Pipeline() as pipeline: + # pylint: disable=no-value-for-parameter + result = ( + pipeline + | "Create" >> beam.Create([example1, example2, example3, example4]) + | "Process" >> beam.Map(metric_util.to_standard_metric_inputs) + | "AddSlice" >> beam.Map(lambda x: ((), x)) + | "ComputeWeightedTotals" >> beam.CombinePerKey(shared_metrics.combiner) + | "ComputeMetric" >> beam.Map(lambda x: (x[0], metric.result(x[1]))) + ) + + # pylint: enable=no-value-for-parameter + + def check_result(got): + try: + self.assertLen(got, 1) + got_slice_key, got_metrics = got[0] + self.assertEqual(got_slice_key, ()) + key = metric.keys[0] + self.assertDictElementsAlmostEqual( + got_metrics, {key: expected_value}, places=5 + ) + + except AssertionError as err: + raise util.BeamAssertException(err) + + util.assert_that(result, check_result, label="result") + + @parameterized.named_parameters( + ( + "coefficient_of_discrimination", + tjur_discrimination.CoefficientOfDiscrimination(), + ), + ( + "relative_coefficient_of_discrimination", + tjur_discrimination.RelativeCoefficientOfDiscrimination(), + ), + ) + def testTjurDiscriminationMetricsWithNan(self, metric): + computations = metric.computations() + shared_metrics = computations[0] + metric = computations[1] + + example = { + "labels": np.array([0.0]), + "predictions": np.array([1.0]), + "example_weights": np.array([1.0]), + } + + with beam.Pipeline() as pipeline: + # pylint: disable=no-value-for-parameter + result = ( + pipeline + | "Create" >> beam.Create([example]) + | "Process" >> beam.Map(metric_util.to_standard_metric_inputs) + | "AddSlice" >> beam.Map(lambda x: ((), x)) + | "ComputeWeightedTotals" >> beam.CombinePerKey(shared_metrics.combiner) + | "ComputeMetric" >> beam.Map(lambda x: (x[0], metric.result(x[1]))) + ) + + # pylint: enable=no-value-for-parameter + + def check_result(got): + try: + self.assertLen(got, 1) + got_slice_key, got_metrics = got[0] + self.assertEqual(got_slice_key, ()) + key = metric.keys[0] + self.assertIn(key, got_metrics) + self.assertTrue(math.isnan(got_metrics[key])) + + except AssertionError as err: + raise util.BeamAssertException(err) + + util.assert_that(result, check_result, label="result") + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow_model_analysis/metrics/weighted_example_count.py b/tensorflow_model_analysis/metrics/weighted_example_count.py index 3826684c18..7164301807 100644 --- a/tensorflow_model_analysis/metrics/weighted_example_count.py +++ 
b/tensorflow_model_analysis/metrics/weighted_example_count.py @@ -13,24 +13,23 @@ # limitations under the License. """Weighted example count metric.""" -from tensorflow_model_analysis.metrics import example_count -from tensorflow_model_analysis.metrics import metric_types +from tensorflow_model_analysis.metrics import example_count, metric_types -WEIGHTED_EXAMPLE_COUNT_NAME = 'weighted_example_count' +WEIGHTED_EXAMPLE_COUNT_NAME = "weighted_example_count" # TODO(b/143180976): Remove. class WeightedExampleCount(example_count.ExampleCount): - """Weighted example count (deprecated - use ExampleCount).""" + """Weighted example count (deprecated - use ExampleCount).""" - def __init__(self, name: str = WEIGHTED_EXAMPLE_COUNT_NAME): - """Initializes weighted example count. + def __init__(self, name: str = WEIGHTED_EXAMPLE_COUNT_NAME): + """Initializes weighted example count. - Args: - name: Metric name. - """ - - super().__init__(name=name) + Args: + ---- + name: Metric name. + """ + super().__init__(name=name) metric_types.register_metric(WeightedExampleCount) diff --git a/tensorflow_model_analysis/notebook/colab/renderer.py b/tensorflow_model_analysis/notebook/colab/renderer.py index e5d0fe995f..f888746b46 100644 --- a/tensorflow_model_analysis/notebook/colab/renderer.py +++ b/tensorflow_model_analysis/notebook/colab/renderer.py @@ -14,14 +14,15 @@ """TFMA API for OSS Colab renderer.""" from typing import Any, Callable, Dict, List, Optional, Union + from tensorflow_model_analysis.notebook.colab import util # See also `loadVulcanizedTemplate` in JS, which adjusts the script path # further at runtime, depending on the environment. def get_trusted_html_for_vulcanized_js(): - """Returns a trusted string of HTML that will load vulcanized_tfma.js.""" - return """ + """Returns a trusted string of HTML that will load vulcanized_tfma.js.""" + return """ """ @@ -29,51 +30,58 @@ def get_trusted_html_for_vulcanized_js(): def render_slicing_metrics( data: List[Dict[str, Union[Dict[str, Any], str]]], config: Dict[str, str], - event_handlers: Optional[Callable[[Dict[str, Union[str, float]]], - None]] = None + event_handlers: Optional[Callable[[Dict[str, Union[str, float]]], None]] = None, ) -> None: - """Renders the slicing metrics view in Colab. + """Renders the slicing metrics view in Colab. - Args: - data: A list of dictionary containing metrics for correpsonding slices. - config: A dictionary of the configuration. - event_handlers: event handlers - """ - util.render_tfma_component( - 'tfma-nb-slicing-metrics', - data, - config, - event_handlers=event_handlers, - trusted_html_for_vulcanized_tfma_js=get_trusted_html_for_vulcanized_js()) + Args: + ---- + data: A list of dictionary containing metrics for correpsonding slices. + config: A dictionary of the configuration. + event_handlers: event handlers + """ + util.render_tfma_component( + "tfma-nb-slicing-metrics", + data, + config, + event_handlers=event_handlers, + trusted_html_for_vulcanized_tfma_js=get_trusted_html_for_vulcanized_js(), + ) -def render_time_series(data: List[Dict[str, Union[Dict[Union[float, str], Any], - str]]], - config: Dict[str, bool]) -> None: - """Renders the time series view in Colab. +def render_time_series( + data: List[Dict[str, Union[Dict[Union[float, str], Any], str]]], + config: Dict[str, bool], +) -> None: + """Renders the time series view in Colab. - Args: - data: A list of dictionary containing metrics for different evaluation runs. - config: A dictionary of the configuration. 
- """ - util.render_tfma_component( - 'tfma-nb-time-series', - data, - config, - trusted_html_for_vulcanized_tfma_js=get_trusted_html_for_vulcanized_js()) + Args: + ---- + data: A list of dictionary containing metrics for different evaluation runs. + config: A dictionary of the configuration. + """ + util.render_tfma_component( + "tfma-nb-time-series", + data, + config, + trusted_html_for_vulcanized_tfma_js=get_trusted_html_for_vulcanized_js(), + ) def render_plot( data: Dict[str, List[Union[str, float, List[float]]]], - config: Dict[str, Union[Dict[str, Dict[str, str]], str]]) -> None: - """Renders the plot view in Colab. + config: Dict[str, Union[Dict[str, Dict[str, str]], str]], +) -> None: + """Renders the plot view in Colab. - Args: - data: A dictionary containing plot data. - config: A dictionary of the configuration. - """ - util.render_tfma_component( - 'tfma-nb-plot', - data, - config, - trusted_html_for_vulcanized_tfma_js=get_trusted_html_for_vulcanized_js()) + Args: + ---- + data: A dictionary containing plot data. + config: A dictionary of the configuration. + """ + util.render_tfma_component( + "tfma-nb-plot", + data, + config, + trusted_html_for_vulcanized_tfma_js=get_trusted_html_for_vulcanized_js(), + ) diff --git a/tensorflow_model_analysis/notebook/colab/util.py b/tensorflow_model_analysis/notebook/colab/util.py index 345fa6e95a..133b36b470 100644 --- a/tensorflow_model_analysis/notebook/colab/util.py +++ b/tensorflow_model_analysis/notebook/colab/util.py @@ -16,6 +16,7 @@ import base64 import json from typing import Any, Callable, Dict, List, Optional, Union + from google.colab import output from IPython import display @@ -24,50 +25,55 @@ # Safelist the web component names that can be rendered. _TRUSTED_TFMA_COMPONENT_NAMES = frozenset( - ['tfma-nb-slicing-metrics', 'tfma-nb-time-series', 'tfma-nb-plot']) + ["tfma-nb-slicing-metrics", "tfma-nb-time-series", "tfma-nb-plot"] +) def _create_handler_wrapper(event_handlers): - """Wraps the event handler and registers it as a callback for the js side. + """Wraps the event handler and registers it as a callback for the js side. - Wraps the event handler and registers it as a callback for js. Keep count and - use it as aprt of the callback name to ensure uniqueness. + Wraps the event handler and registers it as a callback for js. Keep count and + use it as aprt of the callback name to ensure uniqueness. - Args: - event_handlers: The hadnler for the js events. + Args: + ---- + event_handlers: The hadnler for the js events. - Returns: - The name of the js callback, safe to render as HTML or JS. - """ - trusted_name = 'tfma_eventCallback' + str(int(_create_handler_wrapper.count)) - _create_handler_wrapper.count += 1 + Returns: + ------- + The name of the js callback, safe to render as HTML or JS. + """ + trusted_name = "tfma_eventCallback" + str(int(_create_handler_wrapper.count)) + _create_handler_wrapper.count += 1 - def wrapped_function(name='', detail=None): - if event_handlers and name in event_handlers: - event_handlers[name](detail) + def wrapped_function(name="", detail=None): + if event_handlers and name in event_handlers: + event_handlers[name](detail) - output.register_callback(trusted_name, wrapped_function) - return trusted_name + output.register_callback(trusted_name, wrapped_function) + return trusted_name _create_handler_wrapper.count = 0 def make_trusted_event_handler_js(event_handlers): - """Generates event handler code in js if provided. + """Generates event handler code in js if provided. 
- If python event_handlers are provided, generate corresponding js callback and - trigger it when applicable. See tfma-nb-event-mixin.js + If python event_handlers are provided, generate corresponding js callback and + trigger it when applicable. See tfma-nb-event-mixin.js - Args: - event_handlers: The hadnler for the hs events. + Args: + ---- + event_handlers: The hadnler for the hs events. - Returns: - Trusted js code that will call the event handler in python. - """ - if event_handlers: - trusted_callback_name = _create_handler_wrapper(event_handlers) - return """ + Returns: + ------- + Trusted js code that will call the event handler in python. + """ + if event_handlers: + trusted_callback_name = _create_handler_wrapper(event_handlers) + return f""" element.addEventListener('tfma-event', (event) => {{ google.colab.kernel.invokeFunction( '{trusted_callback_name}', [], {{ @@ -75,57 +81,63 @@ def make_trusted_event_handler_js(event_handlers): detail: event.detail.detail }}); }}); - """.format(trusted_callback_name=trusted_callback_name) - else: - return '/** No event handlers needed. */' + """ + else: + return "/** No event handlers needed. */" def to_base64_encoded_json(obj) -> str: - """Encode a Python object as a base64-endoded JSON string. + """Encode a Python object as a base64-endoded JSON string. - When embedding JSON inline inside HTML, serialize it to a JSON string in - Python and base64 encode that string to escape it so that it's safe to render - inside HTML. Then on the JS side, base64 decode it and parse it as JSON. + When embedding JSON inline inside HTML, serialize it to a JSON string in + Python and base64 encode that string to escape it so that it's safe to render + inside HTML. Then on the JS side, base64 decode it and parse it as JSON. - Args: - obj: any Python object serializable to JSON + Args: + ---- + obj: any Python object serializable to JSON - Returns: - base64-encoded string of JSON - """ - json_string = json.dumps(obj) - return base64.b64encode(json_string.encode('utf-8')).decode('utf-8') + Returns: + ------- + base64-encoded string of JSON + """ + json_string = json.dumps(obj) + return base64.b64encode(json_string.encode("utf-8")).decode("utf-8") def generate_html_for_tfma_component( component_name: str, - data: Union[List[Dict[str, Union[Dict[str, Any], str]]], - Dict[str, List[Union[str, float, List[float]]]]], + data: Union[ + List[Dict[str, Union[Dict[str, Any], str]]], + Dict[str, List[Union[str, float, List[float]]]], + ], config: Dict[str, Union[Dict[str, Dict[str, str]], str, bool]], trusted_html_for_vulcanized_tfma_js: str, event_handlers: Optional[PythonEventHandlersMap] = None, ) -> str: - """Generates HTML for TFMA component. - - Args: - component_name: The name of the TFMA web component to render. - data: A dictionary containing data for visualization. - config: A dictionary containing the configuration. - trusted_html_for_vulcanized_tfma_js: Optional string of trusted HTML that is - rendered unescaped. This can be a script tag referencing a trusted - external JS file or a script tag with trusted JS inline. - event_handlers: Handlers for events on the js side. - - Returns: - HTML content of the rendered TFMA component. - """ - - if component_name not in _TRUSTED_TFMA_COMPONENT_NAMES: - raise ValueError('component_name must be one of: ' + - ','.join(_TRUSTED_TFMA_COMPONENT_NAMES)) - - ui_payload = {'config': config, 'data': data} - template = """ + """Generates HTML for TFMA component. 
+ + Args: + ---- + component_name: The name of the TFMA web component to render. + data: A dictionary containing data for visualization. + config: A dictionary containing the configuration. + trusted_html_for_vulcanized_tfma_js: Optional string of trusted HTML that is + rendered unescaped. This can be a script tag referencing a trusted + external JS file or a script tag with trusted JS inline. + event_handlers: Handlers for events on the js side. + + Returns: + ------- + HTML content of the rendered TFMA component. + """ + if component_name not in _TRUSTED_TFMA_COMPONENT_NAMES: + raise ValueError( + "component_name must be one of: " + ",".join(_TRUSTED_TFMA_COMPONENT_NAMES) + ) + + ui_payload = {"config": config, "data": data} + template = """ {trusted_html_for_vulcanized_tfma_js} <{trusted_tfma_component_name} id="component"> """ - html = template.format( - trusted_tfma_component_name=component_name, - trusted_html_for_vulcanized_tfma_js=trusted_html_for_vulcanized_tfma_js, - trusted_event_handler_js=make_trusted_event_handler_js(event_handlers), - base64_encoded_json_payload=to_base64_encoded_json(ui_payload)) - return html + html = template.format( + trusted_tfma_component_name=component_name, + trusted_html_for_vulcanized_tfma_js=trusted_html_for_vulcanized_tfma_js, + trusted_event_handler_js=make_trusted_event_handler_js(event_handlers), + base64_encoded_json_payload=to_base64_encoded_json(ui_payload), + ) + return html def render_tfma_component( component_name: str, - data: Union[List[Dict[str, Union[Dict[str, Any], str]]], - Dict[str, List[Union[str, float, List[float]]]]], + data: Union[ + List[Dict[str, Union[Dict[str, Any], str]]], + Dict[str, List[Union[str, float, List[float]]]], + ], config: Dict[str, Union[Dict[str, Dict[str, str]], str, bool]], trusted_html_for_vulcanized_tfma_js: str, event_handlers: Optional[PythonEventHandlersMap] = None, ) -> None: - """Renders the specified TFMA component in Colab. - - Colab requires custom visualization to be rendered in a sandbox so we cannot - use Jupyter widget. - - Args: - component_name: The name of the TFMA web component to render. - data: A dictionary containing data for visualization. - config: A dictionary containing the configuration. - trusted_html_for_vulcanized_tfma_js: Optional string of trusted HTML that is - rendered unescaped. This can be a script tag referencing a trusted - external JS file or a script tag with trusted JS inline. - event_handlers: Handlers for events on the js side. - """ - - html = generate_html_for_tfma_component(component_name, data, config, - trusted_html_for_vulcanized_tfma_js, - event_handlers) - display.display(display.HTML(html)) + """Renders the specified TFMA component in Colab. + + Colab requires custom visualization to be rendered in a sandbox so we cannot + use Jupyter widget. + + Args: + ---- + component_name: The name of the TFMA web component to render. + data: A dictionary containing data for visualization. + config: A dictionary containing the configuration. + trusted_html_for_vulcanized_tfma_js: Optional string of trusted HTML that is + rendered unescaped. This can be a script tag referencing a trusted + external JS file or a script tag with trusted JS inline. + event_handlers: Handlers for events on the js side. 
+ """ + html = generate_html_for_tfma_component( + component_name, + data, + config, + trusted_html_for_vulcanized_tfma_js, + event_handlers, + ) + display.display(display.HTML(html)) diff --git a/tensorflow_model_analysis/notebook/colab/widget.py b/tensorflow_model_analysis/notebook/colab/widget.py index 5080a3d1ca..916eadf87c 100644 --- a/tensorflow_model_analysis/notebook/colab/widget.py +++ b/tensorflow_model_analysis/notebook/colab/widget.py @@ -17,15 +17,18 @@ # The following empty classes are used for # pytype in tensorflow_model_analysis/view/widget_view.py. class SlicingMetricsViewer: - """Empty viewer class for slicing metrics.""" - pass + """Empty viewer class for slicing metrics.""" + + pass class TimeSeriesViewer: - """Empty viewer class for time series.""" - pass + """Empty viewer class for time series.""" + + pass class PlotViewer: - """Empty viewer class for plot.""" - pass + """Empty viewer class for plot.""" + + pass diff --git a/tensorflow_model_analysis/notebook/jupyter/renderer.py b/tensorflow_model_analysis/notebook/jupyter/renderer.py index b6839d8f15..803901aec4 100644 --- a/tensorflow_model_analysis/notebook/jupyter/renderer.py +++ b/tensorflow_model_analysis/notebook/jupyter/renderer.py @@ -17,60 +17,66 @@ def render_slicing_metrics(data, config, event_handlers=None): - """Renders the slicing metrics view in Jupyter. + """Renders the slicing metrics view in Jupyter. - Args: - data: A list of dictionary containing metrics for correpsonding slices. - config: A dictionary of the configuration. - event_handlers: A dictionary of where keys are event types and values are - event handlers. + Args: + ---- + data: A list of dictionary containing metrics for correpsonding slices. + config: A dictionary of the configuration. + event_handlers: A dictionary of where keys are event types and values are + event handlers. - Returns: - A SlicingMetricsViewer. - """ - if tfma_widget.SlicingMetricsViewer is None: - raise ValueError("tfma_widget.SlicingMetricsViewer is None.") - view = tfma_widget.SlicingMetricsViewer() - view.data = data - view.config = config - view.event_handlers = event_handlers + Returns: + ------- + A SlicingMetricsViewer. + """ + if tfma_widget.SlicingMetricsViewer is None: + raise ValueError("tfma_widget.SlicingMetricsViewer is None.") + view = tfma_widget.SlicingMetricsViewer() + view.data = data + view.config = config + view.event_handlers = event_handlers - return view + return view def render_time_series(data, config): - """Renders the time series view in Jupyter. + """Renders the time series view in Jupyter. - Args: - data: A list of dictionary containing metrics for different evaluation runs. - config: A dictionary of the configuration. + Args: + ---- + data: A list of dictionary containing metrics for different evaluation runs. + config: A dictionary of the configuration. - Returns: - A TimeSeriesViewer. - """ - if tfma_widget.TimeSeriesViewer is None: - raise ValueError("tfma_widget.TimeSeriesViewer is None.") - view = tfma_widget.TimeSeriesViewer() - view.data = data - view.config = config + Returns: + ------- + A TimeSeriesViewer. + """ + if tfma_widget.TimeSeriesViewer is None: + raise ValueError("tfma_widget.TimeSeriesViewer is None.") + view = tfma_widget.TimeSeriesViewer() + view.data = data + view.config = config - return view + return view def render_plot(data, config): - """Renders the plot view in Jupyter. - - Args: - data: A dictionary containing plot data. - config: A dictionary of the configuration. - - Returns: - A PlotViewer. 
- """ - if tfma_widget.PlotViewer is None: - raise ValueError("tfma_widget.PlotViewer is None.") - view = tfma_widget.PlotViewer() - view.data = data - view.config = config - - return view + """Renders the plot view in Jupyter. + + Args: + ---- + data: A dictionary containing plot data. + config: A dictionary of the configuration. + + Returns: + ------- + A PlotViewer. + """ + if tfma_widget.PlotViewer is None: + raise ValueError("tfma_widget.PlotViewer is None.") + view = tfma_widget.PlotViewer() + view.data = data + view.config = config + + return view diff --git a/tensorflow_model_analysis/notebook/jupyter/tfma_widget/__init__.py b/tensorflow_model_analysis/notebook/jupyter/tfma_widget/__init__.py index c853286182..74f752d9d0 100644 --- a/tensorflow_model_analysis/notebook/jupyter/tfma_widget/__init__.py +++ b/tensorflow_model_analysis/notebook/jupyter/tfma_widget/__init__.py @@ -12,6 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. """Initializes TFMA's Jupyter notebook widget.""" -from tensorflow_model_analysis.notebook.jupyter.tfma_widget.widget import PlotViewer -from tensorflow_model_analysis.notebook.jupyter.tfma_widget.widget import SlicingMetricsViewer -from tensorflow_model_analysis.notebook.jupyter.tfma_widget.widget import TimeSeriesViewer + +from tensorflow_model_analysis.notebook.jupyter.tfma_widget.widget import ( + PlotViewer, + SlicingMetricsViewer, + TimeSeriesViewer, +) diff --git a/tensorflow_model_analysis/notebook/jupyter/tfma_widget/widget.py b/tensorflow_model_analysis/notebook/jupyter/tfma_widget/widget.py index 816c65ea4a..e0ce01e117 100644 --- a/tensorflow_model_analysis/notebook/jupyter/tfma_widget/widget.py +++ b/tensorflow_model_analysis/notebook/jupyter/tfma_widget/widget.py @@ -12,61 +12,66 @@ # See the License for the specific language governing permissions and # limitations under the License. """Defines TFMA's Jupyter notebook widgets.""" + import ipywidgets as widgets -from tensorflow_model_analysis.version import VERSION import traitlets +from tensorflow_model_analysis.version import VERSION + @widgets.register class SlicingMetricsViewer(widgets.DOMWidget): - """The slicing metrics visualization widget.""" - _view_name = traitlets.Unicode('SlicingMetricsView').tag(sync=True) - _model_name = traitlets.Unicode('SlicingMetricsModel').tag(sync=True) - _view_module = traitlets.Unicode('tensorflow_model_analysis').tag(sync=True) - _model_module = traitlets.Unicode('tensorflow_model_analysis').tag(sync=True) - _view_module_version = traitlets.Unicode(VERSION).tag(sync=True) - _model_module_version = traitlets.Unicode(VERSION).tag(sync=True) - data = traitlets.List([]).tag(sync=True) - config = traitlets.Dict(dict()).tag(sync=True) + """The slicing metrics visualization widget.""" - # Used for handling on the js side. 
- event_handlers = {} - js_events = traitlets.List([]).tag(sync=True) + _view_name = traitlets.Unicode("SlicingMetricsView").tag(sync=True) + _model_name = traitlets.Unicode("SlicingMetricsModel").tag(sync=True) + _view_module = traitlets.Unicode("tensorflow_model_analysis").tag(sync=True) + _model_module = traitlets.Unicode("tensorflow_model_analysis").tag(sync=True) + _view_module_version = traitlets.Unicode(VERSION).tag(sync=True) + _model_module_version = traitlets.Unicode(VERSION).tag(sync=True) + data = traitlets.List([]).tag(sync=True) + config = traitlets.Dict(dict()).tag(sync=True) - @traitlets.observe('js_events') - def _handle_js_events(self, change): - if self.js_events: - if self.event_handlers: - for event in self.js_events: - event_name = event['name'] - if event_name in self.event_handlers: - self.event_handlers[event_name](event['detail']) + # Used for handling on the js side. + event_handlers = {} + js_events = traitlets.List([]).tag(sync=True) - # clears the event queue. - self.js_events = [] + @traitlets.observe("js_events") + def _handle_js_events(self, change): + if self.js_events: + if self.event_handlers: + for event in self.js_events: + event_name = event["name"] + if event_name in self.event_handlers: + self.event_handlers[event_name](event["detail"]) + + # clears the event queue. + self.js_events = [] @widgets.register class TimeSeriesViewer(widgets.DOMWidget): - """The time series visualization widget.""" - _view_name = traitlets.Unicode('TimeSeriesView').tag(sync=True) - _model_name = traitlets.Unicode('TimeSeriesModel').tag(sync=True) - _view_module = traitlets.Unicode('tensorflow_model_analysis').tag(sync=True) - _model_module = traitlets.Unicode('tensorflow_model_analysis').tag(sync=True) - _view_module_version = traitlets.Unicode(VERSION).tag(sync=True) - _model_module_version = traitlets.Unicode(VERSION).tag(sync=True) - data = traitlets.List([]).tag(sync=True) - config = traitlets.Dict(dict()).tag(sync=True) + """The time series visualization widget.""" + + _view_name = traitlets.Unicode("TimeSeriesView").tag(sync=True) + _model_name = traitlets.Unicode("TimeSeriesModel").tag(sync=True) + _view_module = traitlets.Unicode("tensorflow_model_analysis").tag(sync=True) + _model_module = traitlets.Unicode("tensorflow_model_analysis").tag(sync=True) + _view_module_version = traitlets.Unicode(VERSION).tag(sync=True) + _model_module_version = traitlets.Unicode(VERSION).tag(sync=True) + data = traitlets.List([]).tag(sync=True) + config = traitlets.Dict(dict()).tag(sync=True) @widgets.register class PlotViewer(widgets.DOMWidget): - """The time series visualization widget.""" - _view_name = traitlets.Unicode('PlotView').tag(sync=True) - _model_name = traitlets.Unicode('PlotModel').tag(sync=True) - _view_module = traitlets.Unicode('tensorflow_model_analysis').tag(sync=True) - _model_module = traitlets.Unicode('tensorflow_model_analysis').tag(sync=True) - _view_module_version = traitlets.Unicode(VERSION).tag(sync=True) - _model_module_version = traitlets.Unicode(VERSION).tag(sync=True) - data = traitlets.Dict([]).tag(sync=True) - config = traitlets.Dict(dict()).tag(sync=True) + """The time series visualization widget.""" + + _view_name = traitlets.Unicode("PlotView").tag(sync=True) + _model_name = traitlets.Unicode("PlotModel").tag(sync=True) + _view_module = traitlets.Unicode("tensorflow_model_analysis").tag(sync=True) + _model_module = traitlets.Unicode("tensorflow_model_analysis").tag(sync=True) + _view_module_version = traitlets.Unicode(VERSION).tag(sync=True) + 
_model_module_version = traitlets.Unicode(VERSION).tag(sync=True) + data = traitlets.Dict([]).tag(sync=True) + config = traitlets.Dict(dict()).tag(sync=True) diff --git a/tensorflow_model_analysis/notebook/visualization.py b/tensorflow_model_analysis/notebook/visualization.py index 99f686a560..091dcacf5f 100644 --- a/tensorflow_model_analysis/notebook/visualization.py +++ b/tensorflow_model_analysis/notebook/visualization.py @@ -12,16 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. """Visualization API.""" + import sys def _is_colab(): - return "google.colab" in sys.modules + return "google.colab" in sys.modules if _is_colab(): - from tensorflow_model_analysis.notebook.colab.renderer import * # pylint: disable=wildcard-import,g-import-not-at-top - from tensorflow_model_analysis.notebook.colab.widget import * # pylint: disable=wildcard-import,g-import-not-at-top + from tensorflow_model_analysis.notebook.colab.renderer import * # pylint: disable=wildcard-import,g-import-not-at-top + from tensorflow_model_analysis.notebook.colab.widget import * # pylint: disable=wildcard-import,g-import-not-at-top else: - from tensorflow_model_analysis.notebook.jupyter.renderer import * # pylint: disable=wildcard-import,g-import-not-at-top - from tensorflow_model_analysis.notebook.jupyter.tfma_widget import * # pylint: disable=wildcard-import,g-import-not-at-top + from tensorflow_model_analysis.notebook.jupyter.renderer import * # pylint: disable=wildcard-import,g-import-not-at-top + from tensorflow_model_analysis.notebook.jupyter.tfma_widget import * # pylint: disable=wildcard-import,g-import-not-at-top diff --git a/tensorflow_model_analysis/post_export_metrics/__init__.py b/tensorflow_model_analysis/post_export_metrics/__init__.py index 887163ad00..d9a40624fe 100644 --- a/tensorflow_model_analysis/post_export_metrics/__init__.py +++ b/tensorflow_model_analysis/post_export_metrics/__init__.py @@ -17,24 +17,24 @@ under active development. """ -from tensorflow_model_analysis.post_export_metrics import metric_keys from tensorflow_model_analysis.post_export_metrics import * +from tensorflow_model_analysis.post_export_metrics import metric_keys __all__ = [ - "auc", - "auc_plots", - "calibration", - "calibration_plot_and_prediction_histogram", - "confusion_matrix_at_thresholds", - "DEFAULT_KEY_PREFERENCE", - "example_count", - "example_weight", - "fairness_auc", - "fairness_indicators", - "mean_absolute_error", - "mean_squared_error", - "precision_at_k", - "recall_at_k", - "root_mean_squared_error", - "squared_pearson_correlation", + "auc", + "auc_plots", + "calibration", + "calibration_plot_and_prediction_histogram", + "confusion_matrix_at_thresholds", + "DEFAULT_KEY_PREFERENCE", + "example_count", + "example_weight", + "fairness_auc", + "fairness_indicators", + "mean_absolute_error", + "mean_squared_error", + "precision_at_k", + "recall_at_k", + "root_mean_squared_error", + "squared_pearson_correlation", ] diff --git a/tensorflow_model_analysis/post_export_metrics/metric_keys.py b/tensorflow_model_analysis/post_export_metrics/metric_keys.py index 3b89cea776..074d479ec8 100644 --- a/tensorflow_model_analysis/post_export_metrics/metric_keys.py +++ b/tensorflow_model_analysis/post_export_metrics/metric_keys.py @@ -23,90 +23,99 @@ from typing import Optional # Prefix for post export metrics keys in metric_ops. 
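# Illustrative sketch (not applied by this patch): the key helpers reformatted below just
# join path segments into metric-key strings; the expected values follow directly from
# their definitions in this file ("head_a" is a hypothetical tag used only for
# illustration).
from tensorflow_model_analysis.post_export_metrics import metric_keys

assert metric_keys.base_key("auc") == "post_export_metrics/auc"
assert (metric_keys.tagged_key("post_export_metrics/auc", "head_a")
        == "post_export_metrics/head_a/auc")
assert (metric_keys.upper_bound_key(metric_keys.base_key("auc"))
        == "post_export_metrics/auc/upper_bound")
# Plot keys are recognized by suffix, so extra prefixes (labels/heads) do not matter.
assert metric_keys.is_plot_key(metric_keys.base_key("calibration_plot/matrices"))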
-DEFAULT_PREFIX = 'post_export_metrics' +DEFAULT_PREFIX = "post_export_metrics" def base_key(suffix: str, prefix: Optional[str] = DEFAULT_PREFIX) -> str: - """Creates a base key from a prefix and a suffix.""" - return '%s/%s' % (prefix, suffix) + """Creates a base key from a prefix and a suffix.""" + return "%s/%s" % (prefix, suffix) def tagged_key(key: str, tag: str) -> str: - """Returns a base key tagged with a user defined tag. + """Returns a base key tagged with a user defined tag. - The tag is inserted after the base key's initial prefix. + The tag is inserted after the base key's initial prefix. - Example: tagged_key('a/c', 'b') -> 'a/b/c' - Example: tagged_key('a', 'b') -> 'a/b' # Use case for plots keys. + Example: tagged_key('a/c', 'b') -> 'a/b/c' + Example: tagged_key('a', 'b') -> 'a/b' # Use case for plots keys. - Args: - key: Base key. - tag: Tag to add to base key. - """ - parts = key.split('/') - if len(parts) > 1: - return '%s/%s/%s' % (parts[0], tag, '/'.join(parts[1:])) - return '%s/%s' % (key, tag) + Args: + ---- + key: Base key. + tag: Tag to add to base key. + """ + parts = key.split("/") + if len(parts) > 1: + return "%s/%s/%s" % (parts[0], tag, "/".join(parts[1:])) + return "%s/%s" % (key, tag) def upper_bound_key(key: str) -> str: - """Creates an upper_bound key from a child key.""" - return key + '/upper_bound' + """Creates an upper_bound key from a child key.""" + return key + "/upper_bound" def lower_bound_key(key: str) -> str: - """Create a lower_bound key from a child key.""" - return key + '/lower_bound' + """Create a lower_bound key from a child key.""" + return key + "/lower_bound" # Not actually for any metric, just used for communicating errors. -ERROR_METRIC = '__ERROR__' +ERROR_METRIC = "__ERROR__" -EXAMPLE_WEIGHT = base_key('example_weight') -EXAMPLE_COUNT = base_key('example_count') -SQUARED_PEARSON_CORRELATION = base_key('squared_pearson_correlation') -CALIBRATION = base_key('calibration') -_CALIBRATION_PLOT_MATRICES_SUFFIX = 'calibration_plot/matrices' +EXAMPLE_WEIGHT = base_key("example_weight") +EXAMPLE_COUNT = base_key("example_count") +SQUARED_PEARSON_CORRELATION = base_key("squared_pearson_correlation") +CALIBRATION = base_key("calibration") +_CALIBRATION_PLOT_MATRICES_SUFFIX = "calibration_plot/matrices" CALIBRATION_PLOT_MATRICES = base_key(_CALIBRATION_PLOT_MATRICES_SUFFIX) -_CALIBRATION_PLOT_BOUNDARIES_SUFFIX = 'calibration_plot/boundaries' +_CALIBRATION_PLOT_BOUNDARIES_SUFFIX = "calibration_plot/boundaries" CALIBRATION_PLOT_BOUNDARIES = base_key(_CALIBRATION_PLOT_BOUNDARIES_SUFFIX) CONFUSION_MATRIX_AT_THRESHOLDS_MATRICES = base_key( - 'confusion_matrix_at_thresholds/matrices') + "confusion_matrix_at_thresholds/matrices" +) CONFUSION_MATRIX_AT_THRESHOLDS_THRESHOLDS = base_key( - 'confusion_matrix_at_thresholds/thresholds') + "confusion_matrix_at_thresholds/thresholds" +) CONFUSION_MATRIX_AT_THRESHOLDS = base_key( - 'confusion_matrix_at_thresholds') # Output-only + "confusion_matrix_at_thresholds" +) # Output-only FAIRNESS_CONFUSION_MATRIX_MATRICES = base_key( - 'fairness/confusion_matrix_at_thresholds/matrices') + "fairness/confusion_matrix_at_thresholds/matrices" +) FAIRNESS_CONFUSION_MATRIX_THESHOLDS = base_key( - 'fairness/confusion_matrix_at_thresholds/thresholds') + "fairness/confusion_matrix_at_thresholds/thresholds" +) FAIRNESS_CONFUSION_MATRIX = base_key( - 'fairness/confusion_matrix_at_thresholds') # Output-only -FAIRNESS_AUC = base_key('fairness/auc') -_AUC_PLOTS_MATRICES_SUFFIX = 'auc_plots/matrices' + 
"fairness/confusion_matrix_at_thresholds" +) # Output-only +FAIRNESS_AUC = base_key("fairness/auc") +_AUC_PLOTS_MATRICES_SUFFIX = "auc_plots/matrices" AUC_PLOTS_MATRICES = base_key(_AUC_PLOTS_MATRICES_SUFFIX) -_AUC_PLOTS_THRESHOLDS_SUFFIX = 'auc_plots/thresholds' +_AUC_PLOTS_THRESHOLDS_SUFFIX = "auc_plots/thresholds" AUC_PLOTS_THRESHOLDS = base_key(_AUC_PLOTS_THRESHOLDS_SUFFIX) -AUC = base_key('auc') -AUPRC = base_key('auprc') -PRECISION_AT_K = base_key('precision_at_k') -RECALL_AT_K = base_key('recall_at_k') -MEAN_ABSOLUTE_ERROR = base_key('mean_absolute_error') -MEAN_SQUARED_ERROR = base_key('mean_squared_error') -ROOT_MEAN_SQUARED_ERROR = base_key('root_mean_squared_error') +AUC = base_key("auc") +AUPRC = base_key("auprc") +PRECISION_AT_K = base_key("precision_at_k") +RECALL_AT_K = base_key("recall_at_k") +MEAN_ABSOLUTE_ERROR = base_key("mean_absolute_error") +MEAN_SQUARED_ERROR = base_key("mean_squared_error") +ROOT_MEAN_SQUARED_ERROR = base_key("root_mean_squared_error") # Suffixes of keys where the corresponding values are results for plots _PLOT_SUFFIXES = [ - _CALIBRATION_PLOT_MATRICES_SUFFIX, _CALIBRATION_PLOT_BOUNDARIES_SUFFIX, - _AUC_PLOTS_MATRICES_SUFFIX, _AUC_PLOTS_THRESHOLDS_SUFFIX + _CALIBRATION_PLOT_MATRICES_SUFFIX, + _CALIBRATION_PLOT_BOUNDARIES_SUFFIX, + _AUC_PLOTS_MATRICES_SUFFIX, + _AUC_PLOTS_THRESHOLDS_SUFFIX, ] def is_plot_key(key: str) -> bool: - """Returns true if key is a plot key.""" - # We need to check for suffixes here because metrics may have prefixes based - # on multiple labels and/or heads. - for suffix in _PLOT_SUFFIXES: - if key.endswith(suffix): - return True - return False + """Returns true if key is a plot key.""" + # We need to check for suffixes here because metrics may have prefixes based + # on multiple labels and/or heads. + for suffix in _PLOT_SUFFIXES: + if key.endswith(suffix): + return True + return False diff --git a/tensorflow_model_analysis/sdk.py b/tensorflow_model_analysis/sdk.py index 14f15fd3c9..a5258cef86 100644 --- a/tensorflow_model_analysis/sdk.py +++ b/tensorflow_model_analysis/sdk.py @@ -26,91 +26,74 @@ # pylint: disable=unused-import # Allow constants to be imported at the top-level since they live in root dir. -from tensorflow_model_analysis.constants import ANALYSIS_KEY -from tensorflow_model_analysis.constants import ARROW_INPUT_COLUMN -from tensorflow_model_analysis.constants import ARROW_RECORD_BATCH_KEY -from tensorflow_model_analysis.constants import ATTRIBUTIONS_KEY -from tensorflow_model_analysis.constants import BASELINE_KEY -from tensorflow_model_analysis.constants import BASELINE_SCORE_KEY -from tensorflow_model_analysis.constants import CANDIDATE_KEY -from tensorflow_model_analysis.constants import DATA_CENTRIC_MODE -from tensorflow_model_analysis.constants import EXAMPLE_SCORE_KEY -from tensorflow_model_analysis.constants import EXAMPLE_WEIGHTS_KEY # TODO(b/120222218): Remove after passing of native FPL supported. 
-from tensorflow_model_analysis.constants import FEATURES_KEY -from tensorflow_model_analysis.constants import FEATURES_PREDICTIONS_LABELS_KEY -from tensorflow_model_analysis.constants import INPUT_KEY -from tensorflow_model_analysis.constants import LABELS_KEY -from tensorflow_model_analysis.constants import METRICS_KEY -from tensorflow_model_analysis.constants import MODEL_CENTRIC_MODE -from tensorflow_model_analysis.constants import PLOTS_KEY -from tensorflow_model_analysis.constants import PREDICTIONS_KEY -from tensorflow_model_analysis.constants import SLICE_KEY_TYPES_KEY -from tensorflow_model_analysis.constants import TF_ESTIMATOR -from tensorflow_model_analysis.constants import TF_GENERIC -from tensorflow_model_analysis.constants import TF_JS -from tensorflow_model_analysis.constants import TF_KERAS -from tensorflow_model_analysis.constants import TF_LITE -from tensorflow_model_analysis.constants import TFMA_EVAL -from tensorflow_model_analysis.constants import VALIDATIONS_KEY +from tensorflow_model_analysis.constants import ( + ANALYSIS_KEY, + ARROW_INPUT_COLUMN, + ARROW_RECORD_BATCH_KEY, + ATTRIBUTIONS_KEY, + BASELINE_KEY, + BASELINE_SCORE_KEY, + CANDIDATE_KEY, + DATA_CENTRIC_MODE, + EXAMPLE_SCORE_KEY, + EXAMPLE_WEIGHTS_KEY, + FEATURES_KEY, + FEATURES_PREDICTIONS_LABELS_KEY, + INPUT_KEY, + LABELS_KEY, + METRICS_KEY, + MODEL_CENTRIC_MODE, + PLOTS_KEY, + PREDICTIONS_KEY, + SLICE_KEY_TYPES_KEY, + TF_ESTIMATOR, + TF_GENERIC, + TF_JS, + TF_KERAS, + TF_LITE, + TFMA_EVAL, + VALIDATIONS_KEY, +) # Allow proto types to be imported at the top-level since proto's live in # the tensorflow_model_analysis namespace. # pylint: disable=g-importing-member -from tensorflow_model_analysis.proto.config_pb2 import AggregationOptions -from tensorflow_model_analysis.proto.config_pb2 import BinarizationOptions -from tensorflow_model_analysis.proto.config_pb2 import ConfidenceIntervalOptions -from tensorflow_model_analysis.proto.config_pb2 import CrossSliceMetricThreshold -from tensorflow_model_analysis.proto.config_pb2 import CrossSliceMetricThresholds -from tensorflow_model_analysis.proto.config_pb2 import CrossSlicingSpec -from tensorflow_model_analysis.proto.config_pb2 import EvalConfig -from tensorflow_model_analysis.proto.config_pb2 import ExampleWeightOptions -from tensorflow_model_analysis.proto.config_pb2 import GenericChangeThreshold -from tensorflow_model_analysis.proto.config_pb2 import GenericValueThreshold -from tensorflow_model_analysis.proto.config_pb2 import MetricConfig -from tensorflow_model_analysis.proto.config_pb2 import MetricDirection -from tensorflow_model_analysis.proto.config_pb2 import MetricsSpec -from tensorflow_model_analysis.proto.config_pb2 import MetricThreshold -from tensorflow_model_analysis.proto.config_pb2 import ModelSpec -from tensorflow_model_analysis.proto.config_pb2 import Options -from tensorflow_model_analysis.proto.config_pb2 import PaddingOptions -from tensorflow_model_analysis.proto.config_pb2 import PerSliceMetricThreshold -from tensorflow_model_analysis.proto.config_pb2 import PerSliceMetricThresholds -from tensorflow_model_analysis.proto.config_pb2 import RepeatedInt32Value -from tensorflow_model_analysis.proto.config_pb2 import RepeatedStringValue -from tensorflow_model_analysis.proto.config_pb2 import SlicingSpec -# pylint: enable=g-importing-member +from tensorflow_model_analysis.proto.config_pb2 import ( + MetricDirection, +) +# pylint: enable=g-importing-member # Import VERSION as VERSION_STRING for backwards compatibility. 
from tensorflow_model_analysis.version import VERSION as VERSION_STRING __all__ = [ - "ANALYSIS_KEY", - "ARROW_INPUT_COLUMN", - "ARROW_RECORD_BATCH_KEY", - "ATTRIBUTIONS_KEY", - "BASELINE_KEY", - "BASELINE_SCORE_KEY", - "CANDIDATE_KEY", - "DATA_CENTRIC_MODE", - "EXAMPLE_SCORE_KEY", - "EXAMPLE_WEIGHTS_KEY", - "FEATURES_KEY", - "FEATURES_PREDICTIONS_LABELS_KEY", - "INPUT_KEY", - "LABELS_KEY", - "METRICS_KEY", - "MODEL_CENTRIC_MODE", - "MetricDirection", - "PLOTS_KEY", - "PREDICTIONS_KEY", - "SLICE_KEY_TYPES_KEY", - "TFMA_EVAL", - "TF_ESTIMATOR", - "TF_GENERIC", - "TF_JS", - "TF_KERAS", - "TF_LITE", - "VALIDATIONS_KEY", - "VERSION_STRING", + "ANALYSIS_KEY", + "ARROW_INPUT_COLUMN", + "ARROW_RECORD_BATCH_KEY", + "ATTRIBUTIONS_KEY", + "BASELINE_KEY", + "BASELINE_SCORE_KEY", + "CANDIDATE_KEY", + "DATA_CENTRIC_MODE", + "EXAMPLE_SCORE_KEY", + "EXAMPLE_WEIGHTS_KEY", + "FEATURES_KEY", + "FEATURES_PREDICTIONS_LABELS_KEY", + "INPUT_KEY", + "LABELS_KEY", + "METRICS_KEY", + "MODEL_CENTRIC_MODE", + "MetricDirection", + "PLOTS_KEY", + "PREDICTIONS_KEY", + "SLICE_KEY_TYPES_KEY", + "TFMA_EVAL", + "TF_ESTIMATOR", + "TF_GENERIC", + "TF_JS", + "TF_KERAS", + "TF_LITE", + "VALIDATIONS_KEY", + "VERSION_STRING", ] diff --git a/tensorflow_model_analysis/slicer/__init__.py b/tensorflow_model_analysis/slicer/__init__.py index 6102401bd6..c8f5b1e8e6 100644 --- a/tensorflow_model_analysis/slicer/__init__.py +++ b/tensorflow_model_analysis/slicer/__init__.py @@ -13,12 +13,14 @@ # limitations under the License. """Init module for TensorFlow Model Analysis slicer.""" -from tensorflow_model_analysis.slicer.slicer_lib import CrossSliceKeyType -from tensorflow_model_analysis.slicer.slicer_lib import deserialize_slice_key -from tensorflow_model_analysis.slicer.slicer_lib import FanoutSlices -from tensorflow_model_analysis.slicer.slicer_lib import OVERALL_SLICE_KEY -from tensorflow_model_analysis.slicer.slicer_lib import serialize_slice_key -from tensorflow_model_analysis.slicer.slicer_lib import SingleSliceSpec -from tensorflow_model_analysis.slicer.slicer_lib import SliceKeyOrCrossSliceKeyType -from tensorflow_model_analysis.slicer.slicer_lib import SliceKeyType -from tensorflow_model_analysis.slicer.slicer_lib import stringify_slice_key +from tensorflow_model_analysis.slicer.slicer_lib import ( + OVERALL_SLICE_KEY, + CrossSliceKeyType, + FanoutSlices, + SingleSliceSpec, + SliceKeyOrCrossSliceKeyType, + SliceKeyType, + deserialize_slice_key, + serialize_slice_key, + stringify_slice_key, +) diff --git a/tensorflow_model_analysis/slicer/slice_accessor.py b/tensorflow_model_analysis/slicer/slice_accessor.py index 0847dd2da7..bf0c3e211a 100644 --- a/tensorflow_model_analysis/slicer/slice_accessor.py +++ b/tensorflow_model_analysis/slicer/slice_accessor.py @@ -21,76 +21,93 @@ import numpy as np import pyarrow as pa import tensorflow as tf + from tensorflow_model_analysis.api import types class SliceAccessor: - """Wrapper around features dict for accessing keys and values for slicing.""" + """Wrapper around features dict for accessing keys and values for slicing.""" - def __init__(self, - features_dicts: Iterable[Union[types.DictOfTensorValue, - types.DictOfFetchedTensorValues]], - default_features_dict: Optional[ - Union[types.DictOfTensorValue, - types.DictOfFetchedTensorValues]] = None): - self._features_dicts = features_dicts - self._default_features_dict = default_features_dict + def __init__( + self, + features_dicts: Iterable[ + Union[types.DictOfTensorValue, types.DictOfFetchedTensorValues] + ], + default_features_dict: 
Optional[ + Union[types.DictOfTensorValue, types.DictOfFetchedTensorValues] + ] = None, + ): + self._features_dicts = features_dicts + self._default_features_dict = default_features_dict - def has_key(self, key: str): - for d in self._features_dicts: - if key in d and d[key] is not None: - return True - if (self._default_features_dict and key in self._default_features_dict and - self._default_features_dict[key] is not None): - return True - return False + def has_key(self, key: str): + for d in self._features_dicts: + if key in d and d[key] is not None: + return True + if ( + self._default_features_dict + and key in self._default_features_dict + and self._default_features_dict[key] is not None + ): + return True + return False - def get(self, key: str) -> List[Union[int, bytes, float]]: - """Get the values of the feature with the given key. + def get(self, key: str) -> List[Union[int, bytes, float]]: + """Get the values of the feature with the given key. - Args: - key: the key of the feature to get the values of + Args: + ---- + key: the key of the feature to get the values of - Returns: - The values of the feature. + Returns: + ------- + The values of the feature. - Raises: - KeyError: If the feature was not present in the input example. - ValueError: A dense feature was not a 1D array. - ValueError: The feature had an unknown type. - """ + Raises: + ------ + KeyError: If the feature was not present in the input example. + ValueError: A dense feature was not a 1D array. + ValueError: The feature had an unknown type. + """ - def normalize_value(value): - if value is None: - return None - if isinstance(value, dict) and 'node' in value: - # Backwards compatibility for features that were stored as FPL types - # instead of native dicts. - value = value['node'] - if isinstance(value, (types.SparseTensorValue, types.RaggedTensorValue, - tf.compat.v1.SparseTensorValue, - tf.compat.v1.ragged.RaggedTensorValue)): - value = value.values - if not isinstance(value, (np.ndarray, pa.Array, list)): - raise ValueError( - 'feature had unsupported type: key: %s, value: %s, type: %s' % - (key, value, type(value))) - # Only np.array and multi-dimentional pa.array support flatten. - if hasattr(value, 'flatten'): - value = value.flatten() - return value + def normalize_value(value): + if value is None: + return None + if isinstance(value, dict) and "node" in value: + # Backwards compatibility for features that were stored as FPL types + # instead of native dicts. + value = value["node"] + if isinstance( + value, + ( + types.SparseTensorValue, + types.RaggedTensorValue, + tf.compat.v1.SparseTensorValue, + tf.compat.v1.ragged.RaggedTensorValue, + ), + ): + value = value.values + if not isinstance(value, (np.ndarray, pa.Array, list)): + raise ValueError( + "feature had unsupported type: key: %s, value: %s, type: %s" + % (key, value, type(value)) + ) + # Only np.array and multi-dimentional pa.array support flatten. 
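Editorial aside, not part of the patch above: `SliceAccessor.get()` (the merging loop continues below) gathers values for a key across all supplied feature dicts, flattens dense multi-dimensional values, and deduplicates the result. A minimal sketch of that behaviour, assuming only that `tensorflow_model_analysis` and `numpy` are importable and using invented feature names:

```python
import numpy as np

from tensorflow_model_analysis.slicer import slice_accessor

# Values for the same key are gathered from every dict and deduplicated.
accessor = slice_accessor.SliceAccessor(
    [{"age": np.array([5])}, {"age": np.array([5, 13])}]
)
print(accessor.get("age"))  # [5, 13] (np.unique also sorts)

# Multi-dimensional dense values are flattened before deduplication.
accessor = slice_accessor.SliceAccessor(
    [{"score": np.array([[1.0, 2.0], [3.0, 4.0]])}]
)
print(accessor.get("score"))  # [1.0, 2.0, 3.0, 4.0]
```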
+ if hasattr(value, "flatten"): + value = value.flatten() + return value - values = None - for d in self._features_dicts: - value = normalize_value(d.get(key)) - if value is None: - continue - if values is None: - values = value - else: - values = np.concatenate((values, value)) - if values is None and self._default_features_dict: - values = normalize_value(self._default_features_dict.get(key)) - if values is None: - raise KeyError('key %s not found' % key) - return np.unique(values).tolist() + values = None + for d in self._features_dicts: + value = normalize_value(d.get(key)) + if value is None: + continue + if values is None: + values = value + else: + values = np.concatenate((values, value)) + if values is None and self._default_features_dict: + values = normalize_value(self._default_features_dict.get(key)) + if values is None: + raise KeyError("key %s not found" % key) + return np.unique(values).tolist() diff --git a/tensorflow_model_analysis/slicer/slice_accessor_test.py b/tensorflow_model_analysis/slicer/slice_accessor_test.py index 90f8d949e5..5e12a1462f 100644 --- a/tensorflow_model_analysis/slicer/slice_accessor_test.py +++ b/tensorflow_model_analysis/slicer/slice_accessor_test.py @@ -13,85 +13,95 @@ # limitations under the License. """Slice accessor test.""" -from absl.testing import parameterized import numpy as np import pyarrow as pa import tensorflow as tf +from absl.testing import parameterized + from tensorflow_model_analysis.api import types from tensorflow_model_analysis.slicer import slice_accessor -_ENCODING_NODE_SUFFIX = 'node' +_ENCODING_NODE_SUFFIX = "node" class SliceAccessorTest(tf.test.TestCase, parameterized.TestCase): + def testRaisesKeyError(self): + accessor = slice_accessor.SliceAccessor({}) + with self.assertRaises(KeyError): + accessor.get("no_such_key") - def testRaisesKeyError(self): - accessor = slice_accessor.SliceAccessor({}) - with self.assertRaises(KeyError): - accessor.get('no_such_key') - - @parameterized.named_parameters( - ('sparse_tensor_value', - types.SparseTensorValue( - indices=np.array([[0, 0], [1, 1]]), - values=np.array(['apple', 'banana']), - dense_shape=np.array([2, 2])), ['apple', 'banana']), - ('ragged_tensor_value', - types.RaggedTensorValue( - values=np.array([1, 2, 3]), nested_row_splits=[np.array([0, 0, 1])]), - [1, 2, 3]), ('dense', np.array([1.0, 2.0]), [1.0, 2.0]), - ('dense_single', np.array([7.0]), [7.0]), - ('dense_multidim', np.array([[1.0, 2.0], [3.0, 4.0]]), - [1.0, 2.0, 3.0, 4.0]), ('squeeze_needed', np.array([[2.0]]), [2.0]), - ('list', [1, 2, 3], [1, 2, 3]), - ('pyarrow', pa.array([1, 2, 3]), [1, 2, 3]), - ('pyarrow_ragged', pa.array([[1, 2], [3]]), [1, 2, 3])) - def testAccessFeaturesDict(self, feature_value, slice_value): - accessor = slice_accessor.SliceAccessor([{'feature': feature_value}]) - self.assertEqual(slice_value, accessor.get('feature')) - # Test with multiple dicts and duplicate values - accessor = slice_accessor.SliceAccessor([{ - 'feature': feature_value - }, { - 'feature': feature_value - }]) - self.assertEqual(slice_value, accessor.get('feature')) - # Test with default features dict - accessor = slice_accessor.SliceAccessor( - [{ - 'unmatched_feature': feature_value - }], - default_features_dict={'feature': feature_value}) - self.assertEqual(slice_value, accessor.get('feature')) + @parameterized.named_parameters( + ( + "sparse_tensor_value", + types.SparseTensorValue( + indices=np.array([[0, 0], [1, 1]]), + values=np.array(["apple", "banana"]), + dense_shape=np.array([2, 2]), + ), + ["apple", 
"banana"], + ), + ( + "ragged_tensor_value", + types.RaggedTensorValue( + values=np.array([1, 2, 3]), nested_row_splits=[np.array([0, 0, 1])] + ), + [1, 2, 3], + ), + ("dense", np.array([1.0, 2.0]), [1.0, 2.0]), + ("dense_single", np.array([7.0]), [7.0]), + ("dense_multidim", np.array([[1.0, 2.0], [3.0, 4.0]]), [1.0, 2.0, 3.0, 4.0]), + ("squeeze_needed", np.array([[2.0]]), [2.0]), + ("list", [1, 2, 3], [1, 2, 3]), + ("pyarrow", pa.array([1, 2, 3]), [1, 2, 3]), + ("pyarrow_ragged", pa.array([[1, 2], [3]]), [1, 2, 3]), + ) + def testAccessFeaturesDict(self, feature_value, slice_value): + accessor = slice_accessor.SliceAccessor([{"feature": feature_value}]) + self.assertEqual(slice_value, accessor.get("feature")) + # Test with multiple dicts and duplicate values + accessor = slice_accessor.SliceAccessor( + [{"feature": feature_value}, {"feature": feature_value}] + ) + self.assertEqual(slice_value, accessor.get("feature")) + # Test with default features dict + accessor = slice_accessor.SliceAccessor( + [{"unmatched_feature": feature_value}], + default_features_dict={"feature": feature_value}, + ) + self.assertEqual(slice_value, accessor.get("feature")) - def testLegacyAccessFeaturesDict(self): - with tf.compat.v1.Session() as sess: - sparse = tf.SparseTensor( - indices=[[0, 0], [1, 1]], - values=['apple', 'banana'], - dense_shape=[2, 2]) - dense = tf.constant([1.0, 2.0]) - dense_single = tf.constant([7.0]) - dense_multidim = tf.constant([[1.0, 2.0], [3.0, 4.0]]) - squeeze_needed = tf.constant([[2.0]]) - (sparse_value, dense_value, dense_single_value, dense_multidim_value, - squeeze_needed_value - ) = sess.run( - fetches=[sparse, dense, dense_single, dense_multidim, squeeze_needed]) - features_dict = { - 'sparse': {_ENCODING_NODE_SUFFIX: sparse_value}, - 'dense': {_ENCODING_NODE_SUFFIX: dense_value}, - 'dense_single': {_ENCODING_NODE_SUFFIX: dense_single_value}, - 'squeeze_needed': {_ENCODING_NODE_SUFFIX: squeeze_needed_value}, - 'dense_multidim': {_ENCODING_NODE_SUFFIX: dense_multidim_value}, - } - accessor = slice_accessor.SliceAccessor([features_dict]) - self.assertEqual([b'apple', b'banana'], accessor.get('sparse')) - self.assertEqual([1.0, 2.0], accessor.get('dense')) - self.assertEqual([7.0], accessor.get('dense_single')) - self.assertEqual([1.0, 2.0, 3.0, 4.0], accessor.get('dense_multidim')) - self.assertEqual([2.0], accessor.get('squeeze_needed')) + def testLegacyAccessFeaturesDict(self): + with tf.compat.v1.Session() as sess: + sparse = tf.SparseTensor( + indices=[[0, 0], [1, 1]], values=["apple", "banana"], dense_shape=[2, 2] + ) + dense = tf.constant([1.0, 2.0]) + dense_single = tf.constant([7.0]) + dense_multidim = tf.constant([[1.0, 2.0], [3.0, 4.0]]) + squeeze_needed = tf.constant([[2.0]]) + ( + sparse_value, + dense_value, + dense_single_value, + dense_multidim_value, + squeeze_needed_value, + ) = sess.run( + fetches=[sparse, dense, dense_single, dense_multidim, squeeze_needed] + ) + features_dict = { + "sparse": {_ENCODING_NODE_SUFFIX: sparse_value}, + "dense": {_ENCODING_NODE_SUFFIX: dense_value}, + "dense_single": {_ENCODING_NODE_SUFFIX: dense_single_value}, + "squeeze_needed": {_ENCODING_NODE_SUFFIX: squeeze_needed_value}, + "dense_multidim": {_ENCODING_NODE_SUFFIX: dense_multidim_value}, + } + accessor = slice_accessor.SliceAccessor([features_dict]) + self.assertEqual([b"apple", b"banana"], accessor.get("sparse")) + self.assertEqual([1.0, 2.0], accessor.get("dense")) + self.assertEqual([7.0], accessor.get("dense_single")) + self.assertEqual([1.0, 2.0, 3.0, 4.0], 
accessor.get("dense_multidim")) + self.assertEqual([2.0], accessor.get("squeeze_needed")) -if __name__ == '__main__': - tf.test.main() +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow_model_analysis/slicer/slicer_lib.py b/tensorflow_model_analysis/slicer/slicer_lib.py index 186d7b48af..1ad4488595 100644 --- a/tensorflow_model_analysis/slicer/slicer_lib.py +++ b/tensorflow_model_analysis/slicer/slicer_lib.py @@ -18,16 +18,26 @@ """ import itertools - -from typing import Any, Callable, Dict, Generator, Iterable, List, NamedTuple, Optional, Tuple, Union +from typing import ( + Any, + Callable, + Dict, + Generator, + Iterable, + List, + NamedTuple, + Optional, + Tuple, + Union, +) import apache_beam as beam import numpy as np import tensorflow as tf + from tensorflow_model_analysis import constants from tensorflow_model_analysis.api import types -from tensorflow_model_analysis.proto import config_pb2 -from tensorflow_model_analysis.proto import metrics_for_slice_pb2 +from tensorflow_model_analysis.proto import config_pb2, metrics_for_slice_pb2 from tensorflow_model_analysis.slicer import slice_accessor # FeatureValueType represents a value that a feature could take. @@ -47,577 +57,644 @@ SliceKeyOrCrossSliceKeyType = Union[SliceKeyType, CrossSliceKeyType] # pylint: disable=invalid-name -OVERALL_SLICE_NAME = 'Overall' +OVERALL_SLICE_NAME = "Overall" # The slice key for the slice that includes all of the data. OVERALL_SLICE_KEY = () class SingleSliceSpec: - """Specification for a single slice. - - This is intended to be an immutable class that specifies a single slice. - Use this in conjunction with get_slices_for_features_dicts to generate slices - for dictionaries of features. - - Examples: - - columns = ['age'], features = [] - This means to slice by the 'age' column. - - columns = ['age'], features = [('gender', 'female')] - This means to slice by the 'age' column if the 'gender' is 'female'. - - For more examples, refer to the tests in slicer_test.py. - """ - - def __eq__(self, other: 'SingleSliceSpec'): - # Need access to other's protected fields for comparison. - if isinstance(other, self.__class__): - # pylint: disable=protected-access - return (self._columns == other._columns and - self._features == other._features) - # pylint: enable=protected-access - else: - return False + """Specification for a single slice. + + This is intended to be an immutable class that specifies a single slice. + Use this in conjunction with get_slices_for_features_dicts to generate slices + for dictionaries of features. + + Examples + -------- + - columns = ['age'], features = [] + This means to slice by the 'age' column. + - columns = ['age'], features = [('gender', 'female')] + This means to slice by the 'age' column if the 'gender' is 'female'. + - For more examples, refer to the tests in slicer_test.py. + """ - def __ne__(self, other: 'SingleSliceSpec'): - return not self.__eq__(other) + def __eq__(self, other: "SingleSliceSpec"): + # Need access to other's protected fields for comparison. 
+ if isinstance(other, self.__class__): + # pylint: disable=protected-access + return self._columns == other._columns and self._features == other._features + # pylint: enable=protected-access + else: + return False + + def __ne__(self, other: "SingleSliceSpec"): + return not self.__eq__(other) + + def __hash__(self): + return hash((self._columns, self._features)) + + def __init__( + self, + columns: Iterable[str] = (), + features: Iterable[Tuple[str, FeatureValueType]] = (), + spec: Optional[config_pb2.SlicingSpec] = None, + ): + """Initialises a SingleSliceSpec. + + Args: + ---- + columns: an iterable of column names to slice on. + features: an iterable of features to slice on. Each feature is a (key, + value) tuple. Note that strings representing ints and floats will be + automatically converted to ints and floats respectively and will be + compared against both the string versions and int or float versions of + the associated features. + spec: Initializes slicing spec from proto. If not None, overrides any + values passed in columns or features. + + Raises: + ------ + ValueError: There was overlap between the columns specified in columns + and those in features. + ValueError: columns or features was a string: they should probably be a + singleton list containing that string. + """ + if isinstance(columns, str): + raise ValueError( + "columns is a string: it should probably be a singleton " + "list containing that string" + ) + if isinstance(features, str): + raise ValueError( + "features is a string: it should probably be a " + "singleton list containing that string" + ) + + if spec is not None: + columns = spec.feature_keys + features = [(k, v) for k, v in spec.feature_values.items()] + + features = [(k, _to_type(v)) for (k, v) in features] + + self._columns = frozenset(columns) + self._features = frozenset(features) + + # We build this up as an instance variable, instead of building it each + # time we call generate_slices, for efficiency reasons. + # + # This is a flat list of SingletonSliceKeyTypes, + # i.e. List[SingletonSliceKeyType]. + self._value_matches = [] + + for key, value in self._features: + if key in self._columns: + raise ValueError( + "Columns specified in columns and in features should " + "not overlap, but %s was specified in both." % key + ) + self._value_matches.append((key, value)) + self._value_matches = sorted(self._value_matches) + + def __repr__(self): + return "SingleSliceSpec(columns=%s, features=%s)" % ( + self._columns, + self._features, + ) + + def to_proto(self) -> config_pb2.SlicingSpec: + feature_values = {k: str(v) for (k, v) in self._features} + return config_pb2.SlicingSpec( + feature_keys=self._columns, feature_values=feature_values + ) + + def is_overall(self): + """Returns True if this specification represents the overall slice.""" + return not self._columns and not self._features + + def is_slice_applicable(self, slice_key: SliceKeyType): + """Determines if this slice spec is applicable to a slice of data. + + Args: + ---- + slice_key: The slice as a SliceKeyType + + Returns: + ------- + True if the slice_spec is applicable to the given slice, False otherwise. + """ + columns = list(self._columns) + features = list(self._features) + for singleton_slice_key in slice_key: + # Convert to internal representation of slice (i.e. str -> float, etc). 
+ if len(singleton_slice_key) == 2: + singleton_slice_key = ( + singleton_slice_key[0], + _to_type(singleton_slice_key[1]), + ) + if singleton_slice_key in features: + features.remove(singleton_slice_key) + elif singleton_slice_key[0] in columns: + columns.remove(singleton_slice_key[0]) + else: + return False + return not features and not columns + + def generate_slices( + self, accessor: slice_accessor.SliceAccessor + ) -> Generator[SliceKeyType, None, None]: + """Generates all slices that match this specification from the data. + + Should only be called within this file. + + Examples: + -------- + - columns = [], features = [] (the overall slice case) + slice accessor has features age=[5], gender=['f'], interest=['knitting'] + returns [()] + - columns = ['age'], features = [('gender', 'f')] + slice accessor has features age=[5], gender=['f'], interest=['knitting'] + returns [[('age', 5), ('gender, 'f')]] + - columns = ['interest'], features = [('gender', 'f')] + slice accessor has features age=[5], gender=['f'], + interest=['knitting', 'games'] + returns [[('gender', 'f'), ('interest, 'knitting')], + [('gender', 'f'), ('interest, 'games')]] + + Args: + ---- + accessor: slice accessor. + + Yields: + ------ + A SliceKeyType for each slice that matches this specification. Nothing + will be yielded if there no slices matched this specification. The entries + in the yielded SliceKeyTypes are guaranteed to be sorted by key names (and + then values, if necessary), ascending. + """ + # Check all the value matches (where there's a specific value specified). + for key, value in self._features: + if not accessor.has_key(key): + return + + accessor_values = accessor.get(key) + if value not in accessor_values: + if isinstance(value, str): + if value.encode() not in accessor_values: # For Python3. + return + # Check that string version of int/float not in values. + elif str(value) not in accessor_values: + return + + # Get all the column matches (where we're matching only the column). + # + # For each column, we generate a List[SingletonSliceKeyType] containing + # all pairs (column, value) for all values of the column. So this will be + # a List[List[SingletonSliceKeyType]]. + # + # For example, column_matches might be: + # [[('gender', 'f'), ('gender', 'm')], [('age', 4), ('age', 5)]] + column_matches = [] + for column in self._columns: + # If a column to slice on doesn't appear in the example, then there will + # be no applicable slices, so return. + if not accessor.has_key(column): + return + + column_match = [] + for value in accessor.get(column): + if isinstance(value, bytes): + try: + column_match.append((column, tf.compat.as_text(value))) + except UnicodeDecodeError as e: + raise ValueError( + f"Found non-UTF8 feature value {value} in " + f'column "{column}"' + ) from e + else: + column_match.append((column, value)) + column_matches.append(column_match) + + # We can now take the Cartesian product of the column_matches, and append + # the value matches to each element of that, to generate the final list of + # slices. Note that for the overall slice case the column_matches is [] and + # the Cartesian product of [] is (). 
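As a quick illustration of the columns-versus-features semantics documented above (an editorial sketch, not part of this change; the feature names and values are invented), `is_slice_applicable` accepts a slice key only when every `features` pair matches and every `columns` entry is present:

```python
from tensorflow_model_analysis.slicer import slicer_lib

# Slice by the 'age' column, restricted to rows where gender == 'f'.
spec = slicer_lib.SingleSliceSpec(columns=["age"], features=[("gender", "f")])

print(spec.is_slice_applicable((("age", 5), ("gender", "f"))))  # True
print(spec.is_slice_applicable((("age", 5), ("gender", "m"))))  # False: feature value differs
print(spec.is_slice_applicable((("gender", "f"),)))             # False: no 'age' entry
print(slicer_lib.SingleSliceSpec().is_overall())                # True: empty spec is the overall slice
```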
+ for column_part in itertools.product(*column_matches): + yield tuple(sorted(self._value_matches + list(column_part))) - def __hash__(self): - return hash((self._columns, self._features)) - def __init__(self, - columns: Iterable[str] = (), - features: Iterable[Tuple[str, FeatureValueType]] = (), - spec: Optional[config_pb2.SlicingSpec] = None): - """Initialises a SingleSliceSpec. +class CrossSliceSpec( + NamedTuple( + "CrossSliceSpec", + [ + ("base_slicing_spec", SingleSliceSpec), + ("slicing_specs", Tuple[SingleSliceSpec]), + ], + ) +): + """Specification for a cross slice. + + This is intended to be an immutable class that specifies a cross slice. + Use this in conjunction with get_slices_for_features_dicts to generate slices + for dictionaries of features. + + Attributes + ---------- + base_slicing_spec: The baseline slicing spec. + slicing_specs: A tuple of slicing specs of the proto counterparts. + """ + + def __new__(cls, spec: config_pb2.CrossSlicingSpec): + """Create a new CrrossSliceSpec object from its Proto counnterpart.""" + # This is organized as a Tuple(baseline_spec, Tuple(slice_specs)) + return super(CrossSliceSpec, cls).__new__( + cls, + SingleSliceSpec(spec=spec.baseline_spec), + tuple( + SingleSliceSpec(spec=slicing_spec) + for slicing_spec in spec.slicing_specs + ), + ) + + +def deserialize_slice_spec( + slice_spec: Union[config_pb2.SlicingSpec, config_pb2.CrossSlicingSpec], +) -> Union[SingleSliceSpec, CrossSliceSpec]: + """Creates the appropriate hashable slicing spec object. Args: - columns: an iterable of column names to slice on. - features: an iterable of features to slice on. Each feature is a (key, - value) tuple. Note that strings representing ints and floats will be - automatically converted to ints and floats respectively and will be - compared against both the string versions and int or float versions of - the associated features. - spec: Initializes slicing spec from proto. If not None, overrides any - values passed in columns or features. + ---- + slice_spec: Proto counnterpart of slicing spec. + + Returns: + ------- + The python object of single slice spec or cross slice spec. Raises: - ValueError: There was overlap between the columns specified in columns - and those in features. - ValueError: columns or features was a string: they should probably be a - singleton list containing that string. + ------ + NotImplementedError: if the type of slice_spec is not supported. """ - if isinstance(columns, str): - raise ValueError('columns is a string: it should probably be a singleton ' - 'list containing that string') - if isinstance(features, str): - raise ValueError('features is a string: it should probably be a ' - 'singleton list containing that string') - - if spec is not None: - columns = spec.feature_keys - features = [(k, v) for k, v in spec.feature_values.items()] - - features = [(k, _to_type(v)) for (k, v) in features] - - self._columns = frozenset(columns) - self._features = frozenset(features) - - # We build this up as an instance variable, instead of building it each - # time we call generate_slices, for efficiency reasons. - # - # This is a flat list of SingletonSliceKeyTypes, - # i.e. List[SingletonSliceKeyType]. - self._value_matches = [] - - for (key, value) in self._features: - if key in self._columns: - raise ValueError('Columns specified in columns and in features should ' - 'not overlap, but %s was specified in both.' 
% key) - self._value_matches.append((key, value)) - self._value_matches = sorted(self._value_matches) - - def __repr__(self): - return 'SingleSliceSpec(columns=%s, features=%s)' % (self._columns, - self._features) - - def to_proto(self) -> config_pb2.SlicingSpec: - feature_values = {k: str(v) for (k, v) in self._features} - return config_pb2.SlicingSpec( - feature_keys=self._columns, feature_values=feature_values) - - def is_overall(self): - """Returns True if this specification represents the overall slice.""" - return not self._columns and not self._features - - def is_slice_applicable(self, slice_key: SliceKeyType): - """Determines if this slice spec is applicable to a slice of data. + if isinstance(slice_spec, config_pb2.SlicingSpec): + return SingleSliceSpec(spec=slice_spec) + elif isinstance(slice_spec, config_pb2.CrossSlicingSpec): + return CrossSliceSpec(spec=slice_spec) + else: + raise NotImplementedError( + f"Not implemented for slice_spec type: {type(slice_spec)}" + ) + + +def serialize_slice_key(slice_key: SliceKeyType) -> metrics_for_slice_pb2.SliceKey: + """Converts SliceKeyType to SliceKey proto. Args: - slice_key: The slice as a SliceKeyType + ---- + slice_key: The slice key in the format of SliceKeyType. Returns: - True if the slice_spec is applicable to the given slice, False otherwise. + ------- + The slice key in the format of SliceKey proto. + + Raises: + ------ + TypeError: If the evaluate type is unrecognized. """ - columns = list(self._columns) - features = list(self._features) - for singleton_slice_key in slice_key: - # Convert to internal representation of slice (i.e. str -> float, etc). - if len(singleton_slice_key) == 2: - singleton_slice_key = (singleton_slice_key[0], - _to_type(singleton_slice_key[1])) - if singleton_slice_key in features: - features.remove(singleton_slice_key) - elif singleton_slice_key[0] in columns: - columns.remove(singleton_slice_key[0]) - else: - return False - return not features and not columns - - def generate_slices( - self, accessor: slice_accessor.SliceAccessor - ) -> Generator[SliceKeyType, None, None]: - """Generates all slices that match this specification from the data. - - Should only be called within this file. - - Examples: - - columns = [], features = [] (the overall slice case) - slice accessor has features age=[5], gender=['f'], interest=['knitting'] - returns [()] - - columns = ['age'], features = [('gender', 'f')] - slice accessor has features age=[5], gender=['f'], interest=['knitting'] - returns [[('age', 5), ('gender, 'f')]] - - columns = ['interest'], features = [('gender', 'f')] - slice accessor has features age=[5], gender=['f'], - interest=['knitting', 'games'] - returns [[('gender', 'f'), ('interest, 'knitting')], - [('gender', 'f'), ('interest, 'games')]] + result = metrics_for_slice_pb2.SliceKey() + + for col, val in slice_key: + single_slice_key = result.single_slice_keys.add() + single_slice_key.column = col + if isinstance(val, (bytes, str)): + single_slice_key.bytes_value = tf.compat.as_bytes(val) + elif isinstance(val, int): + single_slice_key.int64_value = val + elif isinstance(val, float): + single_slice_key.float_value = val + else: + raise TypeError("unrecognized type of type %s, value %s" % (type(val), val)) - Args: - accessor: slice accessor. + return result - Yields: - A SliceKeyType for each slice that matches this specification. Nothing - will be yielded if there no slices matched this specification. 
The entries - in the yielded SliceKeyTypes are guaranteed to be sorted by key names (and - then values, if necessary), ascending. - """ - # Check all the value matches (where there's a specific value specified). - for (key, value) in self._features: - if not accessor.has_key(key): - return - - accessor_values = accessor.get(key) - if value not in accessor_values: - if isinstance(value, str): - if value.encode() not in accessor_values: # For Python3. - return - # Check that string version of int/float not in values. - elif str(value) not in accessor_values: - return - - # Get all the column matches (where we're matching only the column). - # - # For each column, we generate a List[SingletonSliceKeyType] containing - # all pairs (column, value) for all values of the column. So this will be - # a List[List[SingletonSliceKeyType]]. - # - # For example, column_matches might be: - # [[('gender', 'f'), ('gender', 'm')], [('age', 4), ('age', 5)]] - column_matches = [] - for column in self._columns: - # If a column to slice on doesn't appear in the example, then there will - # be no applicable slices, so return. - if not accessor.has_key(column): - return - - column_match = [] - for value in accessor.get(column): - if isinstance(value, bytes): - try: - column_match.append((column, tf.compat.as_text(value))) - except UnicodeDecodeError as e: - raise ValueError('Found non-UTF8 feature value {} in ' - 'column "{}"'.format(value, column)) from e + +def _to_type(v: FeatureValueType) -> FeatureValueType: + """Converts string versions of ints and floats to respective values.""" + if isinstance(v, float) or isinstance(v, int): + return v + try: + v = str(v) + if "." in v: + return float(v) else: - column_match.append((column, value)) - column_matches.append(column_match) + return int(v) + except ValueError: + return v - # We can now take the Cartesian product of the column_matches, and append - # the value matches to each element of that, to generate the final list of - # slices. Note that for the overall slice case the column_matches is [] and - # the Cartesian product of [] is (). - for column_part in itertools.product(*column_matches): - yield tuple(sorted(self._value_matches + list(column_part))) +def serialize_cross_slice_key( + cross_slice_key: CrossSliceKeyType, +) -> metrics_for_slice_pb2.CrossSliceKey: + """Converts CrossSliceKeyType to CrossSliceKey proto.""" + result = metrics_for_slice_pb2.CrossSliceKey() + baseline_slice_key, comparison_slice_key = cross_slice_key + result.baseline_slice_key.CopyFrom(serialize_slice_key(baseline_slice_key)) + result.comparison_slice_key.CopyFrom(serialize_slice_key(comparison_slice_key)) + return result -class CrossSliceSpec( - NamedTuple('CrossSliceSpec', [('base_slicing_spec', SingleSliceSpec), - ('slicing_specs', Tuple[SingleSliceSpec])])): - """Specification for a cross slice. - This is intended to be an immutable class that specifies a cross slice. - Use this in conjunction with get_slices_for_features_dicts to generate slices - for dictionaries of features. +def deserialize_slice_key(slice_key: metrics_for_slice_pb2.SliceKey) -> SliceKeyType: + """Converts SliceKey proto to SliceKeyType. - Attributes: - base_slicing_spec: The baseline slicing spec. - slicing_specs: A tuple of slicing specs of the proto counterparts. - """ + Args: + ---- + slice_key: The slice key in the format of proto SliceKey. 
- def __new__(cls, spec: config_pb2.CrossSlicingSpec): - """Create a new CrrossSliceSpec object from its Proto counnterpart.""" - # This is organized as a Tuple(baseline_spec, Tuple(slice_specs)) - return super(CrossSliceSpec, cls).__new__( - cls, SingleSliceSpec(spec=spec.baseline_spec), - tuple( - SingleSliceSpec(spec=slicing_spec) - for slicing_spec in spec.slicing_specs)) + Returns: + ------- + The slice key in the format of SliceKeyType. + Raises: + ------ + TypeError: If the evaluate type is unreconized. + """ + result = [] + for elem in slice_key.single_slice_keys: + if elem.HasField("bytes_value"): + value = tf.compat.as_text(elem.bytes_value) + elif elem.HasField("int64_value"): + value = elem.int64_value + elif elem.HasField("float_value"): + value = elem.float_value + else: + raise TypeError( + "unrecognized type of type %s, value %s" % (type(elem), elem) + ) + result.append((elem.column, value)) + return tuple(result) -def deserialize_slice_spec( - slice_spec: Union[config_pb2.SlicingSpec, config_pb2.CrossSlicingSpec] -) -> Union[SingleSliceSpec, CrossSliceSpec]: - """Creates the appropriate hashable slicing spec object. - - Args: - slice_spec: Proto counnterpart of slicing spec. - - Returns: - The python object of single slice spec or cross slice spec. - - Raises: - NotImplementedError: if the type of slice_spec is not supported. - """ - if isinstance(slice_spec, config_pb2.SlicingSpec): - return SingleSliceSpec(spec=slice_spec) - elif isinstance(slice_spec, config_pb2.CrossSlicingSpec): - return CrossSliceSpec(spec=slice_spec) - else: - raise NotImplementedError( - f'Not implemented for slice_spec type: {type(slice_spec)}') - - -def serialize_slice_key( - slice_key: SliceKeyType) -> metrics_for_slice_pb2.SliceKey: - """Converts SliceKeyType to SliceKey proto. - - Args: - slice_key: The slice key in the format of SliceKeyType. - - Returns: - The slice key in the format of SliceKey proto. - - Raises: - TypeError: If the evaluate type is unrecognized. - """ - result = metrics_for_slice_pb2.SliceKey() - - for (col, val) in slice_key: - single_slice_key = result.single_slice_keys.add() - single_slice_key.column = col - if isinstance(val, (bytes, str)): - single_slice_key.bytes_value = tf.compat.as_bytes(val) - elif isinstance(val, int): - single_slice_key.int64_value = val - elif isinstance(val, float): - single_slice_key.float_value = val - else: - raise TypeError('unrecognized type of type %s, value %s' % - (type(val), val)) - return result +def deserialize_cross_slice_key( + cross_slice_key: metrics_for_slice_pb2.CrossSliceKey, +) -> CrossSliceKeyType: + """Converts CrossSliceKey proto to CrossSliceKeyType. + Args: + ---- + cross_slice_key: The cross slice key in the format of proto CrossSliceKey. -def _to_type(v: FeatureValueType) -> FeatureValueType: - """Converts string versions of ints and floats to respective values.""" - if isinstance(v, float) or isinstance(v, int): - return v - try: - v = str(v) - if '.' in v: - return float(v) - else: - return int(v) - except ValueError: - return v + Returns: + ------- + The cross slice key in the format of CrossSliceKeyType. + + Raises: + ------ + TypeError: If the evaluate type is unrecognized. 
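A small editorial sketch (not part of the diff) of the proto round trips implemented by the serialize/deserialize helpers above; the key contents are arbitrary examples:

```python
from tensorflow_model_analysis.slicer import slicer_lib

slice_key = (("gender", "f"), ("age", 5))
proto = slicer_lib.serialize_slice_key(slice_key)  # str -> bytes_value, int -> int64_value
assert slicer_lib.deserialize_slice_key(proto) == slice_key

# Cross-slice keys are (baseline_slice_key, comparison_slice_key) pairs.
cross_key = ((("gender", "f"),), (("gender", "m"),))
cross_proto = slicer_lib.serialize_cross_slice_key(cross_key)
assert slicer_lib.deserialize_cross_slice_key(cross_proto) == cross_key
```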
+ """ + baseline_key = deserialize_slice_key(cross_slice_key.baseline_slice_key) + comparison_key = deserialize_slice_key(cross_slice_key.comparison_slice_key) + return (baseline_key, comparison_key) -def serialize_cross_slice_key( - cross_slice_key: CrossSliceKeyType) -> metrics_for_slice_pb2.CrossSliceKey: - """Converts CrossSliceKeyType to CrossSliceKey proto.""" - result = metrics_for_slice_pb2.CrossSliceKey() - baseline_slice_key, comparison_slice_key = cross_slice_key - result.baseline_slice_key.CopyFrom(serialize_slice_key(baseline_slice_key)) - result.comparison_slice_key.CopyFrom( - serialize_slice_key(comparison_slice_key)) - return result - - -def deserialize_slice_key( - slice_key: metrics_for_slice_pb2.SliceKey) -> SliceKeyType: - """Converts SliceKey proto to SliceKeyType. - - Args: - slice_key: The slice key in the format of proto SliceKey. - - Returns: - The slice key in the format of SliceKeyType. - - Raises: - TypeError: If the evaluate type is unreconized. - """ - result = [] - for elem in slice_key.single_slice_keys: - if elem.HasField('bytes_value'): - value = tf.compat.as_text(elem.bytes_value) - elif elem.HasField('int64_value'): - value = elem.int64_value - elif elem.HasField('float_value'): - value = elem.float_value - else: - raise TypeError('unrecognized type of type %s, value %s' % - (type(elem), elem)) - result.append((elem.column, value)) - return tuple(result) +def get_slices_for_features_dicts( + features_dicts: Iterable[ + Union[types.DictOfTensorValue, types.DictOfFetchedTensorValues] + ], + default_features_dict: Union[ + types.DictOfTensorValue, types.DictOfFetchedTensorValues + ], + slice_spec: List[SingleSliceSpec], +) -> Iterable[SliceKeyType]: + """Generates the slice keys appropriate for the given features dictionaries. + Args: + ---- + features_dicts: Features dictionaries. For example a list of transformed + features dictionaries. + default_features_dict: Additional dict to search if a match is not found in + features dictionaries. For example the raw features. + slice_spec: slice specification. -def deserialize_cross_slice_key( - cross_slice_key: metrics_for_slice_pb2.CrossSliceKey) -> CrossSliceKeyType: - """Converts CrossSliceKey proto to CrossSliceKeyType. + Yields: + ------ + Slice keys appropriate for the given features dictionaries. + """ + accessor = slice_accessor.SliceAccessor(features_dicts, default_features_dict) + for single_slice_spec in slice_spec: + for slice_key in single_slice_spec.generate_slices(accessor): + yield slice_key - Args: - cross_slice_key: The cross slice key in the format of proto CrossSliceKey. - Returns: - The cross slice key in the format of CrossSliceKeyType. +def stringify_slice_key(slice_key: SliceKeyType) -> str: + """Stringifies a slice key. - Raises: - TypeError: If the evaluate type is unrecognized. - """ - baseline_key = deserialize_slice_key(cross_slice_key.baseline_slice_key) - comparison_key = deserialize_slice_key(cross_slice_key.comparison_slice_key) - return (baseline_key, comparison_key) + The string representation of a SingletonSliceKeyType is "feature:value". When + multiple columns / features are specified, the string representation of a + SliceKeyType is "c1_X_c2_X_...:v1_X_v2_X_..." where c1, c2, ... are the column + names and v1, v2, ... are the corresponding values For example, + ('gender, 'f'), ('age', 5) befores age_X_gender:f_X_5. If no columns / feature + specified, return "Overall". + Note that we do not perform special escaping for slice values that contain + '_X_'. 
This stringified representation is meant to be human-readbale rather + than a reversible encoding. -def get_slices_for_features_dicts( - features_dicts: Iterable[Union[types.DictOfTensorValue, - types.DictOfFetchedTensorValues]], - default_features_dict: Union[types.DictOfTensorValue, - types.DictOfFetchedTensorValues], - slice_spec: List[SingleSliceSpec]) -> Iterable[SliceKeyType]: - """Generates the slice keys appropriate for the given features dictionaries. - - Args: - features_dicts: Features dictionaries. For example a list of transformed - features dictionaries. - default_features_dict: Additional dict to search if a match is not found in - features dictionaries. For example the raw features. - slice_spec: slice specification. - - Yields: - Slice keys appropriate for the given features dictionaries. - """ - accessor = slice_accessor.SliceAccessor(features_dicts, default_features_dict) - for single_slice_spec in slice_spec: - for slice_key in single_slice_spec.generate_slices(accessor): - yield slice_key + The columns will be in the same order as in SliceKeyType. If they are + generated using SingleSliceSpec.generate_slices, they will be in sorted order, + ascending. + Technically float values are not supported, but we don't check for them here. -def stringify_slice_key(slice_key: SliceKeyType) -> str: - """Stringifies a slice key. - - The string representation of a SingletonSliceKeyType is "feature:value". When - multiple columns / features are specified, the string representation of a - SliceKeyType is "c1_X_c2_X_...:v1_X_v2_X_..." where c1, c2, ... are the column - names and v1, v2, ... are the corresponding values For example, - ('gender, 'f'), ('age', 5) befores age_X_gender:f_X_5. If no columns / feature - specified, return "Overall". - - Note that we do not perform special escaping for slice values that contain - '_X_'. This stringified representation is meant to be human-readbale rather - than a reversible encoding. - - The columns will be in the same order as in SliceKeyType. If they are - generated using SingleSliceSpec.generate_slices, they will be in sorted order, - ascending. - - Technically float values are not supported, but we don't check for them here. - - Args: - slice_key: Slice key to stringify. The constituent SingletonSliceKeyTypes - should be sorted in ascending order. - - Returns: - String representation of the slice key. - """ - key_count = len(slice_key) - if not key_count: - return OVERALL_SLICE_NAME - - keys = [] - values = [] - separator = '_X_' - - for (feature, value) in slice_key: - # Since this is meant to be a human-readable string, we assume that the - # feature and feature value are valid UTF-8 strings (might not be true in - # cases where people store serialised protos in the features for instance). - keys.append(tf.compat.as_text(feature)) - # We need to call as_str_any to convert non-string (e.g. integer) values to - # string first before converting to text. - values.append(tf.compat.as_text(tf.compat.as_str_any(value))) - - # To use u'{}' instead of '{}' here to avoid encoding a unicode character with - # ascii codec. - return (separator.join([u'{}'.format(key) for key in keys]) + ':' + - separator.join([u'{}'.format(value) for value in values])) + Args: + ---- + slice_key: Slice key to stringify. The constituent SingletonSliceKeyTypes + should be sorted in ascending order. + + Returns: + ------- + String representation of the slice key. 
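To make the format documented above concrete (an editorial example, not part of the change; the slice key values are invented):

```python
from tensorflow_model_analysis.slicer import slicer_lib

print(slicer_lib.stringify_slice_key((("age", 5), ("gender", "f"))))  # age_X_gender:5_X_f
print(slicer_lib.stringify_slice_key(()))                             # Overall
```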
+ """ + key_count = len(slice_key) + if not key_count: + return OVERALL_SLICE_NAME + + keys = [] + values = [] + separator = "_X_" + + for feature, value in slice_key: + # Since this is meant to be a human-readable string, we assume that the + # feature and feature value are valid UTF-8 strings (might not be true in + # cases where people store serialised protos in the features for instance). + keys.append(tf.compat.as_text(feature)) + # We need to call as_str_any to convert non-string (e.g. integer) values to + # string first before converting to text. + values.append(tf.compat.as_text(tf.compat.as_str_any(value))) + + # To use u'{}' instead of '{}' here to avoid encoding a unicode character with + # ascii codec. + return ( + separator.join([f"{key}" for key in keys]) + + ":" + + separator.join([f"{value}" for value in values]) + ) def slice_keys_to_numpy_array(slice_keys: List[SliceKeyType]) -> np.ndarray: - """Converts a list of slice keys into a numpy array. + """Converts a list of slice keys into a numpy array. - This must be done in a special way to avoid numpy treating the slice key - tuples as additional dimensions in the numpy array. + This must be done in a special way to avoid numpy treating the slice key + tuples as additional dimensions in the numpy array. - Args: - slice_keys: A list of SliceKeyTypes + Args: + ---- + slice_keys: A list of SliceKeyTypes - Returns: - A numpy array with dtype=object where individual values are tuples. - """ - result = np.empty(len(slice_keys), dtype=object) - for i, slice_key in enumerate(slice_keys): - result[i] = slice_key - return result + Returns: + ------- + A numpy array with dtype=object where individual values are tuples. + """ + result = np.empty(len(slice_keys), dtype=object) + for i, slice_key in enumerate(slice_keys): + result[i] = slice_key + return result def is_cross_slice_applicable( - cross_slice_key: CrossSliceKeyType, - cross_slicing_spec: config_pb2.CrossSlicingSpec) -> bool: - """Checks if CrossSlicingSpec is applicable to the CrossSliceKeyType.""" - baseline_slice_key, comparison_slice_key = cross_slice_key - - if not SingleSliceSpec(spec=cross_slicing_spec.baseline_spec - ).is_slice_applicable(baseline_slice_key): + cross_slice_key: CrossSliceKeyType, cross_slicing_spec: config_pb2.CrossSlicingSpec +) -> bool: + """Checks if CrossSlicingSpec is applicable to the CrossSliceKeyType.""" + baseline_slice_key, comparison_slice_key = cross_slice_key + + if not SingleSliceSpec(spec=cross_slicing_spec.baseline_spec).is_slice_applicable( + baseline_slice_key + ): + return False + for comparison_slicing_spec in cross_slicing_spec.slicing_specs: + if SingleSliceSpec(spec=comparison_slicing_spec).is_slice_applicable( + comparison_slice_key + ): + return True return False - for comparison_slicing_spec in cross_slicing_spec.slicing_specs: - if SingleSliceSpec( - spec=comparison_slicing_spec).is_slice_applicable(comparison_slice_key): - return True - return False - -def get_slice_key_type( - slice_key: Union[SliceKeyType, CrossSliceKeyType]) -> Any: - """Determines if the slice_key in SliceKeyType or CrossSliceKeyType format. - Args: - slice_key: The slice key which can be in SliceKeyType format or - CrossSliceType format. +def get_slice_key_type(slice_key: Union[SliceKeyType, CrossSliceKeyType]) -> Any: + """Determines if the slice_key in SliceKeyType or CrossSliceKeyType format. - Returns: - SliceKeyType object or CrossSliceKeyType object. 
+ Args: + ---- + slice_key: The slice key which can be in SliceKeyType format or + CrossSliceType format. - Raises: - TypeError: If slice key is not recognized. - """ + Returns: + ------- + SliceKeyType object or CrossSliceKeyType object. - def is_singleton_slice_key_type( - singleton_slice_key: SingletonSliceKeyType) -> bool: - try: - col, val = singleton_slice_key - except ValueError: - return False - if (isinstance(col, (bytes, str)) and - (isinstance(val, (bytes, str)) or isinstance(val, int) or - isinstance(val, float))): - return True - else: - return False + Raises: + ------ + TypeError: If slice key is not recognized. + """ - def is_slice_key_type(slice_key: SliceKeyType) -> bool: - if not slice_key: - return True + def is_singleton_slice_key_type(singleton_slice_key: SingletonSliceKeyType) -> bool: + try: + col, val = singleton_slice_key + except ValueError: + return False + if isinstance(col, (bytes, str)) and ( + isinstance(val, (bytes, str)) + or isinstance(val, int) + or isinstance(val, float) + ): + return True + else: + return False - for single_slice_key in slice_key: - if not is_singleton_slice_key_type(single_slice_key): - return False - return True + def is_slice_key_type(slice_key: SliceKeyType) -> bool: + if not slice_key: + return True - if is_slice_key_type(slice_key): - return SliceKeyType + for single_slice_key in slice_key: + if not is_singleton_slice_key_type(single_slice_key): + return False + return True - try: - baseline_slice, comparison_slice = slice_key # pytype: disable=bad-unpacking - except ValueError as e: - raise TypeError(f'Unrecognized slice type for slice_key: {slice_key}. ' - 'Neither SliceKeyType nor CrossSliceKeyType.') from e + if is_slice_key_type(slice_key): + return SliceKeyType - if (is_slice_key_type(baseline_slice) and - is_slice_key_type(comparison_slice)): - return CrossSliceKeyType - else: - raise TypeError('Unrecognized slice type. Neither SliceKeyType nor' - ' CrossSliceKeyType.') + try: + baseline_slice, comparison_slice = slice_key # pytype: disable=bad-unpacking + except ValueError as e: + raise TypeError( + f"Unrecognized slice type for slice_key: {slice_key}. " + "Neither SliceKeyType nor CrossSliceKeyType." + ) from e + + if is_slice_key_type(baseline_slice) and is_slice_key_type(comparison_slice): + return CrossSliceKeyType + else: + raise TypeError( + "Unrecognized slice type. Neither SliceKeyType nor" " CrossSliceKeyType." + ) -def is_cross_slice_key( - slice_key: Union[SliceKeyType, CrossSliceKeyType]) -> bool: - """Returns whether slice_key is cross_slice or not.""" - return get_slice_key_type(slice_key) == CrossSliceKeyType +def is_cross_slice_key(slice_key: Union[SliceKeyType, CrossSliceKeyType]) -> bool: + """Returns whether slice_key is cross_slice or not.""" + return get_slice_key_type(slice_key) == CrossSliceKeyType def slice_key_matches_slice_specs( - slice_key: SliceKeyType, slice_specs: Iterable[SingleSliceSpec]) -> bool: - """Checks whether a slice key matches any slice spec. + slice_key: SliceKeyType, slice_specs: Iterable[SingleSliceSpec] +) -> bool: + """Checks whether a slice key matches any slice spec. - In this setting, a slice key matches a slice spec if it could have been - generated by that spec. + In this setting, a slice key matches a slice spec if it could have been + generated by that spec. - Args: - slice_key: The slice key to check for applicability against slice specs. - slice_specs: Slice specs against which to check applicability of a slice - key. 
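An editorial sketch (not part of the diff) of how the helpers above distinguish plain slice keys from cross-slice keys; the keys themselves are made up:

```python
from tensorflow_model_analysis.slicer import slicer_lib

plain_key = (("gender", "f"),)
cross_key = ((("gender", "f"),), (("gender", "m"),))  # (baseline, comparison)

print(slicer_lib.is_cross_slice_key(plain_key))  # False: parses as SliceKeyType
print(slicer_lib.is_cross_slice_key(cross_key))  # True: parses as CrossSliceKeyType
```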
+ Args: + ---- + slice_key: The slice key to check for applicability against slice specs. + slice_specs: Slice specs against which to check applicability of a slice + key. - Returns: - True if the slice_key matches any slice specs, False otherwise. - """ - for slice_spec in slice_specs: - if slice_spec.is_slice_applicable(slice_key): - return True - return False + Returns: + ------- + True if the slice_key matches any slice specs, False otherwise. + """ + for slice_spec in slice_specs: + if slice_spec.is_slice_applicable(slice_key): + return True + return False @beam.typehints.with_input_types(types.Extracts) @beam.typehints.with_output_types(Tuple[SliceKeyType, types.Extracts]) class _FanoutSlicesDoFn(beam.DoFn): - """A DoFn that performs per-slice key fanout prior to computing aggregates.""" - - def __init__(self, key_filter_fn: Callable[[str], bool]): - self._num_slices_generated_per_instance = beam.metrics.Metrics.distribution( - constants.METRICS_NAMESPACE, 'num_slices_generated_per_instance') - self._post_slice_num_instances = beam.metrics.Metrics.counter( - constants.METRICS_NAMESPACE, 'post_slice_num_instances') - self._key_filter_fn = key_filter_fn - - def process( - self, - element: types.Extracts) -> List[Tuple[SliceKeyType, types.Extracts]]: - key_filter_fn = self._key_filter_fn # Local cache. - filtered = {k: v for k, v in element.items() if key_filter_fn(k)} - slice_keys = element.get(constants.SLICE_KEY_TYPES_KEY) - # The query based evaluator will group slices from multiple examples, so we - # deduplicate to avoid overcounting. Depending on whether the rows within a - # batch have a variable or fixed length, either a VarLenTensorValue or a 2D - # np.ndarray will be created. - if isinstance(slice_keys, types.VarLenTensorValue): - slice_keys = slice_keys.values - elif isinstance(slice_keys, np.ndarray) and len(slice_keys.shape) == 2: - slice_keys = slice_keys.flatten() - result = [(slice_key, filtered) for slice_key in set(slice_keys)] - self._num_slices_generated_per_instance.update(len(result)) - self._post_slice_num_instances.inc(len(result)) - return result + """A DoFn that performs per-slice key fanout prior to computing aggregates.""" + + def __init__(self, key_filter_fn: Callable[[str], bool]): + self._num_slices_generated_per_instance = beam.metrics.Metrics.distribution( + constants.METRICS_NAMESPACE, "num_slices_generated_per_instance" + ) + self._post_slice_num_instances = beam.metrics.Metrics.counter( + constants.METRICS_NAMESPACE, "post_slice_num_instances" + ) + self._key_filter_fn = key_filter_fn + + def process( + self, element: types.Extracts + ) -> List[Tuple[SliceKeyType, types.Extracts]]: + key_filter_fn = self._key_filter_fn # Local cache. + filtered = {k: v for k, v in element.items() if key_filter_fn(k)} + slice_keys = element.get(constants.SLICE_KEY_TYPES_KEY) + # The query based evaluator will group slices from multiple examples, so we + # deduplicate to avoid overcounting. Depending on whether the rows within a + # batch have a variable or fixed length, either a VarLenTensorValue or a 2D + # np.ndarray will be created. 
+ if isinstance(slice_keys, types.VarLenTensorValue): + slice_keys = slice_keys.values + elif isinstance(slice_keys, np.ndarray) and len(slice_keys.shape) == 2: + slice_keys = slice_keys.flatten() + result = [(slice_key, filtered) for slice_key in set(slice_keys)] + self._num_slices_generated_per_instance.update(len(result)) + self._post_slice_num_instances.inc(len(result)) + return result # TODO(cyfoo): Possibly introduce the same telemetry in Lantern to help with @@ -626,43 +703,46 @@ def process( @beam.typehints.with_input_types(Tuple[SliceKeyType, types.Extracts]) @beam.typehints.with_output_types(int) def _TrackDistinctSliceKeys( # pylint: disable=invalid-name - slice_keys_and_values: beam.pvalue.PCollection) -> beam.pvalue.PCollection: - """Gathers slice key telemetry post slicing.""" - - def increment_counter(element): # pylint: disable=invalid-name - num_distinct_slice_keys = beam.metrics.Metrics.counter( - constants.METRICS_NAMESPACE, 'num_distinct_slice_keys') - num_distinct_slice_keys.inc(element) - return element - - return (slice_keys_and_values - | 'ExtractSliceKeys' >> beam.Keys() - | 'RemoveDuplicates' >> beam.Distinct() - | 'Size' >> beam.combiners.Count.Globally() - | 'IncrementCounter' >> beam.Map(increment_counter)) + slice_keys_and_values: beam.pvalue.PCollection, +) -> beam.pvalue.PCollection: + """Gathers slice key telemetry post slicing.""" + + def increment_counter(element): # pylint: disable=invalid-name + num_distinct_slice_keys = beam.metrics.Metrics.counter( + constants.METRICS_NAMESPACE, "num_distinct_slice_keys" + ) + num_distinct_slice_keys.inc(element) + return element + + return ( + slice_keys_and_values + | "ExtractSliceKeys" >> beam.Keys() + | "RemoveDuplicates" >> beam.Distinct() + | "Size" >> beam.combiners.Count.Globally() + | "IncrementCounter" >> beam.Map(increment_counter) + ) @beam.ptransform_fn @beam.typehints.with_input_types(types.Extracts) @beam.typehints.with_output_types(tuple[SliceKeyType, types.Extracts]) def FanoutSlices( # pylint: disable=invalid-name - pcoll: beam.pvalue.PCollection, - include_slice_keys_in_output: Optional[bool] = False + pcoll: beam.pvalue.PCollection, include_slice_keys_in_output: Optional[bool] = False ) -> beam.pvalue.PCollection: # pylint: disable=invalid-name - """Fan out extracts based on slice keys (slice keys removed by default).""" - if include_slice_keys_in_output: - key_filter_fn = lambda k: True - else: - pruned_keys = (constants.SLICE_KEY_TYPES_KEY, constants.SLICE_KEYS_KEY) - key_filter_fn = lambda k: k not in pruned_keys + """Fan out extracts based on slice keys (slice keys removed by default).""" + if include_slice_keys_in_output: + key_filter_fn = lambda k: True + else: + pruned_keys = (constants.SLICE_KEY_TYPES_KEY, constants.SLICE_KEYS_KEY) + key_filter_fn = lambda k: k not in pruned_keys - result = pcoll | 'DoSlicing' >> beam.ParDo(_FanoutSlicesDoFn(key_filter_fn)) + result = pcoll | "DoSlicing" >> beam.ParDo(_FanoutSlicesDoFn(key_filter_fn)) - # pylint: disable=no-value-for-parameter - _ = result | 'TrackDistinctSliceKeys' >> _TrackDistinctSliceKeys() - # pylint: enable=no-value-for-parameter + # pylint: disable=no-value-for-parameter + _ = result | "TrackDistinctSliceKeys" >> _TrackDistinctSliceKeys() + # pylint: enable=no-value-for-parameter - return result + return result # TFMA v1 uses Text for its keys while TFMA v2 uses MetricKey @@ -676,68 +756,73 @@ def FilterOutSlices( # pylint: disable=invalid-name values: beam.pvalue.PCollection, slices_count: beam.pvalue.PCollection, min_slice_size: 
int, - error_metric_key: str = '__ERROR__') -> beam.pvalue.PCollection: - """Filter out slices with examples count lower than k_anonymization_count. - - Since we might filter out certain slices to preserve privacy in the case of - small slices, to make end users aware of this, we will append filtered out - slice keys with empty data, and a debug message explaining the omission. - - Args: - values: PCollection of aggregated data keyed at slice_key - slices_count: PCollection of slice keys and their example count. - min_slice_size: If the number of examples in a specific slice is less than - min_slice_size, then an error will be returned for that slice. This will - be useful to ensure privacy by not displaying the aggregated data for - smaller number of examples. - error_metric_key: The special metric key to indicate errors. - - Returns: - A PCollection keyed at all the possible slice_key and aggregated data for - slice keys with example count more than min_slice_size and error - message for filtered out slices. - """ - - class FilterOutSmallSlicesDoFn(beam.DoFn): - """DoFn to filter out small slices.""" - - def __init__(self, error_metric_key: str): - self.error_metric_key = error_metric_key + error_metric_key: str = "__ERROR__", +) -> beam.pvalue.PCollection: + """Filter out slices with examples count lower than k_anonymization_count. - def process( - self, element: Tuple[SliceKeyType, _MetricsDict] - ) -> Generator[Tuple[SliceKeyType, _MetricsDict], None, None]: - """Filter out small slices. - - For slices (excluding overall slice) with examples count lower than - min_slice_size, it adds an error message. - - Args: - element: Tuple containing slice key and a dictionary containing - corresponding elements from merged pcollections. - - Yields: - PCollection of (slice_key, aggregated_data or error message) - """ - (slice_key, value) = element - if value['values']: - if (not slice_key or value['slices_count'][0] >= min_slice_size): - yield (slice_key, value['values'][0]) - else: - yield ( - slice_key, - { - self.error_metric_key: # LINT.IfChange - 'Example count for this slice key is lower than the ' - 'minimum required value: %d. No data is aggregated for ' - 'this slice.' % min_slice_size - # LINT.ThenChange(../addons/fairness/frontend/fairness-metrics-board/fairness-metrics-board.js) - }) - - return ({ - 'values': values, - 'slices_count': slices_count - } - | 'CoGroupingSlicesCountAndAggregatedData' >> beam.CoGroupByKey() - | 'FilterOutSmallSlices' >> beam.ParDo( - FilterOutSmallSlicesDoFn(error_metric_key))) + Since we might filter out certain slices to preserve privacy in the case of + small slices, to make end users aware of this, we will append filtered out + slice keys with empty data, and a debug message explaining the omission. + + Args: + ---- + values: PCollection of aggregated data keyed at slice_key + slices_count: PCollection of slice keys and their example count. + min_slice_size: If the number of examples in a specific slice is less than + min_slice_size, then an error will be returned for that slice. This will + be useful to ensure privacy by not displaying the aggregated data for + smaller number of examples. + error_metric_key: The special metric key to indicate errors. + + Returns: + ------- + A PCollection keyed at all the possible slice_key and aggregated data for + slice keys with example count more than min_slice_size and error + message for filtered out slices. 
+ """ + + class FilterOutSmallSlicesDoFn(beam.DoFn): + """DoFn to filter out small slices.""" + + def __init__(self, error_metric_key: str): + self.error_metric_key = error_metric_key + + def process( + self, element: Tuple[SliceKeyType, _MetricsDict] + ) -> Generator[Tuple[SliceKeyType, _MetricsDict], None, None]: + """Filter out small slices. + + For slices (excluding overall slice) with examples count lower than + min_slice_size, it adds an error message. + + Args: + ---- + element: Tuple containing slice key and a dictionary containing + corresponding elements from merged pcollections. + + Yields: + ------ + PCollection of (slice_key, aggregated_data or error message) + """ + (slice_key, value) = element + if value["values"]: + if not slice_key or value["slices_count"][0] >= min_slice_size: + yield (slice_key, value["values"][0]) + else: + yield ( + slice_key, + { + self.error_metric_key: # LINT.IfChange + "Example count for this slice key is lower than the " + "minimum required value: %d. No data is aggregated for " + "this slice." % min_slice_size + # LINT.ThenChange(../addons/fairness/frontend/fairness-metrics-board/fairness-metrics-board.js) + }, + ) + + return ( + {"values": values, "slices_count": slices_count} + | "CoGroupingSlicesCountAndAggregatedData" >> beam.CoGroupByKey() + | "FilterOutSmallSlices" + >> beam.ParDo(FilterOutSmallSlicesDoFn(error_metric_key)) + ) diff --git a/tensorflow_model_analysis/slicer/slicer_test.py b/tensorflow_model_analysis/slicer/slicer_test.py index 637793477c..353f4a0818 100644 --- a/tensorflow_model_analysis/slicer/slicer_test.py +++ b/tensorflow_model_analysis/slicer/slicer_test.py @@ -13,90 +13,92 @@ # limitations under the License. """Slicer test.""" -from absl.testing import parameterized import apache_beam as beam -from apache_beam.testing import util import numpy as np import six import tensorflow as tf +from absl.testing import parameterized +from apache_beam.testing import util +from google.protobuf import text_format + from tensorflow_model_analysis import constants from tensorflow_model_analysis.api import types from tensorflow_model_analysis.extractors import slice_key_extractor from tensorflow_model_analysis.post_export_metrics import metric_keys -from tensorflow_model_analysis.proto import config_pb2 -from tensorflow_model_analysis.proto import metrics_for_slice_pb2 +from tensorflow_model_analysis.proto import config_pb2, metrics_for_slice_pb2 from tensorflow_model_analysis.slicer import slicer_lib as slicer from tensorflow_model_analysis.utils import test_util from tensorflow_model_analysis.utils import util as tfma_util -from google.protobuf import text_format - def make_features_dict(features_dict): - result = {} - for key, value in features_dict.items(): - result[key] = {'node': np.array(value)} - return result + result = {} + for key, value in features_dict.items(): + result[key] = {"node": np.array(value)} + return result def create_fpls(): - fpl1 = types.FeaturesPredictionsLabels( - input_ref=0, - features=make_features_dict({ - 'gender': ['f'], - 'age': [13], - 'interest': ['cars'] - }), - predictions=make_features_dict({ - 'kb': [1], - }), - labels=make_features_dict({'ad_risk_score': [0]})) - fpl2 = types.FeaturesPredictionsLabels( - input_ref=0, - features=make_features_dict({ - 'gender': ['m'], - 'age': [10], - 'interest': ['cars'] - }), - predictions=make_features_dict({ - 'kb': [1], - }), - labels=make_features_dict({'ad_risk_score': [0]})) - return [fpl1, fpl2] + fpl1 = types.FeaturesPredictionsLabels( + input_ref=0, 
+ features=make_features_dict( + {"gender": ["f"], "age": [13], "interest": ["cars"]} + ), + predictions=make_features_dict( + { + "kb": [1], + } + ), + labels=make_features_dict({"ad_risk_score": [0]}), + ) + fpl2 = types.FeaturesPredictionsLabels( + input_ref=0, + features=make_features_dict( + {"gender": ["m"], "age": [10], "interest": ["cars"]} + ), + predictions=make_features_dict( + { + "kb": [1], + } + ), + labels=make_features_dict({"ad_risk_score": [0]}), + ) + return [fpl1, fpl2] def wrap_fpl(fpl): - return { - constants.INPUT_KEY: fpl, - constants.FEATURES_PREDICTIONS_LABELS_KEY: fpl - } + return {constants.INPUT_KEY: fpl, constants.FEATURES_PREDICTIONS_LABELS_KEY: fpl} class SlicerTest(test_util.TensorflowModelAnalysisTest, parameterized.TestCase): + def setUp(self): + super().setUp() + self.longMessage = True # pylint: disable=invalid-name + beam.typehints.disable_type_annotations() + + def _makeFeaturesDict(self, features_dict): + result = {} + for key, value in features_dict.items(): + result[key] = {"node": np.array(value)} + return result + + def assertSliceResult(self, name, features_dict, columns, features, expected): + spec = slicer.SingleSliceSpec(columns=columns, features=features) + msg = "Test case %s: slice on columns %s, features %s" % ( + name, + columns, + features, + ) + six.assertCountEqual( + self, + expected, + slicer.get_slices_for_features_dicts([features_dict], None, [spec]), + msg, + ) - def setUp(self): - super().setUp() - self.longMessage = True # pylint: disable=invalid-name - beam.typehints.disable_type_annotations() - - def _makeFeaturesDict(self, features_dict): - result = {} - for key, value in features_dict.items(): - result[key] = {'node': np.array(value)} - return result - - def assertSliceResult(self, name, features_dict, columns, features, expected): - spec = slicer.SingleSliceSpec(columns=columns, features=features) - msg = 'Test case %s: slice on columns %s, features %s' % (name, columns, - features) - six.assertCountEqual( - self, expected, - slicer.get_slices_for_features_dicts([features_dict], None, [spec]), - msg) - - def testDeserializeSliceKey(self): - slice_metrics = text_format.Parse( - """ + def testDeserializeSliceKey(self): + slice_metrics = text_format.Parse( + """ single_slice_keys { column: 'age' int64_value: 5 @@ -109,15 +111,18 @@ def testDeserializeSliceKey(self): column: 'price' float_value: 1.0 } - """, metrics_for_slice_pb2.SliceKey()) + """, + metrics_for_slice_pb2.SliceKey(), + ) - got_slice_key = slicer.deserialize_slice_key(slice_metrics) - self.assertCountEqual([('age', 5), ('language', 'english'), ('price', 1.0)], - got_slice_key) + got_slice_key = slicer.deserialize_slice_key(slice_metrics) + self.assertCountEqual( + [("age", 5), ("language", "english"), ("price", 1.0)], got_slice_key + ) - def testDeserializeCrossSliceKey(self): - slice_metrics = text_format.Parse( - """ + def testDeserializeCrossSliceKey(self): + slice_metrics = text_format.Parse( + """ baseline_slice_key { single_slice_keys { column: 'age' @@ -142,604 +147,752 @@ def testDeserializeCrossSliceKey(self): bytes_value: 'hindi' } } - """, metrics_for_slice_pb2.CrossSliceKey()) - - got_slice_key = slicer.deserialize_cross_slice_key(slice_metrics) - self.assertCountEqual( - ((('age', 5), ('language', 'english'), ('price', 1.0)), - (('age', 8), ('language', 'hindi'))), got_slice_key) - - def testSliceEquality(self): - overall = slicer.SingleSliceSpec() - age_column = slicer.SingleSliceSpec(columns=['age']) - age_feature = 
slicer.SingleSliceSpec(features=[('age', 5)]) - age_and_gender = slicer.SingleSliceSpec( - columns=['age'], features=[('gender', 'f')]) - - # Note that we construct new instances of the slices to ensure that we - # aren't just checking object identity. - def check_equality_and_hash_equality(left, right): - self.assertEqual(left, right) - self.assertEqual(hash(left), hash(right)) - - check_equality_and_hash_equality(overall, slicer.SingleSliceSpec()) - check_equality_and_hash_equality(age_column, - slicer.SingleSliceSpec(columns=['age'])) - check_equality_and_hash_equality( - age_feature, slicer.SingleSliceSpec(features=[('age', 5)])) - check_equality_and_hash_equality( - age_and_gender, - slicer.SingleSliceSpec(columns=['age'], features=[('gender', 'f')])) - - self.assertNotEqual(overall, age_column) - self.assertNotEqual(age_column, age_feature) - self.assertNotEqual(age_column, age_and_gender) - self.assertNotEqual(age_feature, age_and_gender) - - self.assertCountEqual([slicer.SingleSliceSpec()], [overall]) - self.assertCountEqual([ - slicer.SingleSliceSpec(columns=['age']), - slicer.SingleSliceSpec(), - slicer.SingleSliceSpec(features=[('age', 5)]), - slicer.SingleSliceSpec(columns=['age'], features=[('gender', 'f')]) - ], [age_and_gender, age_feature, overall, age_column]) - - def testNoOverlappingColumns(self): - self.assertRaises(ValueError, slicer.SingleSliceSpec, ['age'], [('age', 5)]) - - def testNonUTF8ValueRaisesValueError(self): - column_name = 'column_name' - invalid_value = b'\x8a' - spec = slicer.SingleSliceSpec(columns=[column_name]) - features_dict = self._makeFeaturesDict({ - column_name: [invalid_value], - }) - with self.assertRaisesRegex(ValueError, column_name): - list(slicer.get_slices_for_features_dicts([features_dict], None, [spec])) - - def testGetSlicesForFeaturesDictUnivalent(self): - test_cases = [ - ('Overall', [], [], [()]), - ('Feature does not match', [], [('age', 99)], []), - ('No such column', ['no_such_column'], [], []), - ('Single column', ['age'], [], [(('age', 5),)]), - ('Single feature', [], [('age', 5)], [(('age', 5),)]), - ('Single feature type mismatch', [], [('age', '5')], [(('age', 5),)]), - ('One column, one feature', - ['gender'], [('age', 5)], [(('age', 5), ('gender', 'f'))]), - ('Two features', ['interest', 'gender'], [('age', 5)], - [(('age', 5), ('gender', 'f'), ('interest', 'cars'))]), - ] # pyformat: disable - features_dict = self._makeFeaturesDict({ - 'gender': ['f'], - 'age': [5], - 'interest': ['cars'] - }) - for (name, columns, features, expected) in test_cases: - self.assertSliceResult(name, features_dict, columns, features, expected) - - def testGetSlicesForFeaturesDictMultivalent(self): - test_cases = [ - ( - 'One column', - ['fruits'], - [], - [ - (('fruits', 'apples'),), - (('fruits', 'pears'),) - ], - ), - ( - 'Two columns', - ['fruits', 'interests'], - [], - [ - (('fruits', 'apples'), ('interests', 'cars')), - (('fruits', 'apples'), ('interests', 'dogs')), - (('fruits', 'pears'), ('interests', 'cars')), - (('fruits', 'pears'), ('interests', 'dogs')) - ], - ), - ( - 'One feature', - [], - [('interests', 'cars')], - [ - (('interests', 'cars'),) - ], - ), - ( - 'Two features', - [], - [('gender', 'f'), ('interests', 'cars')], - [ - (('gender', 'f'), ('interests', 'cars')) - ], - ), - ( - 'One column, one feature', - ['fruits'], - [('interests', 'cars')], - [ - (('fruits', 'apples'), ('interests', 'cars')), - (('fruits', 'pears'), ('interests', 'cars')) - ], - ), - ( - 'One column, two features', - ['fruits'], - [('gender', 'f'), 
('interests', 'cars')], + """, + metrics_for_slice_pb2.CrossSliceKey(), + ) + + got_slice_key = slicer.deserialize_cross_slice_key(slice_metrics) + self.assertCountEqual( + ( + (("age", 5), ("language", "english"), ("price", 1.0)), + (("age", 8), ("language", "hindi")), + ), + got_slice_key, + ) + + def testSliceEquality(self): + overall = slicer.SingleSliceSpec() + age_column = slicer.SingleSliceSpec(columns=["age"]) + age_feature = slicer.SingleSliceSpec(features=[("age", 5)]) + age_and_gender = slicer.SingleSliceSpec( + columns=["age"], features=[("gender", "f")] + ) + + # Note that we construct new instances of the slices to ensure that we + # aren't just checking object identity. + def check_equality_and_hash_equality(left, right): + self.assertEqual(left, right) + self.assertEqual(hash(left), hash(right)) + + check_equality_and_hash_equality(overall, slicer.SingleSliceSpec()) + check_equality_and_hash_equality( + age_column, slicer.SingleSliceSpec(columns=["age"]) + ) + check_equality_and_hash_equality( + age_feature, slicer.SingleSliceSpec(features=[("age", 5)]) + ) + check_equality_and_hash_equality( + age_and_gender, + slicer.SingleSliceSpec(columns=["age"], features=[("gender", "f")]), + ) + + self.assertNotEqual(overall, age_column) + self.assertNotEqual(age_column, age_feature) + self.assertNotEqual(age_column, age_and_gender) + self.assertNotEqual(age_feature, age_and_gender) + + self.assertCountEqual([slicer.SingleSliceSpec()], [overall]) + self.assertCountEqual( [ - (('fruits', 'apples'), ('gender', 'f'), ('interests', 'cars')), - (('fruits', 'pears'), ('gender', 'f'), ('interests', 'cars')), + slicer.SingleSliceSpec(columns=["age"]), + slicer.SingleSliceSpec(), + slicer.SingleSliceSpec(features=[("age", 5)]), + slicer.SingleSliceSpec(columns=["age"], features=[("gender", "f")]), ], - ), + [age_and_gender, age_feature, overall, age_column], + ) + + def testNoOverlappingColumns(self): + self.assertRaises(ValueError, slicer.SingleSliceSpec, ["age"], [("age", 5)]) + + def testNonUTF8ValueRaisesValueError(self): + column_name = "column_name" + invalid_value = b"\x8a" + spec = slicer.SingleSliceSpec(columns=[column_name]) + features_dict = self._makeFeaturesDict( + { + column_name: [invalid_value], + } + ) + with self.assertRaisesRegex(ValueError, column_name): + list(slicer.get_slices_for_features_dicts([features_dict], None, [spec])) + + def testGetSlicesForFeaturesDictUnivalent(self): + test_cases = [ + ("Overall", [], [], [()]), + ("Feature does not match", [], [("age", 99)], []), + ("No such column", ["no_such_column"], [], []), + ("Single column", ["age"], [], [(("age", 5),)]), + ("Single feature", [], [("age", 5)], [(("age", 5),)]), + ("Single feature type mismatch", [], [("age", "5")], [(("age", 5),)]), + ( + "One column, one feature", + ["gender"], + [("age", 5)], + [(("age", 5), ("gender", "f"))], + ), + ( + "Two features", + ["interest", "gender"], + [("age", 5)], + [(("age", 5), ("gender", "f"), ("interest", "cars"))], + ), + ] # pyformat: disable + features_dict = self._makeFeaturesDict( + {"gender": ["f"], "age": [5], "interest": ["cars"]} + ) + for name, columns, features, expected in test_cases: + self.assertSliceResult(name, features_dict, columns, features, expected) + + def testGetSlicesForFeaturesDictMultivalent(self): + test_cases = [ + ( + "One column", + ["fruits"], + [], + [(("fruits", "apples"),), (("fruits", "pears"),)], + ), + ( + "Two columns", + ["fruits", "interests"], + [], + [ + (("fruits", "apples"), ("interests", "cars")), + (("fruits", 
"apples"), ("interests", "dogs")), + (("fruits", "pears"), ("interests", "cars")), + (("fruits", "pears"), ("interests", "dogs")), + ], + ), + ( + "One feature", + [], + [("interests", "cars")], + [(("interests", "cars"),)], + ), + ( + "Two features", + [], + [("gender", "f"), ("interests", "cars")], + [(("gender", "f"), ("interests", "cars"))], + ), + ( + "One column, one feature", + ["fruits"], + [("interests", "cars")], + [ + (("fruits", "apples"), ("interests", "cars")), + (("fruits", "pears"), ("interests", "cars")), + ], + ), + ( + "One column, two features", + ["fruits"], + [("gender", "f"), ("interests", "cars")], + [ + (("fruits", "apples"), ("gender", "f"), ("interests", "cars")), + (("fruits", "pears"), ("gender", "f"), ("interests", "cars")), + ], + ), + ( + "Two columns, one feature", + ["interests", "fruits"], + [("gender", "f")], + [ + (("fruits", "apples"), ("gender", "f"), ("interests", "cars")), + (("fruits", "pears"), ("gender", "f"), ("interests", "cars")), + (("fruits", "apples"), ("gender", "f"), ("interests", "dogs")), + (("fruits", "pears"), ("gender", "f"), ("interests", "dogs")), + ], + ), + ( + "Two columns, two features", + ["interests", "fruits"], + [("gender", "f"), ("age", 5)], + [ + ( + ("age", 5), + ("fruits", "apples"), + ("gender", "f"), + ("interests", "cars"), + ), + ( + ("age", 5), + ("fruits", "pears"), + ("gender", "f"), + ("interests", "cars"), + ), + ( + ("age", 5), + ("fruits", "apples"), + ("gender", "f"), + ("interests", "dogs"), + ), + ( + ("age", 5), + ("fruits", "pears"), + ("gender", "f"), + ("interests", "dogs"), + ), + ], + ), + ] # pyformat: disable + + features_dict = self._makeFeaturesDict( + { + "gender": ["f"], + "age": [5], + "interests": ["cars", "dogs"], + "fruits": ["apples", "pears"], + } + ) + + for name, columns, features, expected in test_cases: + self.assertSliceResult(name, features_dict, columns, features, expected) + + def testGetSlicesForFeaturesDictMultipleSingleSliceSpecs(self): + features_dict = self._makeFeaturesDict( + {"gender": ["f"], "age": [5], "interest": ["cars"]} + ) + + spec_overall = slicer.SingleSliceSpec() + spec_age = slicer.SingleSliceSpec(columns=["age"]) + spec_age4 = slicer.SingleSliceSpec(features=[("age", 4)]) + spec_age5_gender = slicer.SingleSliceSpec( + columns=["gender"], features=[("age", 5)] + ) + + slice_spec = [spec_overall, spec_age, spec_age4, spec_age5_gender] + expected = [(), (("age", 5),), (("age", 5), ("gender", "f"))] + self.assertCountEqual( + expected, + slicer.get_slices_for_features_dicts([features_dict], None, slice_spec), + ) + + def testStringifySliceKey(self): + test_cases = [ + ("overall", (), "Overall"), + ("one bytes feature", (("age_str", "5"),), "age_str:5"), + ("one int64 feature", (("age", 1),), "age:1"), + ("mixed", (("age", 1), ("gender", "f")), "age_X_gender:1_X_f"), + ( + "more", + (("age", 1), ("gender", "f"), ("interest", "cars")), + "age_X_gender_X_interest:1_X_f_X_cars", + ), + ("unicode", (("text", b"\xe4\xb8\xad\xe6\x96\x87"),), "text:\u4e2d\u6587"), + ] # pyformat: disable + for name, slice_key, stringified_key in test_cases: + self.assertEqual( + stringified_key, slicer.stringify_slice_key(slice_key), msg=name + ) + + @parameterized.named_parameters( + ("empty_slice_keys", [], np.array([])), ( - 'Two columns, one feature', - ['interests', 'fruits'], [('gender', 'f')], - [ - (('fruits', 'apples'), ('gender', 'f'), ('interests', 'cars')), - (('fruits', 'pears'), ('gender', 'f'), ('interests', 'cars')), - (('fruits', 'apples'), ('gender', 'f'), ('interests', 
'dogs')), - (('fruits', 'pears'), ('gender', 'f'), ('interests', 'dogs')) - ], + "specific_and_overall_slice_key", + [("f", 1), ()], + np.array([("f", 1), ()], dtype=object), ), - ( - 'Two columns, two features', - ['interests', 'fruits'], - [('gender', 'f'), ('age', 5)], - [ - (('age', 5), ('fruits', 'apples'), ('gender', 'f'), - ('interests', 'cars')), - (('age', 5), ('fruits', 'pears'), ('gender', 'f'), - ('interests', 'cars')), - (('age', 5), ('fruits', 'apples'), ('gender', 'f'), - ('interests', 'dogs')), - (('age', 5), ('fruits', 'pears'), ('gender', 'f'), - ('interests', 'dogs')) - ], + ) + def testSliceKeysToNumpy(self, slice_keys_tuples, expected_slice_keys_array): + np.testing.assert_array_equal( + slicer.slice_keys_to_numpy_array(slice_keys_tuples), + expected_slice_keys_array, ) - ] # pyformat: disable - - features_dict = self._makeFeaturesDict({ - 'gender': ['f'], - 'age': [5], - 'interests': ['cars', 'dogs'], - 'fruits': ['apples', 'pears'] - }) - - for (name, columns, features, expected) in test_cases: - self.assertSliceResult(name, features_dict, columns, features, expected) - - def testGetSlicesForFeaturesDictMultipleSingleSliceSpecs(self): - features_dict = self._makeFeaturesDict({ - 'gender': ['f'], - 'age': [5], - 'interest': ['cars'] - }) - - spec_overall = slicer.SingleSliceSpec() - spec_age = slicer.SingleSliceSpec(columns=['age']) - spec_age4 = slicer.SingleSliceSpec(features=[('age', 4)]) - spec_age5_gender = slicer.SingleSliceSpec( - columns=['gender'], features=[('age', 5)]) - - slice_spec = [spec_overall, spec_age, spec_age4, spec_age5_gender] - expected = [(), (('age', 5),), (('age', 5), ('gender', 'f'))] - self.assertCountEqual( - expected, - slicer.get_slices_for_features_dicts([features_dict], None, slice_spec)) - - def testStringifySliceKey(self): - test_cases = [ - ('overall', (), 'Overall'), - ('one bytes feature', (('age_str', '5'),), 'age_str:5'), - ('one int64 feature', (('age', 1),), 'age:1'), - ('mixed', (('age', 1), ('gender', 'f')), 'age_X_gender:1_X_f'), - ('more', (('age', 1), ('gender', 'f'), ('interest', 'cars')), - 'age_X_gender_X_interest:1_X_f_X_cars'), - ('unicode', (('text', b'\xe4\xb8\xad\xe6\x96\x87'),), u'text:\u4e2d\u6587'), - ] # pyformat: disable - for (name, slice_key, stringified_key) in test_cases: - self.assertEqual( - stringified_key, slicer.stringify_slice_key(slice_key), msg=name) - - @parameterized.named_parameters(('empty_slice_keys', [], np.array([])), - ('specific_and_overall_slice_key', [ - ('f', 1), () - ], np.array([('f', 1), ()], dtype=object))) - def testSliceKeysToNumpy(self, slice_keys_tuples, expected_slice_keys_array): - np.testing.assert_array_equal( - slicer.slice_keys_to_numpy_array(slice_keys_tuples), - expected_slice_keys_array) - - def testSliceKeysToNumpyOverall(self): - actual = slicer.slice_keys_to_numpy_array([()]) - self.assertIsInstance(actual, np.ndarray) - self.assertEqual(actual.dtype, object) - self.assertEqual(actual.shape, (1,)) - self.assertEqual(actual[0], ()) - - def testIsCrossSliceApplicable(self): - test_cases = [ - (True, 'overall pass', ((), (('b', 2),)), config_pb2.CrossSlicingSpec( - baseline_spec=config_pb2.SlicingSpec(), - slicing_specs=[config_pb2.SlicingSpec(feature_values={'b': '2'})])), - (True, 'value pass', ((('a', 1),), (('b', 2),)), - config_pb2.CrossSlicingSpec( - baseline_spec=config_pb2.SlicingSpec(feature_values={'a': '1'}), - slicing_specs=[config_pb2.SlicingSpec(feature_values={'b': '2'})])), - (True, 'baseline key pass', ((('a', 1),), (('b', 2),)), - 
config_pb2.CrossSlicingSpec( - baseline_spec=config_pb2.SlicingSpec(feature_keys=['a']), - slicing_specs=[config_pb2.SlicingSpec(feature_values={'b': '2'})])), - (True, 'comparison key pass', ((('a', 1),), (('b', 2),)), - config_pb2.CrossSlicingSpec( - baseline_spec=config_pb2.SlicingSpec(feature_values={'a': '1'}), - slicing_specs=[config_pb2.SlicingSpec(feature_keys=['b'])])), - (True, 'comparison multiple key pass', ((('a', 1),), (('c', 3),)), - config_pb2.CrossSlicingSpec( - baseline_spec=config_pb2.SlicingSpec(feature_values={'a': '1'}), - slicing_specs=[config_pb2.SlicingSpec(feature_keys=['b']), - config_pb2.SlicingSpec(feature_keys=['c'])])), - (False, 'overall fail', ((('a', 1),), (('b', 2),)), - config_pb2.CrossSlicingSpec( - baseline_spec=config_pb2.SlicingSpec(), - slicing_specs=[config_pb2.SlicingSpec(feature_values={'b': '2'})])), - (False, 'value fail', ((('a', 1),), (('b', 3),)), - config_pb2.CrossSlicingSpec( - baseline_spec=config_pb2.SlicingSpec(feature_values={'a': '1'}), - slicing_specs=[config_pb2.SlicingSpec(feature_values={'b': '2'})])), - (False, 'baseline key fail', ((('c', 1),), (('b', 2),)), - config_pb2.CrossSlicingSpec( - baseline_spec=config_pb2.SlicingSpec(feature_keys=['a']), - slicing_specs=[config_pb2.SlicingSpec(feature_values={'b': '2'})])), - (False, 'comparison key fail', ((('a', 1),), (('c', 3),)), - config_pb2.CrossSlicingSpec( - baseline_spec=config_pb2.SlicingSpec(feature_values={'a': '1'}), - slicing_specs=[config_pb2.SlicingSpec(feature_keys=['b'])])), - (False, 'comparison multiple key fail', ((('a', 1),), (('d', 3),)), - config_pb2.CrossSlicingSpec( - baseline_spec=config_pb2.SlicingSpec(feature_values={'a': '1'}), - slicing_specs=[config_pb2.SlicingSpec(feature_keys=['b']), - config_pb2.SlicingSpec(feature_keys=['c'])])), - ] # pyformat: disable - for (expected_result, name, sliced_key, slicing_spec) in test_cases: - self.assertEqual( - expected_result, - slicer.is_cross_slice_applicable( - cross_slice_key=sliced_key, cross_slicing_spec=slicing_spec), - msg=name) - - def testGetSliceKeyType(self): - test_cases = [ - (slicer.SliceKeyType, 'overall', ()), - (slicer.SliceKeyType, 'one bytes feature', (('a', '5'),)), - (slicer.SliceKeyType, 'one int64 feature', (('a', 1),)), - (slicer.SliceKeyType, 'mixed', (('a', 1), ('b', 'f'))), - (slicer.SliceKeyType, 'more', (('a', 1), ('b', 'f'), ('c', 'cars'))), - (slicer.SliceKeyType, 'unicode', - (('a', b'\xe4\xb8\xad\xe6\x96\x87'),)), - (slicer.CrossSliceKeyType, 'CrossSlice overall', ((), ())), - (slicer.CrossSliceKeyType, 'CrossSlice one slice key baseline', - ((('a', '5'),), ())), - (slicer.CrossSliceKeyType, 'CrossSlice one slice key comparison', - ((), (('a', 1),))), - (slicer.CrossSliceKeyType, 'CrossSlice two simple slice key', - ((('a', 1),), (('b', 'f'),))), - (slicer.CrossSliceKeyType, 'CrossSlice two multiple slice key', - ((('a', 1), ('b', 'f'), ('c', '11')), - (('a2', 1), ('b', 'm'), ('c', '11')))), - ] # pyformat: disable - for (expected_result, name, slice_key) in test_cases: - self.assertEqual( - expected_result, slicer.get_slice_key_type(slice_key), msg=name) - - unrecognized_test_cases = [ - ('Unrecognized 1: ', ('a')), - ('Unrecognized 2: ', ('a',)), - ('Unrecognized 3: ', ('a', 1)), - ('Unrecognized 4: ', (('a'))), - ('Unrecognized 5: ', (('a',))), - ('Unrecognized 6: ', ((), (), ())), - ('Unrecognized 7: ', ((('a', 1),), (('b', 1),), (('c', 1),))), - ('Unrecognized 8: ', ((('a', 1),), ('b', 1))), - ('Unrecognized 9: ', (('a', 1), (('b', 1),))), - ] # pyformat: disable - for (name, 
slice_key) in unrecognized_test_cases: - with self.assertRaises(TypeError, msg=name + str(slice_key)): - slicer.get_slice_key_type(slice_key) - - @parameterized.named_parameters( - { - 'testcase_name': '_single_slice_spec', - 'slice_type': slicer.SingleSliceSpec, - 'slicing_spec': config_pb2.SlicingSpec(feature_values={'a': '1'}), - }, { - 'testcase_name': - '_cross_slice_spec', - 'slice_type': - slicer.CrossSliceSpec, - 'slicing_spec': - config_pb2.CrossSlicingSpec( - baseline_spec=config_pb2.SlicingSpec(), - slicing_specs=[ - config_pb2.SlicingSpec(feature_values={'b': '2'}) - ]), - }) - def testDeserializeSliceSpec(self, slice_type, slicing_spec): - slice_spec = slicer.deserialize_slice_spec(slicing_spec) - self.assertIsInstance(slice_spec, slice_type) - - def testDeserializeSliceSpec_hashable(self): - single_slice_spec = slicer.deserialize_slice_spec( - config_pb2.SlicingSpec(feature_values={'a': '1'})) - cross_slice_spec = slicer.deserialize_slice_spec( - slicer.config_pb2.CrossSlicingSpec( - baseline_spec=config_pb2.SlicingSpec(), - slicing_specs=[config_pb2.SlicingSpec(feature_values={'b': '2'})])) - # Check either of them can be hashed and used as keys. - slice_map = {single_slice_spec: 1, cross_slice_spec: 2} - self.assertEqual(slice_map[single_slice_spec], 1) - self.assertEqual(slice_map[cross_slice_spec], 2) - - def testIsSliceApplicable(self): - test_cases = [ - ('applicable', ['column1'], - [('column3', 'value3'), ('column4', 'value4')], - (('column1', 'value1'), ('column3', 'value3'), ('column4', 'value4')), - True), - ('wrongcolumns', ['column1', 'column2'], - [('column3', 'value3'), ('column4', 'value4')], - (('column1', 'value1'), ('column3', 'value3'), ('column4', 'value4')), - False), - ('wrongfeatures', ['column1'], [('column3', 'value3')], - (('column1', 'value1'), ('column3', 'value3'), ('column4', 'value4')), - False), - ('nocolumns', [], [('column3', 'value3')], - (('column1', 'value1'), ('column3', 'value3'), ('column4', 'value4')), - False), - ('nofeatures', ['column1'], [], (('column1', 'value1'),), True), - ('empty slice key', ['column1'], [('column2', 'value1')], (), False), - ('overall', [], [], (), True) - ] # pyformat: disable - - for (name, columns, features, slice_key, result) in test_cases: - slice_spec = slicer.SingleSliceSpec(columns=columns, features=features) - self.assertEqual( - slice_spec.is_slice_applicable(slice_key), result, msg=name) - - def testSliceDefaultSlice(self): - with beam.Pipeline() as pipeline: - fpls = create_fpls() - - metrics = ( - pipeline - | 'CreateTestInput' >> beam.Create(fpls) - | 'WrapFpls' >> beam.Map(wrap_fpl) - | 'ExtractSlices' >> slice_key_extractor.ExtractSliceKeys( - [slicer.SingleSliceSpec()]) - | 'FanoutSlices' >> slicer.FanoutSlices()) - - def check_result(got): - try: - self.assertLen(got, 2) - expected_result = [ - ((), wrap_fpl(fpls[0])), - ((), wrap_fpl(fpls[1])), - ] - self.assertEqual(len(got), len(expected_result)) - self.assertTrue( - got[0] == expected_result[0] and got[1] == expected_result[1] or - got[1] == expected_result[0] and got[0] == expected_result[1]) - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that(metrics, check_result) - - def testSliceOneSlice(self): - with beam.Pipeline() as pipeline: - fpls = create_fpls() - metrics = ( - pipeline - | 'CreateTestInput' >> beam.Create(fpls, reshuffle=False) - | 'WrapFpls' >> beam.Map(wrap_fpl) - | 'ExtractSlices' >> slice_key_extractor.ExtractSliceKeys([ - slicer.SingleSliceSpec(), - 
slicer.SingleSliceSpec(columns=['gender']) - ]) - | 'FanoutSlices' >> slicer.FanoutSlices()) - - def check_result(got): - try: - self.assertLen(got, 4) - expected_result = [ - ((), wrap_fpl(fpls[0])), - ((), wrap_fpl(fpls[1])), - ((('gender', 'f'),), wrap_fpl(fpls[0])), - ((('gender', 'm'),), wrap_fpl(fpls[1])), - ] - self.assertCountEqual(got, expected_result) - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that(metrics, check_result) - - def testMultidimSlices(self): - data = [{ - 'features': { - 'gender': [['f'], ['f']], - 'age': [[13], [13]], - 'interest': [['cars'], ['cars']] - }, - 'predictions': [[1], [1]], - 'labels': [[0], [0]], - constants.SLICE_KEY_TYPES_KEY: - np.array([ - slicer.slice_keys_to_numpy_array([(), (('gender', 'f'),)]), - slicer.slice_keys_to_numpy_array([(), (('gender', 'f'),)]) - ]) - }, { - 'features': { - 'gender': [['f'], ['m']], - 'age': [[13], [10]], - 'interest': [['cars'], ['cars']] + + def testSliceKeysToNumpyOverall(self): + actual = slicer.slice_keys_to_numpy_array([()]) + self.assertIsInstance(actual, np.ndarray) + self.assertEqual(actual.dtype, object) + self.assertEqual(actual.shape, (1,)) + self.assertEqual(actual[0], ()) + + def testIsCrossSliceApplicable(self): + test_cases = [ + ( + True, + "overall pass", + ((), (("b", 2),)), + config_pb2.CrossSlicingSpec( + baseline_spec=config_pb2.SlicingSpec(), + slicing_specs=[config_pb2.SlicingSpec(feature_values={"b": "2"})], + ), + ), + ( + True, + "value pass", + ((("a", 1),), (("b", 2),)), + config_pb2.CrossSlicingSpec( + baseline_spec=config_pb2.SlicingSpec(feature_values={"a": "1"}), + slicing_specs=[config_pb2.SlicingSpec(feature_values={"b": "2"})], + ), + ), + ( + True, + "baseline key pass", + ((("a", 1),), (("b", 2),)), + config_pb2.CrossSlicingSpec( + baseline_spec=config_pb2.SlicingSpec(feature_keys=["a"]), + slicing_specs=[config_pb2.SlicingSpec(feature_values={"b": "2"})], + ), + ), + ( + True, + "comparison key pass", + ((("a", 1),), (("b", 2),)), + config_pb2.CrossSlicingSpec( + baseline_spec=config_pb2.SlicingSpec(feature_values={"a": "1"}), + slicing_specs=[config_pb2.SlicingSpec(feature_keys=["b"])], + ), + ), + ( + True, + "comparison multiple key pass", + ((("a", 1),), (("c", 3),)), + config_pb2.CrossSlicingSpec( + baseline_spec=config_pb2.SlicingSpec(feature_values={"a": "1"}), + slicing_specs=[ + config_pb2.SlicingSpec(feature_keys=["b"]), + config_pb2.SlicingSpec(feature_keys=["c"]), + ], + ), + ), + ( + False, + "overall fail", + ((("a", 1),), (("b", 2),)), + config_pb2.CrossSlicingSpec( + baseline_spec=config_pb2.SlicingSpec(), + slicing_specs=[config_pb2.SlicingSpec(feature_values={"b": "2"})], + ), + ), + ( + False, + "value fail", + ((("a", 1),), (("b", 3),)), + config_pb2.CrossSlicingSpec( + baseline_spec=config_pb2.SlicingSpec(feature_values={"a": "1"}), + slicing_specs=[config_pb2.SlicingSpec(feature_values={"b": "2"})], + ), + ), + ( + False, + "baseline key fail", + ((("c", 1),), (("b", 2),)), + config_pb2.CrossSlicingSpec( + baseline_spec=config_pb2.SlicingSpec(feature_keys=["a"]), + slicing_specs=[config_pb2.SlicingSpec(feature_values={"b": "2"})], + ), + ), + ( + False, + "comparison key fail", + ((("a", 1),), (("c", 3),)), + config_pb2.CrossSlicingSpec( + baseline_spec=config_pb2.SlicingSpec(feature_values={"a": "1"}), + slicing_specs=[config_pb2.SlicingSpec(feature_keys=["b"])], + ), + ), + ( + False, + "comparison multiple key fail", + ((("a", 1),), (("d", 3),)), + config_pb2.CrossSlicingSpec( + 
baseline_spec=config_pb2.SlicingSpec(feature_values={"a": "1"}), + slicing_specs=[ + config_pb2.SlicingSpec(feature_keys=["b"]), + config_pb2.SlicingSpec(feature_keys=["c"]), + ], + ), + ), + ] # pyformat: disable + for expected_result, name, sliced_key, slicing_spec in test_cases: + self.assertEqual( + expected_result, + slicer.is_cross_slice_applicable( + cross_slice_key=sliced_key, cross_slicing_spec=slicing_spec + ), + msg=name, + ) + + def testGetSliceKeyType(self): + test_cases = [ + (slicer.SliceKeyType, "overall", ()), + (slicer.SliceKeyType, "one bytes feature", (("a", "5"),)), + (slicer.SliceKeyType, "one int64 feature", (("a", 1),)), + (slicer.SliceKeyType, "mixed", (("a", 1), ("b", "f"))), + (slicer.SliceKeyType, "more", (("a", 1), ("b", "f"), ("c", "cars"))), + (slicer.SliceKeyType, "unicode", (("a", b"\xe4\xb8\xad\xe6\x96\x87"),)), + (slicer.CrossSliceKeyType, "CrossSlice overall", ((), ())), + ( + slicer.CrossSliceKeyType, + "CrossSlice one slice key baseline", + ((("a", "5"),), ()), + ), + ( + slicer.CrossSliceKeyType, + "CrossSlice one slice key comparison", + ((), (("a", 1),)), + ), + ( + slicer.CrossSliceKeyType, + "CrossSlice two simple slice key", + ((("a", 1),), (("b", "f"),)), + ), + ( + slicer.CrossSliceKeyType, + "CrossSlice two multiple slice key", + ( + (("a", 1), ("b", "f"), ("c", "11")), + (("a2", 1), ("b", "m"), ("c", "11")), + ), + ), + ] # pyformat: disable + for expected_result, name, slice_key in test_cases: + self.assertEqual( + expected_result, slicer.get_slice_key_type(slice_key), msg=name + ) + + unrecognized_test_cases = [ + ("Unrecognized 1: ", ("a")), + ("Unrecognized 2: ", ("a",)), + ("Unrecognized 3: ", ("a", 1)), + ("Unrecognized 4: ", ("a")), + ("Unrecognized 5: ", (("a",))), + ("Unrecognized 6: ", ((), (), ())), + ("Unrecognized 7: ", ((("a", 1),), (("b", 1),), (("c", 1),))), + ("Unrecognized 8: ", ((("a", 1),), ("b", 1))), + ("Unrecognized 9: ", (("a", 1), (("b", 1),))), + ] # pyformat: disable + for name, slice_key in unrecognized_test_cases: + with self.assertRaises(TypeError, msg=name + str(slice_key)): + slicer.get_slice_key_type(slice_key) + + @parameterized.named_parameters( + { + "testcase_name": "_single_slice_spec", + "slice_type": slicer.SingleSliceSpec, + "slicing_spec": config_pb2.SlicingSpec(feature_values={"a": "1"}), }, - 'predictions': [[1], [1]], - 'labels': [[0], [0]], - constants.SLICE_KEY_TYPES_KEY: - np.array([ - slicer.slice_keys_to_numpy_array([(), (('gender', 'f'),)]), - slicer.slice_keys_to_numpy_array([(), (('gender', 'm'),)]) - ]) - }] - - with beam.Pipeline() as pipeline: - result = ( - pipeline - | 'CreateTestInput' >> beam.Create(data, reshuffle=False) - | 'FanoutSlices' >> slicer.FanoutSlices()) - - def check_result(got): - try: - self.assertLen(got, 5) - del data[0][constants.SLICE_KEY_TYPES_KEY] - del data[1][constants.SLICE_KEY_TYPES_KEY] - expected_result = [ - ((), data[0]), - ((), data[1]), - ((('gender', 'f'),), data[0]), - ((('gender', 'f'),), data[1]), - ((('gender', 'm'),), data[1]), - ] - self.assertCountEqual(got, expected_result) - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that(result, check_result) - - def testMultidimOverallSlices(self): - data = [ { - constants.SLICE_KEY_TYPES_KEY: ( # variable length batch case - types.VarLenTensorValue.from_dense_rows([ - slicer.slice_keys_to_numpy_array([(('gender', 'f'),), ()]), - slicer.slice_keys_to_numpy_array([()]), - ]) + "testcase_name": "_cross_slice_spec", + "slice_type": slicer.CrossSliceSpec, + 
"slicing_spec": config_pb2.CrossSlicingSpec( + baseline_spec=config_pb2.SlicingSpec(), + slicing_specs=[config_pb2.SlicingSpec(feature_values={"b": "2"})], + ), + }, + ) + def testDeserializeSliceSpec(self, slice_type, slicing_spec): + slice_spec = slicer.deserialize_slice_spec(slicing_spec) + self.assertIsInstance(slice_spec, slice_type) + + def testDeserializeSliceSpec_hashable(self): + single_slice_spec = slicer.deserialize_slice_spec( + config_pb2.SlicingSpec(feature_values={"a": "1"}) + ) + cross_slice_spec = slicer.deserialize_slice_spec( + slicer.config_pb2.CrossSlicingSpec( + baseline_spec=config_pb2.SlicingSpec(), + slicing_specs=[config_pb2.SlicingSpec(feature_values={"b": "2"})], + ) + ) + # Check either of them can be hashed and used as keys. + slice_map = {single_slice_spec: 1, cross_slice_spec: 2} + self.assertEqual(slice_map[single_slice_spec], 1) + self.assertEqual(slice_map[cross_slice_spec], 2) + + def testIsSliceApplicable(self): + test_cases = [ + ( + "applicable", + ["column1"], + [("column3", "value3"), ("column4", "value4")], + (("column1", "value1"), ("column3", "value3"), ("column4", "value4")), + True, + ), + ( + "wrongcolumns", + ["column1", "column2"], + [("column3", "value3"), ("column4", "value4")], + (("column1", "value1"), ("column3", "value3"), ("column4", "value4")), + False, + ), + ( + "wrongfeatures", + ["column1"], + [("column3", "value3")], + (("column1", "value1"), ("column3", "value3"), ("column4", "value4")), + False, + ), + ( + "nocolumns", + [], + [("column3", "value3")], + (("column1", "value1"), ("column3", "value3"), ("column4", "value4")), + False, + ), + ("nofeatures", ["column1"], [], (("column1", "value1"),), True), + ("empty slice key", ["column1"], [("column2", "value1")], (), False), + ("overall", [], [], (), True), + ] # pyformat: disable + + for name, columns, features, slice_key, result in test_cases: + slice_spec = slicer.SingleSliceSpec(columns=columns, features=features) + self.assertEqual( + slice_spec.is_slice_applicable(slice_key), result, msg=name + ) + + def testSliceDefaultSlice(self): + with beam.Pipeline() as pipeline: + fpls = create_fpls() + + metrics = ( + pipeline + | "CreateTestInput" >> beam.Create(fpls) + | "WrapFpls" >> beam.Map(wrap_fpl) + | "ExtractSlices" + >> slice_key_extractor.ExtractSliceKeys([slicer.SingleSliceSpec()]) + | "FanoutSlices" >> slicer.FanoutSlices() ) + + def check_result(got): + try: + self.assertLen(got, 2) + expected_result = [ + ((), wrap_fpl(fpls[0])), + ((), wrap_fpl(fpls[1])), + ] + self.assertEqual(len(got), len(expected_result)) + self.assertTrue( + got[0] == expected_result[0] + and got[1] == expected_result[1] + or got[1] == expected_result[0] + and got[0] == expected_result[1] + ) + except AssertionError as err: + raise util.BeamAssertException(err) + + util.assert_that(metrics, check_result) + + def testSliceOneSlice(self): + with beam.Pipeline() as pipeline: + fpls = create_fpls() + metrics = ( + pipeline + | "CreateTestInput" >> beam.Create(fpls, reshuffle=False) + | "WrapFpls" >> beam.Map(wrap_fpl) + | "ExtractSlices" + >> slice_key_extractor.ExtractSliceKeys( + [ + slicer.SingleSliceSpec(), + slicer.SingleSliceSpec(columns=["gender"]), + ] + ) + | "FanoutSlices" >> slicer.FanoutSlices() + ) + + def check_result(got): + try: + self.assertLen(got, 4) + expected_result = [ + ((), wrap_fpl(fpls[0])), + ((), wrap_fpl(fpls[1])), + ((("gender", "f"),), wrap_fpl(fpls[0])), + ((("gender", "m"),), wrap_fpl(fpls[1])), + ] + self.assertCountEqual(got, expected_result) + except 
AssertionError as err: + raise util.BeamAssertException(err) + + util.assert_that(metrics, check_result) + + def testMultidimSlices(self): + data = [ + { + "features": { + "gender": [["f"], ["f"]], + "age": [[13], [13]], + "interest": [["cars"], ["cars"]], + }, + "predictions": [[1], [1]], + "labels": [[0], [0]], + constants.SLICE_KEY_TYPES_KEY: np.array( + [ + slicer.slice_keys_to_numpy_array([(), (("gender", "f"),)]), + slicer.slice_keys_to_numpy_array([(), (("gender", "f"),)]), + ] + ), + }, + { + "features": { + "gender": [["f"], ["m"]], + "age": [[13], [10]], + "interest": [["cars"], ["cars"]], + }, + "predictions": [[1], [1]], + "labels": [[0], [0]], + constants.SLICE_KEY_TYPES_KEY: np.array( + [ + slicer.slice_keys_to_numpy_array([(), (("gender", "f"),)]), + slicer.slice_keys_to_numpy_array([(), (("gender", "m"),)]), + ] + ), + }, + ] + + with beam.Pipeline() as pipeline: + result = ( + pipeline + | "CreateTestInput" >> beam.Create(data, reshuffle=False) + | "FanoutSlices" >> slicer.FanoutSlices() + ) + + def check_result(got): + try: + self.assertLen(got, 5) + del data[0][constants.SLICE_KEY_TYPES_KEY] + del data[1][constants.SLICE_KEY_TYPES_KEY] + expected_result = [ + ((), data[0]), + ((), data[1]), + ((("gender", "f"),), data[0]), + ((("gender", "f"),), data[1]), + ((("gender", "m"),), data[1]), + ] + self.assertCountEqual(got, expected_result) + except AssertionError as err: + raise util.BeamAssertException(err) + + util.assert_that(result, check_result) + + def testMultidimOverallSlices(self): + data = [ + { + constants.SLICE_KEY_TYPES_KEY: ( # variable length batch case + types.VarLenTensorValue.from_dense_rows( + [ + slicer.slice_keys_to_numpy_array([(("gender", "f"),), ()]), + slicer.slice_keys_to_numpy_array([()]), + ] + ) + ) + }, + { + constants.SLICE_KEY_TYPES_KEY: np.array( + [ # fixed length batch case + slicer.slice_keys_to_numpy_array([()]), + slicer.slice_keys_to_numpy_array([()]), + ] + ) + }, + ] + data = [tfma_util.StandardExtracts(d) for d in data] + with beam.Pipeline() as pipeline: + # Fix the typehint infer error + beam.typehints.disable_type_annotations() + result = ( + pipeline + | "CreateTestInput" + >> beam.Create(data, reshuffle=False).with_output_types(types.Extracts) + | "FanoutSlices" >> slicer.FanoutSlices() + ) + + def check_result(got): + try: + del data[0][constants.SLICE_KEY_TYPES_KEY] + del data[1][constants.SLICE_KEY_TYPES_KEY] + expected_result = [ + ((("gender", "f"),), data[0]), + ((), data[0]), + ((), data[1]), + ] + self.assertCountEqual(got, expected_result) + except AssertionError as err: + raise util.BeamAssertException(err) + + util.assert_that(result, check_result) + + def testFilterOutSlices(self): + slice_key_1 = (("slice_key", "slice1"),) + slice_key_2 = (("slice_key", "slice2"),) + slice_key_3 = (("slice_key", "slice3"),) + + values_list = [ + (slice_key_1, {"val11": "val12"}), + (slice_key_2, {"val21": "val22"}), + ] + slice_counts_list = [(slice_key_1, 2), (slice_key_2, 1), (slice_key_3, 0)] + + def check_output(got): + try: + self.assertLen(got, 2) + slices = {} + for k, v in got: + slices[k] = v + + self.assertEqual(slices[slice_key_1], {"val11": "val12"}) + self.assertIn(metric_keys.ERROR_METRIC, slices[slice_key_2]) + except AssertionError as err: + raise util.BeamAssertException(err) + + with beam.Pipeline() as pipeline: + slice_counts_pcoll = pipeline | "CreateSliceCountsPColl" >> beam.Create( + slice_counts_list + ) + output_dict = ( + pipeline + | "CreateValuesPColl" >> beam.Create(values_list) + | "FilterOutSlices" + 
>> slicer.FilterOutSlices( + slice_counts_pcoll, + min_slice_size=2, + error_metric_key=metric_keys.ERROR_METRIC, + ) + ) + util.assert_that(output_dict, check_output) + + @parameterized.named_parameters( + { + "testcase_name": "matching_single_spec", + "slice_key": (("f1", 1),), + "slice_specs": [slicer.SingleSliceSpec(features=[("f1", 1)])], + "expected_result": True, + }, + { + "testcase_name": "matching_single_spec_with_float", + "slice_key": (("f1", "1.0"),), + "slice_specs": [slicer.SingleSliceSpec(features=[("f1", "1.0")])], + "expected_result": True, + }, + { + "testcase_name": "non_matching_single_spec", + "slice_key": (("f1", 1),), + "slice_specs": [slicer.SingleSliceSpec(columns=["f2"])], + "expected_result": False, + }, + { + "testcase_name": "matching_multiple_specs", + "slice_key": (("f1", 1),), + "slice_specs": [ + slicer.SingleSliceSpec(columns=["f1"]), + slicer.SingleSliceSpec(columns=["f2"]), + ], + "expected_result": True, }, { - constants.SLICE_KEY_TYPES_KEY: np.array([ # fixed length batch case - slicer.slice_keys_to_numpy_array([()]), - slicer.slice_keys_to_numpy_array([()]), - ]) + "testcase_name": "empty_specs", + "slice_key": (("f1", 1),), + "slice_specs": [], + "expected_result": False, }, - ] - data = [tfma_util.StandardExtracts(d) for d in data] - with beam.Pipeline() as pipeline: - # Fix the typehint infer error - beam.typehints.disable_type_annotations() - result = ( - pipeline - | 'CreateTestInput' - >> beam.Create(data, reshuffle=False).with_output_types( - types.Extracts - ) - | 'FanoutSlices' >> slicer.FanoutSlices() - ) - - def check_result(got): - try: - del data[0][constants.SLICE_KEY_TYPES_KEY] - del data[1][constants.SLICE_KEY_TYPES_KEY] - expected_result = [ - ((('gender', 'f'),), data[0]), - ((), data[0]), - ((), data[1]), - ] - self.assertCountEqual(got, expected_result) - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that(result, check_result) - - def testFilterOutSlices(self): - slice_key_1 = (('slice_key', 'slice1'),) - slice_key_2 = (('slice_key', 'slice2'),) - slice_key_3 = (('slice_key', 'slice3'),) - - values_list = [(slice_key_1, { - 'val11': 'val12' - }), (slice_key_2, { - 'val21': 'val22' - })] - slice_counts_list = [(slice_key_1, 2), (slice_key_2, 1), (slice_key_3, 0)] - - def check_output(got): - try: - self.assertLen(got, 2) - slices = {} - for (k, v) in got: - slices[k] = v - - self.assertEqual(slices[slice_key_1], {'val11': 'val12'}) - self.assertIn(metric_keys.ERROR_METRIC, slices[slice_key_2]) - except AssertionError as err: - raise util.BeamAssertException(err) - - with beam.Pipeline() as pipeline: - slice_counts_pcoll = ( - pipeline | 'CreateSliceCountsPColl' >> beam.Create(slice_counts_list)) - output_dict = ( - pipeline - | 'CreateValuesPColl' >> beam.Create(values_list) - | 'FilterOutSlices' >> slicer.FilterOutSlices( - slice_counts_pcoll, - min_slice_size=2, - error_metric_key=metric_keys.ERROR_METRIC)) - util.assert_that(output_dict, check_output) - - @parameterized.named_parameters( - { - 'testcase_name': 'matching_single_spec', - 'slice_key': (('f1', 1),), - 'slice_specs': [slicer.SingleSliceSpec(features=[('f1', 1)])], - 'expected_result': True - }, - { - 'testcase_name': 'matching_single_spec_with_float', - 'slice_key': (('f1', '1.0'),), - 'slice_specs': [slicer.SingleSliceSpec(features=[('f1', '1.0')])], - 'expected_result': True - }, - { - 'testcase_name': 'non_matching_single_spec', - 'slice_key': (('f1', 1),), - 'slice_specs': [slicer.SingleSliceSpec(columns=['f2'])], - 
'expected_result': False - }, - { - 'testcase_name': 'matching_multiple_specs', - 'slice_key': (('f1', 1),), - 'slice_specs': [ - slicer.SingleSliceSpec(columns=['f1']), - slicer.SingleSliceSpec(columns=['f2']) - ], - 'expected_result': True - }, - { - 'testcase_name': 'empty_specs', - 'slice_key': (('f1', 1),), - 'slice_specs': [], - 'expected_result': False - }, - ) - def testSliceKeyMatchesSliceSpecs(self, slice_key, slice_specs, - expected_result): - self.assertEqual( - expected_result, - slicer.slice_key_matches_slice_specs(slice_key, slice_specs)) - - -if __name__ == '__main__': - tf.test.main() + ) + def testSliceKeyMatchesSliceSpecs(self, slice_key, slice_specs, expected_result): + self.assertEqual( + expected_result, + slicer.slice_key_matches_slice_specs(slice_key, slice_specs), + ) + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow_model_analysis/static/extension.js b/tensorflow_model_analysis/static/extension.js index 074e7469fe..ade7954c96 100644 --- a/tensorflow_model_analysis/static/extension.js +++ b/tensorflow_model_analysis/static/extension.js @@ -104,4 +104,4 @@ module.exports = { /***/ }) -/******/ ])});; \ No newline at end of file +/******/ ])});; diff --git a/tensorflow_model_analysis/static/index.js b/tensorflow_model_analysis/static/index.js index d460d31db2..b24d192580 100644 --- a/tensorflow_model_analysis/static/index.js +++ b/tensorflow_model_analysis/static/index.js @@ -17652,4 +17652,4 @@ module.exports = function(module) { /***/ }) /******/ ])});; -//# sourceMappingURL=index.js.map \ No newline at end of file +//# sourceMappingURL=index.js.map diff --git a/tensorflow_model_analysis/static/index.js.map b/tensorflow_model_analysis/static/index.js.map index c5600d9e90..ede74abc73 100644 --- a/tensorflow_model_analysis/static/index.js.map +++ b/tensorflow_model_analysis/static/index.js.map @@ -1 +1 @@ -{"version":3,"sources":["webpack:///webpack/bootstrap 9db01875ed55e7424b10","webpack:///./package.json","webpack:///./lib/index.js","webpack:///./lib/widget.js","webpack:///external 
\"@jupyter-widgets/base\"","webpack:///./node_modules/lodash/lodash.js","webpack:///(webpack)/buildin/global.js","webpack:///(webpack)/buildin/module.js"],"names":[],"mappings":";QAAA;QACA;;QAEA;QACA;;QAEA;QACA;QACA;QACA;QACA;QACA;QACA;QACA;QACA;QACA;;QAEA;QACA;;QAEA;QACA;;QAEA;QACA;QACA;;;QAGA;QACA;;QAEA;QACA;;QAEA;QACA;QACA;QACA;QACA;QACA;QACA;QACA,KAAK;QACL;QACA;;QAEA;QACA;QACA;QACA,2BAA2B,0BAA0B,EAAE;QACvD,iCAAiC,eAAe;QAChD;QACA;QACA;;QAEA;QACA,sDAAsD,+DAA+D;;QAErH;QACA;;QAEA;QACA;;;;;;;AC7DA,kBAAkB,kNAAkN,sEAAsE,4FAA4F,qIAAqI,oBAAoB,qCAAqC,iBAAiB,oEAAoE,eAAe,8CAA8C,yBAAyB,mCAAmC,kBAAkB,uD;;;;;;ACApyB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA,iBAAiB,mBAAO,CAAC,CAAa;AACtC,4BAA4B,mBAAO,CAAC,CAAiB;;;;;;;AClBrD;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,gBAAgB,mBAAO,CAAC,CAAuB;AAC/C,UAAU,mBAAO,CAAC,CAAQ;AAC1B,gBAAgB,mBAAO,CAAC,CAAiB;;AAEzC;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA,mBAAmB,qBAAuB;AAC1C;AACA;AACA,8BAA8B,aAAa;;AAE3C;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA,WAAW,UAAU;AACrB;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,cAAc;AACd;AACA;AACA,GAAG;AACH,CAAC;;AAED;AACA;AACA;;AAEA;AACA;;AAEA;AACA;AACA,KAAK;;AAEL;AACA;AACA;AACA;AACA;AACA,KAAK;AACL,GAAG;AACH;AACA;AACA,GAAG;AACH;AACA;AACA,GAAG;AACH,CAAC;;AAED;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,cAAc;AACd;AACA,GAAG;AACH,CAAC;;AAED;AACA;AACA;;AAEA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA,KAAK;AACL,GAAG;AACH;AACA;AACA,GAAG;AACH;AACA;AACA,GAAG;AACH,CAAC;;AAED;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,cAAc;AACd;AACA,GAAG;AACH,CAAC;;AAED;AACA;AACA;;AAEA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA,KAAK;AACL,GAAG;AACH;AACA;AACA,GAAG;AACH;AACA;AACA,GAAG;AACH,CAAC;;AAED;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,GAAG;AACH,CAAC;;AAED;AACA;AACA;;AAEA;AACA;;AAEA;AACA;AACA,KAAK;;AAEL;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,KAAK;AACL,GAAG;AACH;AACA;AACA,GAAG;AACH;AACA;AACA,GAAG;AACH;AACA;AACA,GAAG;AACH;AACA;AACA,GAAG;AACH,CAAC;;AAED;AACA;AACA,WAAW,OAAO;AAClB,WAAW,SAAS;AACpB;AACA;AACA;AACA;AACA;AACA,iBAAiB,6CAA6C;AAC9D;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;;;;;;AC/QA,+C;;;;;;ACAA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,CAAC;;AAED;AACA;;AAEA;AACA;;AAEA;AACA;;AAEA;AACA;AACA;AACA;;AAEA;AACA;;AAEA;AACA;;AAEA;AACA;;AAEA;AACA;AACA;AACA;;AAEA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;;AAEA;AACA;AACA;;AAEA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA,2CAA2C;AAC3C;AACA,2DAA2D;;AAE3D;AACA,+CAA+C;AAC/C;AACA;AACA;;AAEA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA,sCAAsC;AACtC;;AAEA;AACA;;AAEA;AACA;;AAEA;AACA,yBAAyB;AACzB,yBAAyB;AACzB;;AAEA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA,YAAY;AACZ;AACA;AACA;AACA,2CAA2C;;AAE3C;AACA;;AAEA;AACA;AACA;AACA;AACA,0BAA0B,MAAM,aAAa,OAAO;;AAEpD;AACA;;AAEA;AACA;;AAEA;AACA;;AAEA;AACA;;AAEA;AACA;;AAEA;AACA;;AAEA;AACA;;AAEA;AACA;;AAEA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,iDAAiD,EAAE;AACnD;AACA;AACA;;AAEA;AACA;AAC
CA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,SAAS;AACxB,eAAe,OAAO;AACtB,gBAAgB,OAAO;AACvB,iBAAiB,SAAS;AAC1B;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,SAAS;AACxB,eAAe,OAAO;AACtB,eAAe,OAAO,YAAY;AAClC,eAAe,QAAQ;AACvB;AACA,eAAe,OAAO;AACtB;AACA,eAAe,QAAQ;AACvB;AACA,iBAAiB,SAAS;AAC1B;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,QAAQ;AACR;AACA;AACA,kDAAkD,kBAAkB;AACpE;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;;AAEA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;;AAEA;AACA;AACA;;AAEA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,SAAS;AACxB,eAAe,KAAK;AACpB,iBAAiB,OAAO;AACxB;AACA;AACA;AACA;AACA,QAAQ;AACR;AACA;AACA;AACA;AACA,KAAK;;AAEL;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,SAAS;AACxB,eAAe,OAAO;AACtB,eAAe,KAAK;AACpB,iBAAiB,OAAO;AACxB;AACA;AACA;AACA;AACA,QAAQ;AACR;AACA;AACA;AACA;AACA,KAAK;;AAEL;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,SAAS;AACxB,iBAAiB,SAAS;AAC1B;AACA;AACA;AACA;AACA,QAAQ;AACR;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,SAAS;AACxB,eAAe,SAAS;AACxB,iBAAiB,SAAS;AAC1B;AACA;AACA,qBAAqB;AACrB,oBAAoB;AACpB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,SAAS;AACxB,iBAAiB,SAAS;AAC1B;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,SAAS;AACxB,iBAAiB,SAAS;AAC1B;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,SAAS;AACxB,eAAe,yBAAyB;AACxC;AACA,iBAAiB,SAAS;AAC1B;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,QAAQ;AACR;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA,OAAO;AACP,KAAK;;AAEL;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,SAAS;AACxB,eAAe,KAAK;AACpB,iBAAiB,SAAS;AAC1B;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,KAAK;;AAEL;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,SAAS;AACxB,eAAe,KAAK;AACpB,iBAAiB,SAAS;AAC1B;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,KAAK;;AAEL;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,SAAS;AACxB,eAAe,qBAAqB;AACpC,iBAAiB,SAAS;AAC1B;AACA;AACA;AACA;AACA,QAAQ;AACR;AACA;AACA;AACA;AACA;AACA;AACA,KAAK;;AAEL;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,SAAS;AACxB,eAAe,OAAO;AACtB,iBAAiB,SAAS;AAC1B;AACA;AACA;AACA;AACA;AACA,QAAQ;AACR;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAA
e,SAAS;AACxB,eAAe,OAAO;AACtB,iBAAiB,SAAS;AAC1B;AACA;AACA;AACA;AACA,QAAQ;AACR;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,QAAQ;AACR;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA,OAAO;AACP;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,SAAS;AACxB,eAAe,OAAO;AACtB,eAAe,OAAO,YAAY;AAClC,eAAe,QAAQ;AACvB;AACA,eAAe,QAAQ;AACvB;AACA,iBAAiB,SAAS;AAC1B;AACA;AACA;AACA;AACA;AACA;AACA,uDAAuD,oBAAoB;AAC3E;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,OAAO;AACP;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,SAAS;AACxB,iBAAiB,SAAS;AAC1B;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,EAAE;AACjB,eAAe,SAAS;AACxB,iBAAiB,SAAS;AAC1B;AACA;AACA;AACA;AACA,QAAQ;AACR;AACA;AACA,oCAAoC;AACpC;AACA;AACA;AACA;;AAEA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,EAAE;AACjB,iBAAiB,MAAM;AACvB;AACA;AACA;AACA;AACA;AACA,oBAAoB,SAAS;AAC7B,eAAe,SAAS;AACxB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,EAAE;AACjB,iBAAiB,EAAE;AACnB;AACA;AACA;AACA,uBAAuB,SAAS,GAAG,SAAS;AAC5C;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA,4BAA4B;AAC5B;AACA;AACA;AACA;AACA;AACA,eAAe,EAAE;AACjB,eAAe,SAAS;AACxB,iBAAiB,EAAE;AACnB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,EAAE;AACjB,iBAAiB,EAAE;AACnB;AACA;AACA;AACA,uBAAuB,SAAS,GAAG,SAAS;AAC5C;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,EAAE;AACjB,eAAe,SAAS;AACxB,iBAAiB,EAAE;AACnB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,OAAO;AACtB,eAAe,OAAO;AACtB,iBAAiB,QAAQ;AACzB;AACA;AACA,qBAAqB;AACrB;AACA,6BAA6B,mBAAmB,cAAc,EAAE,EAAE;AAClE;AACA;AACA,6BAA6B,mBAAmB,cAAc,EAAE,EAAE;AAClE;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,EAAE;AACjB,eAAe,EAAE;AACjB,iBAAiB,QAAQ;AACzB;AACA;AACA,qBAAqB;AACrB,oBAAoB;AACpB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,EAAE;AACjB,eAAe,EAAE;AACjB,iBAAiB,QAAQ;AACzB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,EAAE;AACjB,eAAe,EAAE;AACjB,iBAAiB,QAAQ;AACzB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,KAAK;;AAEL;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,EAAE;AACjB,iBAAiB,QAAQ;AACzB;AACA;AACA;AACA,iCAAiC,kBAAkB,EAAE;AACrD;AACA;AACA;AACA;AACA;AACA,kDAAkD,kBAAkB,EAAE;AACtE;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,EAAE;AACjB,iBAAiB,QAAQ;AACzB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,EAAE;AACjB,iBAAiB,QAAQ;AACzB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,EAAE;AACjB,iBAAiB,QAAQ;AACzB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,EAAE;AACjB,iBAAiB,QAAQ;AACzB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,EAAE;AACjB,iBAAiB,QAAQ;AACzB
;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,EAAE;AACjB,iBAAiB,QAAQ;AACzB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,EAAE;AACjB,iBAAiB,QAAQ;AACzB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,EAAE;AACjB,iBAAiB,QAAQ;AACzB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,EAAE;AACjB,iBAAiB,QAAQ;AACzB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,kBAAkB,SAAS;AAC3B;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,EAAE;AACjB,eAAe,EAAE;AACjB,iBAAiB,QAAQ;AACzB;AACA;AACA,qBAAqB;AACrB,oBAAoB;AACpB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,EAAE;AACjB,eAAe,EAAE;AACjB,eAAe,SAAS;AACxB,iBAAiB,QAAQ;AACzB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,EAAE;AACjB,iBAAiB,QAAQ;AACzB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,EAAE;AACjB,iBAAiB,QAAQ;AACzB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,EAAE;AACjB,iBAAiB,QAAQ;AACzB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,EAAE;AACjB,iBAAiB,QAAQ;AACzB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,EAAE;AACjB,iBAAiB,QAAQ;AACzB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,EAAE;AACjB,iBAAiB,QAAQ;AACzB;AACA;AACA,oBAAoB;AACpB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,EAAE;AACjB,iBAAiB,QAAQ;AACzB;AACA;AACA,wBAAwB;AACxB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,EAAE;AACjB,iBAAiB,QAAQ;AACzB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,OAAO;AACtB,eAAe,OAAO;AACtB,iBAAiB,QAAQ;AACzB;AACA;AACA,qBAAqB;AACrB;AACA,0BAA0B,SAAS;AACnC;AACA;AACA,0BAA0B,SAAS;AACnC;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,OAAO;AACtB,eAAe,OAAO;AACtB,eAAe,SAAS;AACxB,iBAAiB,QAAQ;AACzB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,qBAAqB;AACrB,qBAAqB;AACrB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,EAAE;AACjB,iBAAiB,QAAQ;AACzB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,EAAE;AACjB,iBAAiB,QAAQ;AACzB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,EAAE;AACjB,iBAAiB,QAAQ;AACzB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,EAAE;AACjB,iBAAiB,QAAQ;AACzB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AAC
A;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,EAAE;AACjB,iBAAiB,QAAQ;AACzB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,EAAE;AACjB,iBAAiB,QAAQ;AACzB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,wBAAwB,iBAAiB;AACzC;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,EAAE;AACjB,iBAAiB,QAAQ;AACzB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,EAAE;AACjB,iBAAiB,QAAQ;AACzB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,EAAE;AACjB,iBAAiB,QAAQ;AACzB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,EAAE;AACjB,iBAAiB,QAAQ;AACzB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,EAAE;AACjB,iBAAiB,QAAQ;AACzB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,EAAE;AACjB,iBAAiB,QAAQ;AACzB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,EAAE;AACjB,iBAAiB,QAAQ;AACzB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,EAAE;AACjB,iBAAiB,QAAQ;AACzB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,EAAE;AACjB,iBAAiB,QAAQ;AACzB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,EAAE;AACjB,eAAe,EAAE;AACjB,iBAAiB,QAAQ;AACzB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,EAAE;AACjB,eAAe,EAAE;AACjB,iBAAiB,QAAQ;AACzB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,KAAK;;AAEL;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,EAAE;AACjB,iBAAiB,MAAM;AACvB;AACA;AACA,kBAAkB,iBAAiB;AACnC;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,EAAE;AACjB,iBAAiB,OAAO;AACxB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,EAAE;AACjB,iBAAiB,OAAO;AACxB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,EAAE;AACjB,iBAAiB,OAAO;AACxB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,EAAE;AACjB,iBAAiB,OAAO;AACxB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,EAAE;AACjB,iBAAiB,OAAO;AACxB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,iBAAiB,SAAS;AAC1B,cAAc;AACd;AACA,iBAAiB,SAAS;AAC1B,cAAc;AACd;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,EAAE;AACjB,iBAAiB,OAAO;AACxB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,EAAE;AACjB,iBAAiB,OAAO;AACxB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,OAAO;AACtB,eAAe,UAAU;AACzB,iBAAiB,OAAO;AA
CxB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,iBAAiB,SAAS;AAC1B,cAAc;AACd;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,KAAK;;AAEL;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,OAAO;AACtB,eAAe,UAAU;AACzB,iBAAiB,OAAO;AACxB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,mBAAmB,SAAS;AAC5B,cAAc;AACd;AACA;AACA;AACA,KAAK;;AAEL;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,OAAO;AACtB,eAAe,UAAU;AACzB,eAAe,SAAS;AACxB,iBAAiB,OAAO;AACxB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,iBAAiB,SAAS,GAAG,SAAS,GAAG,SAAS;AAClD,cAAc;AACd;AACA;AACA;AACA,KAAK;;AAEL;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,OAAO;AACtB,eAAe,UAAU;AACzB,eAAe,SAAS;AACxB,iBAAiB,OAAO;AACxB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,iBAAiB,SAAS,GAAG,SAAS,GAAG,SAAS;AAClD,cAAc;AACd;AACA;AACA;AACA,KAAK;;AAEL;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,OAAO;AACtB,eAAe,qBAAqB;AACpC,iBAAiB,MAAM;AACvB;AACA;AACA,qBAAqB,QAAQ,OAAO,SAAS,EAAE;AAC/C;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,OAAO;AACtB,eAAe,OAAO;AACtB,iBAAiB,OAAO;AACxB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,QAAQ;AACR;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,OAAO;AACtB,eAAe,UAAU;AACzB,iBAAiB,OAAO;AACxB;AACA;AACA;AACA,mBAAmB,SAAS,GAAG,SAAS,GAAG,SAAS;AACpD,cAAc;AACd;AACA;AACA;;AAEA;AACA;AACA;;AAEA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA,KAAK;;AAEL;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,OAAO;AACtB,eAAe,UAAU;AACzB,iBAAiB,OAAO;AACxB;AACA;AACA;AACA,uBAAuB,OAAO,SAAS,EAAE,GAAG,OAAO,iBAAiB,EAAE;AACtE,cAAc,OAAO,iBAAiB;AACtC;AACA;AACA;AACA;AACA,KAAK;;AAEL;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,OAAO;AACtB,eAAe,SAAS;AACxB,iBAAiB,iBAAiB;AAClC;AACA;AACA;AACA;AACA,qBAAqB,4BAA4B;AACjD,qBAAqB,6BAA6B;AAClD,qBAAqB;AACrB;AACA;AACA,qCAAqC,mBAAmB,EAAE;AAC1D;AACA;AACA;AACA,yBAAyB,2BAA2B;AACpD;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,OAAO;AACtB,eAAe,SAAS;AACxB,iBAAiB,iBAAiB;AAClC;AACA;AACA;AACA;AACA,qBAAqB,4BAA4B;AACjD,qBAAqB,6BAA6B;AAClD,qBAAqB;AACrB;AACA;AACA,yCAAyC,mBAAmB,EAAE;AAC9D;AACA;AACA;AACA,6BAA6B,4BAA4B;AACzD;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,OAAO;AACtB,eAAe,SAAS;AACxB,iBAAiB,OAAO;AACxB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,QAAQ;AACR;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,OAAO;AACtB,eAAe,SAAS;AACxB,iBAAiB,OAAO;AACxB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,QAAQ;AACR;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,OAAO;AACtB,eAAe,SAAS;AACxB,iBAAiB,OAAO;AACxB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,QAAQ;AACR;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,OAAO;AACtB,eAAe,SAAS;AACxB,iBAAiB,OAAO;AACxB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,QAAQ;AACR;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,OAAO;AACtB,iBAAiB,MAAM;AACvB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,OAAO;AACtB,iBAAiB,MAAM;AACvB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,OAAO;A
ACtB,eAAe,aAAa;AAC5B,eAAe,EAAE;AACjB,iBAAiB,EAAE;AACnB;AACA;AACA,qBAAqB,QAAQ,OAAO,SAAS,EAAE;AAC/C;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,OAAO;AACtB,eAAe,aAAa;AAC5B,iBAAiB,QAAQ;AACzB;AACA;AACA,qBAAqB,OAAO,SAAS;AACrC,6BAA6B,gBAAgB,SAAS,GAAG;AACzD;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,OAAO;AACtB,eAAe,aAAa;AAC5B,iBAAiB,QAAQ;AACzB;AACA;AACA,8BAA8B,gBAAgB,SAAS,GAAG;AAC1D;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,OAAO;AACtB,iBAAiB,OAAO;AACxB;AACA;AACA,qBAAqB;AACrB;AACA;AACA,cAAc;AACd;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA,KAAK;;AAEL;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,OAAO;AACtB,eAAe,SAAS;AACxB,iBAAiB,OAAO;AACxB;AACA;AACA,qBAAqB;AACrB;AACA;AACA,cAAc;AACd;AACA;AACA;AACA,QAAQ;AACR,cAAc;AACd;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA,OAAO;AACP;AACA;AACA,KAAK;;AAEL;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,OAAO;AACtB,eAAe,aAAa;AAC5B,eAAe,KAAK;AACpB,iBAAiB,EAAE;AACnB;AACA;AACA,qBAAqB,QAAQ,OAAO,oBAAoB,EAAE;AAC1D;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,OAAO;AACtB,iBAAiB,MAAM;AACvB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,OAAO;AACtB,iBAAiB,MAAM;AACvB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA,qCAAqC;AACrC;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,OAAO;AACtB,eAAe,SAAS;AACxB,iBAAiB,OAAO;AACxB;AACA;AACA;AACA,kBAAkB,iBAAiB;AACnC;AACA,QAAQ;AACR,cAAc;AACd;AACA;AACA;AACA;;AAEA;AACA;AACA,OAAO;AACP;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,OAAO;AACtB,eAAe,SAAS;AACxB,iBAAiB,OAAO;AACxB;AACA;AACA;AACA;AACA,qBAAqB,+BAA+B;AACpD,qBAAqB;AACrB;AACA;AACA,uCAAuC,cAAc,EAAE;AACvD,cAAc,2BAA2B;AACzC;AACA;AACA;AACA,cAAc,2BAA2B;AACzC;AACA;AACA;AACA;;AAEA;AACA;AACA,OAAO;AACP;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,OAAO;AACtB,eAAe,UAAU;AACzB,iBAAiB,OAAO;AACxB;AACA;AACA;AACA,gBAAgB,SAAS,GAAG,SAAS;AACrC;AACA;AACA;AACA,gBAAgB,SAAS,GAAG,SAAS;AACrC;AACA;AACA;AACA,cAAc,QAAQ,iBAAiB,GAAG,iBAAiB;AAC3D;AACA;AACA;AACA,KAAK;;AAEL;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,OAAO;AACtB,eAAe,UAAU;AACzB,eAAe,SAAS;AACxB,iBAAiB,OAAO;AACxB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,qBAAqB;AACrB,oBAAoB;AACpB;AACA;AACA,cAAc;AACd;AACA;AACA;AACA,KAAK;;AAEL;AACA,gCAAgC;AAChC;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,OAAO;AACtB,eAAe,qBAAqB;AACpC,iBAAiB,OAAO;AACxB;AACA;AACA,qBAAqB;AACrB;AACA;AACA,cAAc;AACd;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,OAAO;AACP;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,KAAK;;AAEL;AACA,kCAAkC;AAClC;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,OAAO;AACtB,eAAe,SAAS;AACxB,iBAAiB,OAAO;AACxB;AACA;AACA,qBAAqB;AACrB;AACA;AACA,cAAc;AACd;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,OAAO;AACtB,eAAe,qBAAqB;AACpC,iBAAiB,OAAO;AACxB;AACA;AACA,qBAAqB;AACrB;AACA;AACA,cAAc;AACd;AACA;AACA,gCAAgC;AAChC,KAAK;;AAEL;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,OAAO;AACtB,eAAe,SAAS;AACxB,iBAAiB,OAAO;AACxB;AACA;AACA,qBAAqB;AACrB;AACA;AACA,cAAc;AACd;AACA;AACA;AACA;AACA;AACA;AACA;AACA,OAAO;AACP;AACA;AACA;AACA,OAAO;AACP;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,OAAO;AACtB,eAAe,aAAa;AAC5B,eAAe,EAAE;AACjB,iBAAiB,EAAE;AACnB;AACA;AACA,qBAAqB,QAAQ,OAAO,+BAA+B,EAAE;AACrE;AACA;A
ACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,OAAO;AACtB,eAAe,aAAa;AAC5B,eAAe,EAAE;AACjB,iBAAiB,OAAO;AACxB;AACA;AACA,qBAAqB,QAAQ,OAAO,SAAS,EAAE;AAC/C;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,OAAO;AACtB,eAAe,aAAa;AAC5B,eAAe,EAAE;AACjB,eAAe,SAAS;AACxB,iBAAiB,OAAO;AACxB;AACA;AACA;AACA;AACA;AACA,cAAc,OAAO,WAAW;AAChC;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,OAAO;AACtB,iBAAiB,MAAM;AACvB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,OAAO;AACtB,iBAAiB,MAAM;AACvB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA,oCAAoC;AACpC;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,OAAO;AACtB,eAAe,SAAS;AACxB,eAAe,EAAE;AACjB,iBAAiB,EAAE;AACnB;AACA;AACA;AACA;AACA;AACA,QAAQ;AACR;AACA;AACA,oBAAoB,yBAAyB;AAC7C;AACA,QAAQ,IAAI;AACZ,cAAc;AACd;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,OAAO;AACP;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,OAAO;AACtB,eAAe,aAAa;AAC5B,iBAAiB,QAAQ;AACzB;AACA;AACA,qBAAqB,QAAQ,OAAO,SAAS,EAAE;AAC/C;AACA;AACA;AACA;AACA,cAAc,QAAQ,QAAQ,EAAE;AAChC;AACA;AACA;AACA;AACA;AACA,cAAc,QAAQ,QAAQ,EAAE;AAChC;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,OAAO;AACtB,eAAe,aAAa;AAC5B,eAAe,SAAS;AACxB,iBAAiB,OAAO;AACxB;AACA;AACA,qBAAqB,QAAQ,OAAO,SAAS,EAAE;AAC/C;AACA,iDAAiD,cAAc,EAAE;AACjE;AACA;AACA;AACA,iDAAiD,sBAAsB,EAAE;AACzE;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,OAAO;AACtB,eAAe,aAAa;AAC5B,eAAe,SAAS;AACxB,eAAe,SAAS;AACxB,iBAAiB,OAAO;AACxB;AACA;AACA;AACA;AACA;AACA,cAAc,OAAO,WAAW;AAChC;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,OAAO;AACtB,iBAAiB,MAAM;AACvB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,OAAO;AACtB,iBAAiB,MAAM;AACvB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,OAAO;AACtB,eAAe,OAAO;AACtB,eAAe,OAAO;AACtB,iBAAiB,OAAO;AACxB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,OAAO;AACtB,eAAe,OAAO;AACtB,eAAe,OAAO;AACtB,iBAAiB,QAAQ;AACzB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,OAAO;AACP;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,OAAO;AACtB,eAAe,OAAO;AACtB,eAAe,QAAQ;AACvB,iBAAiB,OAAO;AACxB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,SAAS;AACT;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,OAAO;AACtB,iBAAiB,OAAO;AACxB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,KAAK;;AAEL;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,OAAO;AACtB,iBAAiB,OAAO;AACxB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AA
CA,eAAe,OAAO;AACtB,iBAAiB,OAAO;AACxB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,OAAO;AACtB,eAAe,OAAO;AACtB,eAAe,OAAO;AACtB,iBAAiB,QAAQ;AACzB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,OAAO;AACtB,iBAAiB,OAAO;AACxB;AACA;AACA;AACA,iCAAiC;AACjC;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA,kCAAkC,KAAK;AACvC;AACA;AACA;AACA;AACA;AACA,eAAe,OAAO;AACtB,iBAAiB,OAAO;AACxB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,OAAO;AACtB,iBAAiB,OAAO;AACxB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,KAAK;;AAEL;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,OAAO;AACtB,iBAAiB,OAAO;AACxB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,KAAK;;AAEL;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,OAAO;AACtB,iBAAiB,OAAO;AACxB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,OAAO;AACtB,eAAe,OAAO;AACtB,eAAe,OAAO;AACtB,iBAAiB,OAAO;AACxB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,OAAO;AACtB,eAAe,OAAO;AACtB,eAAe,OAAO;AACtB,iBAAiB,OAAO;AACxB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,OAAO;AACtB,eAAe,OAAO;AACtB,eAAe,OAAO;AACtB,iBAAiB,OAAO;AACxB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,OAAO;AACtB,eAAe,OAAO;AACtB,gBAAgB,OAAO;AACvB,iBAAiB,OAAO;AACxB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,OAAO;AACP;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,OAAO;AACtB,eAAe,OAAO;AACtB,gBAAgB,OAAO;AACvB,iBAAiB,OAAO;AACxB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,OAAO;AACP;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,OAAO;AACtB,eAAe,cAAc;AAC7B,eAAe,gBAAgB;AAC/B,iBAAiB,OAAO;AACxB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,OAAO;AACtB,iBAAiB,OAAO;AACxB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,KAAK;;AAEL;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,OAAO;AACtB,eAAe,cAAc;AAC7B,eAAe,OAAO;AACtB,iBAAiB,MAAM;AACvB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,OAAO;AACtB,iBAAiB,OAAO;AACxB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,KAAK;;AAEL;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,OAAO;AACtB,eAAe,OAAO;AACtB,eAAe,OAAO;AACtB,iBAAiB,QAAQ;AACzB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,OAAO;AACtB,eAAe,OAAO,YAAY;AAClC,eAAe,OAAO;AACtB;AACA,eAAe,OAAO;AACtB;AACA,eAAe,OAAO;AACtB;AACA,eAAe,OAAO;AACtB;AACA,eAAe,OAAO;AACtB;AACA,eAAe,OAAO;AACtB;AACA,gBAAgB,OAAO;AACvB,iBAAiB,SAAS;AAC1B;AACA;AACA;AACA;AACA,iBAAiB,iBAAiB;AAClC;AACA;AACA;AACA;AACA,iBAAiB,sBAAsB;AACvC,qBAAqB,UAAU;AAC/B;AACA;AACA,sEAAsE,2BAA2B,EAAE;AACnG,iBAAiB,8BAA8B;AAC/C;AACA;AACA;AACA,4DAA4D;AAC5D,iBAAiB,mBAAmB;AACpC;AACA;AACA;AACA;AACA,
0CAA0C,OAAO;AACjD,iBAAiB,oBAAoB;AACrC;AACA;AACA;AACA;AACA,iBAAiB,qBAAqB;AACtC;AACA;AACA;AACA,qDAAqD,2BAA2B,EAAE;AAClF,wCAAwC,aAAa,eAAe,EAAE;AACtE,iBAAiB,8BAA8B;AAC/C;AACA;AACA;AACA,wDAAwD,qCAAqC;AAC7F;AACA;AACA;AACA;AACA,0DAA0D,qBAAqB;AAC/E;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,2CAA2C,YAAY;AACvD,0CAA0C,QAAQ;AAClD,iBAAiB,qBAAqB;AACtC;AACA;AACA;AACA;AACA;AACA,oBAAoB;AACpB;AACA,WAAW;AACX;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA,+BAA+B;;AAE/B,mCAAmC;AACnC;AACA;;AAEA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;;AAEA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,uBAAuB,wBAAwB;AAC/C;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA,OAAO;;AAEP,mBAAmB;;AAEnB;AACA;AACA;AACA;AACA,8BAA8B,mBAAmB;AACjD;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA,4CAA4C;;AAE5C;AACA,uDAAuD;AACvD;AACA;AACA,6BAA6B,EAAE;AAC/B;AACA;AACA;AACA;AACA;AACA;AACA;AACA,0CAA0C;AAC1C,+BAA+B,iCAAiC;AAChE,cAAc;AACd;AACA;AACA,sBAAsB;;AAEtB;AACA;AACA;AACA,OAAO;;AAEP;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,OAAO;AACtB,iBAAiB,OAAO;AACxB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,OAAO;AACtB,iBAAiB,OAAO;AACxB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,OAAO;AACtB,eAAe,OAAO;AACtB,gBAAgB,OAAO;AACvB,iBAAiB,OAAO;AACxB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,OAAO;AACtB,eAAe,OAAO;AACtB,gBAAgB,OAAO;AACvB,iBAAiB,OAAO;AACxB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,OAAO;AACtB,eAAe,OAAO;AACtB,gBAAgB,OAAO;AACvB,iBAAiB,OAAO;AACxB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,OAAO;AACtB,eAAe,OAAO,YAAY;AAClC,eAAe,OAAO;AACtB,eAAe,OAAO;AACtB,eAAe,cAAc;AAC7B,iBAAiB,OAAO;AACxB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,QAAQ;AACR;AACA;AACA;AACA;AACA;AACA,QAAQ;AACR;AACA;AACA;AACA;AACA,QAAQ;AACR;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,OAAO;AACP;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA,iCAAiC;AACjC,aAAa,QAAQ,QAAQ,UAAU,aAAa;AACpD;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,OAAO;AACtB,iBAAiB,OAAO;AACxB;AACA;AACA,sCAAsC;AACtC;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,OAAO;AACtB,iBAAiB,OAAO;AACxB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,KAAK;;AAEL;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,OAAO;AACtB,iBAAiB,OAAO;AACxB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,OAAO;AACtB,eAAe,cAAc;AAC7B,gBAAgB,OAAO;AACvB,iBAAiB,MAAM;AACvB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;;AAEA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,SAAS;AACxB,eAAe,KAAK;AACpB,iBAAiB,EAAE;AACnB;AACA;AACA;AACA;AACA;AACA,QAAQ;AACR;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,OAAO;AACP;AACA;AACA,KAAK;;AAEL;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,OAAO;AACtB,eAAe,qBAAqB;AACpC,iBAAiB,OAAO;AACxB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,OAAO;AACP;A
ACA,KAAK;;AAEL;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,MAAM;AACrB,iBAAiB,SAAS;AAC1B;AACA;AACA;AACA,qBAAqB,SAAS;AAC9B,sBAAsB,kBAAkB;AACxC;AACA;AACA;AACA,aAAa,iBAAiB;AAC9B;AACA;AACA,aAAa,iBAAiB;AAC9B;AACA;AACA,aAAa,qBAAqB;AAClC;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA,OAAO;;AAEP;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,OAAO;AACP;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,OAAO;AACtB,iBAAiB,SAAS;AAC1B;AACA;AACA;AACA,UAAU,iBAAiB;AAC3B,UAAU;AACV;AACA;AACA,qCAAqC,mBAAmB,cAAc,EAAE,EAAE;AAC1E,eAAe,iBAAiB;AAChC;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,EAAE;AACjB,iBAAiB,SAAS;AAC1B;AACA;AACA,4CAA4C,SAAS;AACrD;AACA;AACA,eAAe,SAAS,GAAG,SAAS;AACpC;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,EAAE;AACjB,eAAe,EAAE;AACjB,iBAAiB,EAAE;AACnB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,yBAAyB;AACxC,iBAAiB,SAAS;AAC1B;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,yBAAyB;AACxC,iBAAiB,SAAS;AAC1B;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,EAAE;AACjB,iBAAiB,EAAE;AACnB;AACA;AACA,qBAAqB;AACrB;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,EAAE;AACjB,iBAAiB,SAAS;AAC1B;AACA;AACA;AACA,UAAU,8CAA8C;AACxD,UAAU;AACV;AACA;AACA;AACA,mCAAmC,mCAAmC;AACtE,eAAe,8CAA8C;AAC7D;AACA;AACA;AACA,eAAe,4BAA4B;AAC3C;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,QAAQ;AACR;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,OAAO;AACtB,iBAAiB,SAAS;AAC1B;AACA;AACA;AACA,UAAU,yBAAyB;AACnC,UAAU;AACV;AACA;AACA,oCAAoC,iBAAiB;AACrD,eAAe,yBAAyB;AACxC;AACA;AACA,gDAAgD,SAAS,cAAc,SAAS;AAChF,eAAe,yBAAyB,GAAG,yBAAyB;AACpE;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,aAAa;AAC5B,eAAe,EAAE;AACjB,iBAAiB,SAAS;AAC1B;AACA;AACA;AACA,UAAU,yBAAyB;AACnC,UAAU;AACV;AACA;AACA;AACA,cAAc;AACd;AACA;AACA;AACA,eAAe,yBAAyB,GAAG,yBAAyB;AACpE;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,aAAa;AAC5B,eAAe,KAAK;AACpB,iBAAiB,SAAS;AAC1B;AACA;AACA;AACA,UAAU,OAAO,qBAAqB,EAAE;AACxC,UAAU,OAAO,qBAAqB;AACtC;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,KAAK;;AAEL;AACA,kCAAkC;AAClC;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,OAAO;AACtB,eAAe,KAAK;AACpB,iBAAiB,SAAS;AAC1B;AACA;AACA;AACA,qBAAqB;AACrB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,KAAK;;AAEL;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,gBAAgB;AAC/B,eAAe,OAAO;AACtB,eAAe,OAAO,YAAY;AAClC,eAAe,QAAQ;AACvB,iBAAiB,gBAAgB;AACjC;AACA;AACA;AACA;AACA;AACA,UAAU;AACV;AACA;AACA,gBAAgB,mBAAmB;AACnC;AACA;AACA;AACA;AACA;AACA;AACA,gBAAgB,mBAAmB,GAAG,iBAAiB;AACvD;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA,4BAA4B,qDAAqD;AACjF;AACA;AACA;AACA;AACA;AACA;AACA,OAAO;;AAEP;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,iBAAiB,SAAS;AAC1B;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,OAAO;AACtB,iBAAiB,SAAS;AAC1B;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,OAAO;AACP;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,yBAAyB;AACxC;AACA,iBAAiB,SAAS;AAC1B;AACA;AACA;AACA;AACA;AACA;AAC
A;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,yBAAyB;AACxC;AACA,iBAAiB,SAAS;AAC1B;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,yBAAyB;AACxC;AACA,iBAAiB,SAAS;AAC1B;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,sCAAsC,SAAS,GAAG,SAAS;AAC3D;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,aAAa;AAC5B,iBAAiB,SAAS;AAC1B;AACA;AACA;AACA,UAAU,OAAO,SAAS,EAAE;AAC5B,UAAU,OAAO,SAAS;AAC1B;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA,oCAAoC;AACpC;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,OAAO;AACtB,iBAAiB,SAAS;AAC1B;AACA;AACA;AACA,qBAAqB;AACrB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,OAAO;AACtB,eAAe,OAAO;AACtB,eAAe,OAAO;AACtB,iBAAiB,MAAM;AACvB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,OAAO;AACtB,eAAe,OAAO;AACtB,eAAe,OAAO;AACtB,iBAAiB,MAAM;AACvB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,iBAAiB,MAAM;AACvB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,iBAAiB,QAAQ;AACzB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,iBAAiB,OAAO;AACxB;AACA;AACA;AACA;AACA;AACA,gBAAgB,IAAI;AACpB;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,iBAAiB,OAAO;AACxB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,iBAAiB,QAAQ;AACzB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA,kEAAkE;AAClE;AACA;AACA;AACA;AACA;AACA,eAAe,OAAO;AACtB,eAAe,SAAS;AACxB,iBAAiB,MAAM;AACvB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,EAAE;AACjB,iBAAiB,MAAM;AACvB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,OAAO;AACtB,iBAAiB,OAAO;AACxB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,OAAO;AACtB,eAAe,OAAO;AACtB,iBAAiB,OAAO;AACxB;AACA;AACA;AACA;AACA;AACA;AACA;AACA,KAAK;;AAEL;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,OAAO;AACtB,eAAe,OAAO;AACtB,iBAAiB,OAAO;AACxB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,OAAO;AACtB,eAAe,OAAO;AACtB,iBAAiB,OAAO;AACxB;AACA;AACA;AACA;AACA;AACA;AACA;AACA,KAAK;;AAEL;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,OAAO;AACtB,eAAe,OAAO;AACtB,iBAAiB,OAAO;AACxB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,MAAM;AACrB,iBAAiB,EAAE;AACnB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,MAAM;AACrB,eAAe,SAAS;AACxB,iBAAiB,EAAE;AACnB;AACA;AACA,uBAAuB,SAAS,GAAG,SAAS;AAC5C;AACA,qCAAqC,YAAY,EAAE;AACnD,cAAc;AACd;AACA;AACA;AACA,cAAc;AACd;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,MAAM;AACrB,iBAAiB,OAAO;AACxB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,MAAM;AACrB,eAAe,SAAS;AACxB,iBAAiB,OAAO;AACxB;AACA;AACA,uBAAuB,SAAS,GAAG,SAAS,GAAG,SAAS,GAAG,SAAS;AACpE;AACA,sCAAsC,YAAY,EAAE;AACpD;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;
AACA;AACA,eAAe,MAAM;AACrB,iBAAiB,EAAE;AACnB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,MAAM;AACrB,eAAe,SAAS;AACxB,iBAAiB,EAAE;AACnB;AACA;AACA,uBAAuB,SAAS,GAAG,SAAS;AAC5C;AACA,qCAAqC,YAAY,EAAE;AACnD,cAAc;AACd;AACA;AACA;AACA,cAAc;AACd;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,OAAO;AACtB,eAAe,OAAO;AACtB,iBAAiB,OAAO;AACxB;AACA;AACA;AACA;AACA;AACA;AACA;AACA,KAAK;;AAEL;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,OAAO;AACtB,eAAe,OAAO;AACtB,iBAAiB,OAAO;AACxB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,OAAO;AACtB,eAAe,OAAO;AACtB,iBAAiB,OAAO;AACxB;AACA;AACA;AACA;AACA;AACA;AACA;AACA,KAAK;;AAEL;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,MAAM;AACrB,iBAAiB,OAAO;AACxB;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,eAAe,MAAM;AACrB,eAAe,SAAS;AACxB,iBAAiB,OAAO;AACxB;AACA;AACA,uBAAuB,SAAS,GAAG,SAAS,GAAG,SAAS,GAAG,SAAS;AACpE;AACA,qCAAqC,YAAY,EAAE;AACnD;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;;AAEA;AACA;;AAEA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA,OAAO;AACP;AACA,KAAK,MAAM,iBAAiB;;AAE5B;;AAEA;AACA;AACA;AACA;AACA;AACA,cAAc;AACd;AACA;;AAEA;AACA;AACA;AACA,KAAK;;AAEL;AACA;AACA;AACA;;AAEA;AACA;AACA;;AAEA;AACA;AACA,SAAS;AACT;AACA;AACA;AACA,WAAW;AACX;AACA;AACA;;AAEA;AACA;AACA;AACA,KAAK;;AAEL;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA,SAAS;AACT;AACA;AACA;AACA,KAAK;;AAEL;AACA;AACA;;AAEA;AACA;AACA;AACA,KAAK;;AAEL;AACA;AACA;;AAEA;AACA;AACA;AACA,KAAK;;AAEL;AACA;AACA;;AAEA;AACA;AACA;;AAEA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA,OAAO;AACP,KAAK;;AAEL;AACA;AACA;;AAEA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA,OAAO;AACP;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;;AAEA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA,mCAAmC,4DAA4D;AAC/F;AACA;AACA;AACA;AACA;AACA;AACA;A
ACA;AACA,KAAK;;AAEL;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,SAAS;AACT;AACA,KAAK;;AAEL;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,6BAA6B,yCAAyC;AACtE;AACA,KAAK;;AAEL;AACA;AACA;AACA,KAAK;;AAEL;AACA;AACA;AACA;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;;AAEA;AACA;;AAEA;AACA;AACA;AACA;AACA,GAAG;;AAEH;;AAEA;AACA;;AAEA;AACA,MAAM,IAA0E;AAChF;AACA;AACA;AACA;AACA;;AAEA;AACA;AACA,IAAI,mCAAO;AACX;AACA,KAAK;AAAA,oGAAC;AACN;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,CAAC;;;;;;;;ACxzhBD;;AAEA;AACA;AACA;AACA,CAAC;;AAED;AACA;AACA;AACA,CAAC;AACD;AACA;AACA;AACA;;AAEA;AACA;AACA,4CAA4C;;AAE5C;;;;;;;ACpBA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,GAAG;AACH;AACA;AACA;AACA;AACA;AACA,GAAG;AACH;AACA;AACA;AACA","file":"index.js","sourcesContent":[" \t// The module cache\n \tvar installedModules = {};\n\n \t// The require function\n \tfunction __webpack_require__(moduleId) {\n\n \t\t// Check if module is in cache\n \t\tif(installedModules[moduleId]) {\n \t\t\treturn installedModules[moduleId].exports;\n \t\t}\n \t\t// Create a new module (and put it into the cache)\n \t\tvar module = installedModules[moduleId] = {\n \t\t\ti: moduleId,\n \t\t\tl: false,\n \t\t\texports: {}\n \t\t};\n\n \t\t// Execute the module function\n \t\tmodules[moduleId].call(module.exports, module, module.exports, __webpack_require__);\n\n \t\t// Flag the module as loaded\n \t\tmodule.l = true;\n\n \t\t// Return the exports of the module\n \t\treturn module.exports;\n \t}\n\n\n \t// expose the modules object (__webpack_modules__)\n \t__webpack_require__.m = modules;\n\n \t// expose the module cache\n \t__webpack_require__.c = installedModules;\n\n \t// define getter function for harmony exports\n \t__webpack_require__.d = function(exports, name, getter) {\n \t\tif(!__webpack_require__.o(exports, name)) {\n \t\t\tObject.defineProperty(exports, name, {\n \t\t\t\tconfigurable: false,\n \t\t\t\tenumerable: true,\n \t\t\t\tget: getter\n \t\t\t});\n \t\t}\n \t};\n\n \t// getDefaultExport function for compatibility with non-harmony modules\n \t__webpack_require__.n = function(module) {\n \t\tvar getter = module && module.__esModule ?\n \t\t\tfunction getDefault() { return module['default']; } :\n \t\t\tfunction getModuleExports() { return module; };\n \t\t__webpack_require__.d(getter, 'a', getter);\n \t\treturn getter;\n \t};\n\n \t// Object.prototype.hasOwnProperty.call\n \t__webpack_require__.o = function(object, property) { return Object.prototype.hasOwnProperty.call(object, property); };\n\n \t// __webpack_public_path__\n \t__webpack_require__.p = \"\";\n\n \t// Load entry module and return exports\n \treturn __webpack_require__(__webpack_require__.s = 1);\n\n\n\n// WEBPACK FOOTER //\n// webpack/bootstrap 9db01875ed55e7424b10","module.exports = {\"name\":\"tensorflow_model_analysis\",\"version\":\"0.34.1\",\"homepage\":\"https://github.com/tensorflow/model-analysis\",\"bugs\":\"https://github.com/tensorflow/model-analysis/issues\",\"license\":\"Apache-2.0\",\"repository\":{\"type\":\"git\",\"url\":\"https://github.com/tensorflow/model-analysis.git\"},\"main\":\"lib/index.js\",\"files\":[\"lib/**/*.js\",\"dist/*.js\",\"README.md\",\"LICENSE\"],\"scripts\":{\"clean\":\"rimraf dist/\",\"prepare\":\"webpack && ./collect-files-before-publish.sh\",\"test\":\"echo \\\"Error: no test specified\\\" && exit 1\"},\"devDependencies\":{\"webpack\":\"^3.5.5\",\"rimraf\":\"^2.6.1\"},\"dependencies\":{\"@jupyter-widgets/base\":\"^1.1 || ^2 || ^3 || 
^4\",\"lodash\":\"^4.17.4\"},\"jupyterlab\":{\"extension\":\"lib/labplugin\",\"sharedPackages\":{\"@jupyter-widgets/base\":{\"bundled\":false,\"singleton\":true}}},\"publishConfig\":{\"registry\":\"https://wombat-dressing-room.appspot.com\"}}\n\n\n//////////////////\n// WEBPACK FOOTER\n// ./package.json\n// module id = 0\n// module chunks = 0","/**\n * Copyright 2018 Google LLC\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * https://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n// Export widget models and views, and the npm package version number.\nmodule.exports = require('./widget.js');\nmodule.exports['version'] = require('../package.json').version;\n\n\n\n//////////////////\n// WEBPACK FOOTER\n// ./lib/index.js\n// module id = 1\n// module chunks = 0","/**\n * Copyright 2018 Google LLC\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * https://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\nconst widgets = require('@jupyter-widgets/base');\nconst _ = require('lodash');\nconst version = require('../package.json').version;\n\n/**\n * Helper method to load the vulcanized templates.\n */\nfunction loadVulcanizedTemplate() {\n let templatePath;\n const dataBaseUrl =\n document.querySelector('body').getAttribute('data-base-url');\n // Jupyter Classic\n if (dataBaseUrl) {\n templatePath = dataBaseUrl + 'nbextensions/tensorflow_model_analysis/';\n }\n // Jupyter Lab\n else if (window['isJupyterLab']) {\n let baseUrl = '/';\n const jupyterConfigData = document.getElementById('jupyter-config-data');\n if (jupyterConfigData) {\n const configData = JSON.parse(jupyterConfigData.textContent || '');\n if (configData) {\n baseUrl = configMap['baseUrl'] || '/';\n }\n }\n\n templatePath = baseUrl + 'nbextensions/tensorflow_model_analysis/';\n }\n // Kubeflow\n else {\n templatePath = __webpack_public_path__;\n }\n // templatePath ends with a slash.\n const templateLocation = `${templatePath}vulcanized_tfma.js`;\n\n // If the vulcanizes tempalets are not loaded yet, load it now.\n if (!document.querySelector('script[src=\"' + templateLocation + '\"]')) {\n const script = document.createElement('script');\n script.setAttribute('src', templateLocation);\n document.head.appendChild(script);\n }\n}\n\n/**\n * HACK: Calls the render callback in a setTimeout. 
This delay avoids some\n * rendering artifacts.\n * @param {!Function} cb\n */\nfunction delayedRender(cb) {\n setTimeout(cb, 0);\n}\n\nconst MODULE_NAME = 'tensorflow_model_analysis';\nconst MODEL_VERSION = version;\nconst VIEW_VERSION = version;\nconst SLICING_METRICS_MODEL_NAME = 'SlicingMetricsModel';\nconst SLICING_METRICS_VIEW_NAME = 'SlicingMetricsView';\nconst SLICING_METRICS_ELEMENT_NAME = 'tfma-nb-slicing-metrics';\nconst TIME_SERIES_MODEL_NAME = 'TimeSeriesModel';\nconst TIME_SERIES_VIEW_NAME = 'TimeSeriesView';\nconst TIME_SERIES_ELEMENT_NAME = 'tfma-nb-time-series';\nconst PLOT_MODEL_NAME = 'PlotModel';\nconst PLOT_VIEW_NAME = 'PlotView';\nconst PLOT_ELEMENT_NAME = 'tfma-nb-plot';\nconst FAIRNESS_INDICATOR_MODEL_NAME = 'FairnessIndicatorModel';\nconst FAIRNESS_INDICATOR_VIEW_NAME = 'FairnessIndicatorView';\nconst FAIRNESS_INDICATOR_ELEMENT_NAME = 'fairness-nb-container';\n\nconst SlicingMetricsModel = widgets.DOMWidgetModel.extend({\n defaults: _.extend(widgets.DOMWidgetModel.prototype.defaults(), {\n _model_name: SLICING_METRICS_MODEL_NAME,\n _view_name: SLICING_METRICS_VIEW_NAME,\n _model_module: MODULE_NAME,\n _view_module: MODULE_NAME,\n _model_module_version: MODEL_VERSION,\n _view_module_version: VIEW_VERSION,\n config: {},\n data: [],\n js_events: [],\n })\n});\n\nconst SlicingMetricsView = widgets.DOMWidgetView.extend({\n render: function() {\n loadVulcanizedTemplate();\n\n this.view_ = document.createElement(SLICING_METRICS_ELEMENT_NAME);\n this.el.appendChild(this.view_);\n\n this.view_.addEventListener('tfma-event', (e) => {\n handleTfmaEvent(e, this);\n });\n\n delayedRender(() => {\n this.configChanged_();\n this.dataChanged_();\n this.model.on('change:config', this.configChanged_, this);\n this.model.on('change:data', this.dataChanged_, this);\n });\n },\n dataChanged_: function() {\n this.view_.data = this.model.get('data');\n },\n configChanged_: function() {\n this.view_.config = this.model.get('config');\n },\n});\n\nconst TimeSeriesModel = widgets.DOMWidgetModel.extend({\n defaults: _.extend(widgets.DOMWidgetModel.prototype.defaults(), {\n _model_name: TIME_SERIES_MODEL_NAME,\n _view_name: TIME_SERIES_VIEW_NAME,\n _model_module: MODULE_NAME,\n _view_module: MODULE_NAME,\n _model_module_version: MODEL_VERSION,\n _view_module_version: VIEW_VERSION,\n config: {},\n data: [],\n })\n});\n\nconst TimeSeriesView = widgets.DOMWidgetView.extend({\n render: function() {\n loadVulcanizedTemplate();\n\n this.view_ = document.createElement(TIME_SERIES_ELEMENT_NAME);\n this.el.appendChild(this.view_);\n\n delayedRender(() => {\n this.configChanged_();\n this.dataChanged_();\n this.model.on('change:config', this.configChanged_, this);\n this.model.on('change:data', this.dataChanged_, this);\n });\n },\n dataChanged_: function() {\n this.view_.data = this.model.get('data');\n },\n configChanged_: function() {\n this.view_.config = this.model.get('config');\n },\n});\n\nconst PlotModel = widgets.DOMWidgetModel.extend({\n defaults: _.extend(widgets.DOMWidgetModel.prototype.defaults(), {\n _model_name: PLOT_MODEL_NAME,\n _view_name: PLOT_VIEW_NAME,\n _model_module: MODULE_NAME,\n _view_module: MODULE_NAME,\n _model_module_version: MODEL_VERSION,\n _view_module_version: VIEW_VERSION,\n config: {},\n data: [],\n })\n});\n\nconst PlotView = widgets.DOMWidgetView.extend({\n render: function() {\n loadVulcanizedTemplate();\n\n this.view_ = document.createElement(PLOT_ELEMENT_NAME);\n this.el.appendChild(this.view_);\n\n delayedRender(() => {\n this.configChanged_();\n 
this.dataChanged_();\n this.model.on('change:config', this.configChanged_, this);\n this.model.on('change:data', this.dataChanged_, this);\n });\n },\n dataChanged_: function() {\n this.view_.data = this.model.get('data');\n },\n configChanged_: function() {\n this.view_.config = this.model.get('config');\n },\n});\n\nconst FairnessIndicatorModel = widgets.DOMWidgetModel.extend({\n defaults: _.extend(widgets.DOMWidgetModel.prototype.defaults(), {\n _model_name: FAIRNESS_INDICATOR_MODEL_NAME,\n _view_name: FAIRNESS_INDICATOR_VIEW_NAME,\n _model_module: MODULE_NAME,\n _view_module: MODULE_NAME,\n _model_module_version: MODEL_VERSION,\n _view_module_version: VIEW_VERSION,\n slicingMetrics: [],\n slicingMetricsCompare: [],\n evalName: '',\n evalNameCompare: '',\n js_events: [],\n })\n});\n\nconst FairnessIndicatorView = widgets.DOMWidgetView.extend({\n render: function() {\n loadVulcanizedTemplate();\n\n this.view_ = document.createElement(FAIRNESS_INDICATOR_ELEMENT_NAME);\n this.el.appendChild(this.view_);\n\n this.view_.addEventListener('tfma-event', (e) => {\n handleTfmaEvent(e, this);\n });\n\n delayedRender(() => {\n this.slicingMetricsChanged_();\n this.slicingMetricsCompareChanged_();\n this.evalNameChanged_();\n this.evalNameCompareChanged_();\n this.model.on('change:slicingMetrics', this.slicingMetricsChanged_, this);\n this.model.on(\n 'change:slicingMetricsCompare', this.slicingMetricsCompareChanged_,\n this);\n this.model.on('change:evalName', this.evalNameChanged_, this);\n this.model.on(\n 'change:evalNameCompare', this.evalNameCompareChanged_, this);\n });\n },\n slicingMetricsChanged_: function() {\n this.view_.slicingMetrics = this.model.get('slicingMetrics');\n },\n slicingMetricsCompareChanged_: function() {\n this.view_.slicingMetricsCompare = this.model.get('slicingMetricsCompare');\n },\n evalNameChanged_: function() {\n this.view_.evalName = this.model.get('evalName');\n },\n evalNameCompareChanged_: function() {\n this.view_.evalNameCompare = this.model.get('evalNameCompare');\n },\n});\n\n/**\n * Handler for events of type \"tfma-event\" for the given view element.\n * @param {!Event} tfmaEvent\n * @param {!Element} view\n */\nconst handleTfmaEvent = (tfmaEvent, view) => {\n const model = view.model;\n const jsEvents = model.get('js_events').slice();\n const detail = tfmaEvent.detail;\n jsEvents.push({'name': detail.type, 'detail': detail.detail});\n model.set('js_events', jsEvents);\n view.touch();\n};\n\nmodule.exports = {\n [PLOT_MODEL_NAME]: PlotModel,\n [PLOT_VIEW_NAME]: PlotView,\n [SLICING_METRICS_MODEL_NAME]: SlicingMetricsModel,\n [SLICING_METRICS_VIEW_NAME]: SlicingMetricsView,\n [TIME_SERIES_MODEL_NAME]: TimeSeriesModel,\n [TIME_SERIES_VIEW_NAME]: TimeSeriesView,\n [FAIRNESS_INDICATOR_MODEL_NAME]: FairnessIndicatorModel,\n [FAIRNESS_INDICATOR_VIEW_NAME]: FairnessIndicatorView,\n};\n\n\n\n//////////////////\n// WEBPACK FOOTER\n// ./lib/widget.js\n// module id = 2\n// module chunks = 0","module.exports = __WEBPACK_EXTERNAL_MODULE_3__;\n\n\n//////////////////\n// WEBPACK FOOTER\n// external \"@jupyter-widgets/base\"\n// module id = 3\n// module chunks = 0","/**\n * @license\n * Lodash \n * Copyright OpenJS Foundation and other contributors \n * Released under MIT license \n * Based on Underscore.js 1.8.3 \n * Copyright Jeremy Ashkenas, DocumentCloud and Investigative Reporters & Editors\n */\n;(function() {\n\n /** Used as a safe reference for `undefined` in pre-ES5 environments. */\n var undefined;\n\n /** Used as the semantic version number. 
*/\n var VERSION = '4.17.21';\n\n /** Used as the size to enable large array optimizations. */\n var LARGE_ARRAY_SIZE = 200;\n\n /** Error message constants. */\n var CORE_ERROR_TEXT = 'Unsupported core-js use. Try https://npms.io/search?q=ponyfill.',\n FUNC_ERROR_TEXT = 'Expected a function',\n INVALID_TEMPL_VAR_ERROR_TEXT = 'Invalid `variable` option passed into `_.template`';\n\n /** Used to stand-in for `undefined` hash values. */\n var HASH_UNDEFINED = '__lodash_hash_undefined__';\n\n /** Used as the maximum memoize cache size. */\n var MAX_MEMOIZE_SIZE = 500;\n\n /** Used as the internal argument placeholder. */\n var PLACEHOLDER = '__lodash_placeholder__';\n\n /** Used to compose bitmasks for cloning. */\n var CLONE_DEEP_FLAG = 1,\n CLONE_FLAT_FLAG = 2,\n CLONE_SYMBOLS_FLAG = 4;\n\n /** Used to compose bitmasks for value comparisons. */\n var COMPARE_PARTIAL_FLAG = 1,\n COMPARE_UNORDERED_FLAG = 2;\n\n /** Used to compose bitmasks for function metadata. */\n var WRAP_BIND_FLAG = 1,\n WRAP_BIND_KEY_FLAG = 2,\n WRAP_CURRY_BOUND_FLAG = 4,\n WRAP_CURRY_FLAG = 8,\n WRAP_CURRY_RIGHT_FLAG = 16,\n WRAP_PARTIAL_FLAG = 32,\n WRAP_PARTIAL_RIGHT_FLAG = 64,\n WRAP_ARY_FLAG = 128,\n WRAP_REARG_FLAG = 256,\n WRAP_FLIP_FLAG = 512;\n\n /** Used as default options for `_.truncate`. */\n var DEFAULT_TRUNC_LENGTH = 30,\n DEFAULT_TRUNC_OMISSION = '...';\n\n /** Used to detect hot functions by number of calls within a span of milliseconds. */\n var HOT_COUNT = 800,\n HOT_SPAN = 16;\n\n /** Used to indicate the type of lazy iteratees. */\n var LAZY_FILTER_FLAG = 1,\n LAZY_MAP_FLAG = 2,\n LAZY_WHILE_FLAG = 3;\n\n /** Used as references for various `Number` constants. */\n var INFINITY = 1 / 0,\n MAX_SAFE_INTEGER = 9007199254740991,\n MAX_INTEGER = 1.7976931348623157e+308,\n NAN = 0 / 0;\n\n /** Used as references for the maximum length and index of an array. */\n var MAX_ARRAY_LENGTH = 4294967295,\n MAX_ARRAY_INDEX = MAX_ARRAY_LENGTH - 1,\n HALF_MAX_ARRAY_LENGTH = MAX_ARRAY_LENGTH >>> 1;\n\n /** Used to associate wrap methods with their bit flags. */\n var wrapFlags = [\n ['ary', WRAP_ARY_FLAG],\n ['bind', WRAP_BIND_FLAG],\n ['bindKey', WRAP_BIND_KEY_FLAG],\n ['curry', WRAP_CURRY_FLAG],\n ['curryRight', WRAP_CURRY_RIGHT_FLAG],\n ['flip', WRAP_FLIP_FLAG],\n ['partial', WRAP_PARTIAL_FLAG],\n ['partialRight', WRAP_PARTIAL_RIGHT_FLAG],\n ['rearg', WRAP_REARG_FLAG]\n ];\n\n /** `Object#toString` result references. 
*/\n var argsTag = '[object Arguments]',\n arrayTag = '[object Array]',\n asyncTag = '[object AsyncFunction]',\n boolTag = '[object Boolean]',\n dateTag = '[object Date]',\n domExcTag = '[object DOMException]',\n errorTag = '[object Error]',\n funcTag = '[object Function]',\n genTag = '[object GeneratorFunction]',\n mapTag = '[object Map]',\n numberTag = '[object Number]',\n nullTag = '[object Null]',\n objectTag = '[object Object]',\n promiseTag = '[object Promise]',\n proxyTag = '[object Proxy]',\n regexpTag = '[object RegExp]',\n setTag = '[object Set]',\n stringTag = '[object String]',\n symbolTag = '[object Symbol]',\n undefinedTag = '[object Undefined]',\n weakMapTag = '[object WeakMap]',\n weakSetTag = '[object WeakSet]';\n\n var arrayBufferTag = '[object ArrayBuffer]',\n dataViewTag = '[object DataView]',\n float32Tag = '[object Float32Array]',\n float64Tag = '[object Float64Array]',\n int8Tag = '[object Int8Array]',\n int16Tag = '[object Int16Array]',\n int32Tag = '[object Int32Array]',\n uint8Tag = '[object Uint8Array]',\n uint8ClampedTag = '[object Uint8ClampedArray]',\n uint16Tag = '[object Uint16Array]',\n uint32Tag = '[object Uint32Array]';\n\n /** Used to match empty string literals in compiled template source. */\n var reEmptyStringLeading = /\\b__p \\+= '';/g,\n reEmptyStringMiddle = /\\b(__p \\+=) '' \\+/g,\n reEmptyStringTrailing = /(__e\\(.*?\\)|\\b__t\\)) \\+\\n'';/g;\n\n /** Used to match HTML entities and HTML characters. */\n var reEscapedHtml = /&(?:amp|lt|gt|quot|#39);/g,\n reUnescapedHtml = /[&<>\"']/g,\n reHasEscapedHtml = RegExp(reEscapedHtml.source),\n reHasUnescapedHtml = RegExp(reUnescapedHtml.source);\n\n /** Used to match template delimiters. */\n var reEscape = /<%-([\\s\\S]+?)%>/g,\n reEvaluate = /<%([\\s\\S]+?)%>/g,\n reInterpolate = /<%=([\\s\\S]+?)%>/g;\n\n /** Used to match property names within property paths. */\n var reIsDeepProp = /\\.|\\[(?:[^[\\]]*|([\"'])(?:(?!\\1)[^\\\\]|\\\\.)*?\\1)\\]/,\n reIsPlainProp = /^\\w*$/,\n rePropName = /[^.[\\]]+|\\[(?:(-?\\d+(?:\\.\\d+)?)|([\"'])((?:(?!\\2)[^\\\\]|\\\\.)*?)\\2)\\]|(?=(?:\\.|\\[\\])(?:\\.|\\[\\]|$))/g;\n\n /**\n * Used to match `RegExp`\n * [syntax characters](http://ecma-international.org/ecma-262/7.0/#sec-patterns).\n */\n var reRegExpChar = /[\\\\^$.*+?()[\\]{}|]/g,\n reHasRegExpChar = RegExp(reRegExpChar.source);\n\n /** Used to match leading whitespace. */\n var reTrimStart = /^\\s+/;\n\n /** Used to match a single whitespace character. */\n var reWhitespace = /\\s/;\n\n /** Used to match wrap detail comments. */\n var reWrapComment = /\\{(?:\\n\\/\\* \\[wrapped with .+\\] \\*\\/)?\\n?/,\n reWrapDetails = /\\{\\n\\/\\* \\[wrapped with (.+)\\] \\*/,\n reSplitDetails = /,? & /;\n\n /** Used to match words composed of alphanumeric characters. */\n var reAsciiWord = /[^\\x00-\\x2f\\x3a-\\x40\\x5b-\\x60\\x7b-\\x7f]+/g;\n\n /**\n * Used to validate the `validate` option in `_.template` variable.\n *\n * Forbids characters which could potentially change the meaning of the function argument definition:\n * - \"(),\" (modification of function parameters)\n * - \"=\" (default value)\n * - \"[]{}\" (destructuring of function parameters)\n * - \"/\" (beginning of a comment)\n * - whitespace\n */\n var reForbiddenIdentifierChars = /[()=,{}\\[\\]\\/\\s]/;\n\n /** Used to match backslashes in property paths. 
*/\n var reEscapeChar = /\\\\(\\\\)?/g;\n\n /**\n * Used to match\n * [ES template delimiters](http://ecma-international.org/ecma-262/7.0/#sec-template-literal-lexical-components).\n */\n var reEsTemplate = /\\$\\{([^\\\\}]*(?:\\\\.[^\\\\}]*)*)\\}/g;\n\n /** Used to match `RegExp` flags from their coerced string values. */\n var reFlags = /\\w*$/;\n\n /** Used to detect bad signed hexadecimal string values. */\n var reIsBadHex = /^[-+]0x[0-9a-f]+$/i;\n\n /** Used to detect binary string values. */\n var reIsBinary = /^0b[01]+$/i;\n\n /** Used to detect host constructors (Safari). */\n var reIsHostCtor = /^\\[object .+?Constructor\\]$/;\n\n /** Used to detect octal string values. */\n var reIsOctal = /^0o[0-7]+$/i;\n\n /** Used to detect unsigned integer values. */\n var reIsUint = /^(?:0|[1-9]\\d*)$/;\n\n /** Used to match Latin Unicode letters (excluding mathematical operators). */\n var reLatin = /[\\xc0-\\xd6\\xd8-\\xf6\\xf8-\\xff\\u0100-\\u017f]/g;\n\n /** Used to ensure capturing order of template delimiters. */\n var reNoMatch = /($^)/;\n\n /** Used to match unescaped characters in compiled string literals. */\n var reUnescapedString = /['\\n\\r\\u2028\\u2029\\\\]/g;\n\n /** Used to compose unicode character classes. */\n var rsAstralRange = '\\\\ud800-\\\\udfff',\n rsComboMarksRange = '\\\\u0300-\\\\u036f',\n reComboHalfMarksRange = '\\\\ufe20-\\\\ufe2f',\n rsComboSymbolsRange = '\\\\u20d0-\\\\u20ff',\n rsComboRange = rsComboMarksRange + reComboHalfMarksRange + rsComboSymbolsRange,\n rsDingbatRange = '\\\\u2700-\\\\u27bf',\n rsLowerRange = 'a-z\\\\xdf-\\\\xf6\\\\xf8-\\\\xff',\n rsMathOpRange = '\\\\xac\\\\xb1\\\\xd7\\\\xf7',\n rsNonCharRange = '\\\\x00-\\\\x2f\\\\x3a-\\\\x40\\\\x5b-\\\\x60\\\\x7b-\\\\xbf',\n rsPunctuationRange = '\\\\u2000-\\\\u206f',\n rsSpaceRange = ' \\\\t\\\\x0b\\\\f\\\\xa0\\\\ufeff\\\\n\\\\r\\\\u2028\\\\u2029\\\\u1680\\\\u180e\\\\u2000\\\\u2001\\\\u2002\\\\u2003\\\\u2004\\\\u2005\\\\u2006\\\\u2007\\\\u2008\\\\u2009\\\\u200a\\\\u202f\\\\u205f\\\\u3000',\n rsUpperRange = 'A-Z\\\\xc0-\\\\xd6\\\\xd8-\\\\xde',\n rsVarRange = '\\\\ufe0e\\\\ufe0f',\n rsBreakRange = rsMathOpRange + rsNonCharRange + rsPunctuationRange + rsSpaceRange;\n\n /** Used to compose unicode capture groups. */\n var rsApos = \"['\\u2019]\",\n rsAstral = '[' + rsAstralRange + ']',\n rsBreak = '[' + rsBreakRange + ']',\n rsCombo = '[' + rsComboRange + ']',\n rsDigits = '\\\\d+',\n rsDingbat = '[' + rsDingbatRange + ']',\n rsLower = '[' + rsLowerRange + ']',\n rsMisc = '[^' + rsAstralRange + rsBreakRange + rsDigits + rsDingbatRange + rsLowerRange + rsUpperRange + ']',\n rsFitz = '\\\\ud83c[\\\\udffb-\\\\udfff]',\n rsModifier = '(?:' + rsCombo + '|' + rsFitz + ')',\n rsNonAstral = '[^' + rsAstralRange + ']',\n rsRegional = '(?:\\\\ud83c[\\\\udde6-\\\\uddff]){2}',\n rsSurrPair = '[\\\\ud800-\\\\udbff][\\\\udc00-\\\\udfff]',\n rsUpper = '[' + rsUpperRange + ']',\n rsZWJ = '\\\\u200d';\n\n /** Used to compose unicode regexes. 
*/\n var rsMiscLower = '(?:' + rsLower + '|' + rsMisc + ')',\n rsMiscUpper = '(?:' + rsUpper + '|' + rsMisc + ')',\n rsOptContrLower = '(?:' + rsApos + '(?:d|ll|m|re|s|t|ve))?',\n rsOptContrUpper = '(?:' + rsApos + '(?:D|LL|M|RE|S|T|VE))?',\n reOptMod = rsModifier + '?',\n rsOptVar = '[' + rsVarRange + ']?',\n rsOptJoin = '(?:' + rsZWJ + '(?:' + [rsNonAstral, rsRegional, rsSurrPair].join('|') + ')' + rsOptVar + reOptMod + ')*',\n rsOrdLower = '\\\\d*(?:1st|2nd|3rd|(?![123])\\\\dth)(?=\\\\b|[A-Z_])',\n rsOrdUpper = '\\\\d*(?:1ST|2ND|3RD|(?![123])\\\\dTH)(?=\\\\b|[a-z_])',\n rsSeq = rsOptVar + reOptMod + rsOptJoin,\n rsEmoji = '(?:' + [rsDingbat, rsRegional, rsSurrPair].join('|') + ')' + rsSeq,\n rsSymbol = '(?:' + [rsNonAstral + rsCombo + '?', rsCombo, rsRegional, rsSurrPair, rsAstral].join('|') + ')';\n\n /** Used to match apostrophes. */\n var reApos = RegExp(rsApos, 'g');\n\n /**\n * Used to match [combining diacritical marks](https://en.wikipedia.org/wiki/Combining_Diacritical_Marks) and\n * [combining diacritical marks for symbols](https://en.wikipedia.org/wiki/Combining_Diacritical_Marks_for_Symbols).\n */\n var reComboMark = RegExp(rsCombo, 'g');\n\n /** Used to match [string symbols](https://mathiasbynens.be/notes/javascript-unicode). */\n var reUnicode = RegExp(rsFitz + '(?=' + rsFitz + ')|' + rsSymbol + rsSeq, 'g');\n\n /** Used to match complex or compound words. */\n var reUnicodeWord = RegExp([\n rsUpper + '?' + rsLower + '+' + rsOptContrLower + '(?=' + [rsBreak, rsUpper, '$'].join('|') + ')',\n rsMiscUpper + '+' + rsOptContrUpper + '(?=' + [rsBreak, rsUpper + rsMiscLower, '$'].join('|') + ')',\n rsUpper + '?' + rsMiscLower + '+' + rsOptContrLower,\n rsUpper + '+' + rsOptContrUpper,\n rsOrdUpper,\n rsOrdLower,\n rsDigits,\n rsEmoji\n ].join('|'), 'g');\n\n /** Used to detect strings with [zero-width joiners or code points from the astral planes](http://eev.ee/blog/2015/09/12/dark-corners-of-unicode/). */\n var reHasUnicode = RegExp('[' + rsZWJ + rsAstralRange + rsComboRange + rsVarRange + ']');\n\n /** Used to detect strings that need a more robust regexp to match words. */\n var reHasUnicodeWord = /[a-z][A-Z]|[A-Z]{2}[a-z]|[0-9][a-zA-Z]|[a-zA-Z][0-9]|[^a-zA-Z0-9 ]/;\n\n /** Used to assign default `context` object properties. */\n var contextProps = [\n 'Array', 'Buffer', 'DataView', 'Date', 'Error', 'Float32Array', 'Float64Array',\n 'Function', 'Int8Array', 'Int16Array', 'Int32Array', 'Map', 'Math', 'Object',\n 'Promise', 'RegExp', 'Set', 'String', 'Symbol', 'TypeError', 'Uint8Array',\n 'Uint8ClampedArray', 'Uint16Array', 'Uint32Array', 'WeakMap',\n '_', 'clearTimeout', 'isFinite', 'parseInt', 'setTimeout'\n ];\n\n /** Used to make template sourceURLs easier to identify. */\n var templateCounter = -1;\n\n /** Used to identify `toStringTag` values of typed arrays. 
*/\n var typedArrayTags = {};\n typedArrayTags[float32Tag] = typedArrayTags[float64Tag] =\n typedArrayTags[int8Tag] = typedArrayTags[int16Tag] =\n typedArrayTags[int32Tag] = typedArrayTags[uint8Tag] =\n typedArrayTags[uint8ClampedTag] = typedArrayTags[uint16Tag] =\n typedArrayTags[uint32Tag] = true;\n typedArrayTags[argsTag] = typedArrayTags[arrayTag] =\n typedArrayTags[arrayBufferTag] = typedArrayTags[boolTag] =\n typedArrayTags[dataViewTag] = typedArrayTags[dateTag] =\n typedArrayTags[errorTag] = typedArrayTags[funcTag] =\n typedArrayTags[mapTag] = typedArrayTags[numberTag] =\n typedArrayTags[objectTag] = typedArrayTags[regexpTag] =\n typedArrayTags[setTag] = typedArrayTags[stringTag] =\n typedArrayTags[weakMapTag] = false;\n\n /** Used to identify `toStringTag` values supported by `_.clone`. */\n var cloneableTags = {};\n cloneableTags[argsTag] = cloneableTags[arrayTag] =\n cloneableTags[arrayBufferTag] = cloneableTags[dataViewTag] =\n cloneableTags[boolTag] = cloneableTags[dateTag] =\n cloneableTags[float32Tag] = cloneableTags[float64Tag] =\n cloneableTags[int8Tag] = cloneableTags[int16Tag] =\n cloneableTags[int32Tag] = cloneableTags[mapTag] =\n cloneableTags[numberTag] = cloneableTags[objectTag] =\n cloneableTags[regexpTag] = cloneableTags[setTag] =\n cloneableTags[stringTag] = cloneableTags[symbolTag] =\n cloneableTags[uint8Tag] = cloneableTags[uint8ClampedTag] =\n cloneableTags[uint16Tag] = cloneableTags[uint32Tag] = true;\n cloneableTags[errorTag] = cloneableTags[funcTag] =\n cloneableTags[weakMapTag] = false;\n\n /** Used to map Latin Unicode letters to basic Latin letters. */\n var deburredLetters = {\n // Latin-1 Supplement block.\n '\\xc0': 'A', '\\xc1': 'A', '\\xc2': 'A', '\\xc3': 'A', '\\xc4': 'A', '\\xc5': 'A',\n '\\xe0': 'a', '\\xe1': 'a', '\\xe2': 'a', '\\xe3': 'a', '\\xe4': 'a', '\\xe5': 'a',\n '\\xc7': 'C', '\\xe7': 'c',\n '\\xd0': 'D', '\\xf0': 'd',\n '\\xc8': 'E', '\\xc9': 'E', '\\xca': 'E', '\\xcb': 'E',\n '\\xe8': 'e', '\\xe9': 'e', '\\xea': 'e', '\\xeb': 'e',\n '\\xcc': 'I', '\\xcd': 'I', '\\xce': 'I', '\\xcf': 'I',\n '\\xec': 'i', '\\xed': 'i', '\\xee': 'i', '\\xef': 'i',\n '\\xd1': 'N', '\\xf1': 'n',\n '\\xd2': 'O', '\\xd3': 'O', '\\xd4': 'O', '\\xd5': 'O', '\\xd6': 'O', '\\xd8': 'O',\n '\\xf2': 'o', '\\xf3': 'o', '\\xf4': 'o', '\\xf5': 'o', '\\xf6': 'o', '\\xf8': 'o',\n '\\xd9': 'U', '\\xda': 'U', '\\xdb': 'U', '\\xdc': 'U',\n '\\xf9': 'u', '\\xfa': 'u', '\\xfb': 'u', '\\xfc': 'u',\n '\\xdd': 'Y', '\\xfd': 'y', '\\xff': 'y',\n '\\xc6': 'Ae', '\\xe6': 'ae',\n '\\xde': 'Th', '\\xfe': 'th',\n '\\xdf': 'ss',\n // Latin Extended-A block.\n '\\u0100': 'A', '\\u0102': 'A', '\\u0104': 'A',\n '\\u0101': 'a', '\\u0103': 'a', '\\u0105': 'a',\n '\\u0106': 'C', '\\u0108': 'C', '\\u010a': 'C', '\\u010c': 'C',\n '\\u0107': 'c', '\\u0109': 'c', '\\u010b': 'c', '\\u010d': 'c',\n '\\u010e': 'D', '\\u0110': 'D', '\\u010f': 'd', '\\u0111': 'd',\n '\\u0112': 'E', '\\u0114': 'E', '\\u0116': 'E', '\\u0118': 'E', '\\u011a': 'E',\n '\\u0113': 'e', '\\u0115': 'e', '\\u0117': 'e', '\\u0119': 'e', '\\u011b': 'e',\n '\\u011c': 'G', '\\u011e': 'G', '\\u0120': 'G', '\\u0122': 'G',\n '\\u011d': 'g', '\\u011f': 'g', '\\u0121': 'g', '\\u0123': 'g',\n '\\u0124': 'H', '\\u0126': 'H', '\\u0125': 'h', '\\u0127': 'h',\n '\\u0128': 'I', '\\u012a': 'I', '\\u012c': 'I', '\\u012e': 'I', '\\u0130': 'I',\n '\\u0129': 'i', '\\u012b': 'i', '\\u012d': 'i', '\\u012f': 'i', '\\u0131': 'i',\n '\\u0134': 'J', '\\u0135': 'j',\n '\\u0136': 'K', '\\u0137': 'k', '\\u0138': 'k',\n '\\u0139': 'L', '\\u013b': 'L', 
'\\u013d': 'L', '\\u013f': 'L', '\\u0141': 'L',\n '\\u013a': 'l', '\\u013c': 'l', '\\u013e': 'l', '\\u0140': 'l', '\\u0142': 'l',\n '\\u0143': 'N', '\\u0145': 'N', '\\u0147': 'N', '\\u014a': 'N',\n '\\u0144': 'n', '\\u0146': 'n', '\\u0148': 'n', '\\u014b': 'n',\n '\\u014c': 'O', '\\u014e': 'O', '\\u0150': 'O',\n '\\u014d': 'o', '\\u014f': 'o', '\\u0151': 'o',\n '\\u0154': 'R', '\\u0156': 'R', '\\u0158': 'R',\n '\\u0155': 'r', '\\u0157': 'r', '\\u0159': 'r',\n '\\u015a': 'S', '\\u015c': 'S', '\\u015e': 'S', '\\u0160': 'S',\n '\\u015b': 's', '\\u015d': 's', '\\u015f': 's', '\\u0161': 's',\n '\\u0162': 'T', '\\u0164': 'T', '\\u0166': 'T',\n '\\u0163': 't', '\\u0165': 't', '\\u0167': 't',\n '\\u0168': 'U', '\\u016a': 'U', '\\u016c': 'U', '\\u016e': 'U', '\\u0170': 'U', '\\u0172': 'U',\n '\\u0169': 'u', '\\u016b': 'u', '\\u016d': 'u', '\\u016f': 'u', '\\u0171': 'u', '\\u0173': 'u',\n '\\u0174': 'W', '\\u0175': 'w',\n '\\u0176': 'Y', '\\u0177': 'y', '\\u0178': 'Y',\n '\\u0179': 'Z', '\\u017b': 'Z', '\\u017d': 'Z',\n '\\u017a': 'z', '\\u017c': 'z', '\\u017e': 'z',\n '\\u0132': 'IJ', '\\u0133': 'ij',\n '\\u0152': 'Oe', '\\u0153': 'oe',\n '\\u0149': \"'n\", '\\u017f': 's'\n };\n\n /** Used to map characters to HTML entities. */\n var htmlEscapes = {\n '&': '&amp;',\n '<': '&lt;',\n '>': '&gt;',\n '\"': '&quot;',\n \"'\": '&#39;'\n };\n\n /** Used to map HTML entities to characters. */\n var htmlUnescapes = {\n '&amp;': '&',\n '&lt;': '<',\n '&gt;': '>',\n '&quot;': '\"',\n '&#39;': \"'\"\n };\n\n /** Used to escape characters for inclusion in compiled string literals. */\n var stringEscapes = {\n '\\\\': '\\\\',\n \"'\": \"'\",\n '\\n': 'n',\n '\\r': 'r',\n '\\u2028': 'u2028',\n '\\u2029': 'u2029'\n };\n\n /** Built-in method references without a dependency on `root`. */\n var freeParseFloat = parseFloat,\n freeParseInt = parseInt;\n\n /** Detect free variable `global` from Node.js. */\n var freeGlobal = typeof global == 'object' && global && global.Object === Object && global;\n\n /** Detect free variable `self`. */\n var freeSelf = typeof self == 'object' && self && self.Object === Object && self;\n\n /** Used as a reference to the global object. */\n var root = freeGlobal || freeSelf || Function('return this')();\n\n /** Detect free variable `exports`. */\n var freeExports = typeof exports == 'object' && exports && !exports.nodeType && exports;\n\n /** Detect free variable `module`. */\n var freeModule = freeExports && typeof module == 'object' && module && !module.nodeType && module;\n\n /** Detect the popular CommonJS extension `module.exports`. */\n var moduleExports = freeModule && freeModule.exports === freeExports;\n\n /** Detect free variable `process` from Node.js. */\n var freeProcess = moduleExports && freeGlobal.process;\n\n /** Used to access faster Node.js helpers. */\n var nodeUtil = (function() {\n try {\n // Use `util.types` for Node.js 10+.\n var types = freeModule && freeModule.require && freeModule.require('util').types;\n\n if (types) {\n return types;\n }\n\n // Legacy `process.binding('util')` for Node.js < 10.\n return freeProcess && freeProcess.binding && freeProcess.binding('util');\n } catch (e) {}\n }());\n\n /* Node.js helper references. 
*/\n var nodeIsArrayBuffer = nodeUtil && nodeUtil.isArrayBuffer,\n nodeIsDate = nodeUtil && nodeUtil.isDate,\n nodeIsMap = nodeUtil && nodeUtil.isMap,\n nodeIsRegExp = nodeUtil && nodeUtil.isRegExp,\n nodeIsSet = nodeUtil && nodeUtil.isSet,\n nodeIsTypedArray = nodeUtil && nodeUtil.isTypedArray;\n\n /*--------------------------------------------------------------------------*/\n\n /**\n * A faster alternative to `Function#apply`, this function invokes `func`\n * with the `this` binding of `thisArg` and the arguments of `args`.\n *\n * @private\n * @param {Function} func The function to invoke.\n * @param {*} thisArg The `this` binding of `func`.\n * @param {Array} args The arguments to invoke `func` with.\n * @returns {*} Returns the result of `func`.\n */\n function apply(func, thisArg, args) {\n switch (args.length) {\n case 0: return func.call(thisArg);\n case 1: return func.call(thisArg, args[0]);\n case 2: return func.call(thisArg, args[0], args[1]);\n case 3: return func.call(thisArg, args[0], args[1], args[2]);\n }\n return func.apply(thisArg, args);\n }\n\n /**\n * A specialized version of `baseAggregator` for arrays.\n *\n * @private\n * @param {Array} [array] The array to iterate over.\n * @param {Function} setter The function to set `accumulator` values.\n * @param {Function} iteratee The iteratee to transform keys.\n * @param {Object} accumulator The initial aggregated object.\n * @returns {Function} Returns `accumulator`.\n */\n function arrayAggregator(array, setter, iteratee, accumulator) {\n var index = -1,\n length = array == null ? 0 : array.length;\n\n while (++index < length) {\n var value = array[index];\n setter(accumulator, value, iteratee(value), array);\n }\n return accumulator;\n }\n\n /**\n * A specialized version of `_.forEach` for arrays without support for\n * iteratee shorthands.\n *\n * @private\n * @param {Array} [array] The array to iterate over.\n * @param {Function} iteratee The function invoked per iteration.\n * @returns {Array} Returns `array`.\n */\n function arrayEach(array, iteratee) {\n var index = -1,\n length = array == null ? 0 : array.length;\n\n while (++index < length) {\n if (iteratee(array[index], index, array) === false) {\n break;\n }\n }\n return array;\n }\n\n /**\n * A specialized version of `_.forEachRight` for arrays without support for\n * iteratee shorthands.\n *\n * @private\n * @param {Array} [array] The array to iterate over.\n * @param {Function} iteratee The function invoked per iteration.\n * @returns {Array} Returns `array`.\n */\n function arrayEachRight(array, iteratee) {\n var length = array == null ? 0 : array.length;\n\n while (length--) {\n if (iteratee(array[length], length, array) === false) {\n break;\n }\n }\n return array;\n }\n\n /**\n * A specialized version of `_.every` for arrays without support for\n * iteratee shorthands.\n *\n * @private\n * @param {Array} [array] The array to iterate over.\n * @param {Function} predicate The function invoked per iteration.\n * @returns {boolean} Returns `true` if all elements pass the predicate check,\n * else `false`.\n */\n function arrayEvery(array, predicate) {\n var index = -1,\n length = array == null ? 
0 : array.length;\n\n while (++index < length) {\n if (!predicate(array[index], index, array)) {\n return false;\n }\n }\n return true;\n }\n\n /**\n * A specialized version of `_.filter` for arrays without support for\n * iteratee shorthands.\n *\n * @private\n * @param {Array} [array] The array to iterate over.\n * @param {Function} predicate The function invoked per iteration.\n * @returns {Array} Returns the new filtered array.\n */\n function arrayFilter(array, predicate) {\n var index = -1,\n length = array == null ? 0 : array.length,\n resIndex = 0,\n result = [];\n\n while (++index < length) {\n var value = array[index];\n if (predicate(value, index, array)) {\n result[resIndex++] = value;\n }\n }\n return result;\n }\n\n /**\n * A specialized version of `_.includes` for arrays without support for\n * specifying an index to search from.\n *\n * @private\n * @param {Array} [array] The array to inspect.\n * @param {*} target The value to search for.\n * @returns {boolean} Returns `true` if `target` is found, else `false`.\n */\n function arrayIncludes(array, value) {\n var length = array == null ? 0 : array.length;\n return !!length && baseIndexOf(array, value, 0) > -1;\n }\n\n /**\n * This function is like `arrayIncludes` except that it accepts a comparator.\n *\n * @private\n * @param {Array} [array] The array to inspect.\n * @param {*} target The value to search for.\n * @param {Function} comparator The comparator invoked per element.\n * @returns {boolean} Returns `true` if `target` is found, else `false`.\n */\n function arrayIncludesWith(array, value, comparator) {\n var index = -1,\n length = array == null ? 0 : array.length;\n\n while (++index < length) {\n if (comparator(value, array[index])) {\n return true;\n }\n }\n return false;\n }\n\n /**\n * A specialized version of `_.map` for arrays without support for iteratee\n * shorthands.\n *\n * @private\n * @param {Array} [array] The array to iterate over.\n * @param {Function} iteratee The function invoked per iteration.\n * @returns {Array} Returns the new mapped array.\n */\n function arrayMap(array, iteratee) {\n var index = -1,\n length = array == null ? 0 : array.length,\n result = Array(length);\n\n while (++index < length) {\n result[index] = iteratee(array[index], index, array);\n }\n return result;\n }\n\n /**\n * Appends the elements of `values` to `array`.\n *\n * @private\n * @param {Array} array The array to modify.\n * @param {Array} values The values to append.\n * @returns {Array} Returns `array`.\n */\n function arrayPush(array, values) {\n var index = -1,\n length = values.length,\n offset = array.length;\n\n while (++index < length) {\n array[offset + index] = values[index];\n }\n return array;\n }\n\n /**\n * A specialized version of `_.reduce` for arrays without support for\n * iteratee shorthands.\n *\n * @private\n * @param {Array} [array] The array to iterate over.\n * @param {Function} iteratee The function invoked per iteration.\n * @param {*} [accumulator] The initial value.\n * @param {boolean} [initAccum] Specify using the first element of `array` as\n * the initial value.\n * @returns {*} Returns the accumulated value.\n */\n function arrayReduce(array, iteratee, accumulator, initAccum) {\n var index = -1,\n length = array == null ? 
0 : array.length;\n\n if (initAccum && length) {\n accumulator = array[++index];\n }\n while (++index < length) {\n accumulator = iteratee(accumulator, array[index], index, array);\n }\n return accumulator;\n }\n\n /**\n * A specialized version of `_.reduceRight` for arrays without support for\n * iteratee shorthands.\n *\n * @private\n * @param {Array} [array] The array to iterate over.\n * @param {Function} iteratee The function invoked per iteration.\n * @param {*} [accumulator] The initial value.\n * @param {boolean} [initAccum] Specify using the last element of `array` as\n * the initial value.\n * @returns {*} Returns the accumulated value.\n */\n function arrayReduceRight(array, iteratee, accumulator, initAccum) {\n var length = array == null ? 0 : array.length;\n if (initAccum && length) {\n accumulator = array[--length];\n }\n while (length--) {\n accumulator = iteratee(accumulator, array[length], length, array);\n }\n return accumulator;\n }\n\n /**\n * A specialized version of `_.some` for arrays without support for iteratee\n * shorthands.\n *\n * @private\n * @param {Array} [array] The array to iterate over.\n * @param {Function} predicate The function invoked per iteration.\n * @returns {boolean} Returns `true` if any element passes the predicate check,\n * else `false`.\n */\n function arraySome(array, predicate) {\n var index = -1,\n length = array == null ? 0 : array.length;\n\n while (++index < length) {\n if (predicate(array[index], index, array)) {\n return true;\n }\n }\n return false;\n }\n\n /**\n * Gets the size of an ASCII `string`.\n *\n * @private\n * @param {string} string The string inspect.\n * @returns {number} Returns the string size.\n */\n var asciiSize = baseProperty('length');\n\n /**\n * Converts an ASCII `string` to an array.\n *\n * @private\n * @param {string} string The string to convert.\n * @returns {Array} Returns the converted array.\n */\n function asciiToArray(string) {\n return string.split('');\n }\n\n /**\n * Splits an ASCII `string` into an array of its words.\n *\n * @private\n * @param {string} The string to inspect.\n * @returns {Array} Returns the words of `string`.\n */\n function asciiWords(string) {\n return string.match(reAsciiWord) || [];\n }\n\n /**\n * The base implementation of methods like `_.findKey` and `_.findLastKey`,\n * without support for iteratee shorthands, which iterates over `collection`\n * using `eachFunc`.\n *\n * @private\n * @param {Array|Object} collection The collection to inspect.\n * @param {Function} predicate The function invoked per iteration.\n * @param {Function} eachFunc The function to iterate over `collection`.\n * @returns {*} Returns the found element or its key, else `undefined`.\n */\n function baseFindKey(collection, predicate, eachFunc) {\n var result;\n eachFunc(collection, function(value, key, collection) {\n if (predicate(value, key, collection)) {\n result = key;\n return false;\n }\n });\n return result;\n }\n\n /**\n * The base implementation of `_.findIndex` and `_.findLastIndex` without\n * support for iteratee shorthands.\n *\n * @private\n * @param {Array} array The array to inspect.\n * @param {Function} predicate The function invoked per iteration.\n * @param {number} fromIndex The index to search from.\n * @param {boolean} [fromRight] Specify iterating from right to left.\n * @returns {number} Returns the index of the matched value, else `-1`.\n */\n function baseFindIndex(array, predicate, fromIndex, fromRight) {\n var length = array.length,\n index = fromIndex + (fromRight ? 
1 : -1);\n\n while ((fromRight ? index-- : ++index < length)) {\n if (predicate(array[index], index, array)) {\n return index;\n }\n }\n return -1;\n }\n\n /**\n * The base implementation of `_.indexOf` without `fromIndex` bounds checks.\n *\n * @private\n * @param {Array} array The array to inspect.\n * @param {*} value The value to search for.\n * @param {number} fromIndex The index to search from.\n * @returns {number} Returns the index of the matched value, else `-1`.\n */\n function baseIndexOf(array, value, fromIndex) {\n return value === value\n ? strictIndexOf(array, value, fromIndex)\n : baseFindIndex(array, baseIsNaN, fromIndex);\n }\n\n /**\n * This function is like `baseIndexOf` except that it accepts a comparator.\n *\n * @private\n * @param {Array} array The array to inspect.\n * @param {*} value The value to search for.\n * @param {number} fromIndex The index to search from.\n * @param {Function} comparator The comparator invoked per element.\n * @returns {number} Returns the index of the matched value, else `-1`.\n */\n function baseIndexOfWith(array, value, fromIndex, comparator) {\n var index = fromIndex - 1,\n length = array.length;\n\n while (++index < length) {\n if (comparator(array[index], value)) {\n return index;\n }\n }\n return -1;\n }\n\n /**\n * The base implementation of `_.isNaN` without support for number objects.\n *\n * @private\n * @param {*} value The value to check.\n * @returns {boolean} Returns `true` if `value` is `NaN`, else `false`.\n */\n function baseIsNaN(value) {\n return value !== value;\n }\n\n /**\n * The base implementation of `_.mean` and `_.meanBy` without support for\n * iteratee shorthands.\n *\n * @private\n * @param {Array} array The array to iterate over.\n * @param {Function} iteratee The function invoked per iteration.\n * @returns {number} Returns the mean.\n */\n function baseMean(array, iteratee) {\n var length = array == null ? 0 : array.length;\n return length ? (baseSum(array, iteratee) / length) : NAN;\n }\n\n /**\n * The base implementation of `_.property` without support for deep paths.\n *\n * @private\n * @param {string} key The key of the property to get.\n * @returns {Function} Returns the new accessor function.\n */\n function baseProperty(key) {\n return function(object) {\n return object == null ? undefined : object[key];\n };\n }\n\n /**\n * The base implementation of `_.propertyOf` without support for deep paths.\n *\n * @private\n * @param {Object} object The object to query.\n * @returns {Function} Returns the new accessor function.\n */\n function basePropertyOf(object) {\n return function(key) {\n return object == null ? undefined : object[key];\n };\n }\n\n /**\n * The base implementation of `_.reduce` and `_.reduceRight`, without support\n * for iteratee shorthands, which iterates over `collection` using `eachFunc`.\n *\n * @private\n * @param {Array|Object} collection The collection to iterate over.\n * @param {Function} iteratee The function invoked per iteration.\n * @param {*} accumulator The initial value.\n * @param {boolean} initAccum Specify using the first or last element of\n * `collection` as the initial value.\n * @param {Function} eachFunc The function to iterate over `collection`.\n * @returns {*} Returns the accumulated value.\n */\n function baseReduce(collection, iteratee, accumulator, initAccum, eachFunc) {\n eachFunc(collection, function(value, index, collection) {\n accumulator = initAccum\n ? 
(initAccum = false, value)\n : iteratee(accumulator, value, index, collection);\n });\n return accumulator;\n }\n\n /**\n * The base implementation of `_.sortBy` which uses `comparer` to define the\n * sort order of `array` and replaces criteria objects with their corresponding\n * values.\n *\n * @private\n * @param {Array} array The array to sort.\n * @param {Function} comparer The function to define sort order.\n * @returns {Array} Returns `array`.\n */\n function baseSortBy(array, comparer) {\n var length = array.length;\n\n array.sort(comparer);\n while (length--) {\n array[length] = array[length].value;\n }\n return array;\n }\n\n /**\n * The base implementation of `_.sum` and `_.sumBy` without support for\n * iteratee shorthands.\n *\n * @private\n * @param {Array} array The array to iterate over.\n * @param {Function} iteratee The function invoked per iteration.\n * @returns {number} Returns the sum.\n */\n function baseSum(array, iteratee) {\n var result,\n index = -1,\n length = array.length;\n\n while (++index < length) {\n var current = iteratee(array[index]);\n if (current !== undefined) {\n result = result === undefined ? current : (result + current);\n }\n }\n return result;\n }\n\n /**\n * The base implementation of `_.times` without support for iteratee shorthands\n * or max array length checks.\n *\n * @private\n * @param {number} n The number of times to invoke `iteratee`.\n * @param {Function} iteratee The function invoked per iteration.\n * @returns {Array} Returns the array of results.\n */\n function baseTimes(n, iteratee) {\n var index = -1,\n result = Array(n);\n\n while (++index < n) {\n result[index] = iteratee(index);\n }\n return result;\n }\n\n /**\n * The base implementation of `_.toPairs` and `_.toPairsIn` which creates an array\n * of key-value pairs for `object` corresponding to the property names of `props`.\n *\n * @private\n * @param {Object} object The object to query.\n * @param {Array} props The property names to get values for.\n * @returns {Object} Returns the key-value pairs.\n */\n function baseToPairs(object, props) {\n return arrayMap(props, function(key) {\n return [key, object[key]];\n });\n }\n\n /**\n * The base implementation of `_.trim`.\n *\n * @private\n * @param {string} string The string to trim.\n * @returns {string} Returns the trimmed string.\n */\n function baseTrim(string) {\n return string\n ? 
string.slice(0, trimmedEndIndex(string) + 1).replace(reTrimStart, '')\n : string;\n }\n\n /**\n * The base implementation of `_.unary` without support for storing metadata.\n *\n * @private\n * @param {Function} func The function to cap arguments for.\n * @returns {Function} Returns the new capped function.\n */\n function baseUnary(func) {\n return function(value) {\n return func(value);\n };\n }\n\n /**\n * The base implementation of `_.values` and `_.valuesIn` which creates an\n * array of `object` property values corresponding to the property names\n * of `props`.\n *\n * @private\n * @param {Object} object The object to query.\n * @param {Array} props The property names to get values for.\n * @returns {Object} Returns the array of property values.\n */\n function baseValues(object, props) {\n return arrayMap(props, function(key) {\n return object[key];\n });\n }\n\n /**\n * Checks if a `cache` value for `key` exists.\n *\n * @private\n * @param {Object} cache The cache to query.\n * @param {string} key The key of the entry to check.\n * @returns {boolean} Returns `true` if an entry for `key` exists, else `false`.\n */\n function cacheHas(cache, key) {\n return cache.has(key);\n }\n\n /**\n * Used by `_.trim` and `_.trimStart` to get the index of the first string symbol\n * that is not found in the character symbols.\n *\n * @private\n * @param {Array} strSymbols The string symbols to inspect.\n * @param {Array} chrSymbols The character symbols to find.\n * @returns {number} Returns the index of the first unmatched string symbol.\n */\n function charsStartIndex(strSymbols, chrSymbols) {\n var index = -1,\n length = strSymbols.length;\n\n while (++index < length && baseIndexOf(chrSymbols, strSymbols[index], 0) > -1) {}\n return index;\n }\n\n /**\n * Used by `_.trim` and `_.trimEnd` to get the index of the last string symbol\n * that is not found in the character symbols.\n *\n * @private\n * @param {Array} strSymbols The string symbols to inspect.\n * @param {Array} chrSymbols The character symbols to find.\n * @returns {number} Returns the index of the last unmatched string symbol.\n */\n function charsEndIndex(strSymbols, chrSymbols) {\n var index = strSymbols.length;\n\n while (index-- && baseIndexOf(chrSymbols, strSymbols[index], 0) > -1) {}\n return index;\n }\n\n /**\n * Gets the number of `placeholder` occurrences in `array`.\n *\n * @private\n * @param {Array} array The array to inspect.\n * @param {*} placeholder The placeholder to search for.\n * @returns {number} Returns the placeholder count.\n */\n function countHolders(array, placeholder) {\n var length = array.length,\n result = 0;\n\n while (length--) {\n if (array[length] === placeholder) {\n ++result;\n }\n }\n return result;\n }\n\n /**\n * Used by `_.deburr` to convert Latin-1 Supplement and Latin Extended-A\n * letters to basic Latin letters.\n *\n * @private\n * @param {string} letter The matched letter to deburr.\n * @returns {string} Returns the deburred letter.\n */\n var deburrLetter = basePropertyOf(deburredLetters);\n\n /**\n * Used by `_.escape` to convert characters to HTML entities.\n *\n * @private\n * @param {string} chr The matched character to escape.\n * @returns {string} Returns the escaped character.\n */\n var escapeHtmlChar = basePropertyOf(htmlEscapes);\n\n /**\n * Used by `_.template` to escape characters for inclusion in compiled string literals.\n *\n * @private\n * @param {string} chr The matched character to escape.\n * @returns {string} Returns the escaped character.\n */\n function 
escapeStringChar(chr) {\n return '\\\\' + stringEscapes[chr];\n }\n\n /**\n * Gets the value at `key` of `object`.\n *\n * @private\n * @param {Object} [object] The object to query.\n * @param {string} key The key of the property to get.\n * @returns {*} Returns the property value.\n */\n function getValue(object, key) {\n return object == null ? undefined : object[key];\n }\n\n /**\n * Checks if `string` contains Unicode symbols.\n *\n * @private\n * @param {string} string The string to inspect.\n * @returns {boolean} Returns `true` if a symbol is found, else `false`.\n */\n function hasUnicode(string) {\n return reHasUnicode.test(string);\n }\n\n /**\n * Checks if `string` contains a word composed of Unicode symbols.\n *\n * @private\n * @param {string} string The string to inspect.\n * @returns {boolean} Returns `true` if a word is found, else `false`.\n */\n function hasUnicodeWord(string) {\n return reHasUnicodeWord.test(string);\n }\n\n /**\n * Converts `iterator` to an array.\n *\n * @private\n * @param {Object} iterator The iterator to convert.\n * @returns {Array} Returns the converted array.\n */\n function iteratorToArray(iterator) {\n var data,\n result = [];\n\n while (!(data = iterator.next()).done) {\n result.push(data.value);\n }\n return result;\n }\n\n /**\n * Converts `map` to its key-value pairs.\n *\n * @private\n * @param {Object} map The map to convert.\n * @returns {Array} Returns the key-value pairs.\n */\n function mapToArray(map) {\n var index = -1,\n result = Array(map.size);\n\n map.forEach(function(value, key) {\n result[++index] = [key, value];\n });\n return result;\n }\n\n /**\n * Creates a unary function that invokes `func` with its argument transformed.\n *\n * @private\n * @param {Function} func The function to wrap.\n * @param {Function} transform The argument transform.\n * @returns {Function} Returns the new function.\n */\n function overArg(func, transform) {\n return function(arg) {\n return func(transform(arg));\n };\n }\n\n /**\n * Replaces all `placeholder` elements in `array` with an internal placeholder\n * and returns an array of their indexes.\n *\n * @private\n * @param {Array} array The array to modify.\n * @param {*} placeholder The placeholder to replace.\n * @returns {Array} Returns the new array of placeholder indexes.\n */\n function replaceHolders(array, placeholder) {\n var index = -1,\n length = array.length,\n resIndex = 0,\n result = [];\n\n while (++index < length) {\n var value = array[index];\n if (value === placeholder || value === PLACEHOLDER) {\n array[index] = PLACEHOLDER;\n result[resIndex++] = index;\n }\n }\n return result;\n }\n\n /**\n * Converts `set` to an array of its values.\n *\n * @private\n * @param {Object} set The set to convert.\n * @returns {Array} Returns the values.\n */\n function setToArray(set) {\n var index = -1,\n result = Array(set.size);\n\n set.forEach(function(value) {\n result[++index] = value;\n });\n return result;\n }\n\n /**\n * Converts `set` to its value-value pairs.\n *\n * @private\n * @param {Object} set The set to convert.\n * @returns {Array} Returns the value-value pairs.\n */\n function setToPairs(set) {\n var index = -1,\n result = Array(set.size);\n\n set.forEach(function(value) {\n result[++index] = [value, value];\n });\n return result;\n }\n\n /**\n * A specialized version of `_.indexOf` which performs strict equality\n * comparisons of values, i.e. 
`===`.\n *\n * @private\n * @param {Array} array The array to inspect.\n * @param {*} value The value to search for.\n * @param {number} fromIndex The index to search from.\n * @returns {number} Returns the index of the matched value, else `-1`.\n */\n function strictIndexOf(array, value, fromIndex) {\n var index = fromIndex - 1,\n length = array.length;\n\n while (++index < length) {\n if (array[index] === value) {\n return index;\n }\n }\n return -1;\n }\n\n /**\n * A specialized version of `_.lastIndexOf` which performs strict equality\n * comparisons of values, i.e. `===`.\n *\n * @private\n * @param {Array} array The array to inspect.\n * @param {*} value The value to search for.\n * @param {number} fromIndex The index to search from.\n * @returns {number} Returns the index of the matched value, else `-1`.\n */\n function strictLastIndexOf(array, value, fromIndex) {\n var index = fromIndex + 1;\n while (index--) {\n if (array[index] === value) {\n return index;\n }\n }\n return index;\n }\n\n /**\n * Gets the number of symbols in `string`.\n *\n * @private\n * @param {string} string The string to inspect.\n * @returns {number} Returns the string size.\n */\n function stringSize(string) {\n return hasUnicode(string)\n ? unicodeSize(string)\n : asciiSize(string);\n }\n\n /**\n * Converts `string` to an array.\n *\n * @private\n * @param {string} string The string to convert.\n * @returns {Array} Returns the converted array.\n */\n function stringToArray(string) {\n return hasUnicode(string)\n ? unicodeToArray(string)\n : asciiToArray(string);\n }\n\n /**\n * Used by `_.trim` and `_.trimEnd` to get the index of the last non-whitespace\n * character of `string`.\n *\n * @private\n * @param {string} string The string to inspect.\n * @returns {number} Returns the index of the last non-whitespace character.\n */\n function trimmedEndIndex(string) {\n var index = string.length;\n\n while (index-- && reWhitespace.test(string.charAt(index))) {}\n return index;\n }\n\n /**\n * Used by `_.unescape` to convert HTML entities to characters.\n *\n * @private\n * @param {string} chr The matched character to unescape.\n * @returns {string} Returns the unescaped character.\n */\n var unescapeHtmlChar = basePropertyOf(htmlUnescapes);\n\n /**\n * Gets the size of a Unicode `string`.\n *\n * @private\n * @param {string} string The string inspect.\n * @returns {number} Returns the string size.\n */\n function unicodeSize(string) {\n var result = reUnicode.lastIndex = 0;\n while (reUnicode.test(string)) {\n ++result;\n }\n return result;\n }\n\n /**\n * Converts a Unicode `string` to an array.\n *\n * @private\n * @param {string} string The string to convert.\n * @returns {Array} Returns the converted array.\n */\n function unicodeToArray(string) {\n return string.match(reUnicode) || [];\n }\n\n /**\n * Splits a Unicode `string` into an array of its words.\n *\n * @private\n * @param {string} The string to inspect.\n * @returns {Array} Returns the words of `string`.\n */\n function unicodeWords(string) {\n return string.match(reUnicodeWord) || [];\n }\n\n /*--------------------------------------------------------------------------*/\n\n /**\n * Create a new pristine `lodash` function using the `context` object.\n *\n * @static\n * @memberOf _\n * @since 1.1.0\n * @category Util\n * @param {Object} [context=root] The context object.\n * @returns {Function} Returns a new `lodash` function.\n * @example\n *\n * _.mixin({ 'foo': _.constant('foo') });\n *\n * var lodash = _.runInContext();\n * lodash.mixin({ 
'bar': lodash.constant('bar') });\n *\n * _.isFunction(_.foo);\n * // => true\n * _.isFunction(_.bar);\n * // => false\n *\n * lodash.isFunction(lodash.foo);\n * // => false\n * lodash.isFunction(lodash.bar);\n * // => true\n *\n * // Create a suped-up `defer` in Node.js.\n * var defer = _.runInContext({ 'setTimeout': setImmediate }).defer;\n */\n var runInContext = (function runInContext(context) {\n context = context == null ? root : _.defaults(root.Object(), context, _.pick(root, contextProps));\n\n /** Built-in constructor references. */\n var Array = context.Array,\n Date = context.Date,\n Error = context.Error,\n Function = context.Function,\n Math = context.Math,\n Object = context.Object,\n RegExp = context.RegExp,\n String = context.String,\n TypeError = context.TypeError;\n\n /** Used for built-in method references. */\n var arrayProto = Array.prototype,\n funcProto = Function.prototype,\n objectProto = Object.prototype;\n\n /** Used to detect overreaching core-js shims. */\n var coreJsData = context['__core-js_shared__'];\n\n /** Used to resolve the decompiled source of functions. */\n var funcToString = funcProto.toString;\n\n /** Used to check objects for own properties. */\n var hasOwnProperty = objectProto.hasOwnProperty;\n\n /** Used to generate unique IDs. */\n var idCounter = 0;\n\n /** Used to detect methods masquerading as native. */\n var maskSrcKey = (function() {\n var uid = /[^.]+$/.exec(coreJsData && coreJsData.keys && coreJsData.keys.IE_PROTO || '');\n return uid ? ('Symbol(src)_1.' + uid) : '';\n }());\n\n /**\n * Used to resolve the\n * [`toStringTag`](http://ecma-international.org/ecma-262/7.0/#sec-object.prototype.tostring)\n * of values.\n */\n var nativeObjectToString = objectProto.toString;\n\n /** Used to infer the `Object` constructor. */\n var objectCtorString = funcToString.call(Object);\n\n /** Used to restore the original `_` reference in `_.noConflict`. */\n var oldDash = root._;\n\n /** Used to detect if a method is native. */\n var reIsNative = RegExp('^' +\n funcToString.call(hasOwnProperty).replace(reRegExpChar, '\\\\$&')\n .replace(/hasOwnProperty|(function).*?(?=\\\\\\()| for .+?(?=\\\\\\])/g, '$1.*?') + '$'\n );\n\n /** Built-in value references. */\n var Buffer = moduleExports ? context.Buffer : undefined,\n Symbol = context.Symbol,\n Uint8Array = context.Uint8Array,\n allocUnsafe = Buffer ? Buffer.allocUnsafe : undefined,\n getPrototype = overArg(Object.getPrototypeOf, Object),\n objectCreate = Object.create,\n propertyIsEnumerable = objectProto.propertyIsEnumerable,\n splice = arrayProto.splice,\n spreadableSymbol = Symbol ? Symbol.isConcatSpreadable : undefined,\n symIterator = Symbol ? Symbol.iterator : undefined,\n symToStringTag = Symbol ? Symbol.toStringTag : undefined;\n\n var defineProperty = (function() {\n try {\n var func = getNative(Object, 'defineProperty');\n func({}, '', {});\n return func;\n } catch (e) {}\n }());\n\n /** Mocked built-ins. */\n var ctxClearTimeout = context.clearTimeout !== root.clearTimeout && context.clearTimeout,\n ctxNow = Date && Date.now !== root.Date.now && Date.now,\n ctxSetTimeout = context.setTimeout !== root.setTimeout && context.setTimeout;\n\n /* Built-in method references for those with the same name as other `lodash` methods. */\n var nativeCeil = Math.ceil,\n nativeFloor = Math.floor,\n nativeGetSymbols = Object.getOwnPropertySymbols,\n nativeIsBuffer = Buffer ? 
Buffer.isBuffer : undefined,\n nativeIsFinite = context.isFinite,\n nativeJoin = arrayProto.join,\n nativeKeys = overArg(Object.keys, Object),\n nativeMax = Math.max,\n nativeMin = Math.min,\n nativeNow = Date.now,\n nativeParseInt = context.parseInt,\n nativeRandom = Math.random,\n nativeReverse = arrayProto.reverse;\n\n /* Built-in method references that are verified to be native. */\n var DataView = getNative(context, 'DataView'),\n Map = getNative(context, 'Map'),\n Promise = getNative(context, 'Promise'),\n Set = getNative(context, 'Set'),\n WeakMap = getNative(context, 'WeakMap'),\n nativeCreate = getNative(Object, 'create');\n\n /** Used to store function metadata. */\n var metaMap = WeakMap && new WeakMap;\n\n /** Used to lookup unminified function names. */\n var realNames = {};\n\n /** Used to detect maps, sets, and weakmaps. */\n var dataViewCtorString = toSource(DataView),\n mapCtorString = toSource(Map),\n promiseCtorString = toSource(Promise),\n setCtorString = toSource(Set),\n weakMapCtorString = toSource(WeakMap);\n\n /** Used to convert symbols to primitives and strings. */\n var symbolProto = Symbol ? Symbol.prototype : undefined,\n symbolValueOf = symbolProto ? symbolProto.valueOf : undefined,\n symbolToString = symbolProto ? symbolProto.toString : undefined;\n\n /*------------------------------------------------------------------------*/\n\n /**\n * Creates a `lodash` object which wraps `value` to enable implicit method\n * chain sequences. Methods that operate on and return arrays, collections,\n * and functions can be chained together. Methods that retrieve a single value\n * or may return a primitive value will automatically end the chain sequence\n * and return the unwrapped value. Otherwise, the value must be unwrapped\n * with `_#value`.\n *\n * Explicit chain sequences, which must be unwrapped with `_#value`, may be\n * enabled using `_.chain`.\n *\n * The execution of chained methods is lazy, that is, it's deferred until\n * `_#value` is implicitly or explicitly called.\n *\n * Lazy evaluation allows several methods to support shortcut fusion.\n * Shortcut fusion is an optimization to merge iteratee calls; this avoids\n * the creation of intermediate arrays and can greatly reduce the number of\n * iteratee executions. Sections of a chain sequence qualify for shortcut\n * fusion if the section is applied to an array and iteratees accept only\n * one argument. 
The heuristic for whether a section qualifies for shortcut\n * fusion is subject to change.\n *\n * Chaining is supported in custom builds as long as the `_#value` method is\n * directly or indirectly included in the build.\n *\n * In addition to lodash methods, wrappers have `Array` and `String` methods.\n *\n * The wrapper `Array` methods are:\n * `concat`, `join`, `pop`, `push`, `shift`, `sort`, `splice`, and `unshift`\n *\n * The wrapper `String` methods are:\n * `replace` and `split`\n *\n * The wrapper methods that support shortcut fusion are:\n * `at`, `compact`, `drop`, `dropRight`, `dropWhile`, `filter`, `find`,\n * `findLast`, `head`, `initial`, `last`, `map`, `reject`, `reverse`, `slice`,\n * `tail`, `take`, `takeRight`, `takeRightWhile`, `takeWhile`, and `toArray`\n *\n * The chainable wrapper methods are:\n * `after`, `ary`, `assign`, `assignIn`, `assignInWith`, `assignWith`, `at`,\n * `before`, `bind`, `bindAll`, `bindKey`, `castArray`, `chain`, `chunk`,\n * `commit`, `compact`, `concat`, `conforms`, `constant`, `countBy`, `create`,\n * `curry`, `debounce`, `defaults`, `defaultsDeep`, `defer`, `delay`,\n * `difference`, `differenceBy`, `differenceWith`, `drop`, `dropRight`,\n * `dropRightWhile`, `dropWhile`, `extend`, `extendWith`, `fill`, `filter`,\n * `flatMap`, `flatMapDeep`, `flatMapDepth`, `flatten`, `flattenDeep`,\n * `flattenDepth`, `flip`, `flow`, `flowRight`, `fromPairs`, `functions`,\n * `functionsIn`, `groupBy`, `initial`, `intersection`, `intersectionBy`,\n * `intersectionWith`, `invert`, `invertBy`, `invokeMap`, `iteratee`, `keyBy`,\n * `keys`, `keysIn`, `map`, `mapKeys`, `mapValues`, `matches`, `matchesProperty`,\n * `memoize`, `merge`, `mergeWith`, `method`, `methodOf`, `mixin`, `negate`,\n * `nthArg`, `omit`, `omitBy`, `once`, `orderBy`, `over`, `overArgs`,\n * `overEvery`, `overSome`, `partial`, `partialRight`, `partition`, `pick`,\n * `pickBy`, `plant`, `property`, `propertyOf`, `pull`, `pullAll`, `pullAllBy`,\n * `pullAllWith`, `pullAt`, `push`, `range`, `rangeRight`, `rearg`, `reject`,\n * `remove`, `rest`, `reverse`, `sampleSize`, `set`, `setWith`, `shuffle`,\n * `slice`, `sort`, `sortBy`, `splice`, `spread`, `tail`, `take`, `takeRight`,\n * `takeRightWhile`, `takeWhile`, `tap`, `throttle`, `thru`, `toArray`,\n * `toPairs`, `toPairsIn`, `toPath`, `toPlainObject`, `transform`, `unary`,\n * `union`, `unionBy`, `unionWith`, `uniq`, `uniqBy`, `uniqWith`, `unset`,\n * `unshift`, `unzip`, `unzipWith`, `update`, `updateWith`, `values`,\n * `valuesIn`, `without`, `wrap`, `xor`, `xorBy`, `xorWith`, `zip`,\n * `zipObject`, `zipObjectDeep`, and `zipWith`\n *\n * The wrapper methods that are **not** chainable by default are:\n * `add`, `attempt`, `camelCase`, `capitalize`, `ceil`, `clamp`, `clone`,\n * `cloneDeep`, `cloneDeepWith`, `cloneWith`, `conformsTo`, `deburr`,\n * `defaultTo`, `divide`, `each`, `eachRight`, `endsWith`, `eq`, `escape`,\n * `escapeRegExp`, `every`, `find`, `findIndex`, `findKey`, `findLast`,\n * `findLastIndex`, `findLastKey`, `first`, `floor`, `forEach`, `forEachRight`,\n * `forIn`, `forInRight`, `forOwn`, `forOwnRight`, `get`, `gt`, `gte`, `has`,\n * `hasIn`, `head`, `identity`, `includes`, `indexOf`, `inRange`, `invoke`,\n * `isArguments`, `isArray`, `isArrayBuffer`, `isArrayLike`, `isArrayLikeObject`,\n * `isBoolean`, `isBuffer`, `isDate`, `isElement`, `isEmpty`, `isEqual`,\n * `isEqualWith`, `isError`, `isFinite`, `isFunction`, `isInteger`, `isLength`,\n * `isMap`, `isMatch`, `isMatchWith`, `isNaN`, `isNative`, `isNil`, `isNull`,\n * 
`isNumber`, `isObject`, `isObjectLike`, `isPlainObject`, `isRegExp`,\n * `isSafeInteger`, `isSet`, `isString`, `isUndefined`, `isTypedArray`,\n * `isWeakMap`, `isWeakSet`, `join`, `kebabCase`, `last`, `lastIndexOf`,\n * `lowerCase`, `lowerFirst`, `lt`, `lte`, `max`, `maxBy`, `mean`, `meanBy`,\n * `min`, `minBy`, `multiply`, `noConflict`, `noop`, `now`, `nth`, `pad`,\n * `padEnd`, `padStart`, `parseInt`, `pop`, `random`, `reduce`, `reduceRight`,\n * `repeat`, `result`, `round`, `runInContext`, `sample`, `shift`, `size`,\n * `snakeCase`, `some`, `sortedIndex`, `sortedIndexBy`, `sortedLastIndex`,\n * `sortedLastIndexBy`, `startCase`, `startsWith`, `stubArray`, `stubFalse`,\n * `stubObject`, `stubString`, `stubTrue`, `subtract`, `sum`, `sumBy`,\n * `template`, `times`, `toFinite`, `toInteger`, `toJSON`, `toLength`,\n * `toLower`, `toNumber`, `toSafeInteger`, `toString`, `toUpper`, `trim`,\n * `trimEnd`, `trimStart`, `truncate`, `unescape`, `uniqueId`, `upperCase`,\n * `upperFirst`, `value`, and `words`\n *\n * @name _\n * @constructor\n * @category Seq\n * @param {*} value The value to wrap in a `lodash` instance.\n * @returns {Object} Returns the new `lodash` wrapper instance.\n * @example\n *\n * function square(n) {\n * return n * n;\n * }\n *\n * var wrapped = _([1, 2, 3]);\n *\n * // Returns an unwrapped value.\n * wrapped.reduce(_.add);\n * // => 6\n *\n * // Returns a wrapped value.\n * var squares = wrapped.map(square);\n *\n * _.isArray(squares);\n * // => false\n *\n * _.isArray(squares.value());\n * // => true\n */\n function lodash(value) {\n if (isObjectLike(value) && !isArray(value) && !(value instanceof LazyWrapper)) {\n if (value instanceof LodashWrapper) {\n return value;\n }\n if (hasOwnProperty.call(value, '__wrapped__')) {\n return wrapperClone(value);\n }\n }\n return new LodashWrapper(value);\n }\n\n /**\n * The base implementation of `_.create` without support for assigning\n * properties to the created object.\n *\n * @private\n * @param {Object} proto The object to inherit from.\n * @returns {Object} Returns the new object.\n */\n var baseCreate = (function() {\n function object() {}\n return function(proto) {\n if (!isObject(proto)) {\n return {};\n }\n if (objectCreate) {\n return objectCreate(proto);\n }\n object.prototype = proto;\n var result = new object;\n object.prototype = undefined;\n return result;\n };\n }());\n\n /**\n * The function whose prototype chain sequence wrappers inherit from.\n *\n * @private\n */\n function baseLodash() {\n // No operation performed.\n }\n\n /**\n * The base constructor for creating `lodash` wrapper objects.\n *\n * @private\n * @param {*} value The value to wrap.\n * @param {boolean} [chainAll] Enable explicit method chain sequences.\n */\n function LodashWrapper(value, chainAll) {\n this.__wrapped__ = value;\n this.__actions__ = [];\n this.__chain__ = !!chainAll;\n this.__index__ = 0;\n this.__values__ = undefined;\n }\n\n /**\n * By default, the template delimiters used by lodash are like those in\n * embedded Ruby (ERB) as well as ES2015 template strings. 
Change the\n * following template settings to use alternative delimiters.\n *\n * @static\n * @memberOf _\n * @type {Object}\n */\n lodash.templateSettings = {\n\n /**\n * Used to detect `data` property values to be HTML-escaped.\n *\n * @memberOf _.templateSettings\n * @type {RegExp}\n */\n 'escape': reEscape,\n\n /**\n * Used to detect code to be evaluated.\n *\n * @memberOf _.templateSettings\n * @type {RegExp}\n */\n 'evaluate': reEvaluate,\n\n /**\n * Used to detect `data` property values to inject.\n *\n * @memberOf _.templateSettings\n * @type {RegExp}\n */\n 'interpolate': reInterpolate,\n\n /**\n * Used to reference the data object in the template text.\n *\n * @memberOf _.templateSettings\n * @type {string}\n */\n 'variable': '',\n\n /**\n * Used to import variables into the compiled template.\n *\n * @memberOf _.templateSettings\n * @type {Object}\n */\n 'imports': {\n\n /**\n * A reference to the `lodash` function.\n *\n * @memberOf _.templateSettings.imports\n * @type {Function}\n */\n '_': lodash\n }\n };\n\n // Ensure wrappers are instances of `baseLodash`.\n lodash.prototype = baseLodash.prototype;\n lodash.prototype.constructor = lodash;\n\n LodashWrapper.prototype = baseCreate(baseLodash.prototype);\n LodashWrapper.prototype.constructor = LodashWrapper;\n\n /*------------------------------------------------------------------------*/\n\n /**\n * Creates a lazy wrapper object which wraps `value` to enable lazy evaluation.\n *\n * @private\n * @constructor\n * @param {*} value The value to wrap.\n */\n function LazyWrapper(value) {\n this.__wrapped__ = value;\n this.__actions__ = [];\n this.__dir__ = 1;\n this.__filtered__ = false;\n this.__iteratees__ = [];\n this.__takeCount__ = MAX_ARRAY_LENGTH;\n this.__views__ = [];\n }\n\n /**\n * Creates a clone of the lazy wrapper object.\n *\n * @private\n * @name clone\n * @memberOf LazyWrapper\n * @returns {Object} Returns the cloned `LazyWrapper` object.\n */\n function lazyClone() {\n var result = new LazyWrapper(this.__wrapped__);\n result.__actions__ = copyArray(this.__actions__);\n result.__dir__ = this.__dir__;\n result.__filtered__ = this.__filtered__;\n result.__iteratees__ = copyArray(this.__iteratees__);\n result.__takeCount__ = this.__takeCount__;\n result.__views__ = copyArray(this.__views__);\n return result;\n }\n\n /**\n * Reverses the direction of lazy iteration.\n *\n * @private\n * @name reverse\n * @memberOf LazyWrapper\n * @returns {Object} Returns the new reversed `LazyWrapper` object.\n */\n function lazyReverse() {\n if (this.__filtered__) {\n var result = new LazyWrapper(this);\n result.__dir__ = -1;\n result.__filtered__ = true;\n } else {\n result = this.clone();\n result.__dir__ *= -1;\n }\n return result;\n }\n\n /**\n * Extracts the unwrapped value from its lazy wrapper.\n *\n * @private\n * @name value\n * @memberOf LazyWrapper\n * @returns {*} Returns the unwrapped value.\n */\n function lazyValue() {\n var array = this.__wrapped__.value(),\n dir = this.__dir__,\n isArr = isArray(array),\n isRight = dir < 0,\n arrLength = isArr ? array.length : 0,\n view = getView(0, arrLength, this.__views__),\n start = view.start,\n end = view.end,\n length = end - start,\n index = isRight ? 
end : (start - 1),\n iteratees = this.__iteratees__,\n iterLength = iteratees.length,\n resIndex = 0,\n takeCount = nativeMin(length, this.__takeCount__);\n\n if (!isArr || (!isRight && arrLength == length && takeCount == length)) {\n return baseWrapperValue(array, this.__actions__);\n }\n var result = [];\n\n outer:\n while (length-- && resIndex < takeCount) {\n index += dir;\n\n var iterIndex = -1,\n value = array[index];\n\n while (++iterIndex < iterLength) {\n var data = iteratees[iterIndex],\n iteratee = data.iteratee,\n type = data.type,\n computed = iteratee(value);\n\n if (type == LAZY_MAP_FLAG) {\n value = computed;\n } else if (!computed) {\n if (type == LAZY_FILTER_FLAG) {\n continue outer;\n } else {\n break outer;\n }\n }\n }\n result[resIndex++] = value;\n }\n return result;\n }\n\n // Ensure `LazyWrapper` is an instance of `baseLodash`.\n LazyWrapper.prototype = baseCreate(baseLodash.prototype);\n LazyWrapper.prototype.constructor = LazyWrapper;\n\n /*------------------------------------------------------------------------*/\n\n /**\n * Creates a hash object.\n *\n * @private\n * @constructor\n * @param {Array} [entries] The key-value pairs to cache.\n */\n function Hash(entries) {\n var index = -1,\n length = entries == null ? 0 : entries.length;\n\n this.clear();\n while (++index < length) {\n var entry = entries[index];\n this.set(entry[0], entry[1]);\n }\n }\n\n /**\n * Removes all key-value entries from the hash.\n *\n * @private\n * @name clear\n * @memberOf Hash\n */\n function hashClear() {\n this.__data__ = nativeCreate ? nativeCreate(null) : {};\n this.size = 0;\n }\n\n /**\n * Removes `key` and its value from the hash.\n *\n * @private\n * @name delete\n * @memberOf Hash\n * @param {Object} hash The hash to modify.\n * @param {string} key The key of the value to remove.\n * @returns {boolean} Returns `true` if the entry was removed, else `false`.\n */\n function hashDelete(key) {\n var result = this.has(key) && delete this.__data__[key];\n this.size -= result ? 1 : 0;\n return result;\n }\n\n /**\n * Gets the hash value for `key`.\n *\n * @private\n * @name get\n * @memberOf Hash\n * @param {string} key The key of the value to get.\n * @returns {*} Returns the entry value.\n */\n function hashGet(key) {\n var data = this.__data__;\n if (nativeCreate) {\n var result = data[key];\n return result === HASH_UNDEFINED ? undefined : result;\n }\n return hasOwnProperty.call(data, key) ? data[key] : undefined;\n }\n\n /**\n * Checks if a hash value for `key` exists.\n *\n * @private\n * @name has\n * @memberOf Hash\n * @param {string} key The key of the entry to check.\n * @returns {boolean} Returns `true` if an entry for `key` exists, else `false`.\n */\n function hashHas(key) {\n var data = this.__data__;\n return nativeCreate ? (data[key] !== undefined) : hasOwnProperty.call(data, key);\n }\n\n /**\n * Sets the hash `key` to `value`.\n *\n * @private\n * @name set\n * @memberOf Hash\n * @param {string} key The key of the value to set.\n * @param {*} value The value to set.\n * @returns {Object} Returns the hash instance.\n */\n function hashSet(key, value) {\n var data = this.__data__;\n this.size += this.has(key) ? 0 : 1;\n data[key] = (nativeCreate && value === undefined) ? 
HASH_UNDEFINED : value;\n return this;\n }\n\n // Add methods to `Hash`.\n Hash.prototype.clear = hashClear;\n Hash.prototype['delete'] = hashDelete;\n Hash.prototype.get = hashGet;\n Hash.prototype.has = hashHas;\n Hash.prototype.set = hashSet;\n\n /*------------------------------------------------------------------------*/\n\n /**\n * Creates an list cache object.\n *\n * @private\n * @constructor\n * @param {Array} [entries] The key-value pairs to cache.\n */\n function ListCache(entries) {\n var index = -1,\n length = entries == null ? 0 : entries.length;\n\n this.clear();\n while (++index < length) {\n var entry = entries[index];\n this.set(entry[0], entry[1]);\n }\n }\n\n /**\n * Removes all key-value entries from the list cache.\n *\n * @private\n * @name clear\n * @memberOf ListCache\n */\n function listCacheClear() {\n this.__data__ = [];\n this.size = 0;\n }\n\n /**\n * Removes `key` and its value from the list cache.\n *\n * @private\n * @name delete\n * @memberOf ListCache\n * @param {string} key The key of the value to remove.\n * @returns {boolean} Returns `true` if the entry was removed, else `false`.\n */\n function listCacheDelete(key) {\n var data = this.__data__,\n index = assocIndexOf(data, key);\n\n if (index < 0) {\n return false;\n }\n var lastIndex = data.length - 1;\n if (index == lastIndex) {\n data.pop();\n } else {\n splice.call(data, index, 1);\n }\n --this.size;\n return true;\n }\n\n /**\n * Gets the list cache value for `key`.\n *\n * @private\n * @name get\n * @memberOf ListCache\n * @param {string} key The key of the value to get.\n * @returns {*} Returns the entry value.\n */\n function listCacheGet(key) {\n var data = this.__data__,\n index = assocIndexOf(data, key);\n\n return index < 0 ? undefined : data[index][1];\n }\n\n /**\n * Checks if a list cache value for `key` exists.\n *\n * @private\n * @name has\n * @memberOf ListCache\n * @param {string} key The key of the entry to check.\n * @returns {boolean} Returns `true` if an entry for `key` exists, else `false`.\n */\n function listCacheHas(key) {\n return assocIndexOf(this.__data__, key) > -1;\n }\n\n /**\n * Sets the list cache `key` to `value`.\n *\n * @private\n * @name set\n * @memberOf ListCache\n * @param {string} key The key of the value to set.\n * @param {*} value The value to set.\n * @returns {Object} Returns the list cache instance.\n */\n function listCacheSet(key, value) {\n var data = this.__data__,\n index = assocIndexOf(data, key);\n\n if (index < 0) {\n ++this.size;\n data.push([key, value]);\n } else {\n data[index][1] = value;\n }\n return this;\n }\n\n // Add methods to `ListCache`.\n ListCache.prototype.clear = listCacheClear;\n ListCache.prototype['delete'] = listCacheDelete;\n ListCache.prototype.get = listCacheGet;\n ListCache.prototype.has = listCacheHas;\n ListCache.prototype.set = listCacheSet;\n\n /*------------------------------------------------------------------------*/\n\n /**\n * Creates a map cache object to store key-value pairs.\n *\n * @private\n * @constructor\n * @param {Array} [entries] The key-value pairs to cache.\n */\n function MapCache(entries) {\n var index = -1,\n length = entries == null ? 
0 : entries.length;\n\n this.clear();\n while (++index < length) {\n var entry = entries[index];\n this.set(entry[0], entry[1]);\n }\n }\n\n /**\n * Removes all key-value entries from the map.\n *\n * @private\n * @name clear\n * @memberOf MapCache\n */\n function mapCacheClear() {\n this.size = 0;\n this.__data__ = {\n 'hash': new Hash,\n 'map': new (Map || ListCache),\n 'string': new Hash\n };\n }\n\n /**\n * Removes `key` and its value from the map.\n *\n * @private\n * @name delete\n * @memberOf MapCache\n * @param {string} key The key of the value to remove.\n * @returns {boolean} Returns `true` if the entry was removed, else `false`.\n */\n function mapCacheDelete(key) {\n var result = getMapData(this, key)['delete'](key);\n this.size -= result ? 1 : 0;\n return result;\n }\n\n /**\n * Gets the map value for `key`.\n *\n * @private\n * @name get\n * @memberOf MapCache\n * @param {string} key The key of the value to get.\n * @returns {*} Returns the entry value.\n */\n function mapCacheGet(key) {\n return getMapData(this, key).get(key);\n }\n\n /**\n * Checks if a map value for `key` exists.\n *\n * @private\n * @name has\n * @memberOf MapCache\n * @param {string} key The key of the entry to check.\n * @returns {boolean} Returns `true` if an entry for `key` exists, else `false`.\n */\n function mapCacheHas(key) {\n return getMapData(this, key).has(key);\n }\n\n /**\n * Sets the map `key` to `value`.\n *\n * @private\n * @name set\n * @memberOf MapCache\n * @param {string} key The key of the value to set.\n * @param {*} value The value to set.\n * @returns {Object} Returns the map cache instance.\n */\n function mapCacheSet(key, value) {\n var data = getMapData(this, key),\n size = data.size;\n\n data.set(key, value);\n this.size += data.size == size ? 0 : 1;\n return this;\n }\n\n // Add methods to `MapCache`.\n MapCache.prototype.clear = mapCacheClear;\n MapCache.prototype['delete'] = mapCacheDelete;\n MapCache.prototype.get = mapCacheGet;\n MapCache.prototype.has = mapCacheHas;\n MapCache.prototype.set = mapCacheSet;\n\n /*------------------------------------------------------------------------*/\n\n /**\n *\n * Creates an array cache object to store unique values.\n *\n * @private\n * @constructor\n * @param {Array} [values] The values to cache.\n */\n function SetCache(values) {\n var index = -1,\n length = values == null ? 
0 : values.length;\n\n this.__data__ = new MapCache;\n while (++index < length) {\n this.add(values[index]);\n }\n }\n\n /**\n * Adds `value` to the array cache.\n *\n * @private\n * @name add\n * @memberOf SetCache\n * @alias push\n * @param {*} value The value to cache.\n * @returns {Object} Returns the cache instance.\n */\n function setCacheAdd(value) {\n this.__data__.set(value, HASH_UNDEFINED);\n return this;\n }\n\n /**\n * Checks if `value` is in the array cache.\n *\n * @private\n * @name has\n * @memberOf SetCache\n * @param {*} value The value to search for.\n * @returns {number} Returns `true` if `value` is found, else `false`.\n */\n function setCacheHas(value) {\n return this.__data__.has(value);\n }\n\n // Add methods to `SetCache`.\n SetCache.prototype.add = SetCache.prototype.push = setCacheAdd;\n SetCache.prototype.has = setCacheHas;\n\n /*------------------------------------------------------------------------*/\n\n /**\n * Creates a stack cache object to store key-value pairs.\n *\n * @private\n * @constructor\n * @param {Array} [entries] The key-value pairs to cache.\n */\n function Stack(entries) {\n var data = this.__data__ = new ListCache(entries);\n this.size = data.size;\n }\n\n /**\n * Removes all key-value entries from the stack.\n *\n * @private\n * @name clear\n * @memberOf Stack\n */\n function stackClear() {\n this.__data__ = new ListCache;\n this.size = 0;\n }\n\n /**\n * Removes `key` and its value from the stack.\n *\n * @private\n * @name delete\n * @memberOf Stack\n * @param {string} key The key of the value to remove.\n * @returns {boolean} Returns `true` if the entry was removed, else `false`.\n */\n function stackDelete(key) {\n var data = this.__data__,\n result = data['delete'](key);\n\n this.size = data.size;\n return result;\n }\n\n /**\n * Gets the stack value for `key`.\n *\n * @private\n * @name get\n * @memberOf Stack\n * @param {string} key The key of the value to get.\n * @returns {*} Returns the entry value.\n */\n function stackGet(key) {\n return this.__data__.get(key);\n }\n\n /**\n * Checks if a stack value for `key` exists.\n *\n * @private\n * @name has\n * @memberOf Stack\n * @param {string} key The key of the entry to check.\n * @returns {boolean} Returns `true` if an entry for `key` exists, else `false`.\n */\n function stackHas(key) {\n return this.__data__.has(key);\n }\n\n /**\n * Sets the stack `key` to `value`.\n *\n * @private\n * @name set\n * @memberOf Stack\n * @param {string} key The key of the value to set.\n * @param {*} value The value to set.\n * @returns {Object} Returns the stack cache instance.\n */\n function stackSet(key, value) {\n var data = this.__data__;\n if (data instanceof ListCache) {\n var pairs = data.__data__;\n if (!Map || (pairs.length < LARGE_ARRAY_SIZE - 1)) {\n pairs.push([key, value]);\n this.size = ++data.size;\n return this;\n }\n data = this.__data__ = new MapCache(pairs);\n }\n data.set(key, value);\n this.size = data.size;\n return this;\n }\n\n // Add methods to `Stack`.\n Stack.prototype.clear = stackClear;\n Stack.prototype['delete'] = stackDelete;\n Stack.prototype.get = stackGet;\n Stack.prototype.has = stackHas;\n Stack.prototype.set = stackSet;\n\n /*------------------------------------------------------------------------*/\n\n /**\n * Creates an array of the enumerable property names of the array-like `value`.\n *\n * @private\n * @param {*} value The value to query.\n * @param {boolean} inherited Specify returning inherited property names.\n * @returns {Array} Returns the array of 
property names.\n */\n function arrayLikeKeys(value, inherited) {\n var isArr = isArray(value),\n isArg = !isArr && isArguments(value),\n isBuff = !isArr && !isArg && isBuffer(value),\n isType = !isArr && !isArg && !isBuff && isTypedArray(value),\n skipIndexes = isArr || isArg || isBuff || isType,\n result = skipIndexes ? baseTimes(value.length, String) : [],\n length = result.length;\n\n for (var key in value) {\n if ((inherited || hasOwnProperty.call(value, key)) &&\n !(skipIndexes && (\n // Safari 9 has enumerable `arguments.length` in strict mode.\n key == 'length' ||\n // Node.js 0.10 has enumerable non-index properties on buffers.\n (isBuff && (key == 'offset' || key == 'parent')) ||\n // PhantomJS 2 has enumerable non-index properties on typed arrays.\n (isType && (key == 'buffer' || key == 'byteLength' || key == 'byteOffset')) ||\n // Skip index properties.\n isIndex(key, length)\n ))) {\n result.push(key);\n }\n }\n return result;\n }\n\n /**\n * A specialized version of `_.sample` for arrays.\n *\n * @private\n * @param {Array} array The array to sample.\n * @returns {*} Returns the random element.\n */\n function arraySample(array) {\n var length = array.length;\n return length ? array[baseRandom(0, length - 1)] : undefined;\n }\n\n /**\n * A specialized version of `_.sampleSize` for arrays.\n *\n * @private\n * @param {Array} array The array to sample.\n * @param {number} n The number of elements to sample.\n * @returns {Array} Returns the random elements.\n */\n function arraySampleSize(array, n) {\n return shuffleSelf(copyArray(array), baseClamp(n, 0, array.length));\n }\n\n /**\n * A specialized version of `_.shuffle` for arrays.\n *\n * @private\n * @param {Array} array The array to shuffle.\n * @returns {Array} Returns the new shuffled array.\n */\n function arrayShuffle(array) {\n return shuffleSelf(copyArray(array));\n }\n\n /**\n * This function is like `assignValue` except that it doesn't assign\n * `undefined` values.\n *\n * @private\n * @param {Object} object The object to modify.\n * @param {string} key The key of the property to assign.\n * @param {*} value The value to assign.\n */\n function assignMergeValue(object, key, value) {\n if ((value !== undefined && !eq(object[key], value)) ||\n (value === undefined && !(key in object))) {\n baseAssignValue(object, key, value);\n }\n }\n\n /**\n * Assigns `value` to `key` of `object` if the existing value is not equivalent\n * using [`SameValueZero`](http://ecma-international.org/ecma-262/7.0/#sec-samevaluezero)\n * for equality comparisons.\n *\n * @private\n * @param {Object} object The object to modify.\n * @param {string} key The key of the property to assign.\n * @param {*} value The value to assign.\n */\n function assignValue(object, key, value) {\n var objValue = object[key];\n if (!(hasOwnProperty.call(object, key) && eq(objValue, value)) ||\n (value === undefined && !(key in object))) {\n baseAssignValue(object, key, value);\n }\n }\n\n /**\n * Gets the index at which the `key` is found in `array` of key-value pairs.\n *\n * @private\n * @param {Array} array The array to inspect.\n * @param {*} key The key to search for.\n * @returns {number} Returns the index of the matched value, else `-1`.\n */\n function assocIndexOf(array, key) {\n var length = array.length;\n while (length--) {\n if (eq(array[length][0], key)) {\n return length;\n }\n }\n return -1;\n }\n\n /**\n * Aggregates elements of `collection` on `accumulator` with keys transformed\n * by `iteratee` and values set by `setter`.\n *\n * @private\n 
* @param {Array|Object} collection The collection to iterate over.\n * @param {Function} setter The function to set `accumulator` values.\n * @param {Function} iteratee The iteratee to transform keys.\n * @param {Object} accumulator The initial aggregated object.\n * @returns {Function} Returns `accumulator`.\n */\n function baseAggregator(collection, setter, iteratee, accumulator) {\n baseEach(collection, function(value, key, collection) {\n setter(accumulator, value, iteratee(value), collection);\n });\n return accumulator;\n }\n\n /**\n * The base implementation of `_.assign` without support for multiple sources\n * or `customizer` functions.\n *\n * @private\n * @param {Object} object The destination object.\n * @param {Object} source The source object.\n * @returns {Object} Returns `object`.\n */\n function baseAssign(object, source) {\n return object && copyObject(source, keys(source), object);\n }\n\n /**\n * The base implementation of `_.assignIn` without support for multiple sources\n * or `customizer` functions.\n *\n * @private\n * @param {Object} object The destination object.\n * @param {Object} source The source object.\n * @returns {Object} Returns `object`.\n */\n function baseAssignIn(object, source) {\n return object && copyObject(source, keysIn(source), object);\n }\n\n /**\n * The base implementation of `assignValue` and `assignMergeValue` without\n * value checks.\n *\n * @private\n * @param {Object} object The object to modify.\n * @param {string} key The key of the property to assign.\n * @param {*} value The value to assign.\n */\n function baseAssignValue(object, key, value) {\n if (key == '__proto__' && defineProperty) {\n defineProperty(object, key, {\n 'configurable': true,\n 'enumerable': true,\n 'value': value,\n 'writable': true\n });\n } else {\n object[key] = value;\n }\n }\n\n /**\n * The base implementation of `_.at` without support for individual paths.\n *\n * @private\n * @param {Object} object The object to iterate over.\n * @param {string[]} paths The property paths to pick.\n * @returns {Array} Returns the picked elements.\n */\n function baseAt(object, paths) {\n var index = -1,\n length = paths.length,\n result = Array(length),\n skip = object == null;\n\n while (++index < length) {\n result[index] = skip ? undefined : get(object, paths[index]);\n }\n return result;\n }\n\n /**\n * The base implementation of `_.clamp` which doesn't coerce arguments.\n *\n * @private\n * @param {number} number The number to clamp.\n * @param {number} [lower] The lower bound.\n * @param {number} upper The upper bound.\n * @returns {number} Returns the clamped number.\n */\n function baseClamp(number, lower, upper) {\n if (number === number) {\n if (upper !== undefined) {\n number = number <= upper ? number : upper;\n }\n if (lower !== undefined) {\n number = number >= lower ? 
number : lower;\n }\n }\n return number;\n }\n\n /**\n * The base implementation of `_.clone` and `_.cloneDeep` which tracks\n * traversed objects.\n *\n * @private\n * @param {*} value The value to clone.\n * @param {boolean} bitmask The bitmask flags.\n * 1 - Deep clone\n * 2 - Flatten inherited properties\n * 4 - Clone symbols\n * @param {Function} [customizer] The function to customize cloning.\n * @param {string} [key] The key of `value`.\n * @param {Object} [object] The parent object of `value`.\n * @param {Object} [stack] Tracks traversed objects and their clone counterparts.\n * @returns {*} Returns the cloned value.\n */\n function baseClone(value, bitmask, customizer, key, object, stack) {\n var result,\n isDeep = bitmask & CLONE_DEEP_FLAG,\n isFlat = bitmask & CLONE_FLAT_FLAG,\n isFull = bitmask & CLONE_SYMBOLS_FLAG;\n\n if (customizer) {\n result = object ? customizer(value, key, object, stack) : customizer(value);\n }\n if (result !== undefined) {\n return result;\n }\n if (!isObject(value)) {\n return value;\n }\n var isArr = isArray(value);\n if (isArr) {\n result = initCloneArray(value);\n if (!isDeep) {\n return copyArray(value, result);\n }\n } else {\n var tag = getTag(value),\n isFunc = tag == funcTag || tag == genTag;\n\n if (isBuffer(value)) {\n return cloneBuffer(value, isDeep);\n }\n if (tag == objectTag || tag == argsTag || (isFunc && !object)) {\n result = (isFlat || isFunc) ? {} : initCloneObject(value);\n if (!isDeep) {\n return isFlat\n ? copySymbolsIn(value, baseAssignIn(result, value))\n : copySymbols(value, baseAssign(result, value));\n }\n } else {\n if (!cloneableTags[tag]) {\n return object ? value : {};\n }\n result = initCloneByTag(value, tag, isDeep);\n }\n }\n // Check for circular references and return its corresponding clone.\n stack || (stack = new Stack);\n var stacked = stack.get(value);\n if (stacked) {\n return stacked;\n }\n stack.set(value, result);\n\n if (isSet(value)) {\n value.forEach(function(subValue) {\n result.add(baseClone(subValue, bitmask, customizer, subValue, value, stack));\n });\n } else if (isMap(value)) {\n value.forEach(function(subValue, key) {\n result.set(key, baseClone(subValue, bitmask, customizer, key, value, stack));\n });\n }\n\n var keysFunc = isFull\n ? (isFlat ? getAllKeysIn : getAllKeys)\n : (isFlat ? keysIn : keys);\n\n var props = isArr ? 
undefined : keysFunc(value);\n arrayEach(props || value, function(subValue, key) {\n if (props) {\n key = subValue;\n subValue = value[key];\n }\n // Recursively populate clone (susceptible to call stack limits).\n assignValue(result, key, baseClone(subValue, bitmask, customizer, key, value, stack));\n });\n return result;\n }\n\n /**\n * The base implementation of `_.conforms` which doesn't clone `source`.\n *\n * @private\n * @param {Object} source The object of property predicates to conform to.\n * @returns {Function} Returns the new spec function.\n */\n function baseConforms(source) {\n var props = keys(source);\n return function(object) {\n return baseConformsTo(object, source, props);\n };\n }\n\n /**\n * The base implementation of `_.conformsTo` which accepts `props` to check.\n *\n * @private\n * @param {Object} object The object to inspect.\n * @param {Object} source The object of property predicates to conform to.\n * @returns {boolean} Returns `true` if `object` conforms, else `false`.\n */\n function baseConformsTo(object, source, props) {\n var length = props.length;\n if (object == null) {\n return !length;\n }\n object = Object(object);\n while (length--) {\n var key = props[length],\n predicate = source[key],\n value = object[key];\n\n if ((value === undefined && !(key in object)) || !predicate(value)) {\n return false;\n }\n }\n return true;\n }\n\n /**\n * The base implementation of `_.delay` and `_.defer` which accepts `args`\n * to provide to `func`.\n *\n * @private\n * @param {Function} func The function to delay.\n * @param {number} wait The number of milliseconds to delay invocation.\n * @param {Array} args The arguments to provide to `func`.\n * @returns {number|Object} Returns the timer id or timeout object.\n */\n function baseDelay(func, wait, args) {\n if (typeof func != 'function') {\n throw new TypeError(FUNC_ERROR_TEXT);\n }\n return setTimeout(function() { func.apply(undefined, args); }, wait);\n }\n\n /**\n * The base implementation of methods like `_.difference` without support\n * for excluding multiple arrays or iteratee shorthands.\n *\n * @private\n * @param {Array} array The array to inspect.\n * @param {Array} values The values to exclude.\n * @param {Function} [iteratee] The iteratee invoked per element.\n * @param {Function} [comparator] The comparator invoked per element.\n * @returns {Array} Returns the new array of filtered values.\n */\n function baseDifference(array, values, iteratee, comparator) {\n var index = -1,\n includes = arrayIncludes,\n isCommon = true,\n length = array.length,\n result = [],\n valuesLength = values.length;\n\n if (!length) {\n return result;\n }\n if (iteratee) {\n values = arrayMap(values, baseUnary(iteratee));\n }\n if (comparator) {\n includes = arrayIncludesWith;\n isCommon = false;\n }\n else if (values.length >= LARGE_ARRAY_SIZE) {\n includes = cacheHas;\n isCommon = false;\n values = new SetCache(values);\n }\n outer:\n while (++index < length) {\n var value = array[index],\n computed = iteratee == null ? value : iteratee(value);\n\n value = (comparator || value !== 0) ? 
value : 0;\n if (isCommon && computed === computed) {\n var valuesIndex = valuesLength;\n while (valuesIndex--) {\n if (values[valuesIndex] === computed) {\n continue outer;\n }\n }\n result.push(value);\n }\n else if (!includes(values, computed, comparator)) {\n result.push(value);\n }\n }\n return result;\n }\n\n /**\n * The base implementation of `_.forEach` without support for iteratee shorthands.\n *\n * @private\n * @param {Array|Object} collection The collection to iterate over.\n * @param {Function} iteratee The function invoked per iteration.\n * @returns {Array|Object} Returns `collection`.\n */\n var baseEach = createBaseEach(baseForOwn);\n\n /**\n * The base implementation of `_.forEachRight` without support for iteratee shorthands.\n *\n * @private\n * @param {Array|Object} collection The collection to iterate over.\n * @param {Function} iteratee The function invoked per iteration.\n * @returns {Array|Object} Returns `collection`.\n */\n var baseEachRight = createBaseEach(baseForOwnRight, true);\n\n /**\n * The base implementation of `_.every` without support for iteratee shorthands.\n *\n * @private\n * @param {Array|Object} collection The collection to iterate over.\n * @param {Function} predicate The function invoked per iteration.\n * @returns {boolean} Returns `true` if all elements pass the predicate check,\n * else `false`\n */\n function baseEvery(collection, predicate) {\n var result = true;\n baseEach(collection, function(value, index, collection) {\n result = !!predicate(value, index, collection);\n return result;\n });\n return result;\n }\n\n /**\n * The base implementation of methods like `_.max` and `_.min` which accepts a\n * `comparator` to determine the extremum value.\n *\n * @private\n * @param {Array} array The array to iterate over.\n * @param {Function} iteratee The iteratee invoked per iteration.\n * @param {Function} comparator The comparator used to compare values.\n * @returns {*} Returns the extremum value.\n */\n function baseExtremum(array, iteratee, comparator) {\n var index = -1,\n length = array.length;\n\n while (++index < length) {\n var value = array[index],\n current = iteratee(value);\n\n if (current != null && (computed === undefined\n ? (current === current && !isSymbol(current))\n : comparator(current, computed)\n )) {\n var computed = current,\n result = value;\n }\n }\n return result;\n }\n\n /**\n * The base implementation of `_.fill` without an iteratee call guard.\n *\n * @private\n * @param {Array} array The array to fill.\n * @param {*} value The value to fill `array` with.\n * @param {number} [start=0] The start position.\n * @param {number} [end=array.length] The end position.\n * @returns {Array} Returns `array`.\n */\n function baseFill(array, value, start, end) {\n var length = array.length;\n\n start = toInteger(start);\n if (start < 0) {\n start = -start > length ? 0 : (length + start);\n }\n end = (end === undefined || end > length) ? length : toInteger(end);\n if (end < 0) {\n end += length;\n }\n end = start > end ? 
0 : toLength(end);\n while (start < end) {\n array[start++] = value;\n }\n return array;\n }\n\n /**\n * The base implementation of `_.filter` without support for iteratee shorthands.\n *\n * @private\n * @param {Array|Object} collection The collection to iterate over.\n * @param {Function} predicate The function invoked per iteration.\n * @returns {Array} Returns the new filtered array.\n */\n function baseFilter(collection, predicate) {\n var result = [];\n baseEach(collection, function(value, index, collection) {\n if (predicate(value, index, collection)) {\n result.push(value);\n }\n });\n return result;\n }\n\n /**\n * The base implementation of `_.flatten` with support for restricting flattening.\n *\n * @private\n * @param {Array} array The array to flatten.\n * @param {number} depth The maximum recursion depth.\n * @param {boolean} [predicate=isFlattenable] The function invoked per iteration.\n * @param {boolean} [isStrict] Restrict to values that pass `predicate` checks.\n * @param {Array} [result=[]] The initial result value.\n * @returns {Array} Returns the new flattened array.\n */\n function baseFlatten(array, depth, predicate, isStrict, result) {\n var index = -1,\n length = array.length;\n\n predicate || (predicate = isFlattenable);\n result || (result = []);\n\n while (++index < length) {\n var value = array[index];\n if (depth > 0 && predicate(value)) {\n if (depth > 1) {\n // Recursively flatten arrays (susceptible to call stack limits).\n baseFlatten(value, depth - 1, predicate, isStrict, result);\n } else {\n arrayPush(result, value);\n }\n } else if (!isStrict) {\n result[result.length] = value;\n }\n }\n return result;\n }\n\n /**\n * The base implementation of `baseForOwn` which iterates over `object`\n * properties returned by `keysFunc` and invokes `iteratee` for each property.\n * Iteratee functions may exit iteration early by explicitly returning `false`.\n *\n * @private\n * @param {Object} object The object to iterate over.\n * @param {Function} iteratee The function invoked per iteration.\n * @param {Function} keysFunc The function to get the keys of `object`.\n * @returns {Object} Returns `object`.\n */\n var baseFor = createBaseFor();\n\n /**\n * This function is like `baseFor` except that it iterates over properties\n * in the opposite order.\n *\n * @private\n * @param {Object} object The object to iterate over.\n * @param {Function} iteratee The function invoked per iteration.\n * @param {Function} keysFunc The function to get the keys of `object`.\n * @returns {Object} Returns `object`.\n */\n var baseForRight = createBaseFor(true);\n\n /**\n * The base implementation of `_.forOwn` without support for iteratee shorthands.\n *\n * @private\n * @param {Object} object The object to iterate over.\n * @param {Function} iteratee The function invoked per iteration.\n * @returns {Object} Returns `object`.\n */\n function baseForOwn(object, iteratee) {\n return object && baseFor(object, iteratee, keys);\n }\n\n /**\n * The base implementation of `_.forOwnRight` without support for iteratee shorthands.\n *\n * @private\n * @param {Object} object The object to iterate over.\n * @param {Function} iteratee The function invoked per iteration.\n * @returns {Object} Returns `object`.\n */\n function baseForOwnRight(object, iteratee) {\n return object && baseForRight(object, iteratee, keys);\n }\n\n /**\n * The base implementation of `_.functions` which creates an array of\n * `object` function property names filtered from `props`.\n *\n * @private\n * @param {Object} 
object The object to inspect.\n * @param {Array} props The property names to filter.\n * @returns {Array} Returns the function names.\n */\n function baseFunctions(object, props) {\n return arrayFilter(props, function(key) {\n return isFunction(object[key]);\n });\n }\n\n /**\n * The base implementation of `_.get` without support for default values.\n *\n * @private\n * @param {Object} object The object to query.\n * @param {Array|string} path The path of the property to get.\n * @returns {*} Returns the resolved value.\n */\n function baseGet(object, path) {\n path = castPath(path, object);\n\n var index = 0,\n length = path.length;\n\n while (object != null && index < length) {\n object = object[toKey(path[index++])];\n }\n return (index && index == length) ? object : undefined;\n }\n\n /**\n * The base implementation of `getAllKeys` and `getAllKeysIn` which uses\n * `keysFunc` and `symbolsFunc` to get the enumerable property names and\n * symbols of `object`.\n *\n * @private\n * @param {Object} object The object to query.\n * @param {Function} keysFunc The function to get the keys of `object`.\n * @param {Function} symbolsFunc The function to get the symbols of `object`.\n * @returns {Array} Returns the array of property names and symbols.\n */\n function baseGetAllKeys(object, keysFunc, symbolsFunc) {\n var result = keysFunc(object);\n return isArray(object) ? result : arrayPush(result, symbolsFunc(object));\n }\n\n /**\n * The base implementation of `getTag` without fallbacks for buggy environments.\n *\n * @private\n * @param {*} value The value to query.\n * @returns {string} Returns the `toStringTag`.\n */\n function baseGetTag(value) {\n if (value == null) {\n return value === undefined ? undefinedTag : nullTag;\n }\n return (symToStringTag && symToStringTag in Object(value))\n ? 
getRawTag(value)\n : objectToString(value);\n }\n\n /**\n * The base implementation of `_.gt` which doesn't coerce arguments.\n *\n * @private\n * @param {*} value The value to compare.\n * @param {*} other The other value to compare.\n * @returns {boolean} Returns `true` if `value` is greater than `other`,\n * else `false`.\n */\n function baseGt(value, other) {\n return value > other;\n }\n\n /**\n * The base implementation of `_.has` without support for deep paths.\n *\n * @private\n * @param {Object} [object] The object to query.\n * @param {Array|string} key The key to check.\n * @returns {boolean} Returns `true` if `key` exists, else `false`.\n */\n function baseHas(object, key) {\n return object != null && hasOwnProperty.call(object, key);\n }\n\n /**\n * The base implementation of `_.hasIn` without support for deep paths.\n *\n * @private\n * @param {Object} [object] The object to query.\n * @param {Array|string} key The key to check.\n * @returns {boolean} Returns `true` if `key` exists, else `false`.\n */\n function baseHasIn(object, key) {\n return object != null && key in Object(object);\n }\n\n /**\n * The base implementation of `_.inRange` which doesn't coerce arguments.\n *\n * @private\n * @param {number} number The number to check.\n * @param {number} start The start of the range.\n * @param {number} end The end of the range.\n * @returns {boolean} Returns `true` if `number` is in the range, else `false`.\n */\n function baseInRange(number, start, end) {\n return number >= nativeMin(start, end) && number < nativeMax(start, end);\n }\n\n /**\n * The base implementation of methods like `_.intersection`, without support\n * for iteratee shorthands, that accepts an array of arrays to inspect.\n *\n * @private\n * @param {Array} arrays The arrays to inspect.\n * @param {Function} [iteratee] The iteratee invoked per element.\n * @param {Function} [comparator] The comparator invoked per element.\n * @returns {Array} Returns the new array of shared values.\n */\n function baseIntersection(arrays, iteratee, comparator) {\n var includes = comparator ? arrayIncludesWith : arrayIncludes,\n length = arrays[0].length,\n othLength = arrays.length,\n othIndex = othLength,\n caches = Array(othLength),\n maxLength = Infinity,\n result = [];\n\n while (othIndex--) {\n var array = arrays[othIndex];\n if (othIndex && iteratee) {\n array = arrayMap(array, baseUnary(iteratee));\n }\n maxLength = nativeMin(array.length, maxLength);\n caches[othIndex] = !comparator && (iteratee || (length >= 120 && array.length >= 120))\n ? new SetCache(othIndex && array)\n : undefined;\n }\n array = arrays[0];\n\n var index = -1,\n seen = caches[0];\n\n outer:\n while (++index < length && result.length < maxLength) {\n var value = array[index],\n computed = iteratee ? iteratee(value) : value;\n\n value = (comparator || value !== 0) ? value : 0;\n if (!(seen\n ? cacheHas(seen, computed)\n : includes(result, computed, comparator)\n )) {\n othIndex = othLength;\n while (--othIndex) {\n var cache = caches[othIndex];\n if (!(cache\n ? 
cacheHas(cache, computed)\n : includes(arrays[othIndex], computed, comparator))\n ) {\n continue outer;\n }\n }\n if (seen) {\n seen.push(computed);\n }\n result.push(value);\n }\n }\n return result;\n }\n\n /**\n * The base implementation of `_.invert` and `_.invertBy` which inverts\n * `object` with values transformed by `iteratee` and set by `setter`.\n *\n * @private\n * @param {Object} object The object to iterate over.\n * @param {Function} setter The function to set `accumulator` values.\n * @param {Function} iteratee The iteratee to transform values.\n * @param {Object} accumulator The initial inverted object.\n * @returns {Function} Returns `accumulator`.\n */\n function baseInverter(object, setter, iteratee, accumulator) {\n baseForOwn(object, function(value, key, object) {\n setter(accumulator, iteratee(value), key, object);\n });\n return accumulator;\n }\n\n /**\n * The base implementation of `_.invoke` without support for individual\n * method arguments.\n *\n * @private\n * @param {Object} object The object to query.\n * @param {Array|string} path The path of the method to invoke.\n * @param {Array} args The arguments to invoke the method with.\n * @returns {*} Returns the result of the invoked method.\n */\n function baseInvoke(object, path, args) {\n path = castPath(path, object);\n object = parent(object, path);\n var func = object == null ? object : object[toKey(last(path))];\n return func == null ? undefined : apply(func, object, args);\n }\n\n /**\n * The base implementation of `_.isArguments`.\n *\n * @private\n * @param {*} value The value to check.\n * @returns {boolean} Returns `true` if `value` is an `arguments` object,\n */\n function baseIsArguments(value) {\n return isObjectLike(value) && baseGetTag(value) == argsTag;\n }\n\n /**\n * The base implementation of `_.isArrayBuffer` without Node.js optimizations.\n *\n * @private\n * @param {*} value The value to check.\n * @returns {boolean} Returns `true` if `value` is an array buffer, else `false`.\n */\n function baseIsArrayBuffer(value) {\n return isObjectLike(value) && baseGetTag(value) == arrayBufferTag;\n }\n\n /**\n * The base implementation of `_.isDate` without Node.js optimizations.\n *\n * @private\n * @param {*} value The value to check.\n * @returns {boolean} Returns `true` if `value` is a date object, else `false`.\n */\n function baseIsDate(value) {\n return isObjectLike(value) && baseGetTag(value) == dateTag;\n }\n\n /**\n * The base implementation of `_.isEqual` which supports partial comparisons\n * and tracks traversed objects.\n *\n * @private\n * @param {*} value The value to compare.\n * @param {*} other The other value to compare.\n * @param {boolean} bitmask The bitmask flags.\n * 1 - Unordered comparison\n * 2 - Partial comparison\n * @param {Function} [customizer] The function to customize comparisons.\n * @param {Object} [stack] Tracks traversed `value` and `other` objects.\n * @returns {boolean} Returns `true` if the values are equivalent, else `false`.\n */\n function baseIsEqual(value, other, bitmask, customizer, stack) {\n if (value === other) {\n return true;\n }\n if (value == null || other == null || (!isObjectLike(value) && !isObjectLike(other))) {\n return value !== value && other !== other;\n }\n return baseIsEqualDeep(value, other, bitmask, customizer, baseIsEqual, stack);\n }\n\n /**\n * A specialized version of `baseIsEqual` for arrays and objects which performs\n * deep comparisons and tracks traversed objects enabling objects with circular\n * references to be compared.\n 
*\n * @private\n * @param {Object} object The object to compare.\n * @param {Object} other The other object to compare.\n * @param {number} bitmask The bitmask flags. See `baseIsEqual` for more details.\n * @param {Function} customizer The function to customize comparisons.\n * @param {Function} equalFunc The function to determine equivalents of values.\n * @param {Object} [stack] Tracks traversed `object` and `other` objects.\n * @returns {boolean} Returns `true` if the objects are equivalent, else `false`.\n */\n function baseIsEqualDeep(object, other, bitmask, customizer, equalFunc, stack) {\n var objIsArr = isArray(object),\n othIsArr = isArray(other),\n objTag = objIsArr ? arrayTag : getTag(object),\n othTag = othIsArr ? arrayTag : getTag(other);\n\n objTag = objTag == argsTag ? objectTag : objTag;\n othTag = othTag == argsTag ? objectTag : othTag;\n\n var objIsObj = objTag == objectTag,\n othIsObj = othTag == objectTag,\n isSameTag = objTag == othTag;\n\n if (isSameTag && isBuffer(object)) {\n if (!isBuffer(other)) {\n return false;\n }\n objIsArr = true;\n objIsObj = false;\n }\n if (isSameTag && !objIsObj) {\n stack || (stack = new Stack);\n return (objIsArr || isTypedArray(object))\n ? equalArrays(object, other, bitmask, customizer, equalFunc, stack)\n : equalByTag(object, other, objTag, bitmask, customizer, equalFunc, stack);\n }\n if (!(bitmask & COMPARE_PARTIAL_FLAG)) {\n var objIsWrapped = objIsObj && hasOwnProperty.call(object, '__wrapped__'),\n othIsWrapped = othIsObj && hasOwnProperty.call(other, '__wrapped__');\n\n if (objIsWrapped || othIsWrapped) {\n var objUnwrapped = objIsWrapped ? object.value() : object,\n othUnwrapped = othIsWrapped ? other.value() : other;\n\n stack || (stack = new Stack);\n return equalFunc(objUnwrapped, othUnwrapped, bitmask, customizer, stack);\n }\n }\n if (!isSameTag) {\n return false;\n }\n stack || (stack = new Stack);\n return equalObjects(object, other, bitmask, customizer, equalFunc, stack);\n }\n\n /**\n * The base implementation of `_.isMap` without Node.js optimizations.\n *\n * @private\n * @param {*} value The value to check.\n * @returns {boolean} Returns `true` if `value` is a map, else `false`.\n */\n function baseIsMap(value) {\n return isObjectLike(value) && getTag(value) == mapTag;\n }\n\n /**\n * The base implementation of `_.isMatch` without support for iteratee shorthands.\n *\n * @private\n * @param {Object} object The object to inspect.\n * @param {Object} source The object of property values to match.\n * @param {Array} matchData The property names, values, and compare flags to match.\n * @param {Function} [customizer] The function to customize comparisons.\n * @returns {boolean} Returns `true` if `object` is a match, else `false`.\n */\n function baseIsMatch(object, source, matchData, customizer) {\n var index = matchData.length,\n length = index,\n noCustomizer = !customizer;\n\n if (object == null) {\n return !length;\n }\n object = Object(object);\n while (index--) {\n var data = matchData[index];\n if ((noCustomizer && data[2])\n ? data[1] !== object[data[0]]\n : !(data[0] in object)\n ) {\n return false;\n }\n }\n while (++index < length) {\n data = matchData[index];\n var key = data[0],\n objValue = object[key],\n srcValue = data[1];\n\n if (noCustomizer && data[2]) {\n if (objValue === undefined && !(key in object)) {\n return false;\n }\n } else {\n var stack = new Stack;\n if (customizer) {\n var result = customizer(objValue, srcValue, key, object, source, stack);\n }\n if (!(result === undefined\n ? 
baseIsEqual(srcValue, objValue, COMPARE_PARTIAL_FLAG | COMPARE_UNORDERED_FLAG, customizer, stack)\n : result\n )) {\n return false;\n }\n }\n }\n return true;\n }\n\n /**\n * The base implementation of `_.isNative` without bad shim checks.\n *\n * @private\n * @param {*} value The value to check.\n * @returns {boolean} Returns `true` if `value` is a native function,\n * else `false`.\n */\n function baseIsNative(value) {\n if (!isObject(value) || isMasked(value)) {\n return false;\n }\n var pattern = isFunction(value) ? reIsNative : reIsHostCtor;\n return pattern.test(toSource(value));\n }\n\n /**\n * The base implementation of `_.isRegExp` without Node.js optimizations.\n *\n * @private\n * @param {*} value The value to check.\n * @returns {boolean} Returns `true` if `value` is a regexp, else `false`.\n */\n function baseIsRegExp(value) {\n return isObjectLike(value) && baseGetTag(value) == regexpTag;\n }\n\n /**\n * The base implementation of `_.isSet` without Node.js optimizations.\n *\n * @private\n * @param {*} value The value to check.\n * @returns {boolean} Returns `true` if `value` is a set, else `false`.\n */\n function baseIsSet(value) {\n return isObjectLike(value) && getTag(value) == setTag;\n }\n\n /**\n * The base implementation of `_.isTypedArray` without Node.js optimizations.\n *\n * @private\n * @param {*} value The value to check.\n * @returns {boolean} Returns `true` if `value` is a typed array, else `false`.\n */\n function baseIsTypedArray(value) {\n return isObjectLike(value) &&\n isLength(value.length) && !!typedArrayTags[baseGetTag(value)];\n }\n\n /**\n * The base implementation of `_.iteratee`.\n *\n * @private\n * @param {*} [value=_.identity] The value to convert to an iteratee.\n * @returns {Function} Returns the iteratee.\n */\n function baseIteratee(value) {\n // Don't store the `typeof` result in a variable to avoid a JIT bug in Safari 9.\n // See https://bugs.webkit.org/show_bug.cgi?id=156034 for more details.\n if (typeof value == 'function') {\n return value;\n }\n if (value == null) {\n return identity;\n }\n if (typeof value == 'object') {\n return isArray(value)\n ? 
baseMatchesProperty(value[0], value[1])\n : baseMatches(value);\n }\n return property(value);\n }\n\n /**\n * The base implementation of `_.keys` which doesn't treat sparse arrays as dense.\n *\n * @private\n * @param {Object} object The object to query.\n * @returns {Array} Returns the array of property names.\n */\n function baseKeys(object) {\n if (!isPrototype(object)) {\n return nativeKeys(object);\n }\n var result = [];\n for (var key in Object(object)) {\n if (hasOwnProperty.call(object, key) && key != 'constructor') {\n result.push(key);\n }\n }\n return result;\n }\n\n /**\n * The base implementation of `_.keysIn` which doesn't treat sparse arrays as dense.\n *\n * @private\n * @param {Object} object The object to query.\n * @returns {Array} Returns the array of property names.\n */\n function baseKeysIn(object) {\n if (!isObject(object)) {\n return nativeKeysIn(object);\n }\n var isProto = isPrototype(object),\n result = [];\n\n for (var key in object) {\n if (!(key == 'constructor' && (isProto || !hasOwnProperty.call(object, key)))) {\n result.push(key);\n }\n }\n return result;\n }\n\n /**\n * The base implementation of `_.lt` which doesn't coerce arguments.\n *\n * @private\n * @param {*} value The value to compare.\n * @param {*} other The other value to compare.\n * @returns {boolean} Returns `true` if `value` is less than `other`,\n * else `false`.\n */\n function baseLt(value, other) {\n return value < other;\n }\n\n /**\n * The base implementation of `_.map` without support for iteratee shorthands.\n *\n * @private\n * @param {Array|Object} collection The collection to iterate over.\n * @param {Function} iteratee The function invoked per iteration.\n * @returns {Array} Returns the new mapped array.\n */\n function baseMap(collection, iteratee) {\n var index = -1,\n result = isArrayLike(collection) ? Array(collection.length) : [];\n\n baseEach(collection, function(value, key, collection) {\n result[++index] = iteratee(value, key, collection);\n });\n return result;\n }\n\n /**\n * The base implementation of `_.matches` which doesn't clone `source`.\n *\n * @private\n * @param {Object} source The object of property values to match.\n * @returns {Function} Returns the new spec function.\n */\n function baseMatches(source) {\n var matchData = getMatchData(source);\n if (matchData.length == 1 && matchData[0][2]) {\n return matchesStrictComparable(matchData[0][0], matchData[0][1]);\n }\n return function(object) {\n return object === source || baseIsMatch(object, source, matchData);\n };\n }\n\n /**\n * The base implementation of `_.matchesProperty` which doesn't clone `srcValue`.\n *\n * @private\n * @param {string} path The path of the property to get.\n * @param {*} srcValue The value to match.\n * @returns {Function} Returns the new spec function.\n */\n function baseMatchesProperty(path, srcValue) {\n if (isKey(path) && isStrictComparable(srcValue)) {\n return matchesStrictComparable(toKey(path), srcValue);\n }\n return function(object) {\n var objValue = get(object, path);\n return (objValue === undefined && objValue === srcValue)\n ? 
hasIn(object, path)\n : baseIsEqual(srcValue, objValue, COMPARE_PARTIAL_FLAG | COMPARE_UNORDERED_FLAG);\n };\n }\n\n /**\n * The base implementation of `_.merge` without support for multiple sources.\n *\n * @private\n * @param {Object} object The destination object.\n * @param {Object} source The source object.\n * @param {number} srcIndex The index of `source`.\n * @param {Function} [customizer] The function to customize merged values.\n * @param {Object} [stack] Tracks traversed source values and their merged\n * counterparts.\n */\n function baseMerge(object, source, srcIndex, customizer, stack) {\n if (object === source) {\n return;\n }\n baseFor(source, function(srcValue, key) {\n stack || (stack = new Stack);\n if (isObject(srcValue)) {\n baseMergeDeep(object, source, key, srcIndex, baseMerge, customizer, stack);\n }\n else {\n var newValue = customizer\n ? customizer(safeGet(object, key), srcValue, (key + ''), object, source, stack)\n : undefined;\n\n if (newValue === undefined) {\n newValue = srcValue;\n }\n assignMergeValue(object, key, newValue);\n }\n }, keysIn);\n }\n\n /**\n * A specialized version of `baseMerge` for arrays and objects which performs\n * deep merges and tracks traversed objects enabling objects with circular\n * references to be merged.\n *\n * @private\n * @param {Object} object The destination object.\n * @param {Object} source The source object.\n * @param {string} key The key of the value to merge.\n * @param {number} srcIndex The index of `source`.\n * @param {Function} mergeFunc The function to merge values.\n * @param {Function} [customizer] The function to customize assigned values.\n * @param {Object} [stack] Tracks traversed source values and their merged\n * counterparts.\n */\n function baseMergeDeep(object, source, key, srcIndex, mergeFunc, customizer, stack) {\n var objValue = safeGet(object, key),\n srcValue = safeGet(source, key),\n stacked = stack.get(srcValue);\n\n if (stacked) {\n assignMergeValue(object, key, stacked);\n return;\n }\n var newValue = customizer\n ? 
customizer(objValue, srcValue, (key + ''), object, source, stack)\n : undefined;\n\n var isCommon = newValue === undefined;\n\n if (isCommon) {\n var isArr = isArray(srcValue),\n isBuff = !isArr && isBuffer(srcValue),\n isTyped = !isArr && !isBuff && isTypedArray(srcValue);\n\n newValue = srcValue;\n if (isArr || isBuff || isTyped) {\n if (isArray(objValue)) {\n newValue = objValue;\n }\n else if (isArrayLikeObject(objValue)) {\n newValue = copyArray(objValue);\n }\n else if (isBuff) {\n isCommon = false;\n newValue = cloneBuffer(srcValue, true);\n }\n else if (isTyped) {\n isCommon = false;\n newValue = cloneTypedArray(srcValue, true);\n }\n else {\n newValue = [];\n }\n }\n else if (isPlainObject(srcValue) || isArguments(srcValue)) {\n newValue = objValue;\n if (isArguments(objValue)) {\n newValue = toPlainObject(objValue);\n }\n else if (!isObject(objValue) || isFunction(objValue)) {\n newValue = initCloneObject(srcValue);\n }\n }\n else {\n isCommon = false;\n }\n }\n if (isCommon) {\n // Recursively merge objects and arrays (susceptible to call stack limits).\n stack.set(srcValue, newValue);\n mergeFunc(newValue, srcValue, srcIndex, customizer, stack);\n stack['delete'](srcValue);\n }\n assignMergeValue(object, key, newValue);\n }\n\n /**\n * The base implementation of `_.nth` which doesn't coerce arguments.\n *\n * @private\n * @param {Array} array The array to query.\n * @param {number} n The index of the element to return.\n * @returns {*} Returns the nth element of `array`.\n */\n function baseNth(array, n) {\n var length = array.length;\n if (!length) {\n return;\n }\n n += n < 0 ? length : 0;\n return isIndex(n, length) ? array[n] : undefined;\n }\n\n /**\n * The base implementation of `_.orderBy` without param guards.\n *\n * @private\n * @param {Array|Object} collection The collection to iterate over.\n * @param {Function[]|Object[]|string[]} iteratees The iteratees to sort by.\n * @param {string[]} orders The sort orders of `iteratees`.\n * @returns {Array} Returns the new sorted array.\n */\n function baseOrderBy(collection, iteratees, orders) {\n if (iteratees.length) {\n iteratees = arrayMap(iteratees, function(iteratee) {\n if (isArray(iteratee)) {\n return function(value) {\n return baseGet(value, iteratee.length === 1 ? 
iteratee[0] : iteratee);\n }\n }\n return iteratee;\n });\n } else {\n iteratees = [identity];\n }\n\n var index = -1;\n iteratees = arrayMap(iteratees, baseUnary(getIteratee()));\n\n var result = baseMap(collection, function(value, key, collection) {\n var criteria = arrayMap(iteratees, function(iteratee) {\n return iteratee(value);\n });\n return { 'criteria': criteria, 'index': ++index, 'value': value };\n });\n\n return baseSortBy(result, function(object, other) {\n return compareMultiple(object, other, orders);\n });\n }\n\n /**\n * The base implementation of `_.pick` without support for individual\n * property identifiers.\n *\n * @private\n * @param {Object} object The source object.\n * @param {string[]} paths The property paths to pick.\n * @returns {Object} Returns the new object.\n */\n function basePick(object, paths) {\n return basePickBy(object, paths, function(value, path) {\n return hasIn(object, path);\n });\n }\n\n /**\n * The base implementation of `_.pickBy` without support for iteratee shorthands.\n *\n * @private\n * @param {Object} object The source object.\n * @param {string[]} paths The property paths to pick.\n * @param {Function} predicate The function invoked per property.\n * @returns {Object} Returns the new object.\n */\n function basePickBy(object, paths, predicate) {\n var index = -1,\n length = paths.length,\n result = {};\n\n while (++index < length) {\n var path = paths[index],\n value = baseGet(object, path);\n\n if (predicate(value, path)) {\n baseSet(result, castPath(path, object), value);\n }\n }\n return result;\n }\n\n /**\n * A specialized version of `baseProperty` which supports deep paths.\n *\n * @private\n * @param {Array|string} path The path of the property to get.\n * @returns {Function} Returns the new accessor function.\n */\n function basePropertyDeep(path) {\n return function(object) {\n return baseGet(object, path);\n };\n }\n\n /**\n * The base implementation of `_.pullAllBy` without support for iteratee\n * shorthands.\n *\n * @private\n * @param {Array} array The array to modify.\n * @param {Array} values The values to remove.\n * @param {Function} [iteratee] The iteratee invoked per element.\n * @param {Function} [comparator] The comparator invoked per element.\n * @returns {Array} Returns `array`.\n */\n function basePullAll(array, values, iteratee, comparator) {\n var indexOf = comparator ? baseIndexOfWith : baseIndexOf,\n index = -1,\n length = values.length,\n seen = array;\n\n if (array === values) {\n values = copyArray(values);\n }\n if (iteratee) {\n seen = arrayMap(array, baseUnary(iteratee));\n }\n while (++index < length) {\n var fromIndex = 0,\n value = values[index],\n computed = iteratee ? iteratee(value) : value;\n\n while ((fromIndex = indexOf(seen, computed, fromIndex, comparator)) > -1) {\n if (seen !== array) {\n splice.call(seen, fromIndex, 1);\n }\n splice.call(array, fromIndex, 1);\n }\n }\n return array;\n }\n\n /**\n * The base implementation of `_.pullAt` without support for individual\n * indexes or capturing the removed elements.\n *\n * @private\n * @param {Array} array The array to modify.\n * @param {number[]} indexes The indexes of elements to remove.\n * @returns {Array} Returns `array`.\n */\n function basePullAt(array, indexes) {\n var length = array ? 
indexes.length : 0,\n lastIndex = length - 1;\n\n while (length--) {\n var index = indexes[length];\n if (length == lastIndex || index !== previous) {\n var previous = index;\n if (isIndex(index)) {\n splice.call(array, index, 1);\n } else {\n baseUnset(array, index);\n }\n }\n }\n return array;\n }\n\n /**\n * The base implementation of `_.random` without support for returning\n * floating-point numbers.\n *\n * @private\n * @param {number} lower The lower bound.\n * @param {number} upper The upper bound.\n * @returns {number} Returns the random number.\n */\n function baseRandom(lower, upper) {\n return lower + nativeFloor(nativeRandom() * (upper - lower + 1));\n }\n\n /**\n * The base implementation of `_.range` and `_.rangeRight` which doesn't\n * coerce arguments.\n *\n * @private\n * @param {number} start The start of the range.\n * @param {number} end The end of the range.\n * @param {number} step The value to increment or decrement by.\n * @param {boolean} [fromRight] Specify iterating from right to left.\n * @returns {Array} Returns the range of numbers.\n */\n function baseRange(start, end, step, fromRight) {\n var index = -1,\n length = nativeMax(nativeCeil((end - start) / (step || 1)), 0),\n result = Array(length);\n\n while (length--) {\n result[fromRight ? length : ++index] = start;\n start += step;\n }\n return result;\n }\n\n /**\n * The base implementation of `_.repeat` which doesn't coerce arguments.\n *\n * @private\n * @param {string} string The string to repeat.\n * @param {number} n The number of times to repeat the string.\n * @returns {string} Returns the repeated string.\n */\n function baseRepeat(string, n) {\n var result = '';\n if (!string || n < 1 || n > MAX_SAFE_INTEGER) {\n return result;\n }\n // Leverage the exponentiation by squaring algorithm for a faster repeat.\n // See https://en.wikipedia.org/wiki/Exponentiation_by_squaring for more details.\n do {\n if (n % 2) {\n result += string;\n }\n n = nativeFloor(n / 2);\n if (n) {\n string += string;\n }\n } while (n);\n\n return result;\n }\n\n /**\n * The base implementation of `_.rest` which doesn't validate or coerce arguments.\n *\n * @private\n * @param {Function} func The function to apply a rest parameter to.\n * @param {number} [start=func.length-1] The start position of the rest parameter.\n * @returns {Function} Returns the new function.\n */\n function baseRest(func, start) {\n return setToString(overRest(func, start, identity), func + '');\n }\n\n /**\n * The base implementation of `_.sample`.\n *\n * @private\n * @param {Array|Object} collection The collection to sample.\n * @returns {*} Returns the random element.\n */\n function baseSample(collection) {\n return arraySample(values(collection));\n }\n\n /**\n * The base implementation of `_.sampleSize` without param guards.\n *\n * @private\n * @param {Array|Object} collection The collection to sample.\n * @param {number} n The number of elements to sample.\n * @returns {Array} Returns the random elements.\n */\n function baseSampleSize(collection, n) {\n var array = values(collection);\n return shuffleSelf(array, baseClamp(n, 0, array.length));\n }\n\n /**\n * The base implementation of `_.set`.\n *\n * @private\n * @param {Object} object The object to modify.\n * @param {Array|string} path The path of the property to set.\n * @param {*} value The value to set.\n * @param {Function} [customizer] The function to customize path creation.\n * @returns {Object} Returns `object`.\n */\n function baseSet(object, path, value, customizer) {\n if 
(!isObject(object)) {\n return object;\n }\n path = castPath(path, object);\n\n var index = -1,\n length = path.length,\n lastIndex = length - 1,\n nested = object;\n\n while (nested != null && ++index < length) {\n var key = toKey(path[index]),\n newValue = value;\n\n if (key === '__proto__' || key === 'constructor' || key === 'prototype') {\n return object;\n }\n\n if (index != lastIndex) {\n var objValue = nested[key];\n newValue = customizer ? customizer(objValue, key, nested) : undefined;\n if (newValue === undefined) {\n newValue = isObject(objValue)\n ? objValue\n : (isIndex(path[index + 1]) ? [] : {});\n }\n }\n assignValue(nested, key, newValue);\n nested = nested[key];\n }\n return object;\n }\n\n /**\n * The base implementation of `setData` without support for hot loop shorting.\n *\n * @private\n * @param {Function} func The function to associate metadata with.\n * @param {*} data The metadata.\n * @returns {Function} Returns `func`.\n */\n var baseSetData = !metaMap ? identity : function(func, data) {\n metaMap.set(func, data);\n return func;\n };\n\n /**\n * The base implementation of `setToString` without support for hot loop shorting.\n *\n * @private\n * @param {Function} func The function to modify.\n * @param {Function} string The `toString` result.\n * @returns {Function} Returns `func`.\n */\n var baseSetToString = !defineProperty ? identity : function(func, string) {\n return defineProperty(func, 'toString', {\n 'configurable': true,\n 'enumerable': false,\n 'value': constant(string),\n 'writable': true\n });\n };\n\n /**\n * The base implementation of `_.shuffle`.\n *\n * @private\n * @param {Array|Object} collection The collection to shuffle.\n * @returns {Array} Returns the new shuffled array.\n */\n function baseShuffle(collection) {\n return shuffleSelf(values(collection));\n }\n\n /**\n * The base implementation of `_.slice` without an iteratee call guard.\n *\n * @private\n * @param {Array} array The array to slice.\n * @param {number} [start=0] The start position.\n * @param {number} [end=array.length] The end position.\n * @returns {Array} Returns the slice of `array`.\n */\n function baseSlice(array, start, end) {\n var index = -1,\n length = array.length;\n\n if (start < 0) {\n start = -start > length ? 0 : (length + start);\n }\n end = end > length ? length : end;\n if (end < 0) {\n end += length;\n }\n length = start > end ? 
0 : ((end - start) >>> 0);\n start >>>= 0;\n\n var result = Array(length);\n while (++index < length) {\n result[index] = array[index + start];\n }\n return result;\n }\n\n /**\n * The base implementation of `_.some` without support for iteratee shorthands.\n *\n * @private\n * @param {Array|Object} collection The collection to iterate over.\n * @param {Function} predicate The function invoked per iteration.\n * @returns {boolean} Returns `true` if any element passes the predicate check,\n * else `false`.\n */\n function baseSome(collection, predicate) {\n var result;\n\n baseEach(collection, function(value, index, collection) {\n result = predicate(value, index, collection);\n return !result;\n });\n return !!result;\n }\n\n /**\n * The base implementation of `_.sortedIndex` and `_.sortedLastIndex` which\n * performs a binary search of `array` to determine the index at which `value`\n * should be inserted into `array` in order to maintain its sort order.\n *\n * @private\n * @param {Array} array The sorted array to inspect.\n * @param {*} value The value to evaluate.\n * @param {boolean} [retHighest] Specify returning the highest qualified index.\n * @returns {number} Returns the index at which `value` should be inserted\n * into `array`.\n */\n function baseSortedIndex(array, value, retHighest) {\n var low = 0,\n high = array == null ? low : array.length;\n\n if (typeof value == 'number' && value === value && high <= HALF_MAX_ARRAY_LENGTH) {\n while (low < high) {\n var mid = (low + high) >>> 1,\n computed = array[mid];\n\n if (computed !== null && !isSymbol(computed) &&\n (retHighest ? (computed <= value) : (computed < value))) {\n low = mid + 1;\n } else {\n high = mid;\n }\n }\n return high;\n }\n return baseSortedIndexBy(array, value, identity, retHighest);\n }\n\n /**\n * The base implementation of `_.sortedIndexBy` and `_.sortedLastIndexBy`\n * which invokes `iteratee` for `value` and each element of `array` to compute\n * their sort ranking. The iteratee is invoked with one argument; (value).\n *\n * @private\n * @param {Array} array The sorted array to inspect.\n * @param {*} value The value to evaluate.\n * @param {Function} iteratee The iteratee invoked per element.\n * @param {boolean} [retHighest] Specify returning the highest qualified index.\n * @returns {number} Returns the index at which `value` should be inserted\n * into `array`.\n */\n function baseSortedIndexBy(array, value, iteratee, retHighest) {\n var low = 0,\n high = array == null ? 0 : array.length;\n if (high === 0) {\n return 0;\n }\n\n value = iteratee(value);\n var valIsNaN = value !== value,\n valIsNull = value === null,\n valIsSymbol = isSymbol(value),\n valIsUndefined = value === undefined;\n\n while (low < high) {\n var mid = nativeFloor((low + high) / 2),\n computed = iteratee(array[mid]),\n othIsDefined = computed !== undefined,\n othIsNull = computed === null,\n othIsReflexive = computed === computed,\n othIsSymbol = isSymbol(computed);\n\n if (valIsNaN) {\n var setLow = retHighest || othIsReflexive;\n } else if (valIsUndefined) {\n setLow = othIsReflexive && (retHighest || othIsDefined);\n } else if (valIsNull) {\n setLow = othIsReflexive && othIsDefined && (retHighest || !othIsNull);\n } else if (valIsSymbol) {\n setLow = othIsReflexive && othIsDefined && !othIsNull && (retHighest || !othIsSymbol);\n } else if (othIsNull || othIsSymbol) {\n setLow = false;\n } else {\n setLow = retHighest ? 
(computed <= value) : (computed < value);\n }\n if (setLow) {\n low = mid + 1;\n } else {\n high = mid;\n }\n }\n return nativeMin(high, MAX_ARRAY_INDEX);\n }\n\n /**\n * The base implementation of `_.sortedUniq` and `_.sortedUniqBy` without\n * support for iteratee shorthands.\n *\n * @private\n * @param {Array} array The array to inspect.\n * @param {Function} [iteratee] The iteratee invoked per element.\n * @returns {Array} Returns the new duplicate free array.\n */\n function baseSortedUniq(array, iteratee) {\n var index = -1,\n length = array.length,\n resIndex = 0,\n result = [];\n\n while (++index < length) {\n var value = array[index],\n computed = iteratee ? iteratee(value) : value;\n\n if (!index || !eq(computed, seen)) {\n var seen = computed;\n result[resIndex++] = value === 0 ? 0 : value;\n }\n }\n return result;\n }\n\n /**\n * The base implementation of `_.toNumber` which doesn't ensure correct\n * conversions of binary, hexadecimal, or octal string values.\n *\n * @private\n * @param {*} value The value to process.\n * @returns {number} Returns the number.\n */\n function baseToNumber(value) {\n if (typeof value == 'number') {\n return value;\n }\n if (isSymbol(value)) {\n return NAN;\n }\n return +value;\n }\n\n /**\n * The base implementation of `_.toString` which doesn't convert nullish\n * values to empty strings.\n *\n * @private\n * @param {*} value The value to process.\n * @returns {string} Returns the string.\n */\n function baseToString(value) {\n // Exit early for strings to avoid a performance hit in some environments.\n if (typeof value == 'string') {\n return value;\n }\n if (isArray(value)) {\n // Recursively convert values (susceptible to call stack limits).\n return arrayMap(value, baseToString) + '';\n }\n if (isSymbol(value)) {\n return symbolToString ? symbolToString.call(value) : '';\n }\n var result = (value + '');\n return (result == '0' && (1 / value) == -INFINITY) ? '-0' : result;\n }\n\n /**\n * The base implementation of `_.uniqBy` without support for iteratee shorthands.\n *\n * @private\n * @param {Array} array The array to inspect.\n * @param {Function} [iteratee] The iteratee invoked per element.\n * @param {Function} [comparator] The comparator invoked per element.\n * @returns {Array} Returns the new duplicate free array.\n */\n function baseUniq(array, iteratee, comparator) {\n var index = -1,\n includes = arrayIncludes,\n length = array.length,\n isCommon = true,\n result = [],\n seen = result;\n\n if (comparator) {\n isCommon = false;\n includes = arrayIncludesWith;\n }\n else if (length >= LARGE_ARRAY_SIZE) {\n var set = iteratee ? null : createSet(array);\n if (set) {\n return setToArray(set);\n }\n isCommon = false;\n includes = cacheHas;\n seen = new SetCache;\n }\n else {\n seen = iteratee ? [] : result;\n }\n outer:\n while (++index < length) {\n var value = array[index],\n computed = iteratee ? iteratee(value) : value;\n\n value = (comparator || value !== 0) ? 
value : 0;\n if (isCommon && computed === computed) {\n var seenIndex = seen.length;\n while (seenIndex--) {\n if (seen[seenIndex] === computed) {\n continue outer;\n }\n }\n if (iteratee) {\n seen.push(computed);\n }\n result.push(value);\n }\n else if (!includes(seen, computed, comparator)) {\n if (seen !== result) {\n seen.push(computed);\n }\n result.push(value);\n }\n }\n return result;\n }\n\n /**\n * The base implementation of `_.unset`.\n *\n * @private\n * @param {Object} object The object to modify.\n * @param {Array|string} path The property path to unset.\n * @returns {boolean} Returns `true` if the property is deleted, else `false`.\n */\n function baseUnset(object, path) {\n path = castPath(path, object);\n object = parent(object, path);\n return object == null || delete object[toKey(last(path))];\n }\n\n /**\n * The base implementation of `_.update`.\n *\n * @private\n * @param {Object} object The object to modify.\n * @param {Array|string} path The path of the property to update.\n * @param {Function} updater The function to produce the updated value.\n * @param {Function} [customizer] The function to customize path creation.\n * @returns {Object} Returns `object`.\n */\n function baseUpdate(object, path, updater, customizer) {\n return baseSet(object, path, updater(baseGet(object, path)), customizer);\n }\n\n /**\n * The base implementation of methods like `_.dropWhile` and `_.takeWhile`\n * without support for iteratee shorthands.\n *\n * @private\n * @param {Array} array The array to query.\n * @param {Function} predicate The function invoked per iteration.\n * @param {boolean} [isDrop] Specify dropping elements instead of taking them.\n * @param {boolean} [fromRight] Specify iterating from right to left.\n * @returns {Array} Returns the slice of `array`.\n */\n function baseWhile(array, predicate, isDrop, fromRight) {\n var length = array.length,\n index = fromRight ? length : -1;\n\n while ((fromRight ? index-- : ++index < length) &&\n predicate(array[index], index, array)) {}\n\n return isDrop\n ? baseSlice(array, (fromRight ? 0 : index), (fromRight ? index + 1 : length))\n : baseSlice(array, (fromRight ? index + 1 : 0), (fromRight ? length : index));\n }\n\n /**\n * The base implementation of `wrapperValue` which returns the result of\n * performing a sequence of actions on the unwrapped `value`, where each\n * successive action is supplied the return value of the previous.\n *\n * @private\n * @param {*} value The unwrapped value.\n * @param {Array} actions Actions to perform to resolve the unwrapped value.\n * @returns {*} Returns the resolved value.\n */\n function baseWrapperValue(value, actions) {\n var result = value;\n if (result instanceof LazyWrapper) {\n result = result.value();\n }\n return arrayReduce(actions, function(result, action) {\n return action.func.apply(action.thisArg, arrayPush([result], action.args));\n }, result);\n }\n\n /**\n * The base implementation of methods like `_.xor`, without support for\n * iteratee shorthands, that accepts an array of arrays to inspect.\n *\n * @private\n * @param {Array} arrays The arrays to inspect.\n * @param {Function} [iteratee] The iteratee invoked per element.\n * @param {Function} [comparator] The comparator invoked per element.\n * @returns {Array} Returns the new array of values.\n */\n function baseXor(arrays, iteratee, comparator) {\n var length = arrays.length;\n if (length < 2) {\n return length ? 
baseUniq(arrays[0]) : [];\n }\n var index = -1,\n result = Array(length);\n\n while (++index < length) {\n var array = arrays[index],\n othIndex = -1;\n\n while (++othIndex < length) {\n if (othIndex != index) {\n result[index] = baseDifference(result[index] || array, arrays[othIndex], iteratee, comparator);\n }\n }\n }\n return baseUniq(baseFlatten(result, 1), iteratee, comparator);\n }\n\n /**\n * This base implementation of `_.zipObject` which assigns values using `assignFunc`.\n *\n * @private\n * @param {Array} props The property identifiers.\n * @param {Array} values The property values.\n * @param {Function} assignFunc The function to assign values.\n * @returns {Object} Returns the new object.\n */\n function baseZipObject(props, values, assignFunc) {\n var index = -1,\n length = props.length,\n valsLength = values.length,\n result = {};\n\n while (++index < length) {\n var value = index < valsLength ? values[index] : undefined;\n assignFunc(result, props[index], value);\n }\n return result;\n }\n\n /**\n * Casts `value` to an empty array if it's not an array like object.\n *\n * @private\n * @param {*} value The value to inspect.\n * @returns {Array|Object} Returns the cast array-like object.\n */\n function castArrayLikeObject(value) {\n return isArrayLikeObject(value) ? value : [];\n }\n\n /**\n * Casts `value` to `identity` if it's not a function.\n *\n * @private\n * @param {*} value The value to inspect.\n * @returns {Function} Returns cast function.\n */\n function castFunction(value) {\n return typeof value == 'function' ? value : identity;\n }\n\n /**\n * Casts `value` to a path array if it's not one.\n *\n * @private\n * @param {*} value The value to inspect.\n * @param {Object} [object] The object to query keys on.\n * @returns {Array} Returns the cast property path array.\n */\n function castPath(value, object) {\n if (isArray(value)) {\n return value;\n }\n return isKey(value, object) ? [value] : stringToPath(toString(value));\n }\n\n /**\n * A `baseRest` alias which can be replaced with `identity` by module\n * replacement plugins.\n *\n * @private\n * @type {Function}\n * @param {Function} func The function to apply a rest parameter to.\n * @returns {Function} Returns the new function.\n */\n var castRest = baseRest;\n\n /**\n * Casts `array` to a slice if it's needed.\n *\n * @private\n * @param {Array} array The array to inspect.\n * @param {number} start The start position.\n * @param {number} [end=array.length] The end position.\n * @returns {Array} Returns the cast slice.\n */\n function castSlice(array, start, end) {\n var length = array.length;\n end = end === undefined ? length : end;\n return (!start && end >= length) ? array : baseSlice(array, start, end);\n }\n\n /**\n * A simple wrapper around the global [`clearTimeout`](https://mdn.io/clearTimeout).\n *\n * @private\n * @param {number|Object} id The timer id or timeout object of the timer to clear.\n */\n var clearTimeout = ctxClearTimeout || function(id) {\n return root.clearTimeout(id);\n };\n\n /**\n * Creates a clone of `buffer`.\n *\n * @private\n * @param {Buffer} buffer The buffer to clone.\n * @param {boolean} [isDeep] Specify a deep clone.\n * @returns {Buffer} Returns the cloned buffer.\n */\n function cloneBuffer(buffer, isDeep) {\n if (isDeep) {\n return buffer.slice();\n }\n var length = buffer.length,\n result = allocUnsafe ? 
allocUnsafe(length) : new buffer.constructor(length);\n\n buffer.copy(result);\n return result;\n }\n\n /**\n * Creates a clone of `arrayBuffer`.\n *\n * @private\n * @param {ArrayBuffer} arrayBuffer The array buffer to clone.\n * @returns {ArrayBuffer} Returns the cloned array buffer.\n */\n function cloneArrayBuffer(arrayBuffer) {\n var result = new arrayBuffer.constructor(arrayBuffer.byteLength);\n new Uint8Array(result).set(new Uint8Array(arrayBuffer));\n return result;\n }\n\n /**\n * Creates a clone of `dataView`.\n *\n * @private\n * @param {Object} dataView The data view to clone.\n * @param {boolean} [isDeep] Specify a deep clone.\n * @returns {Object} Returns the cloned data view.\n */\n function cloneDataView(dataView, isDeep) {\n var buffer = isDeep ? cloneArrayBuffer(dataView.buffer) : dataView.buffer;\n return new dataView.constructor(buffer, dataView.byteOffset, dataView.byteLength);\n }\n\n /**\n * Creates a clone of `regexp`.\n *\n * @private\n * @param {Object} regexp The regexp to clone.\n * @returns {Object} Returns the cloned regexp.\n */\n function cloneRegExp(regexp) {\n var result = new regexp.constructor(regexp.source, reFlags.exec(regexp));\n result.lastIndex = regexp.lastIndex;\n return result;\n }\n\n /**\n * Creates a clone of the `symbol` object.\n *\n * @private\n * @param {Object} symbol The symbol object to clone.\n * @returns {Object} Returns the cloned symbol object.\n */\n function cloneSymbol(symbol) {\n return symbolValueOf ? Object(symbolValueOf.call(symbol)) : {};\n }\n\n /**\n * Creates a clone of `typedArray`.\n *\n * @private\n * @param {Object} typedArray The typed array to clone.\n * @param {boolean} [isDeep] Specify a deep clone.\n * @returns {Object} Returns the cloned typed array.\n */\n function cloneTypedArray(typedArray, isDeep) {\n var buffer = isDeep ? cloneArrayBuffer(typedArray.buffer) : typedArray.buffer;\n return new typedArray.constructor(buffer, typedArray.byteOffset, typedArray.length);\n }\n\n /**\n * Compares values to sort them in ascending order.\n *\n * @private\n * @param {*} value The value to compare.\n * @param {*} other The other value to compare.\n * @returns {number} Returns the sort order indicator for `value`.\n */\n function compareAscending(value, other) {\n if (value !== other) {\n var valIsDefined = value !== undefined,\n valIsNull = value === null,\n valIsReflexive = value === value,\n valIsSymbol = isSymbol(value);\n\n var othIsDefined = other !== undefined,\n othIsNull = other === null,\n othIsReflexive = other === other,\n othIsSymbol = isSymbol(other);\n\n if ((!othIsNull && !othIsSymbol && !valIsSymbol && value > other) ||\n (valIsSymbol && othIsDefined && othIsReflexive && !othIsNull && !othIsSymbol) ||\n (valIsNull && othIsDefined && othIsReflexive) ||\n (!valIsDefined && othIsReflexive) ||\n !valIsReflexive) {\n return 1;\n }\n if ((!valIsNull && !valIsSymbol && !othIsSymbol && value < other) ||\n (othIsSymbol && valIsDefined && valIsReflexive && !valIsNull && !valIsSymbol) ||\n (othIsNull && valIsDefined && valIsReflexive) ||\n (!othIsDefined && valIsReflexive) ||\n !othIsReflexive) {\n return -1;\n }\n }\n return 0;\n }\n\n /**\n * Used by `_.orderBy` to compare multiple properties of a value to another\n * and stable sort them.\n *\n * If `orders` is unspecified, all values are sorted in ascending order. 
Otherwise,\n * specify an order of \"desc\" for descending or \"asc\" for ascending sort order\n * of corresponding values.\n *\n * @private\n * @param {Object} object The object to compare.\n * @param {Object} other The other object to compare.\n * @param {boolean[]|string[]} orders The order to sort by for each property.\n * @returns {number} Returns the sort order indicator for `object`.\n */\n function compareMultiple(object, other, orders) {\n var index = -1,\n objCriteria = object.criteria,\n othCriteria = other.criteria,\n length = objCriteria.length,\n ordersLength = orders.length;\n\n while (++index < length) {\n var result = compareAscending(objCriteria[index], othCriteria[index]);\n if (result) {\n if (index >= ordersLength) {\n return result;\n }\n var order = orders[index];\n return result * (order == 'desc' ? -1 : 1);\n }\n }\n // Fixes an `Array#sort` bug in the JS engine embedded in Adobe applications\n // that causes it, under certain circumstances, to provide the same value for\n // `object` and `other`. See https://github.com/jashkenas/underscore/pull/1247\n // for more details.\n //\n // This also ensures a stable sort in V8 and other engines.\n // See https://bugs.chromium.org/p/v8/issues/detail?id=90 for more details.\n return object.index - other.index;\n }\n\n /**\n * Creates an array that is the composition of partially applied arguments,\n * placeholders, and provided arguments into a single array of arguments.\n *\n * @private\n * @param {Array} args The provided arguments.\n * @param {Array} partials The arguments to prepend to those provided.\n * @param {Array} holders The `partials` placeholder indexes.\n * @params {boolean} [isCurried] Specify composing for a curried function.\n * @returns {Array} Returns the new array of composed arguments.\n */\n function composeArgs(args, partials, holders, isCurried) {\n var argsIndex = -1,\n argsLength = args.length,\n holdersLength = holders.length,\n leftIndex = -1,\n leftLength = partials.length,\n rangeLength = nativeMax(argsLength - holdersLength, 0),\n result = Array(leftLength + rangeLength),\n isUncurried = !isCurried;\n\n while (++leftIndex < leftLength) {\n result[leftIndex] = partials[leftIndex];\n }\n while (++argsIndex < holdersLength) {\n if (isUncurried || argsIndex < argsLength) {\n result[holders[argsIndex]] = args[argsIndex];\n }\n }\n while (rangeLength--) {\n result[leftIndex++] = args[argsIndex++];\n }\n return result;\n }\n\n /**\n * This function is like `composeArgs` except that the arguments composition\n * is tailored for `_.partialRight`.\n *\n * @private\n * @param {Array} args The provided arguments.\n * @param {Array} partials The arguments to append to those provided.\n * @param {Array} holders The `partials` placeholder indexes.\n * @params {boolean} [isCurried] Specify composing for a curried function.\n * @returns {Array} Returns the new array of composed arguments.\n */\n function composeArgsRight(args, partials, holders, isCurried) {\n var argsIndex = -1,\n argsLength = args.length,\n holdersIndex = -1,\n holdersLength = holders.length,\n rightIndex = -1,\n rightLength = partials.length,\n rangeLength = nativeMax(argsLength - holdersLength, 0),\n result = Array(rangeLength + rightLength),\n isUncurried = !isCurried;\n\n while (++argsIndex < rangeLength) {\n result[argsIndex] = args[argsIndex];\n }\n var offset = argsIndex;\n while (++rightIndex < rightLength) {\n result[offset + rightIndex] = partials[rightIndex];\n }\n while (++holdersIndex < holdersLength) {\n if (isUncurried || 
argsIndex < argsLength) {\n result[offset + holders[holdersIndex]] = args[argsIndex++];\n }\n }\n return result;\n }\n\n /**\n * Copies the values of `source` to `array`.\n *\n * @private\n * @param {Array} source The array to copy values from.\n * @param {Array} [array=[]] The array to copy values to.\n * @returns {Array} Returns `array`.\n */\n function copyArray(source, array) {\n var index = -1,\n length = source.length;\n\n array || (array = Array(length));\n while (++index < length) {\n array[index] = source[index];\n }\n return array;\n }\n\n /**\n * Copies properties of `source` to `object`.\n *\n * @private\n * @param {Object} source The object to copy properties from.\n * @param {Array} props The property identifiers to copy.\n * @param {Object} [object={}] The object to copy properties to.\n * @param {Function} [customizer] The function to customize copied values.\n * @returns {Object} Returns `object`.\n */\n function copyObject(source, props, object, customizer) {\n var isNew = !object;\n object || (object = {});\n\n var index = -1,\n length = props.length;\n\n while (++index < length) {\n var key = props[index];\n\n var newValue = customizer\n ? customizer(object[key], source[key], key, object, source)\n : undefined;\n\n if (newValue === undefined) {\n newValue = source[key];\n }\n if (isNew) {\n baseAssignValue(object, key, newValue);\n } else {\n assignValue(object, key, newValue);\n }\n }\n return object;\n }\n\n /**\n * Copies own symbols of `source` to `object`.\n *\n * @private\n * @param {Object} source The object to copy symbols from.\n * @param {Object} [object={}] The object to copy symbols to.\n * @returns {Object} Returns `object`.\n */\n function copySymbols(source, object) {\n return copyObject(source, getSymbols(source), object);\n }\n\n /**\n * Copies own and inherited symbols of `source` to `object`.\n *\n * @private\n * @param {Object} source The object to copy symbols from.\n * @param {Object} [object={}] The object to copy symbols to.\n * @returns {Object} Returns `object`.\n */\n function copySymbolsIn(source, object) {\n return copyObject(source, getSymbolsIn(source), object);\n }\n\n /**\n * Creates a function like `_.groupBy`.\n *\n * @private\n * @param {Function} setter The function to set accumulator values.\n * @param {Function} [initializer] The accumulator object initializer.\n * @returns {Function} Returns the new aggregator function.\n */\n function createAggregator(setter, initializer) {\n return function(collection, iteratee) {\n var func = isArray(collection) ? arrayAggregator : baseAggregator,\n accumulator = initializer ? initializer() : {};\n\n return func(collection, setter, getIteratee(iteratee, 2), accumulator);\n };\n }\n\n /**\n * Creates a function like `_.assign`.\n *\n * @private\n * @param {Function} assigner The function to assign values.\n * @returns {Function} Returns the new assigner function.\n */\n function createAssigner(assigner) {\n return baseRest(function(object, sources) {\n var index = -1,\n length = sources.length,\n customizer = length > 1 ? sources[length - 1] : undefined,\n guard = length > 2 ? sources[2] : undefined;\n\n customizer = (assigner.length > 3 && typeof customizer == 'function')\n ? (length--, customizer)\n : undefined;\n\n if (guard && isIterateeCall(sources[0], sources[1], guard)) {\n customizer = length < 3 ? 
undefined : customizer;\n length = 1;\n }\n object = Object(object);\n while (++index < length) {\n var source = sources[index];\n if (source) {\n assigner(object, source, index, customizer);\n }\n }\n return object;\n });\n }\n\n /**\n * Creates a `baseEach` or `baseEachRight` function.\n *\n * @private\n * @param {Function} eachFunc The function to iterate over a collection.\n * @param {boolean} [fromRight] Specify iterating from right to left.\n * @returns {Function} Returns the new base function.\n */\n function createBaseEach(eachFunc, fromRight) {\n return function(collection, iteratee) {\n if (collection == null) {\n return collection;\n }\n if (!isArrayLike(collection)) {\n return eachFunc(collection, iteratee);\n }\n var length = collection.length,\n index = fromRight ? length : -1,\n iterable = Object(collection);\n\n while ((fromRight ? index-- : ++index < length)) {\n if (iteratee(iterable[index], index, iterable) === false) {\n break;\n }\n }\n return collection;\n };\n }\n\n /**\n * Creates a base function for methods like `_.forIn` and `_.forOwn`.\n *\n * @private\n * @param {boolean} [fromRight] Specify iterating from right to left.\n * @returns {Function} Returns the new base function.\n */\n function createBaseFor(fromRight) {\n return function(object, iteratee, keysFunc) {\n var index = -1,\n iterable = Object(object),\n props = keysFunc(object),\n length = props.length;\n\n while (length--) {\n var key = props[fromRight ? length : ++index];\n if (iteratee(iterable[key], key, iterable) === false) {\n break;\n }\n }\n return object;\n };\n }\n\n /**\n * Creates a function that wraps `func` to invoke it with the optional `this`\n * binding of `thisArg`.\n *\n * @private\n * @param {Function} func The function to wrap.\n * @param {number} bitmask The bitmask flags. See `createWrap` for more details.\n * @param {*} [thisArg] The `this` binding of `func`.\n * @returns {Function} Returns the new wrapped function.\n */\n function createBind(func, bitmask, thisArg) {\n var isBind = bitmask & WRAP_BIND_FLAG,\n Ctor = createCtor(func);\n\n function wrapper() {\n var fn = (this && this !== root && this instanceof wrapper) ? Ctor : func;\n return fn.apply(isBind ? thisArg : this, arguments);\n }\n return wrapper;\n }\n\n /**\n * Creates a function like `_.lowerFirst`.\n *\n * @private\n * @param {string} methodName The name of the `String` case method to use.\n * @returns {Function} Returns the new case function.\n */\n function createCaseFirst(methodName) {\n return function(string) {\n string = toString(string);\n\n var strSymbols = hasUnicode(string)\n ? stringToArray(string)\n : undefined;\n\n var chr = strSymbols\n ? strSymbols[0]\n : string.charAt(0);\n\n var trailing = strSymbols\n ? 
castSlice(strSymbols, 1).join('')\n : string.slice(1);\n\n return chr[methodName]() + trailing;\n };\n }\n\n /**\n * Creates a function like `_.camelCase`.\n *\n * @private\n * @param {Function} callback The function to combine each word.\n * @returns {Function} Returns the new compounder function.\n */\n function createCompounder(callback) {\n return function(string) {\n return arrayReduce(words(deburr(string).replace(reApos, '')), callback, '');\n };\n }\n\n /**\n * Creates a function that produces an instance of `Ctor` regardless of\n * whether it was invoked as part of a `new` expression or by `call` or `apply`.\n *\n * @private\n * @param {Function} Ctor The constructor to wrap.\n * @returns {Function} Returns the new wrapped function.\n */\n function createCtor(Ctor) {\n return function() {\n // Use a `switch` statement to work with class constructors. See\n // http://ecma-international.org/ecma-262/7.0/#sec-ecmascript-function-objects-call-thisargument-argumentslist\n // for more details.\n var args = arguments;\n switch (args.length) {\n case 0: return new Ctor;\n case 1: return new Ctor(args[0]);\n case 2: return new Ctor(args[0], args[1]);\n case 3: return new Ctor(args[0], args[1], args[2]);\n case 4: return new Ctor(args[0], args[1], args[2], args[3]);\n case 5: return new Ctor(args[0], args[1], args[2], args[3], args[4]);\n case 6: return new Ctor(args[0], args[1], args[2], args[3], args[4], args[5]);\n case 7: return new Ctor(args[0], args[1], args[2], args[3], args[4], args[5], args[6]);\n }\n var thisBinding = baseCreate(Ctor.prototype),\n result = Ctor.apply(thisBinding, args);\n\n // Mimic the constructor's `return` behavior.\n // See https://es5.github.io/#x13.2.2 for more details.\n return isObject(result) ? result : thisBinding;\n };\n }\n\n /**\n * Creates a function that wraps `func` to enable currying.\n *\n * @private\n * @param {Function} func The function to wrap.\n * @param {number} bitmask The bitmask flags. See `createWrap` for more details.\n * @param {number} arity The arity of `func`.\n * @returns {Function} Returns the new wrapped function.\n */\n function createCurry(func, bitmask, arity) {\n var Ctor = createCtor(func);\n\n function wrapper() {\n var length = arguments.length,\n args = Array(length),\n index = length,\n placeholder = getHolder(wrapper);\n\n while (index--) {\n args[index] = arguments[index];\n }\n var holders = (length < 3 && args[0] !== placeholder && args[length - 1] !== placeholder)\n ? []\n : replaceHolders(args, placeholder);\n\n length -= holders.length;\n if (length < arity) {\n return createRecurry(\n func, bitmask, createHybrid, wrapper.placeholder, undefined,\n args, holders, undefined, undefined, arity - length);\n }\n var fn = (this && this !== root && this instanceof wrapper) ? Ctor : func;\n return apply(fn, this, args);\n }\n return wrapper;\n }\n\n /**\n * Creates a `_.find` or `_.findLast` function.\n *\n * @private\n * @param {Function} findIndexFunc The function to find the collection index.\n * @returns {Function} Returns the new find function.\n */\n function createFind(findIndexFunc) {\n return function(collection, predicate, fromIndex) {\n var iterable = Object(collection);\n if (!isArrayLike(collection)) {\n var iteratee = getIteratee(predicate, 3);\n collection = keys(collection);\n predicate = function(key) { return iteratee(iterable[key], key, iterable); };\n }\n var index = findIndexFunc(collection, predicate, fromIndex);\n return index > -1 ? iterable[iteratee ? 
collection[index] : index] : undefined;\n };\n }\n\n /**\n * Creates a `_.flow` or `_.flowRight` function.\n *\n * @private\n * @param {boolean} [fromRight] Specify iterating from right to left.\n * @returns {Function} Returns the new flow function.\n */\n function createFlow(fromRight) {\n return flatRest(function(funcs) {\n var length = funcs.length,\n index = length,\n prereq = LodashWrapper.prototype.thru;\n\n if (fromRight) {\n funcs.reverse();\n }\n while (index--) {\n var func = funcs[index];\n if (typeof func != 'function') {\n throw new TypeError(FUNC_ERROR_TEXT);\n }\n if (prereq && !wrapper && getFuncName(func) == 'wrapper') {\n var wrapper = new LodashWrapper([], true);\n }\n }\n index = wrapper ? index : length;\n while (++index < length) {\n func = funcs[index];\n\n var funcName = getFuncName(func),\n data = funcName == 'wrapper' ? getData(func) : undefined;\n\n if (data && isLaziable(data[0]) &&\n data[1] == (WRAP_ARY_FLAG | WRAP_CURRY_FLAG | WRAP_PARTIAL_FLAG | WRAP_REARG_FLAG) &&\n !data[4].length && data[9] == 1\n ) {\n wrapper = wrapper[getFuncName(data[0])].apply(wrapper, data[3]);\n } else {\n wrapper = (func.length == 1 && isLaziable(func))\n ? wrapper[funcName]()\n : wrapper.thru(func);\n }\n }\n return function() {\n var args = arguments,\n value = args[0];\n\n if (wrapper && args.length == 1 && isArray(value)) {\n return wrapper.plant(value).value();\n }\n var index = 0,\n result = length ? funcs[index].apply(this, args) : value;\n\n while (++index < length) {\n result = funcs[index].call(this, result);\n }\n return result;\n };\n });\n }\n\n /**\n * Creates a function that wraps `func` to invoke it with optional `this`\n * binding of `thisArg`, partial application, and currying.\n *\n * @private\n * @param {Function|string} func The function or method name to wrap.\n * @param {number} bitmask The bitmask flags. See `createWrap` for more details.\n * @param {*} [thisArg] The `this` binding of `func`.\n * @param {Array} [partials] The arguments to prepend to those provided to\n * the new function.\n * @param {Array} [holders] The `partials` placeholder indexes.\n * @param {Array} [partialsRight] The arguments to append to those provided\n * to the new function.\n * @param {Array} [holdersRight] The `partialsRight` placeholder indexes.\n * @param {Array} [argPos] The argument positions of the new function.\n * @param {number} [ary] The arity cap of `func`.\n * @param {number} [arity] The arity of `func`.\n * @returns {Function} Returns the new wrapped function.\n */\n function createHybrid(func, bitmask, thisArg, partials, holders, partialsRight, holdersRight, argPos, ary, arity) {\n var isAry = bitmask & WRAP_ARY_FLAG,\n isBind = bitmask & WRAP_BIND_FLAG,\n isBindKey = bitmask & WRAP_BIND_KEY_FLAG,\n isCurried = bitmask & (WRAP_CURRY_FLAG | WRAP_CURRY_RIGHT_FLAG),\n isFlip = bitmask & WRAP_FLIP_FLAG,\n Ctor = isBindKey ? 
undefined : createCtor(func);\n\n function wrapper() {\n var length = arguments.length,\n args = Array(length),\n index = length;\n\n while (index--) {\n args[index] = arguments[index];\n }\n if (isCurried) {\n var placeholder = getHolder(wrapper),\n holdersCount = countHolders(args, placeholder);\n }\n if (partials) {\n args = composeArgs(args, partials, holders, isCurried);\n }\n if (partialsRight) {\n args = composeArgsRight(args, partialsRight, holdersRight, isCurried);\n }\n length -= holdersCount;\n if (isCurried && length < arity) {\n var newHolders = replaceHolders(args, placeholder);\n return createRecurry(\n func, bitmask, createHybrid, wrapper.placeholder, thisArg,\n args, newHolders, argPos, ary, arity - length\n );\n }\n var thisBinding = isBind ? thisArg : this,\n fn = isBindKey ? thisBinding[func] : func;\n\n length = args.length;\n if (argPos) {\n args = reorder(args, argPos);\n } else if (isFlip && length > 1) {\n args.reverse();\n }\n if (isAry && ary < length) {\n args.length = ary;\n }\n if (this && this !== root && this instanceof wrapper) {\n fn = Ctor || createCtor(fn);\n }\n return fn.apply(thisBinding, args);\n }\n return wrapper;\n }\n\n /**\n * Creates a function like `_.invertBy`.\n *\n * @private\n * @param {Function} setter The function to set accumulator values.\n * @param {Function} toIteratee The function to resolve iteratees.\n * @returns {Function} Returns the new inverter function.\n */\n function createInverter(setter, toIteratee) {\n return function(object, iteratee) {\n return baseInverter(object, setter, toIteratee(iteratee), {});\n };\n }\n\n /**\n * Creates a function that performs a mathematical operation on two values.\n *\n * @private\n * @param {Function} operator The function to perform the operation.\n * @param {number} [defaultValue] The value used for `undefined` arguments.\n * @returns {Function} Returns the new mathematical operation function.\n */\n function createMathOperation(operator, defaultValue) {\n return function(value, other) {\n var result;\n if (value === undefined && other === undefined) {\n return defaultValue;\n }\n if (value !== undefined) {\n result = value;\n }\n if (other !== undefined) {\n if (result === undefined) {\n return other;\n }\n if (typeof value == 'string' || typeof other == 'string') {\n value = baseToString(value);\n other = baseToString(other);\n } else {\n value = baseToNumber(value);\n other = baseToNumber(other);\n }\n result = operator(value, other);\n }\n return result;\n };\n }\n\n /**\n * Creates a function like `_.over`.\n *\n * @private\n * @param {Function} arrayFunc The function to iterate over iteratees.\n * @returns {Function} Returns the new over function.\n */\n function createOver(arrayFunc) {\n return flatRest(function(iteratees) {\n iteratees = arrayMap(iteratees, baseUnary(getIteratee()));\n return baseRest(function(args) {\n var thisArg = this;\n return arrayFunc(iteratees, function(iteratee) {\n return apply(iteratee, thisArg, args);\n });\n });\n });\n }\n\n /**\n * Creates the padding for `string` based on `length`. The `chars` string\n * is truncated if the number of characters exceeds `length`.\n *\n * @private\n * @param {number} length The padding length.\n * @param {string} [chars=' '] The string used as padding.\n * @returns {string} Returns the padding for `string`.\n */\n function createPadding(length, chars) {\n chars = chars === undefined ? ' ' : baseToString(chars);\n\n var charsLength = chars.length;\n if (charsLength < 2) {\n return charsLength ? 
baseRepeat(chars, length) : chars;\n }\n var result = baseRepeat(chars, nativeCeil(length / stringSize(chars)));\n return hasUnicode(chars)\n ? castSlice(stringToArray(result), 0, length).join('')\n : result.slice(0, length);\n }\n\n /**\n * Creates a function that wraps `func` to invoke it with the `this` binding\n * of `thisArg` and `partials` prepended to the arguments it receives.\n *\n * @private\n * @param {Function} func The function to wrap.\n * @param {number} bitmask The bitmask flags. See `createWrap` for more details.\n * @param {*} thisArg The `this` binding of `func`.\n * @param {Array} partials The arguments to prepend to those provided to\n * the new function.\n * @returns {Function} Returns the new wrapped function.\n */\n function createPartial(func, bitmask, thisArg, partials) {\n var isBind = bitmask & WRAP_BIND_FLAG,\n Ctor = createCtor(func);\n\n function wrapper() {\n var argsIndex = -1,\n argsLength = arguments.length,\n leftIndex = -1,\n leftLength = partials.length,\n args = Array(leftLength + argsLength),\n fn = (this && this !== root && this instanceof wrapper) ? Ctor : func;\n\n while (++leftIndex < leftLength) {\n args[leftIndex] = partials[leftIndex];\n }\n while (argsLength--) {\n args[leftIndex++] = arguments[++argsIndex];\n }\n return apply(fn, isBind ? thisArg : this, args);\n }\n return wrapper;\n }\n\n /**\n * Creates a `_.range` or `_.rangeRight` function.\n *\n * @private\n * @param {boolean} [fromRight] Specify iterating from right to left.\n * @returns {Function} Returns the new range function.\n */\n function createRange(fromRight) {\n return function(start, end, step) {\n if (step && typeof step != 'number' && isIterateeCall(start, end, step)) {\n end = step = undefined;\n }\n // Ensure the sign of `-0` is preserved.\n start = toFinite(start);\n if (end === undefined) {\n end = start;\n start = 0;\n } else {\n end = toFinite(end);\n }\n step = step === undefined ? (start < end ? 1 : -1) : toFinite(step);\n return baseRange(start, end, step, fromRight);\n };\n }\n\n /**\n * Creates a function that performs a relational operation on two values.\n *\n * @private\n * @param {Function} operator The function to perform the operation.\n * @returns {Function} Returns the new relational operation function.\n */\n function createRelationalOperation(operator) {\n return function(value, other) {\n if (!(typeof value == 'string' && typeof other == 'string')) {\n value = toNumber(value);\n other = toNumber(other);\n }\n return operator(value, other);\n };\n }\n\n /**\n * Creates a function that wraps `func` to continue currying.\n *\n * @private\n * @param {Function} func The function to wrap.\n * @param {number} bitmask The bitmask flags. See `createWrap` for more details.\n * @param {Function} wrapFunc The function to create the `func` wrapper.\n * @param {*} placeholder The placeholder value.\n * @param {*} [thisArg] The `this` binding of `func`.\n * @param {Array} [partials] The arguments to prepend to those provided to\n * the new function.\n * @param {Array} [holders] The `partials` placeholder indexes.\n * @param {Array} [argPos] The argument positions of the new function.\n * @param {number} [ary] The arity cap of `func`.\n * @param {number} [arity] The arity of `func`.\n * @returns {Function} Returns the new wrapped function.\n */\n function createRecurry(func, bitmask, wrapFunc, placeholder, thisArg, partials, holders, argPos, ary, arity) {\n var isCurry = bitmask & WRAP_CURRY_FLAG,\n newHolders = isCurry ? 
holders : undefined,\n newHoldersRight = isCurry ? undefined : holders,\n newPartials = isCurry ? partials : undefined,\n newPartialsRight = isCurry ? undefined : partials;\n\n bitmask |= (isCurry ? WRAP_PARTIAL_FLAG : WRAP_PARTIAL_RIGHT_FLAG);\n bitmask &= ~(isCurry ? WRAP_PARTIAL_RIGHT_FLAG : WRAP_PARTIAL_FLAG);\n\n if (!(bitmask & WRAP_CURRY_BOUND_FLAG)) {\n bitmask &= ~(WRAP_BIND_FLAG | WRAP_BIND_KEY_FLAG);\n }\n var newData = [\n func, bitmask, thisArg, newPartials, newHolders, newPartialsRight,\n newHoldersRight, argPos, ary, arity\n ];\n\n var result = wrapFunc.apply(undefined, newData);\n if (isLaziable(func)) {\n setData(result, newData);\n }\n result.placeholder = placeholder;\n return setWrapToString(result, func, bitmask);\n }\n\n /**\n * Creates a function like `_.round`.\n *\n * @private\n * @param {string} methodName The name of the `Math` method to use when rounding.\n * @returns {Function} Returns the new round function.\n */\n function createRound(methodName) {\n var func = Math[methodName];\n return function(number, precision) {\n number = toNumber(number);\n precision = precision == null ? 0 : nativeMin(toInteger(precision), 292);\n if (precision && nativeIsFinite(number)) {\n // Shift with exponential notation to avoid floating-point issues.\n // See [MDN](https://mdn.io/round#Examples) for more details.\n var pair = (toString(number) + 'e').split('e'),\n value = func(pair[0] + 'e' + (+pair[1] + precision));\n\n pair = (toString(value) + 'e').split('e');\n return +(pair[0] + 'e' + (+pair[1] - precision));\n }\n return func(number);\n };\n }\n\n /**\n * Creates a set object of `values`.\n *\n * @private\n * @param {Array} values The values to add to the set.\n * @returns {Object} Returns the new set.\n */\n var createSet = !(Set && (1 / setToArray(new Set([,-0]))[1]) == INFINITY) ? 
noop : function(values) {\n return new Set(values);\n };\n\n /**\n * Creates a `_.toPairs` or `_.toPairsIn` function.\n *\n * @private\n * @param {Function} keysFunc The function to get the keys of a given object.\n * @returns {Function} Returns the new pairs function.\n */\n function createToPairs(keysFunc) {\n return function(object) {\n var tag = getTag(object);\n if (tag == mapTag) {\n return mapToArray(object);\n }\n if (tag == setTag) {\n return setToPairs(object);\n }\n return baseToPairs(object, keysFunc(object));\n };\n }\n\n /**\n * Creates a function that either curries or invokes `func` with optional\n * `this` binding and partially applied arguments.\n *\n * @private\n * @param {Function|string} func The function or method name to wrap.\n * @param {number} bitmask The bitmask flags.\n * 1 - `_.bind`\n * 2 - `_.bindKey`\n * 4 - `_.curry` or `_.curryRight` of a bound function\n * 8 - `_.curry`\n * 16 - `_.curryRight`\n * 32 - `_.partial`\n * 64 - `_.partialRight`\n * 128 - `_.rearg`\n * 256 - `_.ary`\n * 512 - `_.flip`\n * @param {*} [thisArg] The `this` binding of `func`.\n * @param {Array} [partials] The arguments to be partially applied.\n * @param {Array} [holders] The `partials` placeholder indexes.\n * @param {Array} [argPos] The argument positions of the new function.\n * @param {number} [ary] The arity cap of `func`.\n * @param {number} [arity] The arity of `func`.\n * @returns {Function} Returns the new wrapped function.\n */\n function createWrap(func, bitmask, thisArg, partials, holders, argPos, ary, arity) {\n var isBindKey = bitmask & WRAP_BIND_KEY_FLAG;\n if (!isBindKey && typeof func != 'function') {\n throw new TypeError(FUNC_ERROR_TEXT);\n }\n var length = partials ? partials.length : 0;\n if (!length) {\n bitmask &= ~(WRAP_PARTIAL_FLAG | WRAP_PARTIAL_RIGHT_FLAG);\n partials = holders = undefined;\n }\n ary = ary === undefined ? ary : nativeMax(toInteger(ary), 0);\n arity = arity === undefined ? arity : toInteger(arity);\n length -= holders ? holders.length : 0;\n\n if (bitmask & WRAP_PARTIAL_RIGHT_FLAG) {\n var partialsRight = partials,\n holdersRight = holders;\n\n partials = holders = undefined;\n }\n var data = isBindKey ? undefined : getData(func);\n\n var newData = [\n func, bitmask, thisArg, partials, holders, partialsRight, holdersRight,\n argPos, ary, arity\n ];\n\n if (data) {\n mergeData(newData, data);\n }\n func = newData[0];\n bitmask = newData[1];\n thisArg = newData[2];\n partials = newData[3];\n holders = newData[4];\n arity = newData[9] = newData[9] === undefined\n ? (isBindKey ? 0 : func.length)\n : nativeMax(newData[9] - length, 0);\n\n if (!arity && bitmask & (WRAP_CURRY_FLAG | WRAP_CURRY_RIGHT_FLAG)) {\n bitmask &= ~(WRAP_CURRY_FLAG | WRAP_CURRY_RIGHT_FLAG);\n }\n if (!bitmask || bitmask == WRAP_BIND_FLAG) {\n var result = createBind(func, bitmask, thisArg);\n } else if (bitmask == WRAP_CURRY_FLAG || bitmask == WRAP_CURRY_RIGHT_FLAG) {\n result = createCurry(func, bitmask, arity);\n } else if ((bitmask == WRAP_PARTIAL_FLAG || bitmask == (WRAP_BIND_FLAG | WRAP_PARTIAL_FLAG)) && !holders.length) {\n result = createPartial(func, bitmask, thisArg, partials);\n } else {\n result = createHybrid.apply(undefined, newData);\n }\n var setter = data ? 
baseSetData : setData;\n return setWrapToString(setter(result, newData), func, bitmask);\n }\n\n /**\n * Used by `_.defaults` to customize its `_.assignIn` use to assign properties\n * of source objects to the destination object for all destination properties\n * that resolve to `undefined`.\n *\n * @private\n * @param {*} objValue The destination value.\n * @param {*} srcValue The source value.\n * @param {string} key The key of the property to assign.\n * @param {Object} object The parent object of `objValue`.\n * @returns {*} Returns the value to assign.\n */\n function customDefaultsAssignIn(objValue, srcValue, key, object) {\n if (objValue === undefined ||\n (eq(objValue, objectProto[key]) && !hasOwnProperty.call(object, key))) {\n return srcValue;\n }\n return objValue;\n }\n\n /**\n * Used by `_.defaultsDeep` to customize its `_.merge` use to merge source\n * objects into destination objects that are passed thru.\n *\n * @private\n * @param {*} objValue The destination value.\n * @param {*} srcValue The source value.\n * @param {string} key The key of the property to merge.\n * @param {Object} object The parent object of `objValue`.\n * @param {Object} source The parent object of `srcValue`.\n * @param {Object} [stack] Tracks traversed source values and their merged\n * counterparts.\n * @returns {*} Returns the value to assign.\n */\n function customDefaultsMerge(objValue, srcValue, key, object, source, stack) {\n if (isObject(objValue) && isObject(srcValue)) {\n // Recursively merge objects and arrays (susceptible to call stack limits).\n stack.set(srcValue, objValue);\n baseMerge(objValue, srcValue, undefined, customDefaultsMerge, stack);\n stack['delete'](srcValue);\n }\n return objValue;\n }\n\n /**\n * Used by `_.omit` to customize its `_.cloneDeep` use to only clone plain\n * objects.\n *\n * @private\n * @param {*} value The value to inspect.\n * @param {string} key The key of the property to inspect.\n * @returns {*} Returns the uncloned value or `undefined` to defer cloning to `_.cloneDeep`.\n */\n function customOmitClone(value) {\n return isPlainObject(value) ? undefined : value;\n }\n\n /**\n * A specialized version of `baseIsEqualDeep` for arrays with support for\n * partial deep comparisons.\n *\n * @private\n * @param {Array} array The array to compare.\n * @param {Array} other The other array to compare.\n * @param {number} bitmask The bitmask flags. See `baseIsEqual` for more details.\n * @param {Function} customizer The function to customize comparisons.\n * @param {Function} equalFunc The function to determine equivalents of values.\n * @param {Object} stack Tracks traversed `array` and `other` objects.\n * @returns {boolean} Returns `true` if the arrays are equivalent, else `false`.\n */\n function equalArrays(array, other, bitmask, customizer, equalFunc, stack) {\n var isPartial = bitmask & COMPARE_PARTIAL_FLAG,\n arrLength = array.length,\n othLength = other.length;\n\n if (arrLength != othLength && !(isPartial && othLength > arrLength)) {\n return false;\n }\n // Check that cyclic values are equal.\n var arrStacked = stack.get(array);\n var othStacked = stack.get(other);\n if (arrStacked && othStacked) {\n return arrStacked == other && othStacked == array;\n }\n var index = -1,\n result = true,\n seen = (bitmask & COMPARE_UNORDERED_FLAG) ? 
new SetCache : undefined;\n\n stack.set(array, other);\n stack.set(other, array);\n\n // Ignore non-index properties.\n while (++index < arrLength) {\n var arrValue = array[index],\n othValue = other[index];\n\n if (customizer) {\n var compared = isPartial\n ? customizer(othValue, arrValue, index, other, array, stack)\n : customizer(arrValue, othValue, index, array, other, stack);\n }\n if (compared !== undefined) {\n if (compared) {\n continue;\n }\n result = false;\n break;\n }\n // Recursively compare arrays (susceptible to call stack limits).\n if (seen) {\n if (!arraySome(other, function(othValue, othIndex) {\n if (!cacheHas(seen, othIndex) &&\n (arrValue === othValue || equalFunc(arrValue, othValue, bitmask, customizer, stack))) {\n return seen.push(othIndex);\n }\n })) {\n result = false;\n break;\n }\n } else if (!(\n arrValue === othValue ||\n equalFunc(arrValue, othValue, bitmask, customizer, stack)\n )) {\n result = false;\n break;\n }\n }\n stack['delete'](array);\n stack['delete'](other);\n return result;\n }\n\n /**\n * A specialized version of `baseIsEqualDeep` for comparing objects of\n * the same `toStringTag`.\n *\n * **Note:** This function only supports comparing values with tags of\n * `Boolean`, `Date`, `Error`, `Number`, `RegExp`, or `String`.\n *\n * @private\n * @param {Object} object The object to compare.\n * @param {Object} other The other object to compare.\n * @param {string} tag The `toStringTag` of the objects to compare.\n * @param {number} bitmask The bitmask flags. See `baseIsEqual` for more details.\n * @param {Function} customizer The function to customize comparisons.\n * @param {Function} equalFunc The function to determine equivalents of values.\n * @param {Object} stack Tracks traversed `object` and `other` objects.\n * @returns {boolean} Returns `true` if the objects are equivalent, else `false`.\n */\n function equalByTag(object, other, tag, bitmask, customizer, equalFunc, stack) {\n switch (tag) {\n case dataViewTag:\n if ((object.byteLength != other.byteLength) ||\n (object.byteOffset != other.byteOffset)) {\n return false;\n }\n object = object.buffer;\n other = other.buffer;\n\n case arrayBufferTag:\n if ((object.byteLength != other.byteLength) ||\n !equalFunc(new Uint8Array(object), new Uint8Array(other))) {\n return false;\n }\n return true;\n\n case boolTag:\n case dateTag:\n case numberTag:\n // Coerce booleans to `1` or `0` and dates to milliseconds.\n // Invalid dates are coerced to `NaN`.\n return eq(+object, +other);\n\n case errorTag:\n return object.name == other.name && object.message == other.message;\n\n case regexpTag:\n case stringTag:\n // Coerce regexes to strings and treat strings, primitives and objects,\n // as equal. 
See http://www.ecma-international.org/ecma-262/7.0/#sec-regexp.prototype.tostring\n // for more details.\n return object == (other + '');\n\n case mapTag:\n var convert = mapToArray;\n\n case setTag:\n var isPartial = bitmask & COMPARE_PARTIAL_FLAG;\n convert || (convert = setToArray);\n\n if (object.size != other.size && !isPartial) {\n return false;\n }\n // Assume cyclic values are equal.\n var stacked = stack.get(object);\n if (stacked) {\n return stacked == other;\n }\n bitmask |= COMPARE_UNORDERED_FLAG;\n\n // Recursively compare objects (susceptible to call stack limits).\n stack.set(object, other);\n var result = equalArrays(convert(object), convert(other), bitmask, customizer, equalFunc, stack);\n stack['delete'](object);\n return result;\n\n case symbolTag:\n if (symbolValueOf) {\n return symbolValueOf.call(object) == symbolValueOf.call(other);\n }\n }\n return false;\n }\n\n /**\n * A specialized version of `baseIsEqualDeep` for objects with support for\n * partial deep comparisons.\n *\n * @private\n * @param {Object} object The object to compare.\n * @param {Object} other The other object to compare.\n * @param {number} bitmask The bitmask flags. See `baseIsEqual` for more details.\n * @param {Function} customizer The function to customize comparisons.\n * @param {Function} equalFunc The function to determine equivalents of values.\n * @param {Object} stack Tracks traversed `object` and `other` objects.\n * @returns {boolean} Returns `true` if the objects are equivalent, else `false`.\n */\n function equalObjects(object, other, bitmask, customizer, equalFunc, stack) {\n var isPartial = bitmask & COMPARE_PARTIAL_FLAG,\n objProps = getAllKeys(object),\n objLength = objProps.length,\n othProps = getAllKeys(other),\n othLength = othProps.length;\n\n if (objLength != othLength && !isPartial) {\n return false;\n }\n var index = objLength;\n while (index--) {\n var key = objProps[index];\n if (!(isPartial ? key in other : hasOwnProperty.call(other, key))) {\n return false;\n }\n }\n // Check that cyclic values are equal.\n var objStacked = stack.get(object);\n var othStacked = stack.get(other);\n if (objStacked && othStacked) {\n return objStacked == other && othStacked == object;\n }\n var result = true;\n stack.set(object, other);\n stack.set(other, object);\n\n var skipCtor = isPartial;\n while (++index < objLength) {\n key = objProps[index];\n var objValue = object[key],\n othValue = other[key];\n\n if (customizer) {\n var compared = isPartial\n ? customizer(othValue, objValue, key, other, object, stack)\n : customizer(objValue, othValue, key, object, other, stack);\n }\n // Recursively compare objects (susceptible to call stack limits).\n if (!(compared === undefined\n ? 
(objValue === othValue || equalFunc(objValue, othValue, bitmask, customizer, stack))\n : compared\n )) {\n result = false;\n break;\n }\n skipCtor || (skipCtor = key == 'constructor');\n }\n if (result && !skipCtor) {\n var objCtor = object.constructor,\n othCtor = other.constructor;\n\n // Non `Object` object instances with different constructors are not equal.\n if (objCtor != othCtor &&\n ('constructor' in object && 'constructor' in other) &&\n !(typeof objCtor == 'function' && objCtor instanceof objCtor &&\n typeof othCtor == 'function' && othCtor instanceof othCtor)) {\n result = false;\n }\n }\n stack['delete'](object);\n stack['delete'](other);\n return result;\n }\n\n /**\n * A specialized version of `baseRest` which flattens the rest array.\n *\n * @private\n * @param {Function} func The function to apply a rest parameter to.\n * @returns {Function} Returns the new function.\n */\n function flatRest(func) {\n return setToString(overRest(func, undefined, flatten), func + '');\n }\n\n /**\n * Creates an array of own enumerable property names and symbols of `object`.\n *\n * @private\n * @param {Object} object The object to query.\n * @returns {Array} Returns the array of property names and symbols.\n */\n function getAllKeys(object) {\n return baseGetAllKeys(object, keys, getSymbols);\n }\n\n /**\n * Creates an array of own and inherited enumerable property names and\n * symbols of `object`.\n *\n * @private\n * @param {Object} object The object to query.\n * @returns {Array} Returns the array of property names and symbols.\n */\n function getAllKeysIn(object) {\n return baseGetAllKeys(object, keysIn, getSymbolsIn);\n }\n\n /**\n * Gets metadata for `func`.\n *\n * @private\n * @param {Function} func The function to query.\n * @returns {*} Returns the metadata for `func`.\n */\n var getData = !metaMap ? noop : function(func) {\n return metaMap.get(func);\n };\n\n /**\n * Gets the name of `func`.\n *\n * @private\n * @param {Function} func The function to query.\n * @returns {string} Returns the function name.\n */\n function getFuncName(func) {\n var result = (func.name + ''),\n array = realNames[result],\n length = hasOwnProperty.call(realNames, result) ? array.length : 0;\n\n while (length--) {\n var data = array[length],\n otherFunc = data.func;\n if (otherFunc == null || otherFunc == func) {\n return data.name;\n }\n }\n return result;\n }\n\n /**\n * Gets the argument placeholder value for `func`.\n *\n * @private\n * @param {Function} func The function to inspect.\n * @returns {*} Returns the placeholder value.\n */\n function getHolder(func) {\n var object = hasOwnProperty.call(lodash, 'placeholder') ? lodash : func;\n return object.placeholder;\n }\n\n /**\n * Gets the appropriate \"iteratee\" function. If `_.iteratee` is customized,\n * this function returns the custom method, otherwise it returns `baseIteratee`.\n * If arguments are provided, the chosen function is invoked with them and\n * its result is returned.\n *\n * @private\n * @param {*} [value] The value to convert to an iteratee.\n * @param {number} [arity] The arity of the created iteratee.\n * @returns {Function} Returns the chosen function or its result.\n */\n function getIteratee() {\n var result = lodash.iteratee || iteratee;\n result = result === iteratee ? baseIteratee : result;\n return arguments.length ? 
result(arguments[0], arguments[1]) : result;\n }\n\n /**\n * Gets the data for `map`.\n *\n * @private\n * @param {Object} map The map to query.\n * @param {string} key The reference key.\n * @returns {*} Returns the map data.\n */\n function getMapData(map, key) {\n var data = map.__data__;\n return isKeyable(key)\n ? data[typeof key == 'string' ? 'string' : 'hash']\n : data.map;\n }\n\n /**\n * Gets the property names, values, and compare flags of `object`.\n *\n * @private\n * @param {Object} object The object to query.\n * @returns {Array} Returns the match data of `object`.\n */\n function getMatchData(object) {\n var result = keys(object),\n length = result.length;\n\n while (length--) {\n var key = result[length],\n value = object[key];\n\n result[length] = [key, value, isStrictComparable(value)];\n }\n return result;\n }\n\n /**\n * Gets the native function at `key` of `object`.\n *\n * @private\n * @param {Object} object The object to query.\n * @param {string} key The key of the method to get.\n * @returns {*} Returns the function if it's native, else `undefined`.\n */\n function getNative(object, key) {\n var value = getValue(object, key);\n return baseIsNative(value) ? value : undefined;\n }\n\n /**\n * A specialized version of `baseGetTag` which ignores `Symbol.toStringTag` values.\n *\n * @private\n * @param {*} value The value to query.\n * @returns {string} Returns the raw `toStringTag`.\n */\n function getRawTag(value) {\n var isOwn = hasOwnProperty.call(value, symToStringTag),\n tag = value[symToStringTag];\n\n try {\n value[symToStringTag] = undefined;\n var unmasked = true;\n } catch (e) {}\n\n var result = nativeObjectToString.call(value);\n if (unmasked) {\n if (isOwn) {\n value[symToStringTag] = tag;\n } else {\n delete value[symToStringTag];\n }\n }\n return result;\n }\n\n /**\n * Creates an array of the own enumerable symbols of `object`.\n *\n * @private\n * @param {Object} object The object to query.\n * @returns {Array} Returns the array of symbols.\n */\n var getSymbols = !nativeGetSymbols ? stubArray : function(object) {\n if (object == null) {\n return [];\n }\n object = Object(object);\n return arrayFilter(nativeGetSymbols(object), function(symbol) {\n return propertyIsEnumerable.call(object, symbol);\n });\n };\n\n /**\n * Creates an array of the own and inherited enumerable symbols of `object`.\n *\n * @private\n * @param {Object} object The object to query.\n * @returns {Array} Returns the array of symbols.\n */\n var getSymbolsIn = !nativeGetSymbols ? stubArray : function(object) {\n var result = [];\n while (object) {\n arrayPush(result, getSymbols(object));\n object = getPrototype(object);\n }\n return result;\n };\n\n /**\n * Gets the `toStringTag` of `value`.\n *\n * @private\n * @param {*} value The value to query.\n * @returns {string} Returns the `toStringTag`.\n */\n var getTag = baseGetTag;\n\n // Fallback for data views, maps, sets, and weak maps in IE 11 and promises in Node.js < 6.\n if ((DataView && getTag(new DataView(new ArrayBuffer(1))) != dataViewTag) ||\n (Map && getTag(new Map) != mapTag) ||\n (Promise && getTag(Promise.resolve()) != promiseTag) ||\n (Set && getTag(new Set) != setTag) ||\n (WeakMap && getTag(new WeakMap) != weakMapTag)) {\n getTag = function(value) {\n var result = baseGetTag(value),\n Ctor = result == objectTag ? value.constructor : undefined,\n ctorString = Ctor ? 
toSource(Ctor) : '';\n\n if (ctorString) {\n switch (ctorString) {\n case dataViewCtorString: return dataViewTag;\n case mapCtorString: return mapTag;\n case promiseCtorString: return promiseTag;\n case setCtorString: return setTag;\n case weakMapCtorString: return weakMapTag;\n }\n }\n return result;\n };\n }\n\n /**\n * Gets the view, applying any `transforms` to the `start` and `end` positions.\n *\n * @private\n * @param {number} start The start of the view.\n * @param {number} end The end of the view.\n * @param {Array} transforms The transformations to apply to the view.\n * @returns {Object} Returns an object containing the `start` and `end`\n * positions of the view.\n */\n function getView(start, end, transforms) {\n var index = -1,\n length = transforms.length;\n\n while (++index < length) {\n var data = transforms[index],\n size = data.size;\n\n switch (data.type) {\n case 'drop': start += size; break;\n case 'dropRight': end -= size; break;\n case 'take': end = nativeMin(end, start + size); break;\n case 'takeRight': start = nativeMax(start, end - size); break;\n }\n }\n return { 'start': start, 'end': end };\n }\n\n /**\n * Extracts wrapper details from the `source` body comment.\n *\n * @private\n * @param {string} source The source to inspect.\n * @returns {Array} Returns the wrapper details.\n */\n function getWrapDetails(source) {\n var match = source.match(reWrapDetails);\n return match ? match[1].split(reSplitDetails) : [];\n }\n\n /**\n * Checks if `path` exists on `object`.\n *\n * @private\n * @param {Object} object The object to query.\n * @param {Array|string} path The path to check.\n * @param {Function} hasFunc The function to check properties.\n * @returns {boolean} Returns `true` if `path` exists, else `false`.\n */\n function hasPath(object, path, hasFunc) {\n path = castPath(path, object);\n\n var index = -1,\n length = path.length,\n result = false;\n\n while (++index < length) {\n var key = toKey(path[index]);\n if (!(result = object != null && hasFunc(object, key))) {\n break;\n }\n object = object[key];\n }\n if (result || ++index != length) {\n return result;\n }\n length = object == null ? 0 : object.length;\n return !!length && isLength(length) && isIndex(key, length) &&\n (isArray(object) || isArguments(object));\n }\n\n /**\n * Initializes an array clone.\n *\n * @private\n * @param {Array} array The array to clone.\n * @returns {Array} Returns the initialized clone.\n */\n function initCloneArray(array) {\n var length = array.length,\n result = new array.constructor(length);\n\n // Add properties assigned by `RegExp#exec`.\n if (length && typeof array[0] == 'string' && hasOwnProperty.call(array, 'index')) {\n result.index = array.index;\n result.input = array.input;\n }\n return result;\n }\n\n /**\n * Initializes an object clone.\n *\n * @private\n * @param {Object} object The object to clone.\n * @returns {Object} Returns the initialized clone.\n */\n function initCloneObject(object) {\n return (typeof object.constructor == 'function' && !isPrototype(object))\n ? 
baseCreate(getPrototype(object))\n : {};\n }\n\n /**\n * Initializes an object clone based on its `toStringTag`.\n *\n * **Note:** This function only supports cloning values with tags of\n * `Boolean`, `Date`, `Error`, `Map`, `Number`, `RegExp`, `Set`, or `String`.\n *\n * @private\n * @param {Object} object The object to clone.\n * @param {string} tag The `toStringTag` of the object to clone.\n * @param {boolean} [isDeep] Specify a deep clone.\n * @returns {Object} Returns the initialized clone.\n */\n function initCloneByTag(object, tag, isDeep) {\n var Ctor = object.constructor;\n switch (tag) {\n case arrayBufferTag:\n return cloneArrayBuffer(object);\n\n case boolTag:\n case dateTag:\n return new Ctor(+object);\n\n case dataViewTag:\n return cloneDataView(object, isDeep);\n\n case float32Tag: case float64Tag:\n case int8Tag: case int16Tag: case int32Tag:\n case uint8Tag: case uint8ClampedTag: case uint16Tag: case uint32Tag:\n return cloneTypedArray(object, isDeep);\n\n case mapTag:\n return new Ctor;\n\n case numberTag:\n case stringTag:\n return new Ctor(object);\n\n case regexpTag:\n return cloneRegExp(object);\n\n case setTag:\n return new Ctor;\n\n case symbolTag:\n return cloneSymbol(object);\n }\n }\n\n /**\n * Inserts wrapper `details` in a comment at the top of the `source` body.\n *\n * @private\n * @param {string} source The source to modify.\n * @returns {Array} details The details to insert.\n * @returns {string} Returns the modified source.\n */\n function insertWrapDetails(source, details) {\n var length = details.length;\n if (!length) {\n return source;\n }\n var lastIndex = length - 1;\n details[lastIndex] = (length > 1 ? '& ' : '') + details[lastIndex];\n details = details.join(length > 2 ? ', ' : ' ');\n return source.replace(reWrapComment, '{\\n/* [wrapped with ' + details + '] */\\n');\n }\n\n /**\n * Checks if `value` is a flattenable `arguments` object or array.\n *\n * @private\n * @param {*} value The value to check.\n * @returns {boolean} Returns `true` if `value` is flattenable, else `false`.\n */\n function isFlattenable(value) {\n return isArray(value) || isArguments(value) ||\n !!(spreadableSymbol && value && value[spreadableSymbol]);\n }\n\n /**\n * Checks if `value` is a valid array-like index.\n *\n * @private\n * @param {*} value The value to check.\n * @param {number} [length=MAX_SAFE_INTEGER] The upper bounds of a valid index.\n * @returns {boolean} Returns `true` if `value` is a valid index, else `false`.\n */\n function isIndex(value, length) {\n var type = typeof value;\n length = length == null ? MAX_SAFE_INTEGER : length;\n\n return !!length &&\n (type == 'number' ||\n (type != 'symbol' && reIsUint.test(value))) &&\n (value > -1 && value % 1 == 0 && value < length);\n }\n\n /**\n * Checks if the given arguments are from an iteratee call.\n *\n * @private\n * @param {*} value The potential iteratee value argument.\n * @param {*} index The potential iteratee index or key argument.\n * @param {*} object The potential iteratee object argument.\n * @returns {boolean} Returns `true` if the arguments are from an iteratee call,\n * else `false`.\n */\n function isIterateeCall(value, index, object) {\n if (!isObject(object)) {\n return false;\n }\n var type = typeof index;\n if (type == 'number'\n ? 
(isArrayLike(object) && isIndex(index, object.length))\n : (type == 'string' && index in object)\n ) {\n return eq(object[index], value);\n }\n return false;\n }\n\n /**\n * Checks if `value` is a property name and not a property path.\n *\n * @private\n * @param {*} value The value to check.\n * @param {Object} [object] The object to query keys on.\n * @returns {boolean} Returns `true` if `value` is a property name, else `false`.\n */\n function isKey(value, object) {\n if (isArray(value)) {\n return false;\n }\n var type = typeof value;\n if (type == 'number' || type == 'symbol' || type == 'boolean' ||\n value == null || isSymbol(value)) {\n return true;\n }\n return reIsPlainProp.test(value) || !reIsDeepProp.test(value) ||\n (object != null && value in Object(object));\n }\n\n /**\n * Checks if `value` is suitable for use as unique object key.\n *\n * @private\n * @param {*} value The value to check.\n * @returns {boolean} Returns `true` if `value` is suitable, else `false`.\n */\n function isKeyable(value) {\n var type = typeof value;\n return (type == 'string' || type == 'number' || type == 'symbol' || type == 'boolean')\n ? (value !== '__proto__')\n : (value === null);\n }\n\n /**\n * Checks if `func` has a lazy counterpart.\n *\n * @private\n * @param {Function} func The function to check.\n * @returns {boolean} Returns `true` if `func` has a lazy counterpart,\n * else `false`.\n */\n function isLaziable(func) {\n var funcName = getFuncName(func),\n other = lodash[funcName];\n\n if (typeof other != 'function' || !(funcName in LazyWrapper.prototype)) {\n return false;\n }\n if (func === other) {\n return true;\n }\n var data = getData(other);\n return !!data && func === data[0];\n }\n\n /**\n * Checks if `func` has its source masked.\n *\n * @private\n * @param {Function} func The function to check.\n * @returns {boolean} Returns `true` if `func` is masked, else `false`.\n */\n function isMasked(func) {\n return !!maskSrcKey && (maskSrcKey in func);\n }\n\n /**\n * Checks if `func` is capable of being masked.\n *\n * @private\n * @param {*} value The value to check.\n * @returns {boolean} Returns `true` if `func` is maskable, else `false`.\n */\n var isMaskable = coreJsData ? isFunction : stubFalse;\n\n /**\n * Checks if `value` is likely a prototype object.\n *\n * @private\n * @param {*} value The value to check.\n * @returns {boolean} Returns `true` if `value` is a prototype, else `false`.\n */\n function isPrototype(value) {\n var Ctor = value && value.constructor,\n proto = (typeof Ctor == 'function' && Ctor.prototype) || objectProto;\n\n return value === proto;\n }\n\n /**\n * Checks if `value` is suitable for strict equality comparisons, i.e. `===`.\n *\n * @private\n * @param {*} value The value to check.\n * @returns {boolean} Returns `true` if `value` if suitable for strict\n * equality comparisons, else `false`.\n */\n function isStrictComparable(value) {\n return value === value && !isObject(value);\n }\n\n /**\n * A specialized version of `matchesProperty` for source values suitable\n * for strict equality comparisons, i.e. 
`===`.\n *\n * @private\n * @param {string} key The key of the property to get.\n * @param {*} srcValue The value to match.\n * @returns {Function} Returns the new spec function.\n */\n function matchesStrictComparable(key, srcValue) {\n return function(object) {\n if (object == null) {\n return false;\n }\n return object[key] === srcValue &&\n (srcValue !== undefined || (key in Object(object)));\n };\n }\n\n /**\n * A specialized version of `_.memoize` which clears the memoized function's\n * cache when it exceeds `MAX_MEMOIZE_SIZE`.\n *\n * @private\n * @param {Function} func The function to have its output memoized.\n * @returns {Function} Returns the new memoized function.\n */\n function memoizeCapped(func) {\n var result = memoize(func, function(key) {\n if (cache.size === MAX_MEMOIZE_SIZE) {\n cache.clear();\n }\n return key;\n });\n\n var cache = result.cache;\n return result;\n }\n\n /**\n * Merges the function metadata of `source` into `data`.\n *\n * Merging metadata reduces the number of wrappers used to invoke a function.\n * This is possible because methods like `_.bind`, `_.curry`, and `_.partial`\n * may be applied regardless of execution order. Methods like `_.ary` and\n * `_.rearg` modify function arguments, making the order in which they are\n * executed important, preventing the merging of metadata. However, we make\n * an exception for a safe combined case where curried functions have `_.ary`\n * and or `_.rearg` applied.\n *\n * @private\n * @param {Array} data The destination metadata.\n * @param {Array} source The source metadata.\n * @returns {Array} Returns `data`.\n */\n function mergeData(data, source) {\n var bitmask = data[1],\n srcBitmask = source[1],\n newBitmask = bitmask | srcBitmask,\n isCommon = newBitmask < (WRAP_BIND_FLAG | WRAP_BIND_KEY_FLAG | WRAP_ARY_FLAG);\n\n var isCombo =\n ((srcBitmask == WRAP_ARY_FLAG) && (bitmask == WRAP_CURRY_FLAG)) ||\n ((srcBitmask == WRAP_ARY_FLAG) && (bitmask == WRAP_REARG_FLAG) && (data[7].length <= source[8])) ||\n ((srcBitmask == (WRAP_ARY_FLAG | WRAP_REARG_FLAG)) && (source[7].length <= source[8]) && (bitmask == WRAP_CURRY_FLAG));\n\n // Exit early if metadata can't be merged.\n if (!(isCommon || isCombo)) {\n return data;\n }\n // Use source `thisArg` if available.\n if (srcBitmask & WRAP_BIND_FLAG) {\n data[2] = source[2];\n // Set when currying a bound function.\n newBitmask |= bitmask & WRAP_BIND_FLAG ? 0 : WRAP_CURRY_BOUND_FLAG;\n }\n // Compose partial arguments.\n var value = source[3];\n if (value) {\n var partials = data[3];\n data[3] = partials ? composeArgs(partials, value, source[4]) : value;\n data[4] = partials ? replaceHolders(data[3], PLACEHOLDER) : source[4];\n }\n // Compose partial right arguments.\n value = source[5];\n if (value) {\n partials = data[5];\n data[5] = partials ? composeArgsRight(partials, value, source[6]) : value;\n data[6] = partials ? replaceHolders(data[5], PLACEHOLDER) : source[6];\n }\n // Use source `argPos` if available.\n value = source[7];\n if (value) {\n data[7] = value;\n }\n // Use source `ary` if it's smaller.\n if (srcBitmask & WRAP_ARY_FLAG) {\n data[8] = data[8] == null ? 
source[8] : nativeMin(data[8], source[8]);\n }\n // Use source `arity` if one is not provided.\n if (data[9] == null) {\n data[9] = source[9];\n }\n // Use source `func` and merge bitmasks.\n data[0] = source[0];\n data[1] = newBitmask;\n\n return data;\n }\n\n /**\n * This function is like\n * [`Object.keys`](http://ecma-international.org/ecma-262/7.0/#sec-object.keys)\n * except that it includes inherited enumerable properties.\n *\n * @private\n * @param {Object} object The object to query.\n * @returns {Array} Returns the array of property names.\n */\n function nativeKeysIn(object) {\n var result = [];\n if (object != null) {\n for (var key in Object(object)) {\n result.push(key);\n }\n }\n return result;\n }\n\n /**\n * Converts `value` to a string using `Object.prototype.toString`.\n *\n * @private\n * @param {*} value The value to convert.\n * @returns {string} Returns the converted string.\n */\n function objectToString(value) {\n return nativeObjectToString.call(value);\n }\n\n /**\n * A specialized version of `baseRest` which transforms the rest array.\n *\n * @private\n * @param {Function} func The function to apply a rest parameter to.\n * @param {number} [start=func.length-1] The start position of the rest parameter.\n * @param {Function} transform The rest array transform.\n * @returns {Function} Returns the new function.\n */\n function overRest(func, start, transform) {\n start = nativeMax(start === undefined ? (func.length - 1) : start, 0);\n return function() {\n var args = arguments,\n index = -1,\n length = nativeMax(args.length - start, 0),\n array = Array(length);\n\n while (++index < length) {\n array[index] = args[start + index];\n }\n index = -1;\n var otherArgs = Array(start + 1);\n while (++index < start) {\n otherArgs[index] = args[index];\n }\n otherArgs[start] = transform(array);\n return apply(func, this, otherArgs);\n };\n }\n\n /**\n * Gets the parent value at `path` of `object`.\n *\n * @private\n * @param {Object} object The object to query.\n * @param {Array} path The path to get the parent value of.\n * @returns {*} Returns the parent value.\n */\n function parent(object, path) {\n return path.length < 2 ? object : baseGet(object, baseSlice(path, 0, -1));\n }\n\n /**\n * Reorder `array` according to the specified indexes where the element at\n * the first index is assigned as the first element, the element at\n * the second index is assigned as the second element, and so on.\n *\n * @private\n * @param {Array} array The array to reorder.\n * @param {Array} indexes The arranged array indexes.\n * @returns {Array} Returns `array`.\n */\n function reorder(array, indexes) {\n var arrLength = array.length,\n length = nativeMin(indexes.length, arrLength),\n oldArray = copyArray(array);\n\n while (length--) {\n var index = indexes[length];\n array[length] = isIndex(index, arrLength) ? oldArray[index] : undefined;\n }\n return array;\n }\n\n /**\n * Gets the value at `key`, unless `key` is \"__proto__\" or \"constructor\".\n *\n * @private\n * @param {Object} object The object to query.\n * @param {string} key The key of the property to get.\n * @returns {*} Returns the property value.\n */\n function safeGet(object, key) {\n if (key === 'constructor' && typeof object[key] === 'function') {\n return;\n }\n\n if (key == '__proto__') {\n return;\n }\n\n return object[key];\n }\n\n /**\n * Sets metadata for `func`.\n *\n * **Note:** If this function becomes hot, i.e. 
is invoked a lot in a short\n * period of time, it will trip its breaker and transition to an identity\n * function to avoid garbage collection pauses in V8. See\n * [V8 issue 2070](https://bugs.chromium.org/p/v8/issues/detail?id=2070)\n * for more details.\n *\n * @private\n * @param {Function} func The function to associate metadata with.\n * @param {*} data The metadata.\n * @returns {Function} Returns `func`.\n */\n var setData = shortOut(baseSetData);\n\n /**\n * A simple wrapper around the global [`setTimeout`](https://mdn.io/setTimeout).\n *\n * @private\n * @param {Function} func The function to delay.\n * @param {number} wait The number of milliseconds to delay invocation.\n * @returns {number|Object} Returns the timer id or timeout object.\n */\n var setTimeout = ctxSetTimeout || function(func, wait) {\n return root.setTimeout(func, wait);\n };\n\n /**\n * Sets the `toString` method of `func` to return `string`.\n *\n * @private\n * @param {Function} func The function to modify.\n * @param {Function} string The `toString` result.\n * @returns {Function} Returns `func`.\n */\n var setToString = shortOut(baseSetToString);\n\n /**\n * Sets the `toString` method of `wrapper` to mimic the source of `reference`\n * with wrapper details in a comment at the top of the source body.\n *\n * @private\n * @param {Function} wrapper The function to modify.\n * @param {Function} reference The reference function.\n * @param {number} bitmask The bitmask flags. See `createWrap` for more details.\n * @returns {Function} Returns `wrapper`.\n */\n function setWrapToString(wrapper, reference, bitmask) {\n var source = (reference + '');\n return setToString(wrapper, insertWrapDetails(source, updateWrapDetails(getWrapDetails(source), bitmask)));\n }\n\n /**\n * Creates a function that'll short out and invoke `identity` instead\n * of `func` when it's called `HOT_COUNT` or more times in `HOT_SPAN`\n * milliseconds.\n *\n * @private\n * @param {Function} func The function to restrict.\n * @returns {Function} Returns the new shortable function.\n */\n function shortOut(func) {\n var count = 0,\n lastCalled = 0;\n\n return function() {\n var stamp = nativeNow(),\n remaining = HOT_SPAN - (stamp - lastCalled);\n\n lastCalled = stamp;\n if (remaining > 0) {\n if (++count >= HOT_COUNT) {\n return arguments[0];\n }\n } else {\n count = 0;\n }\n return func.apply(undefined, arguments);\n };\n }\n\n /**\n * A specialized version of `_.shuffle` which mutates and sets the size of `array`.\n *\n * @private\n * @param {Array} array The array to shuffle.\n * @param {number} [size=array.length] The size of `array`.\n * @returns {Array} Returns `array`.\n */\n function shuffleSelf(array, size) {\n var index = -1,\n length = array.length,\n lastIndex = length - 1;\n\n size = size === undefined ? length : size;\n while (++index < size) {\n var rand = baseRandom(index, lastIndex),\n value = array[rand];\n\n array[rand] = array[index];\n array[index] = value;\n }\n array.length = size;\n return array;\n }\n\n /**\n * Converts `string` to a property path array.\n *\n * @private\n * @param {string} string The string to convert.\n * @returns {Array} Returns the property path array.\n */\n var stringToPath = memoizeCapped(function(string) {\n var result = [];\n if (string.charCodeAt(0) === 46 /* . */) {\n result.push('');\n }\n string.replace(rePropName, function(match, number, quote, subString) {\n result.push(quote ? 
subString.replace(reEscapeChar, '$1') : (number || match));\n });\n return result;\n });\n\n /**\n * Converts `value` to a string key if it's not a string or symbol.\n *\n * @private\n * @param {*} value The value to inspect.\n * @returns {string|symbol} Returns the key.\n */\n function toKey(value) {\n if (typeof value == 'string' || isSymbol(value)) {\n return value;\n }\n var result = (value + '');\n return (result == '0' && (1 / value) == -INFINITY) ? '-0' : result;\n }\n\n /**\n * Converts `func` to its source code.\n *\n * @private\n * @param {Function} func The function to convert.\n * @returns {string} Returns the source code.\n */\n function toSource(func) {\n if (func != null) {\n try {\n return funcToString.call(func);\n } catch (e) {}\n try {\n return (func + '');\n } catch (e) {}\n }\n return '';\n }\n\n /**\n * Updates wrapper `details` based on `bitmask` flags.\n *\n * @private\n * @returns {Array} details The details to modify.\n * @param {number} bitmask The bitmask flags. See `createWrap` for more details.\n * @returns {Array} Returns `details`.\n */\n function updateWrapDetails(details, bitmask) {\n arrayEach(wrapFlags, function(pair) {\n var value = '_.' + pair[0];\n if ((bitmask & pair[1]) && !arrayIncludes(details, value)) {\n details.push(value);\n }\n });\n return details.sort();\n }\n\n /**\n * Creates a clone of `wrapper`.\n *\n * @private\n * @param {Object} wrapper The wrapper to clone.\n * @returns {Object} Returns the cloned wrapper.\n */\n function wrapperClone(wrapper) {\n if (wrapper instanceof LazyWrapper) {\n return wrapper.clone();\n }\n var result = new LodashWrapper(wrapper.__wrapped__, wrapper.__chain__);\n result.__actions__ = copyArray(wrapper.__actions__);\n result.__index__ = wrapper.__index__;\n result.__values__ = wrapper.__values__;\n return result;\n }\n\n /*------------------------------------------------------------------------*/\n\n /**\n * Creates an array of elements split into groups the length of `size`.\n * If `array` can't be split evenly, the final chunk will be the remaining\n * elements.\n *\n * @static\n * @memberOf _\n * @since 3.0.0\n * @category Array\n * @param {Array} array The array to process.\n * @param {number} [size=1] The length of each chunk\n * @param- {Object} [guard] Enables use as an iteratee for methods like `_.map`.\n * @returns {Array} Returns the new array of chunks.\n * @example\n *\n * _.chunk(['a', 'b', 'c', 'd'], 2);\n * // => [['a', 'b'], ['c', 'd']]\n *\n * _.chunk(['a', 'b', 'c', 'd'], 3);\n * // => [['a', 'b', 'c'], ['d']]\n */\n function chunk(array, size, guard) {\n if ((guard ? isIterateeCall(array, size, guard) : size === undefined)) {\n size = 1;\n } else {\n size = nativeMax(toInteger(size), 0);\n }\n var length = array == null ? 0 : array.length;\n if (!length || size < 1) {\n return [];\n }\n var index = 0,\n resIndex = 0,\n result = Array(nativeCeil(length / size));\n\n while (index < length) {\n result[resIndex++] = baseSlice(array, index, (index += size));\n }\n return result;\n }\n\n /**\n * Creates an array with all falsey values removed. The values `false`, `null`,\n * `0`, `\"\"`, `undefined`, and `NaN` are falsey.\n *\n * @static\n * @memberOf _\n * @since 0.1.0\n * @category Array\n * @param {Array} array The array to compact.\n * @returns {Array} Returns the new array of filtered values.\n * @example\n *\n * _.compact([0, 1, false, 2, '', 3]);\n * // => [1, 2, 3]\n */\n function compact(array) {\n var index = -1,\n length = array == null ? 
0 : array.length,\n resIndex = 0,\n result = [];\n\n while (++index < length) {\n var value = array[index];\n if (value) {\n result[resIndex++] = value;\n }\n }\n return result;\n }\n\n /**\n * Creates a new array concatenating `array` with any additional arrays\n * and/or values.\n *\n * @static\n * @memberOf _\n * @since 4.0.0\n * @category Array\n * @param {Array} array The array to concatenate.\n * @param {...*} [values] The values to concatenate.\n * @returns {Array} Returns the new concatenated array.\n * @example\n *\n * var array = [1];\n * var other = _.concat(array, 2, [3], [[4]]);\n *\n * console.log(other);\n * // => [1, 2, 3, [4]]\n *\n * console.log(array);\n * // => [1]\n */\n function concat() {\n var length = arguments.length;\n if (!length) {\n return [];\n }\n var args = Array(length - 1),\n array = arguments[0],\n index = length;\n\n while (index--) {\n args[index - 1] = arguments[index];\n }\n return arrayPush(isArray(array) ? copyArray(array) : [array], baseFlatten(args, 1));\n }\n\n /**\n * Creates an array of `array` values not included in the other given arrays\n * using [`SameValueZero`](http://ecma-international.org/ecma-262/7.0/#sec-samevaluezero)\n * for equality comparisons. The order and references of result values are\n * determined by the first array.\n *\n * **Note:** Unlike `_.pullAll`, this method returns a new array.\n *\n * @static\n * @memberOf _\n * @since 0.1.0\n * @category Array\n * @param {Array} array The array to inspect.\n * @param {...Array} [values] The values to exclude.\n * @returns {Array} Returns the new array of filtered values.\n * @see _.without, _.xor\n * @example\n *\n * _.difference([2, 1], [2, 3]);\n * // => [1]\n */\n var difference = baseRest(function(array, values) {\n return isArrayLikeObject(array)\n ? baseDifference(array, baseFlatten(values, 1, isArrayLikeObject, true))\n : [];\n });\n\n /**\n * This method is like `_.difference` except that it accepts `iteratee` which\n * is invoked for each element of `array` and `values` to generate the criterion\n * by which they're compared. The order and references of result values are\n * determined by the first array. The iteratee is invoked with one argument:\n * (value).\n *\n * **Note:** Unlike `_.pullAllBy`, this method returns a new array.\n *\n * @static\n * @memberOf _\n * @since 4.0.0\n * @category Array\n * @param {Array} array The array to inspect.\n * @param {...Array} [values] The values to exclude.\n * @param {Function} [iteratee=_.identity] The iteratee invoked per element.\n * @returns {Array} Returns the new array of filtered values.\n * @example\n *\n * _.differenceBy([2.1, 1.2], [2.3, 3.4], Math.floor);\n * // => [1.2]\n *\n * // The `_.property` iteratee shorthand.\n * _.differenceBy([{ 'x': 2 }, { 'x': 1 }], [{ 'x': 1 }], 'x');\n * // => [{ 'x': 2 }]\n */\n var differenceBy = baseRest(function(array, values) {\n var iteratee = last(values);\n if (isArrayLikeObject(iteratee)) {\n iteratee = undefined;\n }\n return isArrayLikeObject(array)\n ? baseDifference(array, baseFlatten(values, 1, isArrayLikeObject, true), getIteratee(iteratee, 2))\n : [];\n });\n\n /**\n * This method is like `_.difference` except that it accepts `comparator`\n * which is invoked to compare elements of `array` to `values`. The order and\n * references of result values are determined by the first array. 
The comparator\n * is invoked with two arguments: (arrVal, othVal).\n *\n * **Note:** Unlike `_.pullAllWith`, this method returns a new array.\n *\n * @static\n * @memberOf _\n * @since 4.0.0\n * @category Array\n * @param {Array} array The array to inspect.\n * @param {...Array} [values] The values to exclude.\n * @param {Function} [comparator] The comparator invoked per element.\n * @returns {Array} Returns the new array of filtered values.\n * @example\n *\n * var objects = [{ 'x': 1, 'y': 2 }, { 'x': 2, 'y': 1 }];\n *\n * _.differenceWith(objects, [{ 'x': 1, 'y': 2 }], _.isEqual);\n * // => [{ 'x': 2, 'y': 1 }]\n */\n var differenceWith = baseRest(function(array, values) {\n var comparator = last(values);\n if (isArrayLikeObject(comparator)) {\n comparator = undefined;\n }\n return isArrayLikeObject(array)\n ? baseDifference(array, baseFlatten(values, 1, isArrayLikeObject, true), undefined, comparator)\n : [];\n });\n\n /**\n * Creates a slice of `array` with `n` elements dropped from the beginning.\n *\n * @static\n * @memberOf _\n * @since 0.5.0\n * @category Array\n * @param {Array} array The array to query.\n * @param {number} [n=1] The number of elements to drop.\n * @param- {Object} [guard] Enables use as an iteratee for methods like `_.map`.\n * @returns {Array} Returns the slice of `array`.\n * @example\n *\n * _.drop([1, 2, 3]);\n * // => [2, 3]\n *\n * _.drop([1, 2, 3], 2);\n * // => [3]\n *\n * _.drop([1, 2, 3], 5);\n * // => []\n *\n * _.drop([1, 2, 3], 0);\n * // => [1, 2, 3]\n */\n function drop(array, n, guard) {\n var length = array == null ? 0 : array.length;\n if (!length) {\n return [];\n }\n n = (guard || n === undefined) ? 1 : toInteger(n);\n return baseSlice(array, n < 0 ? 0 : n, length);\n }\n\n /**\n * Creates a slice of `array` with `n` elements dropped from the end.\n *\n * @static\n * @memberOf _\n * @since 3.0.0\n * @category Array\n * @param {Array} array The array to query.\n * @param {number} [n=1] The number of elements to drop.\n * @param- {Object} [guard] Enables use as an iteratee for methods like `_.map`.\n * @returns {Array} Returns the slice of `array`.\n * @example\n *\n * _.dropRight([1, 2, 3]);\n * // => [1, 2]\n *\n * _.dropRight([1, 2, 3], 2);\n * // => [1]\n *\n * _.dropRight([1, 2, 3], 5);\n * // => []\n *\n * _.dropRight([1, 2, 3], 0);\n * // => [1, 2, 3]\n */\n function dropRight(array, n, guard) {\n var length = array == null ? 0 : array.length;\n if (!length) {\n return [];\n }\n n = (guard || n === undefined) ? 1 : toInteger(n);\n n = length - n;\n return baseSlice(array, 0, n < 0 ? 0 : n);\n }\n\n /**\n * Creates a slice of `array` excluding elements dropped from the end.\n * Elements are dropped until `predicate` returns falsey. 
The predicate is\n * invoked with three arguments: (value, index, array).\n *\n * @static\n * @memberOf _\n * @since 3.0.0\n * @category Array\n * @param {Array} array The array to query.\n * @param {Function} [predicate=_.identity] The function invoked per iteration.\n * @returns {Array} Returns the slice of `array`.\n * @example\n *\n * var users = [\n * { 'user': 'barney', 'active': true },\n * { 'user': 'fred', 'active': false },\n * { 'user': 'pebbles', 'active': false }\n * ];\n *\n * _.dropRightWhile(users, function(o) { return !o.active; });\n * // => objects for ['barney']\n *\n * // The `_.matches` iteratee shorthand.\n * _.dropRightWhile(users, { 'user': 'pebbles', 'active': false });\n * // => objects for ['barney', 'fred']\n *\n * // The `_.matchesProperty` iteratee shorthand.\n * _.dropRightWhile(users, ['active', false]);\n * // => objects for ['barney']\n *\n * // The `_.property` iteratee shorthand.\n * _.dropRightWhile(users, 'active');\n * // => objects for ['barney', 'fred', 'pebbles']\n */\n function dropRightWhile(array, predicate) {\n return (array && array.length)\n ? baseWhile(array, getIteratee(predicate, 3), true, true)\n : [];\n }\n\n /**\n * Creates a slice of `array` excluding elements dropped from the beginning.\n * Elements are dropped until `predicate` returns falsey. The predicate is\n * invoked with three arguments: (value, index, array).\n *\n * @static\n * @memberOf _\n * @since 3.0.0\n * @category Array\n * @param {Array} array The array to query.\n * @param {Function} [predicate=_.identity] The function invoked per iteration.\n * @returns {Array} Returns the slice of `array`.\n * @example\n *\n * var users = [\n * { 'user': 'barney', 'active': false },\n * { 'user': 'fred', 'active': false },\n * { 'user': 'pebbles', 'active': true }\n * ];\n *\n * _.dropWhile(users, function(o) { return !o.active; });\n * // => objects for ['pebbles']\n *\n * // The `_.matches` iteratee shorthand.\n * _.dropWhile(users, { 'user': 'barney', 'active': false });\n * // => objects for ['fred', 'pebbles']\n *\n * // The `_.matchesProperty` iteratee shorthand.\n * _.dropWhile(users, ['active', false]);\n * // => objects for ['pebbles']\n *\n * // The `_.property` iteratee shorthand.\n * _.dropWhile(users, 'active');\n * // => objects for ['barney', 'fred', 'pebbles']\n */\n function dropWhile(array, predicate) {\n return (array && array.length)\n ? baseWhile(array, getIteratee(predicate, 3), true)\n : [];\n }\n\n /**\n * Fills elements of `array` with `value` from `start` up to, but not\n * including, `end`.\n *\n * **Note:** This method mutates `array`.\n *\n * @static\n * @memberOf _\n * @since 3.2.0\n * @category Array\n * @param {Array} array The array to fill.\n * @param {*} value The value to fill `array` with.\n * @param {number} [start=0] The start position.\n * @param {number} [end=array.length] The end position.\n * @returns {Array} Returns `array`.\n * @example\n *\n * var array = [1, 2, 3];\n *\n * _.fill(array, 'a');\n * console.log(array);\n * // => ['a', 'a', 'a']\n *\n * _.fill(Array(3), 2);\n * // => [2, 2, 2]\n *\n * _.fill([4, 6, 8, 10], '*', 1, 3);\n * // => [4, '*', '*', 10]\n */\n function fill(array, value, start, end) {\n var length = array == null ? 
0 : array.length;\n if (!length) {\n return [];\n }\n if (start && typeof start != 'number' && isIterateeCall(array, value, start)) {\n start = 0;\n end = length;\n }\n return baseFill(array, value, start, end);\n }\n\n /**\n * This method is like `_.find` except that it returns the index of the first\n * element `predicate` returns truthy for instead of the element itself.\n *\n * @static\n * @memberOf _\n * @since 1.1.0\n * @category Array\n * @param {Array} array The array to inspect.\n * @param {Function} [predicate=_.identity] The function invoked per iteration.\n * @param {number} [fromIndex=0] The index to search from.\n * @returns {number} Returns the index of the found element, else `-1`.\n * @example\n *\n * var users = [\n * { 'user': 'barney', 'active': false },\n * { 'user': 'fred', 'active': false },\n * { 'user': 'pebbles', 'active': true }\n * ];\n *\n * _.findIndex(users, function(o) { return o.user == 'barney'; });\n * // => 0\n *\n * // The `_.matches` iteratee shorthand.\n * _.findIndex(users, { 'user': 'fred', 'active': false });\n * // => 1\n *\n * // The `_.matchesProperty` iteratee shorthand.\n * _.findIndex(users, ['active', false]);\n * // => 0\n *\n * // The `_.property` iteratee shorthand.\n * _.findIndex(users, 'active');\n * // => 2\n */\n function findIndex(array, predicate, fromIndex) {\n var length = array == null ? 0 : array.length;\n if (!length) {\n return -1;\n }\n var index = fromIndex == null ? 0 : toInteger(fromIndex);\n if (index < 0) {\n index = nativeMax(length + index, 0);\n }\n return baseFindIndex(array, getIteratee(predicate, 3), index);\n }\n\n /**\n * This method is like `_.findIndex` except that it iterates over elements\n * of `collection` from right to left.\n *\n * @static\n * @memberOf _\n * @since 2.0.0\n * @category Array\n * @param {Array} array The array to inspect.\n * @param {Function} [predicate=_.identity] The function invoked per iteration.\n * @param {number} [fromIndex=array.length-1] The index to search from.\n * @returns {number} Returns the index of the found element, else `-1`.\n * @example\n *\n * var users = [\n * { 'user': 'barney', 'active': true },\n * { 'user': 'fred', 'active': false },\n * { 'user': 'pebbles', 'active': false }\n * ];\n *\n * _.findLastIndex(users, function(o) { return o.user == 'pebbles'; });\n * // => 2\n *\n * // The `_.matches` iteratee shorthand.\n * _.findLastIndex(users, { 'user': 'barney', 'active': true });\n * // => 0\n *\n * // The `_.matchesProperty` iteratee shorthand.\n * _.findLastIndex(users, ['active', false]);\n * // => 2\n *\n * // The `_.property` iteratee shorthand.\n * _.findLastIndex(users, 'active');\n * // => 0\n */\n function findLastIndex(array, predicate, fromIndex) {\n var length = array == null ? 0 : array.length;\n if (!length) {\n return -1;\n }\n var index = length - 1;\n if (fromIndex !== undefined) {\n index = toInteger(fromIndex);\n index = fromIndex < 0\n ? nativeMax(length + index, 0)\n : nativeMin(index, length - 1);\n }\n return baseFindIndex(array, getIteratee(predicate, 3), index, true);\n }\n\n /**\n * Flattens `array` a single level deep.\n *\n * @static\n * @memberOf _\n * @since 0.1.0\n * @category Array\n * @param {Array} array The array to flatten.\n * @returns {Array} Returns the new flattened array.\n * @example\n *\n * _.flatten([1, [2, [3, [4]], 5]]);\n * // => [1, 2, [3, [4]], 5]\n */\n function flatten(array) {\n var length = array == null ? 0 : array.length;\n return length ? 
baseFlatten(array, 1) : [];\n }\n\n /**\n * Recursively flattens `array`.\n *\n * @static\n * @memberOf _\n * @since 3.0.0\n * @category Array\n * @param {Array} array The array to flatten.\n * @returns {Array} Returns the new flattened array.\n * @example\n *\n * _.flattenDeep([1, [2, [3, [4]], 5]]);\n * // => [1, 2, 3, 4, 5]\n */\n function flattenDeep(array) {\n var length = array == null ? 0 : array.length;\n return length ? baseFlatten(array, INFINITY) : [];\n }\n\n /**\n * Recursively flatten `array` up to `depth` times.\n *\n * @static\n * @memberOf _\n * @since 4.4.0\n * @category Array\n * @param {Array} array The array to flatten.\n * @param {number} [depth=1] The maximum recursion depth.\n * @returns {Array} Returns the new flattened array.\n * @example\n *\n * var array = [1, [2, [3, [4]], 5]];\n *\n * _.flattenDepth(array, 1);\n * // => [1, 2, [3, [4]], 5]\n *\n * _.flattenDepth(array, 2);\n * // => [1, 2, 3, [4], 5]\n */\n function flattenDepth(array, depth) {\n var length = array == null ? 0 : array.length;\n if (!length) {\n return [];\n }\n depth = depth === undefined ? 1 : toInteger(depth);\n return baseFlatten(array, depth);\n }\n\n /**\n * The inverse of `_.toPairs`; this method returns an object composed\n * from key-value `pairs`.\n *\n * @static\n * @memberOf _\n * @since 4.0.0\n * @category Array\n * @param {Array} pairs The key-value pairs.\n * @returns {Object} Returns the new object.\n * @example\n *\n * _.fromPairs([['a', 1], ['b', 2]]);\n * // => { 'a': 1, 'b': 2 }\n */\n function fromPairs(pairs) {\n var index = -1,\n length = pairs == null ? 0 : pairs.length,\n result = {};\n\n while (++index < length) {\n var pair = pairs[index];\n result[pair[0]] = pair[1];\n }\n return result;\n }\n\n /**\n * Gets the first element of `array`.\n *\n * @static\n * @memberOf _\n * @since 0.1.0\n * @alias first\n * @category Array\n * @param {Array} array The array to query.\n * @returns {*} Returns the first element of `array`.\n * @example\n *\n * _.head([1, 2, 3]);\n * // => 1\n *\n * _.head([]);\n * // => undefined\n */\n function head(array) {\n return (array && array.length) ? array[0] : undefined;\n }\n\n /**\n * Gets the index at which the first occurrence of `value` is found in `array`\n * using [`SameValueZero`](http://ecma-international.org/ecma-262/7.0/#sec-samevaluezero)\n * for equality comparisons. If `fromIndex` is negative, it's used as the\n * offset from the end of `array`.\n *\n * @static\n * @memberOf _\n * @since 0.1.0\n * @category Array\n * @param {Array} array The array to inspect.\n * @param {*} value The value to search for.\n * @param {number} [fromIndex=0] The index to search from.\n * @returns {number} Returns the index of the matched value, else `-1`.\n * @example\n *\n * _.indexOf([1, 2, 1, 2], 2);\n * // => 1\n *\n * // Search from the `fromIndex`.\n * _.indexOf([1, 2, 1, 2], 2, 2);\n * // => 3\n */\n function indexOf(array, value, fromIndex) {\n var length = array == null ? 0 : array.length;\n if (!length) {\n return -1;\n }\n var index = fromIndex == null ? 0 : toInteger(fromIndex);\n if (index < 0) {\n index = nativeMax(length + index, 0);\n }\n return baseIndexOf(array, value, index);\n }\n\n /**\n * Gets all but the last element of `array`.\n *\n * @static\n * @memberOf _\n * @since 0.1.0\n * @category Array\n * @param {Array} array The array to query.\n * @returns {Array} Returns the slice of `array`.\n * @example\n *\n * _.initial([1, 2, 3]);\n * // => [1, 2]\n */\n function initial(array) {\n var length = array == null ? 
0 : array.length;\n return length ? baseSlice(array, 0, -1) : [];\n }\n\n /**\n * Creates an array of unique values that are included in all given arrays\n * using [`SameValueZero`](http://ecma-international.org/ecma-262/7.0/#sec-samevaluezero)\n * for equality comparisons. The order and references of result values are\n * determined by the first array.\n *\n * @static\n * @memberOf _\n * @since 0.1.0\n * @category Array\n * @param {...Array} [arrays] The arrays to inspect.\n * @returns {Array} Returns the new array of intersecting values.\n * @example\n *\n * _.intersection([2, 1], [2, 3]);\n * // => [2]\n */\n var intersection = baseRest(function(arrays) {\n var mapped = arrayMap(arrays, castArrayLikeObject);\n return (mapped.length && mapped[0] === arrays[0])\n ? baseIntersection(mapped)\n : [];\n });\n\n /**\n * This method is like `_.intersection` except that it accepts `iteratee`\n * which is invoked for each element of each `arrays` to generate the criterion\n * by which they're compared. The order and references of result values are\n * determined by the first array. The iteratee is invoked with one argument:\n * (value).\n *\n * @static\n * @memberOf _\n * @since 4.0.0\n * @category Array\n * @param {...Array} [arrays] The arrays to inspect.\n * @param {Function} [iteratee=_.identity] The iteratee invoked per element.\n * @returns {Array} Returns the new array of intersecting values.\n * @example\n *\n * _.intersectionBy([2.1, 1.2], [2.3, 3.4], Math.floor);\n * // => [2.1]\n *\n * // The `_.property` iteratee shorthand.\n * _.intersectionBy([{ 'x': 1 }], [{ 'x': 2 }, { 'x': 1 }], 'x');\n * // => [{ 'x': 1 }]\n */\n var intersectionBy = baseRest(function(arrays) {\n var iteratee = last(arrays),\n mapped = arrayMap(arrays, castArrayLikeObject);\n\n if (iteratee === last(mapped)) {\n iteratee = undefined;\n } else {\n mapped.pop();\n }\n return (mapped.length && mapped[0] === arrays[0])\n ? baseIntersection(mapped, getIteratee(iteratee, 2))\n : [];\n });\n\n /**\n * This method is like `_.intersection` except that it accepts `comparator`\n * which is invoked to compare elements of `arrays`. The order and references\n * of result values are determined by the first array. The comparator is\n * invoked with two arguments: (arrVal, othVal).\n *\n * @static\n * @memberOf _\n * @since 4.0.0\n * @category Array\n * @param {...Array} [arrays] The arrays to inspect.\n * @param {Function} [comparator] The comparator invoked per element.\n * @returns {Array} Returns the new array of intersecting values.\n * @example\n *\n * var objects = [{ 'x': 1, 'y': 2 }, { 'x': 2, 'y': 1 }];\n * var others = [{ 'x': 1, 'y': 1 }, { 'x': 1, 'y': 2 }];\n *\n * _.intersectionWith(objects, others, _.isEqual);\n * // => [{ 'x': 1, 'y': 2 }]\n */\n var intersectionWith = baseRest(function(arrays) {\n var comparator = last(arrays),\n mapped = arrayMap(arrays, castArrayLikeObject);\n\n comparator = typeof comparator == 'function' ? comparator : undefined;\n if (comparator) {\n mapped.pop();\n }\n return (mapped.length && mapped[0] === arrays[0])\n ? 
baseIntersection(mapped, undefined, comparator)\n : [];\n });\n\n /**\n * Converts all elements in `array` into a string separated by `separator`.\n *\n * @static\n * @memberOf _\n * @since 4.0.0\n * @category Array\n * @param {Array} array The array to convert.\n * @param {string} [separator=','] The element separator.\n * @returns {string} Returns the joined string.\n * @example\n *\n * _.join(['a', 'b', 'c'], '~');\n * // => 'a~b~c'\n */\n function join(array, separator) {\n return array == null ? '' : nativeJoin.call(array, separator);\n }\n\n /**\n * Gets the last element of `array`.\n *\n * @static\n * @memberOf _\n * @since 0.1.0\n * @category Array\n * @param {Array} array The array to query.\n * @returns {*} Returns the last element of `array`.\n * @example\n *\n * _.last([1, 2, 3]);\n * // => 3\n */\n function last(array) {\n var length = array == null ? 0 : array.length;\n return length ? array[length - 1] : undefined;\n }\n\n /**\n * This method is like `_.indexOf` except that it iterates over elements of\n * `array` from right to left.\n *\n * @static\n * @memberOf _\n * @since 0.1.0\n * @category Array\n * @param {Array} array The array to inspect.\n * @param {*} value The value to search for.\n * @param {number} [fromIndex=array.length-1] The index to search from.\n * @returns {number} Returns the index of the matched value, else `-1`.\n * @example\n *\n * _.lastIndexOf([1, 2, 1, 2], 2);\n * // => 3\n *\n * // Search from the `fromIndex`.\n * _.lastIndexOf([1, 2, 1, 2], 2, 2);\n * // => 1\n */\n function lastIndexOf(array, value, fromIndex) {\n var length = array == null ? 0 : array.length;\n if (!length) {\n return -1;\n }\n var index = length;\n if (fromIndex !== undefined) {\n index = toInteger(fromIndex);\n index = index < 0 ? nativeMax(length + index, 0) : nativeMin(index, length - 1);\n }\n return value === value\n ? strictLastIndexOf(array, value, index)\n : baseFindIndex(array, baseIsNaN, index, true);\n }\n\n /**\n * Gets the element at index `n` of `array`. If `n` is negative, the nth\n * element from the end is returned.\n *\n * @static\n * @memberOf _\n * @since 4.11.0\n * @category Array\n * @param {Array} array The array to query.\n * @param {number} [n=0] The index of the element to return.\n * @returns {*} Returns the nth element of `array`.\n * @example\n *\n * var array = ['a', 'b', 'c', 'd'];\n *\n * _.nth(array, 1);\n * // => 'b'\n *\n * _.nth(array, -2);\n * // => 'c';\n */\n function nth(array, n) {\n return (array && array.length) ? baseNth(array, toInteger(n)) : undefined;\n }\n\n /**\n * Removes all given values from `array` using\n * [`SameValueZero`](http://ecma-international.org/ecma-262/7.0/#sec-samevaluezero)\n * for equality comparisons.\n *\n * **Note:** Unlike `_.without`, this method mutates `array`. 
Use `_.remove`\n * to remove elements from an array by predicate.\n *\n * @static\n * @memberOf _\n * @since 2.0.0\n * @category Array\n * @param {Array} array The array to modify.\n * @param {...*} [values] The values to remove.\n * @returns {Array} Returns `array`.\n * @example\n *\n * var array = ['a', 'b', 'c', 'a', 'b', 'c'];\n *\n * _.pull(array, 'a', 'c');\n * console.log(array);\n * // => ['b', 'b']\n */\n var pull = baseRest(pullAll);\n\n /**\n * This method is like `_.pull` except that it accepts an array of values to remove.\n *\n * **Note:** Unlike `_.difference`, this method mutates `array`.\n *\n * @static\n * @memberOf _\n * @since 4.0.0\n * @category Array\n * @param {Array} array The array to modify.\n * @param {Array} values The values to remove.\n * @returns {Array} Returns `array`.\n * @example\n *\n * var array = ['a', 'b', 'c', 'a', 'b', 'c'];\n *\n * _.pullAll(array, ['a', 'c']);\n * console.log(array);\n * // => ['b', 'b']\n */\n function pullAll(array, values) {\n return (array && array.length && values && values.length)\n ? basePullAll(array, values)\n : array;\n }\n\n /**\n * This method is like `_.pullAll` except that it accepts `iteratee` which is\n * invoked for each element of `array` and `values` to generate the criterion\n * by which they're compared. The iteratee is invoked with one argument: (value).\n *\n * **Note:** Unlike `_.differenceBy`, this method mutates `array`.\n *\n * @static\n * @memberOf _\n * @since 4.0.0\n * @category Array\n * @param {Array} array The array to modify.\n * @param {Array} values The values to remove.\n * @param {Function} [iteratee=_.identity] The iteratee invoked per element.\n * @returns {Array} Returns `array`.\n * @example\n *\n * var array = [{ 'x': 1 }, { 'x': 2 }, { 'x': 3 }, { 'x': 1 }];\n *\n * _.pullAllBy(array, [{ 'x': 1 }, { 'x': 3 }], 'x');\n * console.log(array);\n * // => [{ 'x': 2 }]\n */\n function pullAllBy(array, values, iteratee) {\n return (array && array.length && values && values.length)\n ? basePullAll(array, values, getIteratee(iteratee, 2))\n : array;\n }\n\n /**\n * This method is like `_.pullAll` except that it accepts `comparator` which\n * is invoked to compare elements of `array` to `values`. The comparator is\n * invoked with two arguments: (arrVal, othVal).\n *\n * **Note:** Unlike `_.differenceWith`, this method mutates `array`.\n *\n * @static\n * @memberOf _\n * @since 4.6.0\n * @category Array\n * @param {Array} array The array to modify.\n * @param {Array} values The values to remove.\n * @param {Function} [comparator] The comparator invoked per element.\n * @returns {Array} Returns `array`.\n * @example\n *\n * var array = [{ 'x': 1, 'y': 2 }, { 'x': 3, 'y': 4 }, { 'x': 5, 'y': 6 }];\n *\n * _.pullAllWith(array, [{ 'x': 3, 'y': 4 }], _.isEqual);\n * console.log(array);\n * // => [{ 'x': 1, 'y': 2 }, { 'x': 5, 'y': 6 }]\n */\n function pullAllWith(array, values, comparator) {\n return (array && array.length && values && values.length)\n ? 
basePullAll(array, values, undefined, comparator)\n : array;\n }\n\n /**\n * Removes elements from `array` corresponding to `indexes` and returns an\n * array of removed elements.\n *\n * **Note:** Unlike `_.at`, this method mutates `array`.\n *\n * @static\n * @memberOf _\n * @since 3.0.0\n * @category Array\n * @param {Array} array The array to modify.\n * @param {...(number|number[])} [indexes] The indexes of elements to remove.\n * @returns {Array} Returns the new array of removed elements.\n * @example\n *\n * var array = ['a', 'b', 'c', 'd'];\n * var pulled = _.pullAt(array, [1, 3]);\n *\n * console.log(array);\n * // => ['a', 'c']\n *\n * console.log(pulled);\n * // => ['b', 'd']\n */\n var pullAt = flatRest(function(array, indexes) {\n var length = array == null ? 0 : array.length,\n result = baseAt(array, indexes);\n\n basePullAt(array, arrayMap(indexes, function(index) {\n return isIndex(index, length) ? +index : index;\n }).sort(compareAscending));\n\n return result;\n });\n\n /**\n * Removes all elements from `array` that `predicate` returns truthy for\n * and returns an array of the removed elements. The predicate is invoked\n * with three arguments: (value, index, array).\n *\n * **Note:** Unlike `_.filter`, this method mutates `array`. Use `_.pull`\n * to pull elements from an array by value.\n *\n * @static\n * @memberOf _\n * @since 2.0.0\n * @category Array\n * @param {Array} array The array to modify.\n * @param {Function} [predicate=_.identity] The function invoked per iteration.\n * @returns {Array} Returns the new array of removed elements.\n * @example\n *\n * var array = [1, 2, 3, 4];\n * var evens = _.remove(array, function(n) {\n * return n % 2 == 0;\n * });\n *\n * console.log(array);\n * // => [1, 3]\n *\n * console.log(evens);\n * // => [2, 4]\n */\n function remove(array, predicate) {\n var result = [];\n if (!(array && array.length)) {\n return result;\n }\n var index = -1,\n indexes = [],\n length = array.length;\n\n predicate = getIteratee(predicate, 3);\n while (++index < length) {\n var value = array[index];\n if (predicate(value, index, array)) {\n result.push(value);\n indexes.push(index);\n }\n }\n basePullAt(array, indexes);\n return result;\n }\n\n /**\n * Reverses `array` so that the first element becomes the last, the second\n * element becomes the second to last, and so on.\n *\n * **Note:** This method mutates `array` and is based on\n * [`Array#reverse`](https://mdn.io/Array/reverse).\n *\n * @static\n * @memberOf _\n * @since 4.0.0\n * @category Array\n * @param {Array} array The array to modify.\n * @returns {Array} Returns `array`.\n * @example\n *\n * var array = [1, 2, 3];\n *\n * _.reverse(array);\n * // => [3, 2, 1]\n *\n * console.log(array);\n * // => [3, 2, 1]\n */\n function reverse(array) {\n return array == null ? array : nativeReverse.call(array);\n }\n\n /**\n * Creates a slice of `array` from `start` up to, but not including, `end`.\n *\n * **Note:** This method is used instead of\n * [`Array#slice`](https://mdn.io/Array/slice) to ensure dense arrays are\n * returned.\n *\n * @static\n * @memberOf _\n * @since 3.0.0\n * @category Array\n * @param {Array} array The array to slice.\n * @param {number} [start=0] The start position.\n * @param {number} [end=array.length] The end position.\n * @returns {Array} Returns the slice of `array`.\n */\n function slice(array, start, end) {\n var length = array == null ? 
0 : array.length;\n if (!length) {\n return [];\n }\n if (end && typeof end != 'number' && isIterateeCall(array, start, end)) {\n start = 0;\n end = length;\n }\n else {\n start = start == null ? 0 : toInteger(start);\n end = end === undefined ? length : toInteger(end);\n }\n return baseSlice(array, start, end);\n }\n\n /**\n * Uses a binary search to determine the lowest index at which `value`\n * should be inserted into `array` in order to maintain its sort order.\n *\n * @static\n * @memberOf _\n * @since 0.1.0\n * @category Array\n * @param {Array} array The sorted array to inspect.\n * @param {*} value The value to evaluate.\n * @returns {number} Returns the index at which `value` should be inserted\n * into `array`.\n * @example\n *\n * _.sortedIndex([30, 50], 40);\n * // => 1\n */\n function sortedIndex(array, value) {\n return baseSortedIndex(array, value);\n }\n\n /**\n * This method is like `_.sortedIndex` except that it accepts `iteratee`\n * which is invoked for `value` and each element of `array` to compute their\n * sort ranking. The iteratee is invoked with one argument: (value).\n *\n * @static\n * @memberOf _\n * @since 4.0.0\n * @category Array\n * @param {Array} array The sorted array to inspect.\n * @param {*} value The value to evaluate.\n * @param {Function} [iteratee=_.identity] The iteratee invoked per element.\n * @returns {number} Returns the index at which `value` should be inserted\n * into `array`.\n * @example\n *\n * var objects = [{ 'x': 4 }, { 'x': 5 }];\n *\n * _.sortedIndexBy(objects, { 'x': 4 }, function(o) { return o.x; });\n * // => 0\n *\n * // The `_.property` iteratee shorthand.\n * _.sortedIndexBy(objects, { 'x': 4 }, 'x');\n * // => 0\n */\n function sortedIndexBy(array, value, iteratee) {\n return baseSortedIndexBy(array, value, getIteratee(iteratee, 2));\n }\n\n /**\n * This method is like `_.indexOf` except that it performs a binary\n * search on a sorted `array`.\n *\n * @static\n * @memberOf _\n * @since 4.0.0\n * @category Array\n * @param {Array} array The array to inspect.\n * @param {*} value The value to search for.\n * @returns {number} Returns the index of the matched value, else `-1`.\n * @example\n *\n * _.sortedIndexOf([4, 5, 5, 5, 6], 5);\n * // => 1\n */\n function sortedIndexOf(array, value) {\n var length = array == null ? 0 : array.length;\n if (length) {\n var index = baseSortedIndex(array, value);\n if (index < length && eq(array[index], value)) {\n return index;\n }\n }\n return -1;\n }\n\n /**\n * This method is like `_.sortedIndex` except that it returns the highest\n * index at which `value` should be inserted into `array` in order to\n * maintain its sort order.\n *\n * @static\n * @memberOf _\n * @since 3.0.0\n * @category Array\n * @param {Array} array The sorted array to inspect.\n * @param {*} value The value to evaluate.\n * @returns {number} Returns the index at which `value` should be inserted\n * into `array`.\n * @example\n *\n * _.sortedLastIndex([4, 5, 5, 5, 6], 5);\n * // => 4\n */\n function sortedLastIndex(array, value) {\n return baseSortedIndex(array, value, true);\n }\n\n /**\n * This method is like `_.sortedLastIndex` except that it accepts `iteratee`\n * which is invoked for `value` and each element of `array` to compute their\n * sort ranking. 
The iteratee is invoked with one argument: (value).\n *\n * @static\n * @memberOf _\n * @since 4.0.0\n * @category Array\n * @param {Array} array The sorted array to inspect.\n * @param {*} value The value to evaluate.\n * @param {Function} [iteratee=_.identity] The iteratee invoked per element.\n * @returns {number} Returns the index at which `value` should be inserted\n * into `array`.\n * @example\n *\n * var objects = [{ 'x': 4 }, { 'x': 5 }];\n *\n * _.sortedLastIndexBy(objects, { 'x': 4 }, function(o) { return o.x; });\n * // => 1\n *\n * // The `_.property` iteratee shorthand.\n * _.sortedLastIndexBy(objects, { 'x': 4 }, 'x');\n * // => 1\n */\n function sortedLastIndexBy(array, value, iteratee) {\n return baseSortedIndexBy(array, value, getIteratee(iteratee, 2), true);\n }\n\n /**\n * This method is like `_.lastIndexOf` except that it performs a binary\n * search on a sorted `array`.\n *\n * @static\n * @memberOf _\n * @since 4.0.0\n * @category Array\n * @param {Array} array The array to inspect.\n * @param {*} value The value to search for.\n * @returns {number} Returns the index of the matched value, else `-1`.\n * @example\n *\n * _.sortedLastIndexOf([4, 5, 5, 5, 6], 5);\n * // => 3\n */\n function sortedLastIndexOf(array, value) {\n var length = array == null ? 0 : array.length;\n if (length) {\n var index = baseSortedIndex(array, value, true) - 1;\n if (eq(array[index], value)) {\n return index;\n }\n }\n return -1;\n }\n\n /**\n * This method is like `_.uniq` except that it's designed and optimized\n * for sorted arrays.\n *\n * @static\n * @memberOf _\n * @since 4.0.0\n * @category Array\n * @param {Array} array The array to inspect.\n * @returns {Array} Returns the new duplicate free array.\n * @example\n *\n * _.sortedUniq([1, 1, 2]);\n * // => [1, 2]\n */\n function sortedUniq(array) {\n return (array && array.length)\n ? baseSortedUniq(array)\n : [];\n }\n\n /**\n * This method is like `_.uniqBy` except that it's designed and optimized\n * for sorted arrays.\n *\n * @static\n * @memberOf _\n * @since 4.0.0\n * @category Array\n * @param {Array} array The array to inspect.\n * @param {Function} [iteratee] The iteratee invoked per element.\n * @returns {Array} Returns the new duplicate free array.\n * @example\n *\n * _.sortedUniqBy([1.1, 1.2, 2.3, 2.4], Math.floor);\n * // => [1.1, 2.3]\n */\n function sortedUniqBy(array, iteratee) {\n return (array && array.length)\n ? baseSortedUniq(array, getIteratee(iteratee, 2))\n : [];\n }\n\n /**\n * Gets all but the first element of `array`.\n *\n * @static\n * @memberOf _\n * @since 4.0.0\n * @category Array\n * @param {Array} array The array to query.\n * @returns {Array} Returns the slice of `array`.\n * @example\n *\n * _.tail([1, 2, 3]);\n * // => [2, 3]\n */\n function tail(array) {\n var length = array == null ? 0 : array.length;\n return length ? 
baseSlice(array, 1, length) : [];\n }\n\n /**\n * Creates a slice of `array` with `n` elements taken from the beginning.\n *\n * @static\n * @memberOf _\n * @since 0.1.0\n * @category Array\n * @param {Array} array The array to query.\n * @param {number} [n=1] The number of elements to take.\n * @param- {Object} [guard] Enables use as an iteratee for methods like `_.map`.\n * @returns {Array} Returns the slice of `array`.\n * @example\n *\n * _.take([1, 2, 3]);\n * // => [1]\n *\n * _.take([1, 2, 3], 2);\n * // => [1, 2]\n *\n * _.take([1, 2, 3], 5);\n * // => [1, 2, 3]\n *\n * _.take([1, 2, 3], 0);\n * // => []\n */\n function take(array, n, guard) {\n if (!(array && array.length)) {\n return [];\n }\n n = (guard || n === undefined) ? 1 : toInteger(n);\n return baseSlice(array, 0, n < 0 ? 0 : n);\n }\n\n /**\n * Creates a slice of `array` with `n` elements taken from the end.\n *\n * @static\n * @memberOf _\n * @since 3.0.0\n * @category Array\n * @param {Array} array The array to query.\n * @param {number} [n=1] The number of elements to take.\n * @param- {Object} [guard] Enables use as an iteratee for methods like `_.map`.\n * @returns {Array} Returns the slice of `array`.\n * @example\n *\n * _.takeRight([1, 2, 3]);\n * // => [3]\n *\n * _.takeRight([1, 2, 3], 2);\n * // => [2, 3]\n *\n * _.takeRight([1, 2, 3], 5);\n * // => [1, 2, 3]\n *\n * _.takeRight([1, 2, 3], 0);\n * // => []\n */\n function takeRight(array, n, guard) {\n var length = array == null ? 0 : array.length;\n if (!length) {\n return [];\n }\n n = (guard || n === undefined) ? 1 : toInteger(n);\n n = length - n;\n return baseSlice(array, n < 0 ? 0 : n, length);\n }\n\n /**\n * Creates a slice of `array` with elements taken from the end. Elements are\n * taken until `predicate` returns falsey. The predicate is invoked with\n * three arguments: (value, index, array).\n *\n * @static\n * @memberOf _\n * @since 3.0.0\n * @category Array\n * @param {Array} array The array to query.\n * @param {Function} [predicate=_.identity] The function invoked per iteration.\n * @returns {Array} Returns the slice of `array`.\n * @example\n *\n * var users = [\n * { 'user': 'barney', 'active': true },\n * { 'user': 'fred', 'active': false },\n * { 'user': 'pebbles', 'active': false }\n * ];\n *\n * _.takeRightWhile(users, function(o) { return !o.active; });\n * // => objects for ['fred', 'pebbles']\n *\n * // The `_.matches` iteratee shorthand.\n * _.takeRightWhile(users, { 'user': 'pebbles', 'active': false });\n * // => objects for ['pebbles']\n *\n * // The `_.matchesProperty` iteratee shorthand.\n * _.takeRightWhile(users, ['active', false]);\n * // => objects for ['fred', 'pebbles']\n *\n * // The `_.property` iteratee shorthand.\n * _.takeRightWhile(users, 'active');\n * // => []\n */\n function takeRightWhile(array, predicate) {\n return (array && array.length)\n ? baseWhile(array, getIteratee(predicate, 3), false, true)\n : [];\n }\n\n /**\n * Creates a slice of `array` with elements taken from the beginning. Elements\n * are taken until `predicate` returns falsey. 
The predicate is invoked with\n * three arguments: (value, index, array).\n *\n * @static\n * @memberOf _\n * @since 3.0.0\n * @category Array\n * @param {Array} array The array to query.\n * @param {Function} [predicate=_.identity] The function invoked per iteration.\n * @returns {Array} Returns the slice of `array`.\n * @example\n *\n * var users = [\n * { 'user': 'barney', 'active': false },\n * { 'user': 'fred', 'active': false },\n * { 'user': 'pebbles', 'active': true }\n * ];\n *\n * _.takeWhile(users, function(o) { return !o.active; });\n * // => objects for ['barney', 'fred']\n *\n * // The `_.matches` iteratee shorthand.\n * _.takeWhile(users, { 'user': 'barney', 'active': false });\n * // => objects for ['barney']\n *\n * // The `_.matchesProperty` iteratee shorthand.\n * _.takeWhile(users, ['active', false]);\n * // => objects for ['barney', 'fred']\n *\n * // The `_.property` iteratee shorthand.\n * _.takeWhile(users, 'active');\n * // => []\n */\n function takeWhile(array, predicate) {\n return (array && array.length)\n ? baseWhile(array, getIteratee(predicate, 3))\n : [];\n }\n\n /**\n * Creates an array of unique values, in order, from all given arrays using\n * [`SameValueZero`](http://ecma-international.org/ecma-262/7.0/#sec-samevaluezero)\n * for equality comparisons.\n *\n * @static\n * @memberOf _\n * @since 0.1.0\n * @category Array\n * @param {...Array} [arrays] The arrays to inspect.\n * @returns {Array} Returns the new array of combined values.\n * @example\n *\n * _.union([2], [1, 2]);\n * // => [2, 1]\n */\n var union = baseRest(function(arrays) {\n return baseUniq(baseFlatten(arrays, 1, isArrayLikeObject, true));\n });\n\n /**\n * This method is like `_.union` except that it accepts `iteratee` which is\n * invoked for each element of each `arrays` to generate the criterion by\n * which uniqueness is computed. Result values are chosen from the first\n * array in which the value occurs. The iteratee is invoked with one argument:\n * (value).\n *\n * @static\n * @memberOf _\n * @since 4.0.0\n * @category Array\n * @param {...Array} [arrays] The arrays to inspect.\n * @param {Function} [iteratee=_.identity] The iteratee invoked per element.\n * @returns {Array} Returns the new array of combined values.\n * @example\n *\n * _.unionBy([2.1], [1.2, 2.3], Math.floor);\n * // => [2.1, 1.2]\n *\n * // The `_.property` iteratee shorthand.\n * _.unionBy([{ 'x': 1 }], [{ 'x': 2 }, { 'x': 1 }], 'x');\n * // => [{ 'x': 1 }, { 'x': 2 }]\n */\n var unionBy = baseRest(function(arrays) {\n var iteratee = last(arrays);\n if (isArrayLikeObject(iteratee)) {\n iteratee = undefined;\n }\n return baseUniq(baseFlatten(arrays, 1, isArrayLikeObject, true), getIteratee(iteratee, 2));\n });\n\n /**\n * This method is like `_.union` except that it accepts `comparator` which\n * is invoked to compare elements of `arrays`. Result values are chosen from\n * the first array in which the value occurs. 
The comparator is invoked\n * with two arguments: (arrVal, othVal).\n *\n * @static\n * @memberOf _\n * @since 4.0.0\n * @category Array\n * @param {...Array} [arrays] The arrays to inspect.\n * @param {Function} [comparator] The comparator invoked per element.\n * @returns {Array} Returns the new array of combined values.\n * @example\n *\n * var objects = [{ 'x': 1, 'y': 2 }, { 'x': 2, 'y': 1 }];\n * var others = [{ 'x': 1, 'y': 1 }, { 'x': 1, 'y': 2 }];\n *\n * _.unionWith(objects, others, _.isEqual);\n * // => [{ 'x': 1, 'y': 2 }, { 'x': 2, 'y': 1 }, { 'x': 1, 'y': 1 }]\n */\n var unionWith = baseRest(function(arrays) {\n var comparator = last(arrays);\n comparator = typeof comparator == 'function' ? comparator : undefined;\n return baseUniq(baseFlatten(arrays, 1, isArrayLikeObject, true), undefined, comparator);\n });\n\n /**\n * Creates a duplicate-free version of an array, using\n * [`SameValueZero`](http://ecma-international.org/ecma-262/7.0/#sec-samevaluezero)\n * for equality comparisons, in which only the first occurrence of each element\n * is kept. The order of result values is determined by the order they occur\n * in the array.\n *\n * @static\n * @memberOf _\n * @since 0.1.0\n * @category Array\n * @param {Array} array The array to inspect.\n * @returns {Array} Returns the new duplicate free array.\n * @example\n *\n * _.uniq([2, 1, 2]);\n * // => [2, 1]\n */\n function uniq(array) {\n return (array && array.length) ? baseUniq(array) : [];\n }\n\n /**\n * This method is like `_.uniq` except that it accepts `iteratee` which is\n * invoked for each element in `array` to generate the criterion by which\n * uniqueness is computed. The order of result values is determined by the\n * order they occur in the array. The iteratee is invoked with one argument:\n * (value).\n *\n * @static\n * @memberOf _\n * @since 4.0.0\n * @category Array\n * @param {Array} array The array to inspect.\n * @param {Function} [iteratee=_.identity] The iteratee invoked per element.\n * @returns {Array} Returns the new duplicate free array.\n * @example\n *\n * _.uniqBy([2.1, 1.2, 2.3], Math.floor);\n * // => [2.1, 1.2]\n *\n * // The `_.property` iteratee shorthand.\n * _.uniqBy([{ 'x': 1 }, { 'x': 2 }, { 'x': 1 }], 'x');\n * // => [{ 'x': 1 }, { 'x': 2 }]\n */\n function uniqBy(array, iteratee) {\n return (array && array.length) ? baseUniq(array, getIteratee(iteratee, 2)) : [];\n }\n\n /**\n * This method is like `_.uniq` except that it accepts `comparator` which\n * is invoked to compare elements of `array`. The order of result values is\n * determined by the order they occur in the array.The comparator is invoked\n * with two arguments: (arrVal, othVal).\n *\n * @static\n * @memberOf _\n * @since 4.0.0\n * @category Array\n * @param {Array} array The array to inspect.\n * @param {Function} [comparator] The comparator invoked per element.\n * @returns {Array} Returns the new duplicate free array.\n * @example\n *\n * var objects = [{ 'x': 1, 'y': 2 }, { 'x': 2, 'y': 1 }, { 'x': 1, 'y': 2 }];\n *\n * _.uniqWith(objects, _.isEqual);\n * // => [{ 'x': 1, 'y': 2 }, { 'x': 2, 'y': 1 }]\n */\n function uniqWith(array, comparator) {\n comparator = typeof comparator == 'function' ? comparator : undefined;\n return (array && array.length) ? 
baseUniq(array, undefined, comparator) : [];\n }\n\n /**\n * This method is like `_.zip` except that it accepts an array of grouped\n * elements and creates an array regrouping the elements to their pre-zip\n * configuration.\n *\n * @static\n * @memberOf _\n * @since 1.2.0\n * @category Array\n * @param {Array} array The array of grouped elements to process.\n * @returns {Array} Returns the new array of regrouped elements.\n * @example\n *\n * var zipped = _.zip(['a', 'b'], [1, 2], [true, false]);\n * // => [['a', 1, true], ['b', 2, false]]\n *\n * _.unzip(zipped);\n * // => [['a', 'b'], [1, 2], [true, false]]\n */\n function unzip(array) {\n if (!(array && array.length)) {\n return [];\n }\n var length = 0;\n array = arrayFilter(array, function(group) {\n if (isArrayLikeObject(group)) {\n length = nativeMax(group.length, length);\n return true;\n }\n });\n return baseTimes(length, function(index) {\n return arrayMap(array, baseProperty(index));\n });\n }\n\n /**\n * This method is like `_.unzip` except that it accepts `iteratee` to specify\n * how regrouped values should be combined. The iteratee is invoked with the\n * elements of each group: (...group).\n *\n * @static\n * @memberOf _\n * @since 3.8.0\n * @category Array\n * @param {Array} array The array of grouped elements to process.\n * @param {Function} [iteratee=_.identity] The function to combine\n * regrouped values.\n * @returns {Array} Returns the new array of regrouped elements.\n * @example\n *\n * var zipped = _.zip([1, 2], [10, 20], [100, 200]);\n * // => [[1, 10, 100], [2, 20, 200]]\n *\n * _.unzipWith(zipped, _.add);\n * // => [3, 30, 300]\n */\n function unzipWith(array, iteratee) {\n if (!(array && array.length)) {\n return [];\n }\n var result = unzip(array);\n if (iteratee == null) {\n return result;\n }\n return arrayMap(result, function(group) {\n return apply(iteratee, undefined, group);\n });\n }\n\n /**\n * Creates an array excluding all given values using\n * [`SameValueZero`](http://ecma-international.org/ecma-262/7.0/#sec-samevaluezero)\n * for equality comparisons.\n *\n * **Note:** Unlike `_.pull`, this method returns a new array.\n *\n * @static\n * @memberOf _\n * @since 0.1.0\n * @category Array\n * @param {Array} array The array to inspect.\n * @param {...*} [values] The values to exclude.\n * @returns {Array} Returns the new array of filtered values.\n * @see _.difference, _.xor\n * @example\n *\n * _.without([2, 1, 2, 3], 1, 2);\n * // => [3]\n */\n var without = baseRest(function(array, values) {\n return isArrayLikeObject(array)\n ? baseDifference(array, values)\n : [];\n });\n\n /**\n * Creates an array of unique values that is the\n * [symmetric difference](https://en.wikipedia.org/wiki/Symmetric_difference)\n * of the given arrays. The order of result values is determined by the order\n * they occur in the arrays.\n *\n * @static\n * @memberOf _\n * @since 2.4.0\n * @category Array\n * @param {...Array} [arrays] The arrays to inspect.\n * @returns {Array} Returns the new array of filtered values.\n * @see _.difference, _.without\n * @example\n *\n * _.xor([2, 1], [2, 3]);\n * // => [1, 3]\n */\n var xor = baseRest(function(arrays) {\n return baseXor(arrayFilter(arrays, isArrayLikeObject));\n });\n\n /**\n * This method is like `_.xor` except that it accepts `iteratee` which is\n * invoked for each element of each `arrays` to generate the criterion by\n * which by which they're compared. The order of result values is determined\n * by the order they occur in the arrays. 
The iteratee is invoked with one\n * argument: (value).\n *\n * @static\n * @memberOf _\n * @since 4.0.0\n * @category Array\n * @param {...Array} [arrays] The arrays to inspect.\n * @param {Function} [iteratee=_.identity] The iteratee invoked per element.\n * @returns {Array} Returns the new array of filtered values.\n * @example\n *\n * _.xorBy([2.1, 1.2], [2.3, 3.4], Math.floor);\n * // => [1.2, 3.4]\n *\n * // The `_.property` iteratee shorthand.\n * _.xorBy([{ 'x': 1 }], [{ 'x': 2 }, { 'x': 1 }], 'x');\n * // => [{ 'x': 2 }]\n */\n var xorBy = baseRest(function(arrays) {\n var iteratee = last(arrays);\n if (isArrayLikeObject(iteratee)) {\n iteratee = undefined;\n }\n return baseXor(arrayFilter(arrays, isArrayLikeObject), getIteratee(iteratee, 2));\n });\n\n /**\n * This method is like `_.xor` except that it accepts `comparator` which is\n * invoked to compare elements of `arrays`. The order of result values is\n * determined by the order they occur in the arrays. The comparator is invoked\n * with two arguments: (arrVal, othVal).\n *\n * @static\n * @memberOf _\n * @since 4.0.0\n * @category Array\n * @param {...Array} [arrays] The arrays to inspect.\n * @param {Function} [comparator] The comparator invoked per element.\n * @returns {Array} Returns the new array of filtered values.\n * @example\n *\n * var objects = [{ 'x': 1, 'y': 2 }, { 'x': 2, 'y': 1 }];\n * var others = [{ 'x': 1, 'y': 1 }, { 'x': 1, 'y': 2 }];\n *\n * _.xorWith(objects, others, _.isEqual);\n * // => [{ 'x': 2, 'y': 1 }, { 'x': 1, 'y': 1 }]\n */\n var xorWith = baseRest(function(arrays) {\n var comparator = last(arrays);\n comparator = typeof comparator == 'function' ? comparator : undefined;\n return baseXor(arrayFilter(arrays, isArrayLikeObject), undefined, comparator);\n });\n\n /**\n * Creates an array of grouped elements, the first of which contains the\n * first elements of the given arrays, the second of which contains the\n * second elements of the given arrays, and so on.\n *\n * @static\n * @memberOf _\n * @since 0.1.0\n * @category Array\n * @param {...Array} [arrays] The arrays to process.\n * @returns {Array} Returns the new array of grouped elements.\n * @example\n *\n * _.zip(['a', 'b'], [1, 2], [true, false]);\n * // => [['a', 1, true], ['b', 2, false]]\n */\n var zip = baseRest(unzip);\n\n /**\n * This method is like `_.fromPairs` except that it accepts two arrays,\n * one of property identifiers and one of corresponding values.\n *\n * @static\n * @memberOf _\n * @since 0.4.0\n * @category Array\n * @param {Array} [props=[]] The property identifiers.\n * @param {Array} [values=[]] The property values.\n * @returns {Object} Returns the new object.\n * @example\n *\n * _.zipObject(['a', 'b'], [1, 2]);\n * // => { 'a': 1, 'b': 2 }\n */\n function zipObject(props, values) {\n return baseZipObject(props || [], values || [], assignValue);\n }\n\n /**\n * This method is like `_.zipObject` except that it supports property paths.\n *\n * @static\n * @memberOf _\n * @since 4.1.0\n * @category Array\n * @param {Array} [props=[]] The property identifiers.\n * @param {Array} [values=[]] The property values.\n * @returns {Object} Returns the new object.\n * @example\n *\n * _.zipObjectDeep(['a.b[0].c', 'a.b[1].d'], [1, 2]);\n * // => { 'a': { 'b': [{ 'c': 1 }, { 'd': 2 }] } }\n */\n function zipObjectDeep(props, values) {\n return baseZipObject(props || [], values || [], baseSet);\n }\n\n /**\n * This method is like `_.zip` except that it accepts `iteratee` to specify\n * how grouped values should be combined. 
The iteratee is invoked with the\n * elements of each group: (...group).\n *\n * @static\n * @memberOf _\n * @since 3.8.0\n * @category Array\n * @param {...Array} [arrays] The arrays to process.\n * @param {Function} [iteratee=_.identity] The function to combine\n * grouped values.\n * @returns {Array} Returns the new array of grouped elements.\n * @example\n *\n * _.zipWith([1, 2], [10, 20], [100, 200], function(a, b, c) {\n * return a + b + c;\n * });\n * // => [111, 222]\n */\n var zipWith = baseRest(function(arrays) {\n var length = arrays.length,\n iteratee = length > 1 ? arrays[length - 1] : undefined;\n\n iteratee = typeof iteratee == 'function' ? (arrays.pop(), iteratee) : undefined;\n return unzipWith(arrays, iteratee);\n });\n\n /*------------------------------------------------------------------------*/\n\n /**\n * Creates a `lodash` wrapper instance that wraps `value` with explicit method\n * chain sequences enabled. The result of such sequences must be unwrapped\n * with `_#value`.\n *\n * @static\n * @memberOf _\n * @since 1.3.0\n * @category Seq\n * @param {*} value The value to wrap.\n * @returns {Object} Returns the new `lodash` wrapper instance.\n * @example\n *\n * var users = [\n * { 'user': 'barney', 'age': 36 },\n * { 'user': 'fred', 'age': 40 },\n * { 'user': 'pebbles', 'age': 1 }\n * ];\n *\n * var youngest = _\n * .chain(users)\n * .sortBy('age')\n * .map(function(o) {\n * return o.user + ' is ' + o.age;\n * })\n * .head()\n * .value();\n * // => 'pebbles is 1'\n */\n function chain(value) {\n var result = lodash(value);\n result.__chain__ = true;\n return result;\n }\n\n /**\n * This method invokes `interceptor` and returns `value`. The interceptor\n * is invoked with one argument; (value). The purpose of this method is to\n * \"tap into\" a method chain sequence in order to modify intermediate results.\n *\n * @static\n * @memberOf _\n * @since 0.1.0\n * @category Seq\n * @param {*} value The value to provide to `interceptor`.\n * @param {Function} interceptor The function to invoke.\n * @returns {*} Returns `value`.\n * @example\n *\n * _([1, 2, 3])\n * .tap(function(array) {\n * // Mutate input array.\n * array.pop();\n * })\n * .reverse()\n * .value();\n * // => [2, 1]\n */\n function tap(value, interceptor) {\n interceptor(value);\n return value;\n }\n\n /**\n * This method is like `_.tap` except that it returns the result of `interceptor`.\n * The purpose of this method is to \"pass thru\" values replacing intermediate\n * results in a method chain sequence.\n *\n * @static\n * @memberOf _\n * @since 3.0.0\n * @category Seq\n * @param {*} value The value to provide to `interceptor`.\n * @param {Function} interceptor The function to invoke.\n * @returns {*} Returns the result of `interceptor`.\n * @example\n *\n * _(' abc ')\n * .chain()\n * .trim()\n * .thru(function(value) {\n * return [value];\n * })\n * .value();\n * // => ['abc']\n */\n function thru(value, interceptor) {\n return interceptor(value);\n }\n\n /**\n * This method is the wrapper version of `_.at`.\n *\n * @name at\n * @memberOf _\n * @since 1.0.0\n * @category Seq\n * @param {...(string|string[])} [paths] The property paths to pick.\n * @returns {Object} Returns the new `lodash` wrapper instance.\n * @example\n *\n * var object = { 'a': [{ 'b': { 'c': 3 } }, 4] };\n *\n * _(object).at(['a[0].b.c', 'a[1]']).value();\n * // => [3, 4]\n */\n var wrapperAt = flatRest(function(paths) {\n var length = paths.length,\n start = length ? 
paths[0] : 0,\n value = this.__wrapped__,\n interceptor = function(object) { return baseAt(object, paths); };\n\n if (length > 1 || this.__actions__.length ||\n !(value instanceof LazyWrapper) || !isIndex(start)) {\n return this.thru(interceptor);\n }\n value = value.slice(start, +start + (length ? 1 : 0));\n value.__actions__.push({\n 'func': thru,\n 'args': [interceptor],\n 'thisArg': undefined\n });\n return new LodashWrapper(value, this.__chain__).thru(function(array) {\n if (length && !array.length) {\n array.push(undefined);\n }\n return array;\n });\n });\n\n /**\n * Creates a `lodash` wrapper instance with explicit method chain sequences enabled.\n *\n * @name chain\n * @memberOf _\n * @since 0.1.0\n * @category Seq\n * @returns {Object} Returns the new `lodash` wrapper instance.\n * @example\n *\n * var users = [\n * { 'user': 'barney', 'age': 36 },\n * { 'user': 'fred', 'age': 40 }\n * ];\n *\n * // A sequence without explicit chaining.\n * _(users).head();\n * // => { 'user': 'barney', 'age': 36 }\n *\n * // A sequence with explicit chaining.\n * _(users)\n * .chain()\n * .head()\n * .pick('user')\n * .value();\n * // => { 'user': 'barney' }\n */\n function wrapperChain() {\n return chain(this);\n }\n\n /**\n * Executes the chain sequence and returns the wrapped result.\n *\n * @name commit\n * @memberOf _\n * @since 3.2.0\n * @category Seq\n * @returns {Object} Returns the new `lodash` wrapper instance.\n * @example\n *\n * var array = [1, 2];\n * var wrapped = _(array).push(3);\n *\n * console.log(array);\n * // => [1, 2]\n *\n * wrapped = wrapped.commit();\n * console.log(array);\n * // => [1, 2, 3]\n *\n * wrapped.last();\n * // => 3\n *\n * console.log(array);\n * // => [1, 2, 3]\n */\n function wrapperCommit() {\n return new LodashWrapper(this.value(), this.__chain__);\n }\n\n /**\n * Gets the next value on a wrapped object following the\n * [iterator protocol](https://mdn.io/iteration_protocols#iterator).\n *\n * @name next\n * @memberOf _\n * @since 4.0.0\n * @category Seq\n * @returns {Object} Returns the next iterator value.\n * @example\n *\n * var wrapped = _([1, 2]);\n *\n * wrapped.next();\n * // => { 'done': false, 'value': 1 }\n *\n * wrapped.next();\n * // => { 'done': false, 'value': 2 }\n *\n * wrapped.next();\n * // => { 'done': true, 'value': undefined }\n */\n function wrapperNext() {\n if (this.__values__ === undefined) {\n this.__values__ = toArray(this.value());\n }\n var done = this.__index__ >= this.__values__.length,\n value = done ? 
undefined : this.__values__[this.__index__++];\n\n return { 'done': done, 'value': value };\n }\n\n /**\n * Enables the wrapper to be iterable.\n *\n * @name Symbol.iterator\n * @memberOf _\n * @since 4.0.0\n * @category Seq\n * @returns {Object} Returns the wrapper object.\n * @example\n *\n * var wrapped = _([1, 2]);\n *\n * wrapped[Symbol.iterator]() === wrapped;\n * // => true\n *\n * Array.from(wrapped);\n * // => [1, 2]\n */\n function wrapperToIterator() {\n return this;\n }\n\n /**\n * Creates a clone of the chain sequence planting `value` as the wrapped value.\n *\n * @name plant\n * @memberOf _\n * @since 3.2.0\n * @category Seq\n * @param {*} value The value to plant.\n * @returns {Object} Returns the new `lodash` wrapper instance.\n * @example\n *\n * function square(n) {\n * return n * n;\n * }\n *\n * var wrapped = _([1, 2]).map(square);\n * var other = wrapped.plant([3, 4]);\n *\n * other.value();\n * // => [9, 16]\n *\n * wrapped.value();\n * // => [1, 4]\n */\n function wrapperPlant(value) {\n var result,\n parent = this;\n\n while (parent instanceof baseLodash) {\n var clone = wrapperClone(parent);\n clone.__index__ = 0;\n clone.__values__ = undefined;\n if (result) {\n previous.__wrapped__ = clone;\n } else {\n result = clone;\n }\n var previous = clone;\n parent = parent.__wrapped__;\n }\n previous.__wrapped__ = value;\n return result;\n }\n\n /**\n * This method is the wrapper version of `_.reverse`.\n *\n * **Note:** This method mutates the wrapped array.\n *\n * @name reverse\n * @memberOf _\n * @since 0.1.0\n * @category Seq\n * @returns {Object} Returns the new `lodash` wrapper instance.\n * @example\n *\n * var array = [1, 2, 3];\n *\n * _(array).reverse().value()\n * // => [3, 2, 1]\n *\n * console.log(array);\n * // => [3, 2, 1]\n */\n function wrapperReverse() {\n var value = this.__wrapped__;\n if (value instanceof LazyWrapper) {\n var wrapped = value;\n if (this.__actions__.length) {\n wrapped = new LazyWrapper(this);\n }\n wrapped = wrapped.reverse();\n wrapped.__actions__.push({\n 'func': thru,\n 'args': [reverse],\n 'thisArg': undefined\n });\n return new LodashWrapper(wrapped, this.__chain__);\n }\n return this.thru(reverse);\n }\n\n /**\n * Executes the chain sequence to resolve the unwrapped value.\n *\n * @name value\n * @memberOf _\n * @since 0.1.0\n * @alias toJSON, valueOf\n * @category Seq\n * @returns {*} Returns the resolved unwrapped value.\n * @example\n *\n * _([1, 2, 3]).value();\n * // => [1, 2, 3]\n */\n function wrapperValue() {\n return baseWrapperValue(this.__wrapped__, this.__actions__);\n }\n\n /*------------------------------------------------------------------------*/\n\n /**\n * Creates an object composed of keys generated from the results of running\n * each element of `collection` thru `iteratee`. The corresponding value of\n * each key is the number of times the key was returned by `iteratee`. 
The\n * iteratee is invoked with one argument: (value).\n *\n * @static\n * @memberOf _\n * @since 0.5.0\n * @category Collection\n * @param {Array|Object} collection The collection to iterate over.\n * @param {Function} [iteratee=_.identity] The iteratee to transform keys.\n * @returns {Object} Returns the composed aggregate object.\n * @example\n *\n * _.countBy([6.1, 4.2, 6.3], Math.floor);\n * // => { '4': 1, '6': 2 }\n *\n * // The `_.property` iteratee shorthand.\n * _.countBy(['one', 'two', 'three'], 'length');\n * // => { '3': 2, '5': 1 }\n */\n var countBy = createAggregator(function(result, value, key) {\n if (hasOwnProperty.call(result, key)) {\n ++result[key];\n } else {\n baseAssignValue(result, key, 1);\n }\n });\n\n /**\n * Checks if `predicate` returns truthy for **all** elements of `collection`.\n * Iteration is stopped once `predicate` returns falsey. The predicate is\n * invoked with three arguments: (value, index|key, collection).\n *\n * **Note:** This method returns `true` for\n * [empty collections](https://en.wikipedia.org/wiki/Empty_set) because\n * [everything is true](https://en.wikipedia.org/wiki/Vacuous_truth) of\n * elements of empty collections.\n *\n * @static\n * @memberOf _\n * @since 0.1.0\n * @category Collection\n * @param {Array|Object} collection The collection to iterate over.\n * @param {Function} [predicate=_.identity] The function invoked per iteration.\n * @param- {Object} [guard] Enables use as an iteratee for methods like `_.map`.\n * @returns {boolean} Returns `true` if all elements pass the predicate check,\n * else `false`.\n * @example\n *\n * _.every([true, 1, null, 'yes'], Boolean);\n * // => false\n *\n * var users = [\n * { 'user': 'barney', 'age': 36, 'active': false },\n * { 'user': 'fred', 'age': 40, 'active': false }\n * ];\n *\n * // The `_.matches` iteratee shorthand.\n * _.every(users, { 'user': 'barney', 'active': false });\n * // => false\n *\n * // The `_.matchesProperty` iteratee shorthand.\n * _.every(users, ['active', false]);\n * // => true\n *\n * // The `_.property` iteratee shorthand.\n * _.every(users, 'active');\n * // => false\n */\n function every(collection, predicate, guard) {\n var func = isArray(collection) ? arrayEvery : baseEvery;\n if (guard && isIterateeCall(collection, predicate, guard)) {\n predicate = undefined;\n }\n return func(collection, getIteratee(predicate, 3));\n }\n\n /**\n * Iterates over elements of `collection`, returning an array of all elements\n * `predicate` returns truthy for. 
The predicate is invoked with three\n * arguments: (value, index|key, collection).\n *\n * **Note:** Unlike `_.remove`, this method returns a new array.\n *\n * @static\n * @memberOf _\n * @since 0.1.0\n * @category Collection\n * @param {Array|Object} collection The collection to iterate over.\n * @param {Function} [predicate=_.identity] The function invoked per iteration.\n * @returns {Array} Returns the new filtered array.\n * @see _.reject\n * @example\n *\n * var users = [\n * { 'user': 'barney', 'age': 36, 'active': true },\n * { 'user': 'fred', 'age': 40, 'active': false }\n * ];\n *\n * _.filter(users, function(o) { return !o.active; });\n * // => objects for ['fred']\n *\n * // The `_.matches` iteratee shorthand.\n * _.filter(users, { 'age': 36, 'active': true });\n * // => objects for ['barney']\n *\n * // The `_.matchesProperty` iteratee shorthand.\n * _.filter(users, ['active', false]);\n * // => objects for ['fred']\n *\n * // The `_.property` iteratee shorthand.\n * _.filter(users, 'active');\n * // => objects for ['barney']\n *\n * // Combining several predicates using `_.overEvery` or `_.overSome`.\n * _.filter(users, _.overSome([{ 'age': 36 }, ['age', 40]]));\n * // => objects for ['fred', 'barney']\n */\n function filter(collection, predicate) {\n var func = isArray(collection) ? arrayFilter : baseFilter;\n return func(collection, getIteratee(predicate, 3));\n }\n\n /**\n * Iterates over elements of `collection`, returning the first element\n * `predicate` returns truthy for. The predicate is invoked with three\n * arguments: (value, index|key, collection).\n *\n * @static\n * @memberOf _\n * @since 0.1.0\n * @category Collection\n * @param {Array|Object} collection The collection to inspect.\n * @param {Function} [predicate=_.identity] The function invoked per iteration.\n * @param {number} [fromIndex=0] The index to search from.\n * @returns {*} Returns the matched element, else `undefined`.\n * @example\n *\n * var users = [\n * { 'user': 'barney', 'age': 36, 'active': true },\n * { 'user': 'fred', 'age': 40, 'active': false },\n * { 'user': 'pebbles', 'age': 1, 'active': true }\n * ];\n *\n * _.find(users, function(o) { return o.age < 40; });\n * // => object for 'barney'\n *\n * // The `_.matches` iteratee shorthand.\n * _.find(users, { 'age': 1, 'active': true });\n * // => object for 'pebbles'\n *\n * // The `_.matchesProperty` iteratee shorthand.\n * _.find(users, ['active', false]);\n * // => object for 'fred'\n *\n * // The `_.property` iteratee shorthand.\n * _.find(users, 'active');\n * // => object for 'barney'\n */\n var find = createFind(findIndex);\n\n /**\n * This method is like `_.find` except that it iterates over elements of\n * `collection` from right to left.\n *\n * @static\n * @memberOf _\n * @since 2.0.0\n * @category Collection\n * @param {Array|Object} collection The collection to inspect.\n * @param {Function} [predicate=_.identity] The function invoked per iteration.\n * @param {number} [fromIndex=collection.length-1] The index to search from.\n * @returns {*} Returns the matched element, else `undefined`.\n * @example\n *\n * _.findLast([1, 2, 3, 4], function(n) {\n * return n % 2 == 1;\n * });\n * // => 3\n */\n var findLast = createFind(findLastIndex);\n\n /**\n * Creates a flattened array of values by running each element in `collection`\n * thru `iteratee` and flattening the mapped results. 
The iteratee is invoked\n * with three arguments: (value, index|key, collection).\n *\n * @static\n * @memberOf _\n * @since 4.0.0\n * @category Collection\n * @param {Array|Object} collection The collection to iterate over.\n * @param {Function} [iteratee=_.identity] The function invoked per iteration.\n * @returns {Array} Returns the new flattened array.\n * @example\n *\n * function duplicate(n) {\n * return [n, n];\n * }\n *\n * _.flatMap([1, 2], duplicate);\n * // => [1, 1, 2, 2]\n */\n function flatMap(collection, iteratee) {\n return baseFlatten(map(collection, iteratee), 1);\n }\n\n /**\n * This method is like `_.flatMap` except that it recursively flattens the\n * mapped results.\n *\n * @static\n * @memberOf _\n * @since 4.7.0\n * @category Collection\n * @param {Array|Object} collection The collection to iterate over.\n * @param {Function} [iteratee=_.identity] The function invoked per iteration.\n * @returns {Array} Returns the new flattened array.\n * @example\n *\n * function duplicate(n) {\n * return [[[n, n]]];\n * }\n *\n * _.flatMapDeep([1, 2], duplicate);\n * // => [1, 1, 2, 2]\n */\n function flatMapDeep(collection, iteratee) {\n return baseFlatten(map(collection, iteratee), INFINITY);\n }\n\n /**\n * This method is like `_.flatMap` except that it recursively flattens the\n * mapped results up to `depth` times.\n *\n * @static\n * @memberOf _\n * @since 4.7.0\n * @category Collection\n * @param {Array|Object} collection The collection to iterate over.\n * @param {Function} [iteratee=_.identity] The function invoked per iteration.\n * @param {number} [depth=1] The maximum recursion depth.\n * @returns {Array} Returns the new flattened array.\n * @example\n *\n * function duplicate(n) {\n * return [[[n, n]]];\n * }\n *\n * _.flatMapDepth([1, 2], duplicate, 2);\n * // => [[1, 1], [2, 2]]\n */\n function flatMapDepth(collection, iteratee, depth) {\n depth = depth === undefined ? 1 : toInteger(depth);\n return baseFlatten(map(collection, iteratee), depth);\n }\n\n /**\n * Iterates over elements of `collection` and invokes `iteratee` for each element.\n * The iteratee is invoked with three arguments: (value, index|key, collection).\n * Iteratee functions may exit iteration early by explicitly returning `false`.\n *\n * **Note:** As with other \"Collections\" methods, objects with a \"length\"\n * property are iterated like arrays. To avoid this behavior use `_.forIn`\n * or `_.forOwn` for object iteration.\n *\n * @static\n * @memberOf _\n * @since 0.1.0\n * @alias each\n * @category Collection\n * @param {Array|Object} collection The collection to iterate over.\n * @param {Function} [iteratee=_.identity] The function invoked per iteration.\n * @returns {Array|Object} Returns `collection`.\n * @see _.forEachRight\n * @example\n *\n * _.forEach([1, 2], function(value) {\n * console.log(value);\n * });\n * // => Logs `1` then `2`.\n *\n * _.forEach({ 'a': 1, 'b': 2 }, function(value, key) {\n * console.log(key);\n * });\n * // => Logs 'a' then 'b' (iteration order is not guaranteed).\n */\n function forEach(collection, iteratee) {\n var func = isArray(collection) ? 
arrayEach : baseEach;\n return func(collection, getIteratee(iteratee, 3));\n }\n\n /**\n * This method is like `_.forEach` except that it iterates over elements of\n * `collection` from right to left.\n *\n * @static\n * @memberOf _\n * @since 2.0.0\n * @alias eachRight\n * @category Collection\n * @param {Array|Object} collection The collection to iterate over.\n * @param {Function} [iteratee=_.identity] The function invoked per iteration.\n * @returns {Array|Object} Returns `collection`.\n * @see _.forEach\n * @example\n *\n * _.forEachRight([1, 2], function(value) {\n * console.log(value);\n * });\n * // => Logs `2` then `1`.\n */\n function forEachRight(collection, iteratee) {\n var func = isArray(collection) ? arrayEachRight : baseEachRight;\n return func(collection, getIteratee(iteratee, 3));\n }\n\n /**\n * Creates an object composed of keys generated from the results of running\n * each element of `collection` thru `iteratee`. The order of grouped values\n * is determined by the order they occur in `collection`. The corresponding\n * value of each key is an array of elements responsible for generating the\n * key. The iteratee is invoked with one argument: (value).\n *\n * @static\n * @memberOf _\n * @since 0.1.0\n * @category Collection\n * @param {Array|Object} collection The collection to iterate over.\n * @param {Function} [iteratee=_.identity] The iteratee to transform keys.\n * @returns {Object} Returns the composed aggregate object.\n * @example\n *\n * _.groupBy([6.1, 4.2, 6.3], Math.floor);\n * // => { '4': [4.2], '6': [6.1, 6.3] }\n *\n * // The `_.property` iteratee shorthand.\n * _.groupBy(['one', 'two', 'three'], 'length');\n * // => { '3': ['one', 'two'], '5': ['three'] }\n */\n var groupBy = createAggregator(function(result, value, key) {\n if (hasOwnProperty.call(result, key)) {\n result[key].push(value);\n } else {\n baseAssignValue(result, key, [value]);\n }\n });\n\n /**\n * Checks if `value` is in `collection`. If `collection` is a string, it's\n * checked for a substring of `value`, otherwise\n * [`SameValueZero`](http://ecma-international.org/ecma-262/7.0/#sec-samevaluezero)\n * is used for equality comparisons. If `fromIndex` is negative, it's used as\n * the offset from the end of `collection`.\n *\n * @static\n * @memberOf _\n * @since 0.1.0\n * @category Collection\n * @param {Array|Object|string} collection The collection to inspect.\n * @param {*} value The value to search for.\n * @param {number} [fromIndex=0] The index to search from.\n * @param- {Object} [guard] Enables use as an iteratee for methods like `_.reduce`.\n * @returns {boolean} Returns `true` if `value` is found, else `false`.\n * @example\n *\n * _.includes([1, 2, 3], 1);\n * // => true\n *\n * _.includes([1, 2, 3], 1, 2);\n * // => false\n *\n * _.includes({ 'a': 1, 'b': 2 }, 1);\n * // => true\n *\n * _.includes('abcd', 'bc');\n * // => true\n */\n function includes(collection, value, fromIndex, guard) {\n collection = isArrayLike(collection) ? collection : values(collection);\n fromIndex = (fromIndex && !guard) ? toInteger(fromIndex) : 0;\n\n var length = collection.length;\n if (fromIndex < 0) {\n fromIndex = nativeMax(length + fromIndex, 0);\n }\n return isString(collection)\n ? (fromIndex <= length && collection.indexOf(value, fromIndex) > -1)\n : (!!length && baseIndexOf(collection, value, fromIndex) > -1);\n }\n\n /**\n * Invokes the method at `path` of each element in `collection`, returning\n * an array of the results of each invoked method. 
Any additional arguments\n * are provided to each invoked method. If `path` is a function, it's invoked\n * for, and `this` bound to, each element in `collection`.\n *\n * @static\n * @memberOf _\n * @since 4.0.0\n * @category Collection\n * @param {Array|Object} collection The collection to iterate over.\n * @param {Array|Function|string} path The path of the method to invoke or\n * the function invoked per iteration.\n * @param {...*} [args] The arguments to invoke each method with.\n * @returns {Array} Returns the array of results.\n * @example\n *\n * _.invokeMap([[5, 1, 7], [3, 2, 1]], 'sort');\n * // => [[1, 5, 7], [1, 2, 3]]\n *\n * _.invokeMap([123, 456], String.prototype.split, '');\n * // => [['1', '2', '3'], ['4', '5', '6']]\n */\n var invokeMap = baseRest(function(collection, path, args) {\n var index = -1,\n isFunc = typeof path == 'function',\n result = isArrayLike(collection) ? Array(collection.length) : [];\n\n baseEach(collection, function(value) {\n result[++index] = isFunc ? apply(path, value, args) : baseInvoke(value, path, args);\n });\n return result;\n });\n\n /**\n * Creates an object composed of keys generated from the results of running\n * each element of `collection` thru `iteratee`. The corresponding value of\n * each key is the last element responsible for generating the key. The\n * iteratee is invoked with one argument: (value).\n *\n * @static\n * @memberOf _\n * @since 4.0.0\n * @category Collection\n * @param {Array|Object} collection The collection to iterate over.\n * @param {Function} [iteratee=_.identity] The iteratee to transform keys.\n * @returns {Object} Returns the composed aggregate object.\n * @example\n *\n * var array = [\n * { 'dir': 'left', 'code': 97 },\n * { 'dir': 'right', 'code': 100 }\n * ];\n *\n * _.keyBy(array, function(o) {\n * return String.fromCharCode(o.code);\n * });\n * // => { 'a': { 'dir': 'left', 'code': 97 }, 'd': { 'dir': 'right', 'code': 100 } }\n *\n * _.keyBy(array, 'dir');\n * // => { 'left': { 'dir': 'left', 'code': 97 }, 'right': { 'dir': 'right', 'code': 100 } }\n */\n var keyBy = createAggregator(function(result, value, key) {\n baseAssignValue(result, key, value);\n });\n\n /**\n * Creates an array of values by running each element in `collection` thru\n * `iteratee`. 
The iteratee is invoked with three arguments:\n * (value, index|key, collection).\n *\n * Many lodash methods are guarded to work as iteratees for methods like\n * `_.every`, `_.filter`, `_.map`, `_.mapValues`, `_.reject`, and `_.some`.\n *\n * The guarded methods are:\n * `ary`, `chunk`, `curry`, `curryRight`, `drop`, `dropRight`, `every`,\n * `fill`, `invert`, `parseInt`, `random`, `range`, `rangeRight`, `repeat`,\n * `sampleSize`, `slice`, `some`, `sortBy`, `split`, `take`, `takeRight`,\n * `template`, `trim`, `trimEnd`, `trimStart`, and `words`\n *\n * @static\n * @memberOf _\n * @since 0.1.0\n * @category Collection\n * @param {Array|Object} collection The collection to iterate over.\n * @param {Function} [iteratee=_.identity] The function invoked per iteration.\n * @returns {Array} Returns the new mapped array.\n * @example\n *\n * function square(n) {\n * return n * n;\n * }\n *\n * _.map([4, 8], square);\n * // => [16, 64]\n *\n * _.map({ 'a': 4, 'b': 8 }, square);\n * // => [16, 64] (iteration order is not guaranteed)\n *\n * var users = [\n * { 'user': 'barney' },\n * { 'user': 'fred' }\n * ];\n *\n * // The `_.property` iteratee shorthand.\n * _.map(users, 'user');\n * // => ['barney', 'fred']\n */\n function map(collection, iteratee) {\n var func = isArray(collection) ? arrayMap : baseMap;\n return func(collection, getIteratee(iteratee, 3));\n }\n\n /**\n * This method is like `_.sortBy` except that it allows specifying the sort\n * orders of the iteratees to sort by. If `orders` is unspecified, all values\n * are sorted in ascending order. Otherwise, specify an order of \"desc\" for\n * descending or \"asc\" for ascending sort order of corresponding values.\n *\n * @static\n * @memberOf _\n * @since 4.0.0\n * @category Collection\n * @param {Array|Object} collection The collection to iterate over.\n * @param {Array[]|Function[]|Object[]|string[]} [iteratees=[_.identity]]\n * The iteratees to sort by.\n * @param {string[]} [orders] The sort orders of `iteratees`.\n * @param- {Object} [guard] Enables use as an iteratee for methods like `_.reduce`.\n * @returns {Array} Returns the new sorted array.\n * @example\n *\n * var users = [\n * { 'user': 'fred', 'age': 48 },\n * { 'user': 'barney', 'age': 34 },\n * { 'user': 'fred', 'age': 40 },\n * { 'user': 'barney', 'age': 36 }\n * ];\n *\n * // Sort by `user` in ascending order and by `age` in descending order.\n * _.orderBy(users, ['user', 'age'], ['asc', 'desc']);\n * // => objects for [['barney', 36], ['barney', 34], ['fred', 48], ['fred', 40]]\n */\n function orderBy(collection, iteratees, orders, guard) {\n if (collection == null) {\n return [];\n }\n if (!isArray(iteratees)) {\n iteratees = iteratees == null ? [] : [iteratees];\n }\n orders = guard ? undefined : orders;\n if (!isArray(orders)) {\n orders = orders == null ? [] : [orders];\n }\n return baseOrderBy(collection, iteratees, orders);\n }\n\n /**\n * Creates an array of elements split into two groups, the first of which\n * contains elements `predicate` returns truthy for, the second of which\n * contains elements `predicate` returns falsey for. 
The predicate is\n * invoked with one argument: (value).\n *\n * @static\n * @memberOf _\n * @since 3.0.0\n * @category Collection\n * @param {Array|Object} collection The collection to iterate over.\n * @param {Function} [predicate=_.identity] The function invoked per iteration.\n * @returns {Array} Returns the array of grouped elements.\n * @example\n *\n * var users = [\n * { 'user': 'barney', 'age': 36, 'active': false },\n * { 'user': 'fred', 'age': 40, 'active': true },\n * { 'user': 'pebbles', 'age': 1, 'active': false }\n * ];\n *\n * _.partition(users, function(o) { return o.active; });\n * // => objects for [['fred'], ['barney', 'pebbles']]\n *\n * // The `_.matches` iteratee shorthand.\n * _.partition(users, { 'age': 1, 'active': false });\n * // => objects for [['pebbles'], ['barney', 'fred']]\n *\n * // The `_.matchesProperty` iteratee shorthand.\n * _.partition(users, ['active', false]);\n * // => objects for [['barney', 'pebbles'], ['fred']]\n *\n * // The `_.property` iteratee shorthand.\n * _.partition(users, 'active');\n * // => objects for [['fred'], ['barney', 'pebbles']]\n */\n var partition = createAggregator(function(result, value, key) {\n result[key ? 0 : 1].push(value);\n }, function() { return [[], []]; });\n\n /**\n * Reduces `collection` to a value which is the accumulated result of running\n * each element in `collection` thru `iteratee`, where each successive\n * invocation is supplied the return value of the previous. If `accumulator`\n * is not given, the first element of `collection` is used as the initial\n * value. The iteratee is invoked with four arguments:\n * (accumulator, value, index|key, collection).\n *\n * Many lodash methods are guarded to work as iteratees for methods like\n * `_.reduce`, `_.reduceRight`, and `_.transform`.\n *\n * The guarded methods are:\n * `assign`, `defaults`, `defaultsDeep`, `includes`, `merge`, `orderBy`,\n * and `sortBy`\n *\n * @static\n * @memberOf _\n * @since 0.1.0\n * @category Collection\n * @param {Array|Object} collection The collection to iterate over.\n * @param {Function} [iteratee=_.identity] The function invoked per iteration.\n * @param {*} [accumulator] The initial value.\n * @returns {*} Returns the accumulated value.\n * @see _.reduceRight\n * @example\n *\n * _.reduce([1, 2], function(sum, n) {\n * return sum + n;\n * }, 0);\n * // => 3\n *\n * _.reduce({ 'a': 1, 'b': 2, 'c': 1 }, function(result, value, key) {\n * (result[value] || (result[value] = [])).push(key);\n * return result;\n * }, {});\n * // => { '1': ['a', 'c'], '2': ['b'] } (iteration order is not guaranteed)\n */\n function reduce(collection, iteratee, accumulator) {\n var func = isArray(collection) ? 
arrayReduce : baseReduce,\n initAccum = arguments.length < 3;\n\n return func(collection, getIteratee(iteratee, 4), accumulator, initAccum, baseEach);\n }\n\n /**\n * This method is like `_.reduce` except that it iterates over elements of\n * `collection` from right to left.\n *\n * @static\n * @memberOf _\n * @since 0.1.0\n * @category Collection\n * @param {Array|Object} collection The collection to iterate over.\n * @param {Function} [iteratee=_.identity] The function invoked per iteration.\n * @param {*} [accumulator] The initial value.\n * @returns {*} Returns the accumulated value.\n * @see _.reduce\n * @example\n *\n * var array = [[0, 1], [2, 3], [4, 5]];\n *\n * _.reduceRight(array, function(flattened, other) {\n * return flattened.concat(other);\n * }, []);\n * // => [4, 5, 2, 3, 0, 1]\n */\n function reduceRight(collection, iteratee, accumulator) {\n var func = isArray(collection) ? arrayReduceRight : baseReduce,\n initAccum = arguments.length < 3;\n\n return func(collection, getIteratee(iteratee, 4), accumulator, initAccum, baseEachRight);\n }\n\n /**\n * The opposite of `_.filter`; this method returns the elements of `collection`\n * that `predicate` does **not** return truthy for.\n *\n * @static\n * @memberOf _\n * @since 0.1.0\n * @category Collection\n * @param {Array|Object} collection The collection to iterate over.\n * @param {Function} [predicate=_.identity] The function invoked per iteration.\n * @returns {Array} Returns the new filtered array.\n * @see _.filter\n * @example\n *\n * var users = [\n * { 'user': 'barney', 'age': 36, 'active': false },\n * { 'user': 'fred', 'age': 40, 'active': true }\n * ];\n *\n * _.reject(users, function(o) { return !o.active; });\n * // => objects for ['fred']\n *\n * // The `_.matches` iteratee shorthand.\n * _.reject(users, { 'age': 40, 'active': true });\n * // => objects for ['barney']\n *\n * // The `_.matchesProperty` iteratee shorthand.\n * _.reject(users, ['active', false]);\n * // => objects for ['fred']\n *\n * // The `_.property` iteratee shorthand.\n * _.reject(users, 'active');\n * // => objects for ['barney']\n */\n function reject(collection, predicate) {\n var func = isArray(collection) ? arrayFilter : baseFilter;\n return func(collection, negate(getIteratee(predicate, 3)));\n }\n\n /**\n * Gets a random element from `collection`.\n *\n * @static\n * @memberOf _\n * @since 2.0.0\n * @category Collection\n * @param {Array|Object} collection The collection to sample.\n * @returns {*} Returns the random element.\n * @example\n *\n * _.sample([1, 2, 3, 4]);\n * // => 2\n */\n function sample(collection) {\n var func = isArray(collection) ? arraySample : baseSample;\n return func(collection);\n }\n\n /**\n * Gets `n` random elements at unique keys from `collection` up to the\n * size of `collection`.\n *\n * @static\n * @memberOf _\n * @since 4.0.0\n * @category Collection\n * @param {Array|Object} collection The collection to sample.\n * @param {number} [n=1] The number of elements to sample.\n * @param- {Object} [guard] Enables use as an iteratee for methods like `_.map`.\n * @returns {Array} Returns the random elements.\n * @example\n *\n * _.sampleSize([1, 2, 3], 2);\n * // => [3, 1]\n *\n * _.sampleSize([1, 2, 3], 4);\n * // => [2, 3, 1]\n */\n function sampleSize(collection, n, guard) {\n if ((guard ? isIterateeCall(collection, n, guard) : n === undefined)) {\n n = 1;\n } else {\n n = toInteger(n);\n }\n var func = isArray(collection) ? 
arraySampleSize : baseSampleSize;\n return func(collection, n);\n }\n\n /**\n * Creates an array of shuffled values, using a version of the\n * [Fisher-Yates shuffle](https://en.wikipedia.org/wiki/Fisher-Yates_shuffle).\n *\n * @static\n * @memberOf _\n * @since 0.1.0\n * @category Collection\n * @param {Array|Object} collection The collection to shuffle.\n * @returns {Array} Returns the new shuffled array.\n * @example\n *\n * _.shuffle([1, 2, 3, 4]);\n * // => [4, 1, 3, 2]\n */\n function shuffle(collection) {\n var func = isArray(collection) ? arrayShuffle : baseShuffle;\n return func(collection);\n }\n\n /**\n * Gets the size of `collection` by returning its length for array-like\n * values or the number of own enumerable string keyed properties for objects.\n *\n * @static\n * @memberOf _\n * @since 0.1.0\n * @category Collection\n * @param {Array|Object|string} collection The collection to inspect.\n * @returns {number} Returns the collection size.\n * @example\n *\n * _.size([1, 2, 3]);\n * // => 3\n *\n * _.size({ 'a': 1, 'b': 2 });\n * // => 2\n *\n * _.size('pebbles');\n * // => 7\n */\n function size(collection) {\n if (collection == null) {\n return 0;\n }\n if (isArrayLike(collection)) {\n return isString(collection) ? stringSize(collection) : collection.length;\n }\n var tag = getTag(collection);\n if (tag == mapTag || tag == setTag) {\n return collection.size;\n }\n return baseKeys(collection).length;\n }\n\n /**\n * Checks if `predicate` returns truthy for **any** element of `collection`.\n * Iteration is stopped once `predicate` returns truthy. The predicate is\n * invoked with three arguments: (value, index|key, collection).\n *\n * @static\n * @memberOf _\n * @since 0.1.0\n * @category Collection\n * @param {Array|Object} collection The collection to iterate over.\n * @param {Function} [predicate=_.identity] The function invoked per iteration.\n * @param- {Object} [guard] Enables use as an iteratee for methods like `_.map`.\n * @returns {boolean} Returns `true` if any element passes the predicate check,\n * else `false`.\n * @example\n *\n * _.some([null, 0, 'yes', false], Boolean);\n * // => true\n *\n * var users = [\n * { 'user': 'barney', 'active': true },\n * { 'user': 'fred', 'active': false }\n * ];\n *\n * // The `_.matches` iteratee shorthand.\n * _.some(users, { 'user': 'barney', 'active': false });\n * // => false\n *\n * // The `_.matchesProperty` iteratee shorthand.\n * _.some(users, ['active', false]);\n * // => true\n *\n * // The `_.property` iteratee shorthand.\n * _.some(users, 'active');\n * // => true\n */\n function some(collection, predicate, guard) {\n var func = isArray(collection) ? arraySome : baseSome;\n if (guard && isIterateeCall(collection, predicate, guard)) {\n predicate = undefined;\n }\n return func(collection, getIteratee(predicate, 3));\n }\n\n /**\n * Creates an array of elements, sorted in ascending order by the results of\n * running each element in a collection thru each iteratee. This method\n * performs a stable sort, that is, it preserves the original sort order of\n * equal elements. 
The iteratees are invoked with one argument: (value).\n *\n * @static\n * @memberOf _\n * @since 0.1.0\n * @category Collection\n * @param {Array|Object} collection The collection to iterate over.\n * @param {...(Function|Function[])} [iteratees=[_.identity]]\n * The iteratees to sort by.\n * @returns {Array} Returns the new sorted array.\n * @example\n *\n * var users = [\n * { 'user': 'fred', 'age': 48 },\n * { 'user': 'barney', 'age': 36 },\n * { 'user': 'fred', 'age': 30 },\n * { 'user': 'barney', 'age': 34 }\n * ];\n *\n * _.sortBy(users, [function(o) { return o.user; }]);\n * // => objects for [['barney', 36], ['barney', 34], ['fred', 48], ['fred', 30]]\n *\n * _.sortBy(users, ['user', 'age']);\n * // => objects for [['barney', 34], ['barney', 36], ['fred', 30], ['fred', 48]]\n */\n var sortBy = baseRest(function(collection, iteratees) {\n if (collection == null) {\n return [];\n }\n var length = iteratees.length;\n if (length > 1 && isIterateeCall(collection, iteratees[0], iteratees[1])) {\n iteratees = [];\n } else if (length > 2 && isIterateeCall(iteratees[0], iteratees[1], iteratees[2])) {\n iteratees = [iteratees[0]];\n }\n return baseOrderBy(collection, baseFlatten(iteratees, 1), []);\n });\n\n /*------------------------------------------------------------------------*/\n\n /**\n * Gets the timestamp of the number of milliseconds that have elapsed since\n * the Unix epoch (1 January 1970 00:00:00 UTC).\n *\n * @static\n * @memberOf _\n * @since 2.4.0\n * @category Date\n * @returns {number} Returns the timestamp.\n * @example\n *\n * _.defer(function(stamp) {\n * console.log(_.now() - stamp);\n * }, _.now());\n * // => Logs the number of milliseconds it took for the deferred invocation.\n */\n var now = ctxNow || function() {\n return root.Date.now();\n };\n\n /*------------------------------------------------------------------------*/\n\n /**\n * The opposite of `_.before`; this method creates a function that invokes\n * `func` once it's called `n` or more times.\n *\n * @static\n * @memberOf _\n * @since 0.1.0\n * @category Function\n * @param {number} n The number of calls before `func` is invoked.\n * @param {Function} func The function to restrict.\n * @returns {Function} Returns the new restricted function.\n * @example\n *\n * var saves = ['profile', 'settings'];\n *\n * var done = _.after(saves.length, function() {\n * console.log('done saving!');\n * });\n *\n * _.forEach(saves, function(type) {\n * asyncSave({ 'type': type, 'complete': done });\n * });\n * // => Logs 'done saving!' after the two async saves have completed.\n */\n function after(n, func) {\n if (typeof func != 'function') {\n throw new TypeError(FUNC_ERROR_TEXT);\n }\n n = toInteger(n);\n return function() {\n if (--n < 1) {\n return func.apply(this, arguments);\n }\n };\n }\n\n /**\n * Creates a function that invokes `func`, with up to `n` arguments,\n * ignoring any additional arguments.\n *\n * @static\n * @memberOf _\n * @since 3.0.0\n * @category Function\n * @param {Function} func The function to cap arguments for.\n * @param {number} [n=func.length] The arity cap.\n * @param- {Object} [guard] Enables use as an iteratee for methods like `_.map`.\n * @returns {Function} Returns the new capped function.\n * @example\n *\n * _.map(['6', '8', '10'], _.ary(parseInt, 1));\n * // => [6, 8, 10]\n */\n function ary(func, n, guard) {\n n = guard ? undefined : n;\n n = (func && n == null) ? 
func.length : n;\n return createWrap(func, WRAP_ARY_FLAG, undefined, undefined, undefined, undefined, n);\n }\n\n /**\n * Creates a function that invokes `func`, with the `this` binding and arguments\n * of the created function, while it's called less than `n` times. Subsequent\n * calls to the created function return the result of the last `func` invocation.\n *\n * @static\n * @memberOf _\n * @since 3.0.0\n * @category Function\n * @param {number} n The number of calls at which `func` is no longer invoked.\n * @param {Function} func The function to restrict.\n * @returns {Function} Returns the new restricted function.\n * @example\n *\n * jQuery(element).on('click', _.before(5, addContactToList));\n * // => Allows adding up to 4 contacts to the list.\n */\n function before(n, func) {\n var result;\n if (typeof func != 'function') {\n throw new TypeError(FUNC_ERROR_TEXT);\n }\n n = toInteger(n);\n return function() {\n if (--n > 0) {\n result = func.apply(this, arguments);\n }\n if (n <= 1) {\n func = undefined;\n }\n return result;\n };\n }\n\n /**\n * Creates a function that invokes `func` with the `this` binding of `thisArg`\n * and `partials` prepended to the arguments it receives.\n *\n * The `_.bind.placeholder` value, which defaults to `_` in monolithic builds,\n * may be used as a placeholder for partially applied arguments.\n *\n * **Note:** Unlike native `Function#bind`, this method doesn't set the \"length\"\n * property of bound functions.\n *\n * @static\n * @memberOf _\n * @since 0.1.0\n * @category Function\n * @param {Function} func The function to bind.\n * @param {*} thisArg The `this` binding of `func`.\n * @param {...*} [partials] The arguments to be partially applied.\n * @returns {Function} Returns the new bound function.\n * @example\n *\n * function greet(greeting, punctuation) {\n * return greeting + ' ' + this.user + punctuation;\n * }\n *\n * var object = { 'user': 'fred' };\n *\n * var bound = _.bind(greet, object, 'hi');\n * bound('!');\n * // => 'hi fred!'\n *\n * // Bound with placeholders.\n * var bound = _.bind(greet, object, _, '!');\n * bound('hi');\n * // => 'hi fred!'\n */\n var bind = baseRest(function(func, thisArg, partials) {\n var bitmask = WRAP_BIND_FLAG;\n if (partials.length) {\n var holders = replaceHolders(partials, getHolder(bind));\n bitmask |= WRAP_PARTIAL_FLAG;\n }\n return createWrap(func, bitmask, thisArg, partials, holders);\n });\n\n /**\n * Creates a function that invokes the method at `object[key]` with `partials`\n * prepended to the arguments it receives.\n *\n * This method differs from `_.bind` by allowing bound functions to reference\n * methods that may be redefined or don't yet exist. 
See\n * [Peter Michaux's article](http://peter.michaux.ca/articles/lazy-function-definition-pattern)\n * for more details.\n *\n * The `_.bindKey.placeholder` value, which defaults to `_` in monolithic\n * builds, may be used as a placeholder for partially applied arguments.\n *\n * @static\n * @memberOf _\n * @since 0.10.0\n * @category Function\n * @param {Object} object The object to invoke the method on.\n * @param {string} key The key of the method.\n * @param {...*} [partials] The arguments to be partially applied.\n * @returns {Function} Returns the new bound function.\n * @example\n *\n * var object = {\n * 'user': 'fred',\n * 'greet': function(greeting, punctuation) {\n * return greeting + ' ' + this.user + punctuation;\n * }\n * };\n *\n * var bound = _.bindKey(object, 'greet', 'hi');\n * bound('!');\n * // => 'hi fred!'\n *\n * object.greet = function(greeting, punctuation) {\n * return greeting + 'ya ' + this.user + punctuation;\n * };\n *\n * bound('!');\n * // => 'hiya fred!'\n *\n * // Bound with placeholders.\n * var bound = _.bindKey(object, 'greet', _, '!');\n * bound('hi');\n * // => 'hiya fred!'\n */\n var bindKey = baseRest(function(object, key, partials) {\n var bitmask = WRAP_BIND_FLAG | WRAP_BIND_KEY_FLAG;\n if (partials.length) {\n var holders = replaceHolders(partials, getHolder(bindKey));\n bitmask |= WRAP_PARTIAL_FLAG;\n }\n return createWrap(key, bitmask, object, partials, holders);\n });\n\n /**\n * Creates a function that accepts arguments of `func` and either invokes\n * `func` returning its result, if at least `arity` number of arguments have\n * been provided, or returns a function that accepts the remaining `func`\n * arguments, and so on. The arity of `func` may be specified if `func.length`\n * is not sufficient.\n *\n * The `_.curry.placeholder` value, which defaults to `_` in monolithic builds,\n * may be used as a placeholder for provided arguments.\n *\n * **Note:** This method doesn't set the \"length\" property of curried functions.\n *\n * @static\n * @memberOf _\n * @since 2.0.0\n * @category Function\n * @param {Function} func The function to curry.\n * @param {number} [arity=func.length] The arity of `func`.\n * @param- {Object} [guard] Enables use as an iteratee for methods like `_.map`.\n * @returns {Function} Returns the new curried function.\n * @example\n *\n * var abc = function(a, b, c) {\n * return [a, b, c];\n * };\n *\n * var curried = _.curry(abc);\n *\n * curried(1)(2)(3);\n * // => [1, 2, 3]\n *\n * curried(1, 2)(3);\n * // => [1, 2, 3]\n *\n * curried(1, 2, 3);\n * // => [1, 2, 3]\n *\n * // Curried with placeholders.\n * curried(1)(_, 3)(2);\n * // => [1, 2, 3]\n */\n function curry(func, arity, guard) {\n arity = guard ? 
undefined : arity;\n var result = createWrap(func, WRAP_CURRY_FLAG, undefined, undefined, undefined, undefined, undefined, arity);\n result.placeholder = curry.placeholder;\n return result;\n }\n\n /**\n * This method is like `_.curry` except that arguments are applied to `func`\n * in the manner of `_.partialRight` instead of `_.partial`.\n *\n * The `_.curryRight.placeholder` value, which defaults to `_` in monolithic\n * builds, may be used as a placeholder for provided arguments.\n *\n * **Note:** This method doesn't set the \"length\" property of curried functions.\n *\n * @static\n * @memberOf _\n * @since 3.0.0\n * @category Function\n * @param {Function} func The function to curry.\n * @param {number} [arity=func.length] The arity of `func`.\n * @param- {Object} [guard] Enables use as an iteratee for methods like `_.map`.\n * @returns {Function} Returns the new curried function.\n * @example\n *\n * var abc = function(a, b, c) {\n * return [a, b, c];\n * };\n *\n * var curried = _.curryRight(abc);\n *\n * curried(3)(2)(1);\n * // => [1, 2, 3]\n *\n * curried(2, 3)(1);\n * // => [1, 2, 3]\n *\n * curried(1, 2, 3);\n * // => [1, 2, 3]\n *\n * // Curried with placeholders.\n * curried(3)(1, _)(2);\n * // => [1, 2, 3]\n */\n function curryRight(func, arity, guard) {\n arity = guard ? undefined : arity;\n var result = createWrap(func, WRAP_CURRY_RIGHT_FLAG, undefined, undefined, undefined, undefined, undefined, arity);\n result.placeholder = curryRight.placeholder;\n return result;\n }\n\n /**\n * Creates a debounced function that delays invoking `func` until after `wait`\n * milliseconds have elapsed since the last time the debounced function was\n * invoked. The debounced function comes with a `cancel` method to cancel\n * delayed `func` invocations and a `flush` method to immediately invoke them.\n * Provide `options` to indicate whether `func` should be invoked on the\n * leading and/or trailing edge of the `wait` timeout. The `func` is invoked\n * with the last arguments provided to the debounced function. 
Subsequent\n * calls to the debounced function return the result of the last `func`\n * invocation.\n *\n * **Note:** If `leading` and `trailing` options are `true`, `func` is\n * invoked on the trailing edge of the timeout only if the debounced function\n * is invoked more than once during the `wait` timeout.\n *\n * If `wait` is `0` and `leading` is `false`, `func` invocation is deferred\n * until to the next tick, similar to `setTimeout` with a timeout of `0`.\n *\n * See [David Corbacho's article](https://css-tricks.com/debouncing-throttling-explained-examples/)\n * for details over the differences between `_.debounce` and `_.throttle`.\n *\n * @static\n * @memberOf _\n * @since 0.1.0\n * @category Function\n * @param {Function} func The function to debounce.\n * @param {number} [wait=0] The number of milliseconds to delay.\n * @param {Object} [options={}] The options object.\n * @param {boolean} [options.leading=false]\n * Specify invoking on the leading edge of the timeout.\n * @param {number} [options.maxWait]\n * The maximum time `func` is allowed to be delayed before it's invoked.\n * @param {boolean} [options.trailing=true]\n * Specify invoking on the trailing edge of the timeout.\n * @returns {Function} Returns the new debounced function.\n * @example\n *\n * // Avoid costly calculations while the window size is in flux.\n * jQuery(window).on('resize', _.debounce(calculateLayout, 150));\n *\n * // Invoke `sendMail` when clicked, debouncing subsequent calls.\n * jQuery(element).on('click', _.debounce(sendMail, 300, {\n * 'leading': true,\n * 'trailing': false\n * }));\n *\n * // Ensure `batchLog` is invoked once after 1 second of debounced calls.\n * var debounced = _.debounce(batchLog, 250, { 'maxWait': 1000 });\n * var source = new EventSource('/stream');\n * jQuery(source).on('message', debounced);\n *\n * // Cancel the trailing debounced invocation.\n * jQuery(window).on('popstate', debounced.cancel);\n */\n function debounce(func, wait, options) {\n var lastArgs,\n lastThis,\n maxWait,\n result,\n timerId,\n lastCallTime,\n lastInvokeTime = 0,\n leading = false,\n maxing = false,\n trailing = true;\n\n if (typeof func != 'function') {\n throw new TypeError(FUNC_ERROR_TEXT);\n }\n wait = toNumber(wait) || 0;\n if (isObject(options)) {\n leading = !!options.leading;\n maxing = 'maxWait' in options;\n maxWait = maxing ? nativeMax(toNumber(options.maxWait) || 0, wait) : maxWait;\n trailing = 'trailing' in options ? !!options.trailing : trailing;\n }\n\n function invokeFunc(time) {\n var args = lastArgs,\n thisArg = lastThis;\n\n lastArgs = lastThis = undefined;\n lastInvokeTime = time;\n result = func.apply(thisArg, args);\n return result;\n }\n\n function leadingEdge(time) {\n // Reset any `maxWait` timer.\n lastInvokeTime = time;\n // Start the timer for the trailing edge.\n timerId = setTimeout(timerExpired, wait);\n // Invoke the leading edge.\n return leading ? invokeFunc(time) : result;\n }\n\n function remainingWait(time) {\n var timeSinceLastCall = time - lastCallTime,\n timeSinceLastInvoke = time - lastInvokeTime,\n timeWaiting = wait - timeSinceLastCall;\n\n return maxing\n ? 
nativeMin(timeWaiting, maxWait - timeSinceLastInvoke)\n : timeWaiting;\n }\n\n function shouldInvoke(time) {\n var timeSinceLastCall = time - lastCallTime,\n timeSinceLastInvoke = time - lastInvokeTime;\n\n // Either this is the first call, activity has stopped and we're at the\n // trailing edge, the system time has gone backwards and we're treating\n // it as the trailing edge, or we've hit the `maxWait` limit.\n return (lastCallTime === undefined || (timeSinceLastCall >= wait) ||\n (timeSinceLastCall < 0) || (maxing && timeSinceLastInvoke >= maxWait));\n }\n\n function timerExpired() {\n var time = now();\n if (shouldInvoke(time)) {\n return trailingEdge(time);\n }\n // Restart the timer.\n timerId = setTimeout(timerExpired, remainingWait(time));\n }\n\n function trailingEdge(time) {\n timerId = undefined;\n\n // Only invoke if we have `lastArgs` which means `func` has been\n // debounced at least once.\n if (trailing && lastArgs) {\n return invokeFunc(time);\n }\n lastArgs = lastThis = undefined;\n return result;\n }\n\n function cancel() {\n if (timerId !== undefined) {\n clearTimeout(timerId);\n }\n lastInvokeTime = 0;\n lastArgs = lastCallTime = lastThis = timerId = undefined;\n }\n\n function flush() {\n return timerId === undefined ? result : trailingEdge(now());\n }\n\n function debounced() {\n var time = now(),\n isInvoking = shouldInvoke(time);\n\n lastArgs = arguments;\n lastThis = this;\n lastCallTime = time;\n\n if (isInvoking) {\n if (timerId === undefined) {\n return leadingEdge(lastCallTime);\n }\n if (maxing) {\n // Handle invocations in a tight loop.\n clearTimeout(timerId);\n timerId = setTimeout(timerExpired, wait);\n return invokeFunc(lastCallTime);\n }\n }\n if (timerId === undefined) {\n timerId = setTimeout(timerExpired, wait);\n }\n return result;\n }\n debounced.cancel = cancel;\n debounced.flush = flush;\n return debounced;\n }\n\n /**\n * Defers invoking the `func` until the current call stack has cleared. Any\n * additional arguments are provided to `func` when it's invoked.\n *\n * @static\n * @memberOf _\n * @since 0.1.0\n * @category Function\n * @param {Function} func The function to defer.\n * @param {...*} [args] The arguments to invoke `func` with.\n * @returns {number} Returns the timer id.\n * @example\n *\n * _.defer(function(text) {\n * console.log(text);\n * }, 'deferred');\n * // => Logs 'deferred' after one millisecond.\n */\n var defer = baseRest(function(func, args) {\n return baseDelay(func, 1, args);\n });\n\n /**\n * Invokes `func` after `wait` milliseconds. 
Any additional arguments are\n * provided to `func` when it's invoked.\n *\n * @static\n * @memberOf _\n * @since 0.1.0\n * @category Function\n * @param {Function} func The function to delay.\n * @param {number} wait The number of milliseconds to delay invocation.\n * @param {...*} [args] The arguments to invoke `func` with.\n * @returns {number} Returns the timer id.\n * @example\n *\n * _.delay(function(text) {\n * console.log(text);\n * }, 1000, 'later');\n * // => Logs 'later' after one second.\n */\n var delay = baseRest(function(func, wait, args) {\n return baseDelay(func, toNumber(wait) || 0, args);\n });\n\n /**\n * Creates a function that invokes `func` with arguments reversed.\n *\n * @static\n * @memberOf _\n * @since 4.0.0\n * @category Function\n * @param {Function} func The function to flip arguments for.\n * @returns {Function} Returns the new flipped function.\n * @example\n *\n * var flipped = _.flip(function() {\n * return _.toArray(arguments);\n * });\n *\n * flipped('a', 'b', 'c', 'd');\n * // => ['d', 'c', 'b', 'a']\n */\n function flip(func) {\n return createWrap(func, WRAP_FLIP_FLAG);\n }\n\n /**\n * Creates a function that memoizes the result of `func`. If `resolver` is\n * provided, it determines the cache key for storing the result based on the\n * arguments provided to the memoized function. By default, the first argument\n * provided to the memoized function is used as the map cache key. The `func`\n * is invoked with the `this` binding of the memoized function.\n *\n * **Note:** The cache is exposed as the `cache` property on the memoized\n * function. Its creation may be customized by replacing the `_.memoize.Cache`\n * constructor with one whose instances implement the\n * [`Map`](http://ecma-international.org/ecma-262/7.0/#sec-properties-of-the-map-prototype-object)\n * method interface of `clear`, `delete`, `get`, `has`, and `set`.\n *\n * @static\n * @memberOf _\n * @since 0.1.0\n * @category Function\n * @param {Function} func The function to have its output memoized.\n * @param {Function} [resolver] The function to resolve the cache key.\n * @returns {Function} Returns the new memoized function.\n * @example\n *\n * var object = { 'a': 1, 'b': 2 };\n * var other = { 'c': 3, 'd': 4 };\n *\n * var values = _.memoize(_.values);\n * values(object);\n * // => [1, 2]\n *\n * values(other);\n * // => [3, 4]\n *\n * object.a = 2;\n * values(object);\n * // => [1, 2]\n *\n * // Modify the result cache.\n * values.cache.set(object, ['a', 'b']);\n * values(object);\n * // => ['a', 'b']\n *\n * // Replace `_.memoize.Cache`.\n * _.memoize.Cache = WeakMap;\n */\n function memoize(func, resolver) {\n if (typeof func != 'function' || (resolver != null && typeof resolver != 'function')) {\n throw new TypeError(FUNC_ERROR_TEXT);\n }\n var memoized = function() {\n var args = arguments,\n key = resolver ? resolver.apply(this, args) : args[0],\n cache = memoized.cache;\n\n if (cache.has(key)) {\n return cache.get(key);\n }\n var result = func.apply(this, args);\n memoized.cache = cache.set(key, result) || cache;\n return result;\n };\n memoized.cache = new (memoize.Cache || MapCache);\n return memoized;\n }\n\n // Expose `MapCache`.\n memoize.Cache = MapCache;\n\n /**\n * Creates a function that negates the result of the predicate `func`. 
The\n * `func` predicate is invoked with the `this` binding and arguments of the\n * created function.\n *\n * @static\n * @memberOf _\n * @since 3.0.0\n * @category Function\n * @param {Function} predicate The predicate to negate.\n * @returns {Function} Returns the new negated function.\n * @example\n *\n * function isEven(n) {\n * return n % 2 == 0;\n * }\n *\n * _.filter([1, 2, 3, 4, 5, 6], _.negate(isEven));\n * // => [1, 3, 5]\n */\n function negate(predicate) {\n if (typeof predicate != 'function') {\n throw new TypeError(FUNC_ERROR_TEXT);\n }\n return function() {\n var args = arguments;\n switch (args.length) {\n case 0: return !predicate.call(this);\n case 1: return !predicate.call(this, args[0]);\n case 2: return !predicate.call(this, args[0], args[1]);\n case 3: return !predicate.call(this, args[0], args[1], args[2]);\n }\n return !predicate.apply(this, args);\n };\n }\n\n /**\n * Creates a function that is restricted to invoking `func` once. Repeat calls\n * to the function return the value of the first invocation. The `func` is\n * invoked with the `this` binding and arguments of the created function.\n *\n * @static\n * @memberOf _\n * @since 0.1.0\n * @category Function\n * @param {Function} func The function to restrict.\n * @returns {Function} Returns the new restricted function.\n * @example\n *\n * var initialize = _.once(createApplication);\n * initialize();\n * initialize();\n * // => `createApplication` is invoked once\n */\n function once(func) {\n return before(2, func);\n }\n\n /**\n * Creates a function that invokes `func` with its arguments transformed.\n *\n * @static\n * @since 4.0.0\n * @memberOf _\n * @category Function\n * @param {Function} func The function to wrap.\n * @param {...(Function|Function[])} [transforms=[_.identity]]\n * The argument transforms.\n * @returns {Function} Returns the new function.\n * @example\n *\n * function doubled(n) {\n * return n * 2;\n * }\n *\n * function square(n) {\n * return n * n;\n * }\n *\n * var func = _.overArgs(function(x, y) {\n * return [x, y];\n * }, [square, doubled]);\n *\n * func(9, 3);\n * // => [81, 6]\n *\n * func(10, 5);\n * // => [100, 10]\n */\n var overArgs = castRest(function(func, transforms) {\n transforms = (transforms.length == 1 && isArray(transforms[0]))\n ? arrayMap(transforms[0], baseUnary(getIteratee()))\n : arrayMap(baseFlatten(transforms, 1), baseUnary(getIteratee()));\n\n var funcsLength = transforms.length;\n return baseRest(function(args) {\n var index = -1,\n length = nativeMin(args.length, funcsLength);\n\n while (++index < length) {\n args[index] = transforms[index].call(this, args[index]);\n }\n return apply(func, this, args);\n });\n });\n\n /**\n * Creates a function that invokes `func` with `partials` prepended to the\n * arguments it receives. 
This method is like `_.bind` except it does **not**\n * alter the `this` binding.\n *\n * The `_.partial.placeholder` value, which defaults to `_` in monolithic\n * builds, may be used as a placeholder for partially applied arguments.\n *\n * **Note:** This method doesn't set the \"length\" property of partially\n * applied functions.\n *\n * @static\n * @memberOf _\n * @since 0.2.0\n * @category Function\n * @param {Function} func The function to partially apply arguments to.\n * @param {...*} [partials] The arguments to be partially applied.\n * @returns {Function} Returns the new partially applied function.\n * @example\n *\n * function greet(greeting, name) {\n * return greeting + ' ' + name;\n * }\n *\n * var sayHelloTo = _.partial(greet, 'hello');\n * sayHelloTo('fred');\n * // => 'hello fred'\n *\n * // Partially applied with placeholders.\n * var greetFred = _.partial(greet, _, 'fred');\n * greetFred('hi');\n * // => 'hi fred'\n */\n var partial = baseRest(function(func, partials) {\n var holders = replaceHolders(partials, getHolder(partial));\n return createWrap(func, WRAP_PARTIAL_FLAG, undefined, partials, holders);\n });\n\n /**\n * This method is like `_.partial` except that partially applied arguments\n * are appended to the arguments it receives.\n *\n * The `_.partialRight.placeholder` value, which defaults to `_` in monolithic\n * builds, may be used as a placeholder for partially applied arguments.\n *\n * **Note:** This method doesn't set the \"length\" property of partially\n * applied functions.\n *\n * @static\n * @memberOf _\n * @since 1.0.0\n * @category Function\n * @param {Function} func The function to partially apply arguments to.\n * @param {...*} [partials] The arguments to be partially applied.\n * @returns {Function} Returns the new partially applied function.\n * @example\n *\n * function greet(greeting, name) {\n * return greeting + ' ' + name;\n * }\n *\n * var greetFred = _.partialRight(greet, 'fred');\n * greetFred('hi');\n * // => 'hi fred'\n *\n * // Partially applied with placeholders.\n * var sayHelloTo = _.partialRight(greet, 'hello', _);\n * sayHelloTo('fred');\n * // => 'hello fred'\n */\n var partialRight = baseRest(function(func, partials) {\n var holders = replaceHolders(partials, getHolder(partialRight));\n return createWrap(func, WRAP_PARTIAL_RIGHT_FLAG, undefined, partials, holders);\n });\n\n /**\n * Creates a function that invokes `func` with arguments arranged according\n * to the specified `indexes` where the argument value at the first index is\n * provided as the first argument, the argument value at the second index is\n * provided as the second argument, and so on.\n *\n * @static\n * @memberOf _\n * @since 3.0.0\n * @category Function\n * @param {Function} func The function to rearrange arguments for.\n * @param {...(number|number[])} indexes The arranged argument indexes.\n * @returns {Function} Returns the new function.\n * @example\n *\n * var rearged = _.rearg(function(a, b, c) {\n * return [a, b, c];\n * }, [2, 0, 1]);\n *\n * rearged('b', 'c', 'a')\n * // => ['a', 'b', 'c']\n */\n var rearg = flatRest(function(func, indexes) {\n return createWrap(func, WRAP_REARG_FLAG, undefined, undefined, undefined, indexes);\n });\n\n /**\n * Creates a function that invokes `func` with the `this` binding of the\n * created function and arguments from `start` and beyond provided as\n * an array.\n *\n * **Note:** This method is based on the\n * [rest parameter](https://mdn.io/rest_parameters).\n *\n * @static\n * @memberOf _\n * @since 
4.0.0\n * @category Function\n * @param {Function} func The function to apply a rest parameter to.\n * @param {number} [start=func.length-1] The start position of the rest parameter.\n * @returns {Function} Returns the new function.\n * @example\n *\n * var say = _.rest(function(what, names) {\n * return what + ' ' + _.initial(names).join(', ') +\n * (_.size(names) > 1 ? ', & ' : '') + _.last(names);\n * });\n *\n * say('hello', 'fred', 'barney', 'pebbles');\n * // => 'hello fred, barney, & pebbles'\n */\n function rest(func, start) {\n if (typeof func != 'function') {\n throw new TypeError(FUNC_ERROR_TEXT);\n }\n start = start === undefined ? start : toInteger(start);\n return baseRest(func, start);\n }\n\n /**\n * Creates a function that invokes `func` with the `this` binding of the\n * create function and an array of arguments much like\n * [`Function#apply`](http://www.ecma-international.org/ecma-262/7.0/#sec-function.prototype.apply).\n *\n * **Note:** This method is based on the\n * [spread operator](https://mdn.io/spread_operator).\n *\n * @static\n * @memberOf _\n * @since 3.2.0\n * @category Function\n * @param {Function} func The function to spread arguments over.\n * @param {number} [start=0] The start position of the spread.\n * @returns {Function} Returns the new function.\n * @example\n *\n * var say = _.spread(function(who, what) {\n * return who + ' says ' + what;\n * });\n *\n * say(['fred', 'hello']);\n * // => 'fred says hello'\n *\n * var numbers = Promise.all([\n * Promise.resolve(40),\n * Promise.resolve(36)\n * ]);\n *\n * numbers.then(_.spread(function(x, y) {\n * return x + y;\n * }));\n * // => a Promise of 76\n */\n function spread(func, start) {\n if (typeof func != 'function') {\n throw new TypeError(FUNC_ERROR_TEXT);\n }\n start = start == null ? 0 : nativeMax(toInteger(start), 0);\n return baseRest(function(args) {\n var array = args[start],\n otherArgs = castSlice(args, 0, start);\n\n if (array) {\n arrayPush(otherArgs, array);\n }\n return apply(func, this, otherArgs);\n });\n }\n\n /**\n * Creates a throttled function that only invokes `func` at most once per\n * every `wait` milliseconds. The throttled function comes with a `cancel`\n * method to cancel delayed `func` invocations and a `flush` method to\n * immediately invoke them. Provide `options` to indicate whether `func`\n * should be invoked on the leading and/or trailing edge of the `wait`\n * timeout. The `func` is invoked with the last arguments provided to the\n * throttled function. 
Subsequent calls to the throttled function return the\n * result of the last `func` invocation.\n *\n * **Note:** If `leading` and `trailing` options are `true`, `func` is\n * invoked on the trailing edge of the timeout only if the throttled function\n * is invoked more than once during the `wait` timeout.\n *\n * If `wait` is `0` and `leading` is `false`, `func` invocation is deferred\n * until to the next tick, similar to `setTimeout` with a timeout of `0`.\n *\n * See [David Corbacho's article](https://css-tricks.com/debouncing-throttling-explained-examples/)\n * for details over the differences between `_.throttle` and `_.debounce`.\n *\n * @static\n * @memberOf _\n * @since 0.1.0\n * @category Function\n * @param {Function} func The function to throttle.\n * @param {number} [wait=0] The number of milliseconds to throttle invocations to.\n * @param {Object} [options={}] The options object.\n * @param {boolean} [options.leading=true]\n * Specify invoking on the leading edge of the timeout.\n * @param {boolean} [options.trailing=true]\n * Specify invoking on the trailing edge of the timeout.\n * @returns {Function} Returns the new throttled function.\n * @example\n *\n * // Avoid excessively updating the position while scrolling.\n * jQuery(window).on('scroll', _.throttle(updatePosition, 100));\n *\n * // Invoke `renewToken` when the click event is fired, but not more than once every 5 minutes.\n * var throttled = _.throttle(renewToken, 300000, { 'trailing': false });\n * jQuery(element).on('click', throttled);\n *\n * // Cancel the trailing throttled invocation.\n * jQuery(window).on('popstate', throttled.cancel);\n */\n function throttle(func, wait, options) {\n var leading = true,\n trailing = true;\n\n if (typeof func != 'function') {\n throw new TypeError(FUNC_ERROR_TEXT);\n }\n if (isObject(options)) {\n leading = 'leading' in options ? !!options.leading : leading;\n trailing = 'trailing' in options ? !!options.trailing : trailing;\n }\n return debounce(func, wait, {\n 'leading': leading,\n 'maxWait': wait,\n 'trailing': trailing\n });\n }\n\n /**\n * Creates a function that accepts up to one argument, ignoring any\n * additional arguments.\n *\n * @static\n * @memberOf _\n * @since 4.0.0\n * @category Function\n * @param {Function} func The function to cap arguments for.\n * @returns {Function} Returns the new capped function.\n * @example\n *\n * _.map(['6', '8', '10'], _.unary(parseInt));\n * // => [6, 8, 10]\n */\n function unary(func) {\n return ary(func, 1);\n }\n\n /**\n * Creates a function that provides `value` to `wrapper` as its first\n * argument. Any additional arguments provided to the function are appended\n * to those provided to the `wrapper`. The wrapper is invoked with the `this`\n * binding of the created function.\n *\n * @static\n * @memberOf _\n * @since 0.1.0\n * @category Function\n * @param {*} value The value to wrap.\n * @param {Function} [wrapper=identity] The wrapper function.\n * @returns {Function} Returns the new function.\n * @example\n *\n * var p = _.wrap(_.escape, function(func, text) {\n * return '
<p>' + func(text) + '</p>';\n * });\n *\n * p('fred, barney, & pebbles');\n * // => '<p>fred, barney, &amp; pebbles</p>
'\n */\n function wrap(value, wrapper) {\n return partial(castFunction(wrapper), value);\n }\n\n /*------------------------------------------------------------------------*/\n\n /**\n * Casts `value` as an array if it's not one.\n *\n * @static\n * @memberOf _\n * @since 4.4.0\n * @category Lang\n * @param {*} value The value to inspect.\n * @returns {Array} Returns the cast array.\n * @example\n *\n * _.castArray(1);\n * // => [1]\n *\n * _.castArray({ 'a': 1 });\n * // => [{ 'a': 1 }]\n *\n * _.castArray('abc');\n * // => ['abc']\n *\n * _.castArray(null);\n * // => [null]\n *\n * _.castArray(undefined);\n * // => [undefined]\n *\n * _.castArray();\n * // => []\n *\n * var array = [1, 2, 3];\n * console.log(_.castArray(array) === array);\n * // => true\n */\n function castArray() {\n if (!arguments.length) {\n return [];\n }\n var value = arguments[0];\n return isArray(value) ? value : [value];\n }\n\n /**\n * Creates a shallow clone of `value`.\n *\n * **Note:** This method is loosely based on the\n * [structured clone algorithm](https://mdn.io/Structured_clone_algorithm)\n * and supports cloning arrays, array buffers, booleans, date objects, maps,\n * numbers, `Object` objects, regexes, sets, strings, symbols, and typed\n * arrays. The own enumerable properties of `arguments` objects are cloned\n * as plain objects. An empty object is returned for uncloneable values such\n * as error objects, functions, DOM nodes, and WeakMaps.\n *\n * @static\n * @memberOf _\n * @since 0.1.0\n * @category Lang\n * @param {*} value The value to clone.\n * @returns {*} Returns the cloned value.\n * @see _.cloneDeep\n * @example\n *\n * var objects = [{ 'a': 1 }, { 'b': 2 }];\n *\n * var shallow = _.clone(objects);\n * console.log(shallow[0] === objects[0]);\n * // => true\n */\n function clone(value) {\n return baseClone(value, CLONE_SYMBOLS_FLAG);\n }\n\n /**\n * This method is like `_.clone` except that it accepts `customizer` which\n * is invoked to produce the cloned value. If `customizer` returns `undefined`,\n * cloning is handled by the method instead. The `customizer` is invoked with\n * up to four arguments; (value [, index|key, object, stack]).\n *\n * @static\n * @memberOf _\n * @since 4.0.0\n * @category Lang\n * @param {*} value The value to clone.\n * @param {Function} [customizer] The function to customize cloning.\n * @returns {*} Returns the cloned value.\n * @see _.cloneDeepWith\n * @example\n *\n * function customizer(value) {\n * if (_.isElement(value)) {\n * return value.cloneNode(false);\n * }\n * }\n *\n * var el = _.cloneWith(document.body, customizer);\n *\n * console.log(el === document.body);\n * // => false\n * console.log(el.nodeName);\n * // => 'BODY'\n * console.log(el.childNodes.length);\n * // => 0\n */\n function cloneWith(value, customizer) {\n customizer = typeof customizer == 'function' ? 
customizer : undefined;\n return baseClone(value, CLONE_SYMBOLS_FLAG, customizer);\n }\n\n /**\n * This method is like `_.clone` except that it recursively clones `value`.\n *\n * @static\n * @memberOf _\n * @since 1.0.0\n * @category Lang\n * @param {*} value The value to recursively clone.\n * @returns {*} Returns the deep cloned value.\n * @see _.clone\n * @example\n *\n * var objects = [{ 'a': 1 }, { 'b': 2 }];\n *\n * var deep = _.cloneDeep(objects);\n * console.log(deep[0] === objects[0]);\n * // => false\n */\n function cloneDeep(value) {\n return baseClone(value, CLONE_DEEP_FLAG | CLONE_SYMBOLS_FLAG);\n }\n\n /**\n * This method is like `_.cloneWith` except that it recursively clones `value`.\n *\n * @static\n * @memberOf _\n * @since 4.0.0\n * @category Lang\n * @param {*} value The value to recursively clone.\n * @param {Function} [customizer] The function to customize cloning.\n * @returns {*} Returns the deep cloned value.\n * @see _.cloneWith\n * @example\n *\n * function customizer(value) {\n * if (_.isElement(value)) {\n * return value.cloneNode(true);\n * }\n * }\n *\n * var el = _.cloneDeepWith(document.body, customizer);\n *\n * console.log(el === document.body);\n * // => false\n * console.log(el.nodeName);\n * // => 'BODY'\n * console.log(el.childNodes.length);\n * // => 20\n */\n function cloneDeepWith(value, customizer) {\n customizer = typeof customizer == 'function' ? customizer : undefined;\n return baseClone(value, CLONE_DEEP_FLAG | CLONE_SYMBOLS_FLAG, customizer);\n }\n\n /**\n * Checks if `object` conforms to `source` by invoking the predicate\n * properties of `source` with the corresponding property values of `object`.\n *\n * **Note:** This method is equivalent to `_.conforms` when `source` is\n * partially applied.\n *\n * @static\n * @memberOf _\n * @since 4.14.0\n * @category Lang\n * @param {Object} object The object to inspect.\n * @param {Object} source The object of property predicates to conform to.\n * @returns {boolean} Returns `true` if `object` conforms, else `false`.\n * @example\n *\n * var object = { 'a': 1, 'b': 2 };\n *\n * _.conformsTo(object, { 'b': function(n) { return n > 1; } });\n * // => true\n *\n * _.conformsTo(object, { 'b': function(n) { return n > 2; } });\n * // => false\n */\n function conformsTo(object, source) {\n return source == null || baseConformsTo(object, source, keys(source));\n }\n\n /**\n * Performs a\n * [`SameValueZero`](http://ecma-international.org/ecma-262/7.0/#sec-samevaluezero)\n * comparison between two values to determine if they are equivalent.\n *\n * @static\n * @memberOf _\n * @since 4.0.0\n * @category Lang\n * @param {*} value The value to compare.\n * @param {*} other The other value to compare.\n * @returns {boolean} Returns `true` if the values are equivalent, else `false`.\n * @example\n *\n * var object = { 'a': 1 };\n * var other = { 'a': 1 };\n *\n * _.eq(object, object);\n * // => true\n *\n * _.eq(object, other);\n * // => false\n *\n * _.eq('a', 'a');\n * // => true\n *\n * _.eq('a', Object('a'));\n * // => false\n *\n * _.eq(NaN, NaN);\n * // => true\n */\n function eq(value, other) {\n return value === other || (value !== value && other !== other);\n }\n\n /**\n * Checks if `value` is greater than `other`.\n *\n * @static\n * @memberOf _\n * @since 3.9.0\n * @category Lang\n * @param {*} value The value to compare.\n * @param {*} other The other value to compare.\n * @returns {boolean} Returns `true` if `value` is greater than `other`,\n * else `false`.\n * @see _.lt\n * @example\n *\n * 
_.gt(3, 1);\n * // => true\n *\n * _.gt(3, 3);\n * // => false\n *\n * _.gt(1, 3);\n * // => false\n */\n var gt = createRelationalOperation(baseGt);\n\n /**\n * Checks if `value` is greater than or equal to `other`.\n *\n * @static\n * @memberOf _\n * @since 3.9.0\n * @category Lang\n * @param {*} value The value to compare.\n * @param {*} other The other value to compare.\n * @returns {boolean} Returns `true` if `value` is greater than or equal to\n * `other`, else `false`.\n * @see _.lte\n * @example\n *\n * _.gte(3, 1);\n * // => true\n *\n * _.gte(3, 3);\n * // => true\n *\n * _.gte(1, 3);\n * // => false\n */\n var gte = createRelationalOperation(function(value, other) {\n return value >= other;\n });\n\n /**\n * Checks if `value` is likely an `arguments` object.\n *\n * @static\n * @memberOf _\n * @since 0.1.0\n * @category Lang\n * @param {*} value The value to check.\n * @returns {boolean} Returns `true` if `value` is an `arguments` object,\n * else `false`.\n * @example\n *\n * _.isArguments(function() { return arguments; }());\n * // => true\n *\n * _.isArguments([1, 2, 3]);\n * // => false\n */\n var isArguments = baseIsArguments(function() { return arguments; }()) ? baseIsArguments : function(value) {\n return isObjectLike(value) && hasOwnProperty.call(value, 'callee') &&\n !propertyIsEnumerable.call(value, 'callee');\n };\n\n /**\n * Checks if `value` is classified as an `Array` object.\n *\n * @static\n * @memberOf _\n * @since 0.1.0\n * @category Lang\n * @param {*} value The value to check.\n * @returns {boolean} Returns `true` if `value` is an array, else `false`.\n * @example\n *\n * _.isArray([1, 2, 3]);\n * // => true\n *\n * _.isArray(document.body.children);\n * // => false\n *\n * _.isArray('abc');\n * // => false\n *\n * _.isArray(_.noop);\n * // => false\n */\n var isArray = Array.isArray;\n\n /**\n * Checks if `value` is classified as an `ArrayBuffer` object.\n *\n * @static\n * @memberOf _\n * @since 4.3.0\n * @category Lang\n * @param {*} value The value to check.\n * @returns {boolean} Returns `true` if `value` is an array buffer, else `false`.\n * @example\n *\n * _.isArrayBuffer(new ArrayBuffer(2));\n * // => true\n *\n * _.isArrayBuffer(new Array(2));\n * // => false\n */\n var isArrayBuffer = nodeIsArrayBuffer ? baseUnary(nodeIsArrayBuffer) : baseIsArrayBuffer;\n\n /**\n * Checks if `value` is array-like. 
A value is considered array-like if it's\n * not a function and has a `value.length` that's an integer greater than or\n * equal to `0` and less than or equal to `Number.MAX_SAFE_INTEGER`.\n *\n * @static\n * @memberOf _\n * @since 4.0.0\n * @category Lang\n * @param {*} value The value to check.\n * @returns {boolean} Returns `true` if `value` is array-like, else `false`.\n * @example\n *\n * _.isArrayLike([1, 2, 3]);\n * // => true\n *\n * _.isArrayLike(document.body.children);\n * // => true\n *\n * _.isArrayLike('abc');\n * // => true\n *\n * _.isArrayLike(_.noop);\n * // => false\n */\n function isArrayLike(value) {\n return value != null && isLength(value.length) && !isFunction(value);\n }\n\n /**\n * This method is like `_.isArrayLike` except that it also checks if `value`\n * is an object.\n *\n * @static\n * @memberOf _\n * @since 4.0.0\n * @category Lang\n * @param {*} value The value to check.\n * @returns {boolean} Returns `true` if `value` is an array-like object,\n * else `false`.\n * @example\n *\n * _.isArrayLikeObject([1, 2, 3]);\n * // => true\n *\n * _.isArrayLikeObject(document.body.children);\n * // => true\n *\n * _.isArrayLikeObject('abc');\n * // => false\n *\n * _.isArrayLikeObject(_.noop);\n * // => false\n */\n function isArrayLikeObject(value) {\n return isObjectLike(value) && isArrayLike(value);\n }\n\n /**\n * Checks if `value` is classified as a boolean primitive or object.\n *\n * @static\n * @memberOf _\n * @since 0.1.0\n * @category Lang\n * @param {*} value The value to check.\n * @returns {boolean} Returns `true` if `value` is a boolean, else `false`.\n * @example\n *\n * _.isBoolean(false);\n * // => true\n *\n * _.isBoolean(null);\n * // => false\n */\n function isBoolean(value) {\n return value === true || value === false ||\n (isObjectLike(value) && baseGetTag(value) == boolTag);\n }\n\n /**\n * Checks if `value` is a buffer.\n *\n * @static\n * @memberOf _\n * @since 4.3.0\n * @category Lang\n * @param {*} value The value to check.\n * @returns {boolean} Returns `true` if `value` is a buffer, else `false`.\n * @example\n *\n * _.isBuffer(new Buffer(2));\n * // => true\n *\n * _.isBuffer(new Uint8Array(2));\n * // => false\n */\n var isBuffer = nativeIsBuffer || stubFalse;\n\n /**\n * Checks if `value` is classified as a `Date` object.\n *\n * @static\n * @memberOf _\n * @since 0.1.0\n * @category Lang\n * @param {*} value The value to check.\n * @returns {boolean} Returns `true` if `value` is a date object, else `false`.\n * @example\n *\n * _.isDate(new Date);\n * // => true\n *\n * _.isDate('Mon April 23 2012');\n * // => false\n */\n var isDate = nodeIsDate ? 
baseUnary(nodeIsDate) : baseIsDate;\n\n /**\n * Checks if `value` is likely a DOM element.\n *\n * @static\n * @memberOf _\n * @since 0.1.0\n * @category Lang\n * @param {*} value The value to check.\n * @returns {boolean} Returns `true` if `value` is a DOM element, else `false`.\n * @example\n *\n * _.isElement(document.body);\n * // => true\n *\n * _.isElement('');\n * // => false\n */\n function isElement(value) {\n return isObjectLike(value) && value.nodeType === 1 && !isPlainObject(value);\n }\n\n /**\n * Checks if `value` is an empty object, collection, map, or set.\n *\n * Objects are considered empty if they have no own enumerable string keyed\n * properties.\n *\n * Array-like values such as `arguments` objects, arrays, buffers, strings, or\n * jQuery-like collections are considered empty if they have a `length` of `0`.\n * Similarly, maps and sets are considered empty if they have a `size` of `0`.\n *\n * @static\n * @memberOf _\n * @since 0.1.0\n * @category Lang\n * @param {*} value The value to check.\n * @returns {boolean} Returns `true` if `value` is empty, else `false`.\n * @example\n *\n * _.isEmpty(null);\n * // => true\n *\n * _.isEmpty(true);\n * // => true\n *\n * _.isEmpty(1);\n * // => true\n *\n * _.isEmpty([1, 2, 3]);\n * // => false\n *\n * _.isEmpty({ 'a': 1 });\n * // => false\n */\n function isEmpty(value) {\n if (value == null) {\n return true;\n }\n if (isArrayLike(value) &&\n (isArray(value) || typeof value == 'string' || typeof value.splice == 'function' ||\n isBuffer(value) || isTypedArray(value) || isArguments(value))) {\n return !value.length;\n }\n var tag = getTag(value);\n if (tag == mapTag || tag == setTag) {\n return !value.size;\n }\n if (isPrototype(value)) {\n return !baseKeys(value).length;\n }\n for (var key in value) {\n if (hasOwnProperty.call(value, key)) {\n return false;\n }\n }\n return true;\n }\n\n /**\n * Performs a deep comparison between two values to determine if they are\n * equivalent.\n *\n * **Note:** This method supports comparing arrays, array buffers, booleans,\n * date objects, error objects, maps, numbers, `Object` objects, regexes,\n * sets, strings, symbols, and typed arrays. `Object` objects are compared\n * by their own, not inherited, enumerable properties. Functions and DOM\n * nodes are compared by strict equality, i.e. `===`.\n *\n * @static\n * @memberOf _\n * @since 0.1.0\n * @category Lang\n * @param {*} value The value to compare.\n * @param {*} other The other value to compare.\n * @returns {boolean} Returns `true` if the values are equivalent, else `false`.\n * @example\n *\n * var object = { 'a': 1 };\n * var other = { 'a': 1 };\n *\n * _.isEqual(object, other);\n * // => true\n *\n * object === other;\n * // => false\n */\n function isEqual(value, other) {\n return baseIsEqual(value, other);\n }\n\n /**\n * This method is like `_.isEqual` except that it accepts `customizer` which\n * is invoked to compare values. If `customizer` returns `undefined`, comparisons\n * are handled by the method instead. 
The `customizer` is invoked with up to\n * six arguments: (objValue, othValue [, index|key, object, other, stack]).\n *\n * @static\n * @memberOf _\n * @since 4.0.0\n * @category Lang\n * @param {*} value The value to compare.\n * @param {*} other The other value to compare.\n * @param {Function} [customizer] The function to customize comparisons.\n * @returns {boolean} Returns `true` if the values are equivalent, else `false`.\n * @example\n *\n * function isGreeting(value) {\n * return /^h(?:i|ello)$/.test(value);\n * }\n *\n * function customizer(objValue, othValue) {\n * if (isGreeting(objValue) && isGreeting(othValue)) {\n * return true;\n * }\n * }\n *\n * var array = ['hello', 'goodbye'];\n * var other = ['hi', 'goodbye'];\n *\n * _.isEqualWith(array, other, customizer);\n * // => true\n */\n function isEqualWith(value, other, customizer) {\n customizer = typeof customizer == 'function' ? customizer : undefined;\n var result = customizer ? customizer(value, other) : undefined;\n return result === undefined ? baseIsEqual(value, other, undefined, customizer) : !!result;\n }\n\n /**\n * Checks if `value` is an `Error`, `EvalError`, `RangeError`, `ReferenceError`,\n * `SyntaxError`, `TypeError`, or `URIError` object.\n *\n * @static\n * @memberOf _\n * @since 3.0.0\n * @category Lang\n * @param {*} value The value to check.\n * @returns {boolean} Returns `true` if `value` is an error object, else `false`.\n * @example\n *\n * _.isError(new Error);\n * // => true\n *\n * _.isError(Error);\n * // => false\n */\n function isError(value) {\n if (!isObjectLike(value)) {\n return false;\n }\n var tag = baseGetTag(value);\n return tag == errorTag || tag == domExcTag ||\n (typeof value.message == 'string' && typeof value.name == 'string' && !isPlainObject(value));\n }\n\n /**\n * Checks if `value` is a finite primitive number.\n *\n * **Note:** This method is based on\n * [`Number.isFinite`](https://mdn.io/Number/isFinite).\n *\n * @static\n * @memberOf _\n * @since 0.1.0\n * @category Lang\n * @param {*} value The value to check.\n * @returns {boolean} Returns `true` if `value` is a finite number, else `false`.\n * @example\n *\n * _.isFinite(3);\n * // => true\n *\n * _.isFinite(Number.MIN_VALUE);\n * // => true\n *\n * _.isFinite(Infinity);\n * // => false\n *\n * _.isFinite('3');\n * // => false\n */\n function isFinite(value) {\n return typeof value == 'number' && nativeIsFinite(value);\n }\n\n /**\n * Checks if `value` is classified as a `Function` object.\n *\n * @static\n * @memberOf _\n * @since 0.1.0\n * @category Lang\n * @param {*} value The value to check.\n * @returns {boolean} Returns `true` if `value` is a function, else `false`.\n * @example\n *\n * _.isFunction(_);\n * // => true\n *\n * _.isFunction(/abc/);\n * // => false\n */\n function isFunction(value) {\n if (!isObject(value)) {\n return false;\n }\n // The use of `Object#toString` avoids issues with the `typeof` operator\n // in Safari 9 which returns 'object' for typed arrays and other constructors.\n var tag = baseGetTag(value);\n return tag == funcTag || tag == genTag || tag == asyncTag || tag == proxyTag;\n }\n\n /**\n * Checks if `value` is an integer.\n *\n * **Note:** This method is based on\n * [`Number.isInteger`](https://mdn.io/Number/isInteger).\n *\n * @static\n * @memberOf _\n * @since 4.0.0\n * @category Lang\n * @param {*} value The value to check.\n * @returns {boolean} Returns `true` if `value` is an integer, else `false`.\n * @example\n *\n * _.isInteger(3);\n * // => true\n *\n * 
_.isInteger(Number.MIN_VALUE);\n * // => false\n *\n * _.isInteger(Infinity);\n * // => false\n *\n * _.isInteger('3');\n * // => false\n */\n function isInteger(value) {\n return typeof value == 'number' && value == toInteger(value);\n }\n\n /**\n * Checks if `value` is a valid array-like length.\n *\n * **Note:** This method is loosely based on\n * [`ToLength`](http://ecma-international.org/ecma-262/7.0/#sec-tolength).\n *\n * @static\n * @memberOf _\n * @since 4.0.0\n * @category Lang\n * @param {*} value The value to check.\n * @returns {boolean} Returns `true` if `value` is a valid length, else `false`.\n * @example\n *\n * _.isLength(3);\n * // => true\n *\n * _.isLength(Number.MIN_VALUE);\n * // => false\n *\n * _.isLength(Infinity);\n * // => false\n *\n * _.isLength('3');\n * // => false\n */\n function isLength(value) {\n return typeof value == 'number' &&\n value > -1 && value % 1 == 0 && value <= MAX_SAFE_INTEGER;\n }\n\n /**\n * Checks if `value` is the\n * [language type](http://www.ecma-international.org/ecma-262/7.0/#sec-ecmascript-language-types)\n * of `Object`. (e.g. arrays, functions, objects, regexes, `new Number(0)`, and `new String('')`)\n *\n * @static\n * @memberOf _\n * @since 0.1.0\n * @category Lang\n * @param {*} value The value to check.\n * @returns {boolean} Returns `true` if `value` is an object, else `false`.\n * @example\n *\n * _.isObject({});\n * // => true\n *\n * _.isObject([1, 2, 3]);\n * // => true\n *\n * _.isObject(_.noop);\n * // => true\n *\n * _.isObject(null);\n * // => false\n */\n function isObject(value) {\n var type = typeof value;\n return value != null && (type == 'object' || type == 'function');\n }\n\n /**\n * Checks if `value` is object-like. A value is object-like if it's not `null`\n * and has a `typeof` result of \"object\".\n *\n * @static\n * @memberOf _\n * @since 4.0.0\n * @category Lang\n * @param {*} value The value to check.\n * @returns {boolean} Returns `true` if `value` is object-like, else `false`.\n * @example\n *\n * _.isObjectLike({});\n * // => true\n *\n * _.isObjectLike([1, 2, 3]);\n * // => true\n *\n * _.isObjectLike(_.noop);\n * // => false\n *\n * _.isObjectLike(null);\n * // => false\n */\n function isObjectLike(value) {\n return value != null && typeof value == 'object';\n }\n\n /**\n * Checks if `value` is classified as a `Map` object.\n *\n * @static\n * @memberOf _\n * @since 4.3.0\n * @category Lang\n * @param {*} value The value to check.\n * @returns {boolean} Returns `true` if `value` is a map, else `false`.\n * @example\n *\n * _.isMap(new Map);\n * // => true\n *\n * _.isMap(new WeakMap);\n * // => false\n */\n var isMap = nodeIsMap ? baseUnary(nodeIsMap) : baseIsMap;\n\n /**\n * Performs a partial deep comparison between `object` and `source` to\n * determine if `object` contains equivalent property values.\n *\n * **Note:** This method is equivalent to `_.matches` when `source` is\n * partially applied.\n *\n * Partial comparisons will match empty array and empty object `source`\n * values against any array or object value, respectively. 
See `_.isEqual`\n * for a list of supported value comparisons.\n *\n * @static\n * @memberOf _\n * @since 3.0.0\n * @category Lang\n * @param {Object} object The object to inspect.\n * @param {Object} source The object of property values to match.\n * @returns {boolean} Returns `true` if `object` is a match, else `false`.\n * @example\n *\n * var object = { 'a': 1, 'b': 2 };\n *\n * _.isMatch(object, { 'b': 2 });\n * // => true\n *\n * _.isMatch(object, { 'b': 1 });\n * // => false\n */\n function isMatch(object, source) {\n return object === source || baseIsMatch(object, source, getMatchData(source));\n }\n\n /**\n * This method is like `_.isMatch` except that it accepts `customizer` which\n * is invoked to compare values. If `customizer` returns `undefined`, comparisons\n * are handled by the method instead. The `customizer` is invoked with five\n * arguments: (objValue, srcValue, index|key, object, source).\n *\n * @static\n * @memberOf _\n * @since 4.0.0\n * @category Lang\n * @param {Object} object The object to inspect.\n * @param {Object} source The object of property values to match.\n * @param {Function} [customizer] The function to customize comparisons.\n * @returns {boolean} Returns `true` if `object` is a match, else `false`.\n * @example\n *\n * function isGreeting(value) {\n * return /^h(?:i|ello)$/.test(value);\n * }\n *\n * function customizer(objValue, srcValue) {\n * if (isGreeting(objValue) && isGreeting(srcValue)) {\n * return true;\n * }\n * }\n *\n * var object = { 'greeting': 'hello' };\n * var source = { 'greeting': 'hi' };\n *\n * _.isMatchWith(object, source, customizer);\n * // => true\n */\n function isMatchWith(object, source, customizer) {\n customizer = typeof customizer == 'function' ? customizer : undefined;\n return baseIsMatch(object, source, getMatchData(source), customizer);\n }\n\n /**\n * Checks if `value` is `NaN`.\n *\n * **Note:** This method is based on\n * [`Number.isNaN`](https://mdn.io/Number/isNaN) and is not the same as\n * global [`isNaN`](https://mdn.io/isNaN) which returns `true` for\n * `undefined` and other non-number values.\n *\n * @static\n * @memberOf _\n * @since 0.1.0\n * @category Lang\n * @param {*} value The value to check.\n * @returns {boolean} Returns `true` if `value` is `NaN`, else `false`.\n * @example\n *\n * _.isNaN(NaN);\n * // => true\n *\n * _.isNaN(new Number(NaN));\n * // => true\n *\n * isNaN(undefined);\n * // => true\n *\n * _.isNaN(undefined);\n * // => false\n */\n function isNaN(value) {\n // An `NaN` primitive is the only value that is not equal to itself.\n // Perform the `toStringTag` check first to avoid errors with some\n // ActiveX objects in IE.\n return isNumber(value) && value != +value;\n }\n\n /**\n * Checks if `value` is a pristine native function.\n *\n * **Note:** This method can't reliably detect native functions in the presence\n * of the core-js package because core-js circumvents this kind of detection.\n * Despite multiple requests, the core-js maintainer has made it clear: any\n * attempt to fix the detection will be obstructed. As a result, we're left\n * with little choice but to throw an error. 
Unfortunately, this also affects\n * packages, like [babel-polyfill](https://www.npmjs.com/package/babel-polyfill),\n * which rely on core-js.\n *\n * @static\n * @memberOf _\n * @since 3.0.0\n * @category Lang\n * @param {*} value The value to check.\n * @returns {boolean} Returns `true` if `value` is a native function,\n * else `false`.\n * @example\n *\n * _.isNative(Array.prototype.push);\n * // => true\n *\n * _.isNative(_);\n * // => false\n */\n function isNative(value) {\n if (isMaskable(value)) {\n throw new Error(CORE_ERROR_TEXT);\n }\n return baseIsNative(value);\n }\n\n /**\n * Checks if `value` is `null`.\n *\n * @static\n * @memberOf _\n * @since 0.1.0\n * @category Lang\n * @param {*} value The value to check.\n * @returns {boolean} Returns `true` if `value` is `null`, else `false`.\n * @example\n *\n * _.isNull(null);\n * // => true\n *\n * _.isNull(void 0);\n * // => false\n */\n function isNull(value) {\n return value === null;\n }\n\n /**\n * Checks if `value` is `null` or `undefined`.\n *\n * @static\n * @memberOf _\n * @since 4.0.0\n * @category Lang\n * @param {*} value The value to check.\n * @returns {boolean} Returns `true` if `value` is nullish, else `false`.\n * @example\n *\n * _.isNil(null);\n * // => true\n *\n * _.isNil(void 0);\n * // => true\n *\n * _.isNil(NaN);\n * // => false\n */\n function isNil(value) {\n return value == null;\n }\n\n /**\n * Checks if `value` is classified as a `Number` primitive or object.\n *\n * **Note:** To exclude `Infinity`, `-Infinity`, and `NaN`, which are\n * classified as numbers, use the `_.isFinite` method.\n *\n * @static\n * @memberOf _\n * @since 0.1.0\n * @category Lang\n * @param {*} value The value to check.\n * @returns {boolean} Returns `true` if `value` is a number, else `false`.\n * @example\n *\n * _.isNumber(3);\n * // => true\n *\n * _.isNumber(Number.MIN_VALUE);\n * // => true\n *\n * _.isNumber(Infinity);\n * // => true\n *\n * _.isNumber('3');\n * // => false\n */\n function isNumber(value) {\n return typeof value == 'number' ||\n (isObjectLike(value) && baseGetTag(value) == numberTag);\n }\n\n /**\n * Checks if `value` is a plain object, that is, an object created by the\n * `Object` constructor or one with a `[[Prototype]]` of `null`.\n *\n * @static\n * @memberOf _\n * @since 0.8.0\n * @category Lang\n * @param {*} value The value to check.\n * @returns {boolean} Returns `true` if `value` is a plain object, else `false`.\n * @example\n *\n * function Foo() {\n * this.a = 1;\n * }\n *\n * _.isPlainObject(new Foo);\n * // => false\n *\n * _.isPlainObject([1, 2, 3]);\n * // => false\n *\n * _.isPlainObject({ 'x': 0, 'y': 0 });\n * // => true\n *\n * _.isPlainObject(Object.create(null));\n * // => true\n */\n function isPlainObject(value) {\n if (!isObjectLike(value) || baseGetTag(value) != objectTag) {\n return false;\n }\n var proto = getPrototype(value);\n if (proto === null) {\n return true;\n }\n var Ctor = hasOwnProperty.call(proto, 'constructor') && proto.constructor;\n return typeof Ctor == 'function' && Ctor instanceof Ctor &&\n funcToString.call(Ctor) == objectCtorString;\n }\n\n /**\n * Checks if `value` is classified as a `RegExp` object.\n *\n * @static\n * @memberOf _\n * @since 0.1.0\n * @category Lang\n * @param {*} value The value to check.\n * @returns {boolean} Returns `true` if `value` is a regexp, else `false`.\n * @example\n *\n * _.isRegExp(/abc/);\n * // => true\n *\n * _.isRegExp('/abc/');\n * // => false\n */\n var isRegExp = nodeIsRegExp ? 
baseUnary(nodeIsRegExp) : baseIsRegExp;\n\n /**\n * Checks if `value` is a safe integer. An integer is safe if it's an IEEE-754\n * double precision number which isn't the result of a rounded unsafe integer.\n *\n * **Note:** This method is based on\n * [`Number.isSafeInteger`](https://mdn.io/Number/isSafeInteger).\n *\n * @static\n * @memberOf _\n * @since 4.0.0\n * @category Lang\n * @param {*} value The value to check.\n * @returns {boolean} Returns `true` if `value` is a safe integer, else `false`.\n * @example\n *\n * _.isSafeInteger(3);\n * // => true\n *\n * _.isSafeInteger(Number.MIN_VALUE);\n * // => false\n *\n * _.isSafeInteger(Infinity);\n * // => false\n *\n * _.isSafeInteger('3');\n * // => false\n */\n function isSafeInteger(value) {\n return isInteger(value) && value >= -MAX_SAFE_INTEGER && value <= MAX_SAFE_INTEGER;\n }\n\n /**\n * Checks if `value` is classified as a `Set` object.\n *\n * @static\n * @memberOf _\n * @since 4.3.0\n * @category Lang\n * @param {*} value The value to check.\n * @returns {boolean} Returns `true` if `value` is a set, else `false`.\n * @example\n *\n * _.isSet(new Set);\n * // => true\n *\n * _.isSet(new WeakSet);\n * // => false\n */\n var isSet = nodeIsSet ? baseUnary(nodeIsSet) : baseIsSet;\n\n /**\n * Checks if `value` is classified as a `String` primitive or object.\n *\n * @static\n * @since 0.1.0\n * @memberOf _\n * @category Lang\n * @param {*} value The value to check.\n * @returns {boolean} Returns `true` if `value` is a string, else `false`.\n * @example\n *\n * _.isString('abc');\n * // => true\n *\n * _.isString(1);\n * // => false\n */\n function isString(value) {\n return typeof value == 'string' ||\n (!isArray(value) && isObjectLike(value) && baseGetTag(value) == stringTag);\n }\n\n /**\n * Checks if `value` is classified as a `Symbol` primitive or object.\n *\n * @static\n * @memberOf _\n * @since 4.0.0\n * @category Lang\n * @param {*} value The value to check.\n * @returns {boolean} Returns `true` if `value` is a symbol, else `false`.\n * @example\n *\n * _.isSymbol(Symbol.iterator);\n * // => true\n *\n * _.isSymbol('abc');\n * // => false\n */\n function isSymbol(value) {\n return typeof value == 'symbol' ||\n (isObjectLike(value) && baseGetTag(value) == symbolTag);\n }\n\n /**\n * Checks if `value` is classified as a typed array.\n *\n * @static\n * @memberOf _\n * @since 3.0.0\n * @category Lang\n * @param {*} value The value to check.\n * @returns {boolean} Returns `true` if `value` is a typed array, else `false`.\n * @example\n *\n * _.isTypedArray(new Uint8Array);\n * // => true\n *\n * _.isTypedArray([]);\n * // => false\n */\n var isTypedArray = nodeIsTypedArray ? 
baseUnary(nodeIsTypedArray) : baseIsTypedArray;\n\n /**\n * Checks if `value` is `undefined`.\n *\n * @static\n * @since 0.1.0\n * @memberOf _\n * @category Lang\n * @param {*} value The value to check.\n * @returns {boolean} Returns `true` if `value` is `undefined`, else `false`.\n * @example\n *\n * _.isUndefined(void 0);\n * // => true\n *\n * _.isUndefined(null);\n * // => false\n */\n function isUndefined(value) {\n return value === undefined;\n }\n\n /**\n * Checks if `value` is classified as a `WeakMap` object.\n *\n * @static\n * @memberOf _\n * @since 4.3.0\n * @category Lang\n * @param {*} value The value to check.\n * @returns {boolean} Returns `true` if `value` is a weak map, else `false`.\n * @example\n *\n * _.isWeakMap(new WeakMap);\n * // => true\n *\n * _.isWeakMap(new Map);\n * // => false\n */\n function isWeakMap(value) {\n return isObjectLike(value) && getTag(value) == weakMapTag;\n }\n\n /**\n * Checks if `value` is classified as a `WeakSet` object.\n *\n * @static\n * @memberOf _\n * @since 4.3.0\n * @category Lang\n * @param {*} value The value to check.\n * @returns {boolean} Returns `true` if `value` is a weak set, else `false`.\n * @example\n *\n * _.isWeakSet(new WeakSet);\n * // => true\n *\n * _.isWeakSet(new Set);\n * // => false\n */\n function isWeakSet(value) {\n return isObjectLike(value) && baseGetTag(value) == weakSetTag;\n }\n\n /**\n * Checks if `value` is less than `other`.\n *\n * @static\n * @memberOf _\n * @since 3.9.0\n * @category Lang\n * @param {*} value The value to compare.\n * @param {*} other The other value to compare.\n * @returns {boolean} Returns `true` if `value` is less than `other`,\n * else `false`.\n * @see _.gt\n * @example\n *\n * _.lt(1, 3);\n * // => true\n *\n * _.lt(3, 3);\n * // => false\n *\n * _.lt(3, 1);\n * // => false\n */\n var lt = createRelationalOperation(baseLt);\n\n /**\n * Checks if `value` is less than or equal to `other`.\n *\n * @static\n * @memberOf _\n * @since 3.9.0\n * @category Lang\n * @param {*} value The value to compare.\n * @param {*} other The other value to compare.\n * @returns {boolean} Returns `true` if `value` is less than or equal to\n * `other`, else `false`.\n * @see _.gte\n * @example\n *\n * _.lte(1, 3);\n * // => true\n *\n * _.lte(3, 3);\n * // => true\n *\n * _.lte(3, 1);\n * // => false\n */\n var lte = createRelationalOperation(function(value, other) {\n return value <= other;\n });\n\n /**\n * Converts `value` to an array.\n *\n * @static\n * @since 0.1.0\n * @memberOf _\n * @category Lang\n * @param {*} value The value to convert.\n * @returns {Array} Returns the converted array.\n * @example\n *\n * _.toArray({ 'a': 1, 'b': 2 });\n * // => [1, 2]\n *\n * _.toArray('abc');\n * // => ['a', 'b', 'c']\n *\n * _.toArray(1);\n * // => []\n *\n * _.toArray(null);\n * // => []\n */\n function toArray(value) {\n if (!value) {\n return [];\n }\n if (isArrayLike(value)) {\n return isString(value) ? stringToArray(value) : copyArray(value);\n }\n if (symIterator && value[symIterator]) {\n return iteratorToArray(value[symIterator]());\n }\n var tag = getTag(value),\n func = tag == mapTag ? mapToArray : (tag == setTag ? 
setToArray : values);\n\n return func(value);\n }\n\n /**\n * Converts `value` to a finite number.\n *\n * @static\n * @memberOf _\n * @since 4.12.0\n * @category Lang\n * @param {*} value The value to convert.\n * @returns {number} Returns the converted number.\n * @example\n *\n * _.toFinite(3.2);\n * // => 3.2\n *\n * _.toFinite(Number.MIN_VALUE);\n * // => 5e-324\n *\n * _.toFinite(Infinity);\n * // => 1.7976931348623157e+308\n *\n * _.toFinite('3.2');\n * // => 3.2\n */\n function toFinite(value) {\n if (!value) {\n return value === 0 ? value : 0;\n }\n value = toNumber(value);\n if (value === INFINITY || value === -INFINITY) {\n var sign = (value < 0 ? -1 : 1);\n return sign * MAX_INTEGER;\n }\n return value === value ? value : 0;\n }\n\n /**\n * Converts `value` to an integer.\n *\n * **Note:** This method is loosely based on\n * [`ToInteger`](http://www.ecma-international.org/ecma-262/7.0/#sec-tointeger).\n *\n * @static\n * @memberOf _\n * @since 4.0.0\n * @category Lang\n * @param {*} value The value to convert.\n * @returns {number} Returns the converted integer.\n * @example\n *\n * _.toInteger(3.2);\n * // => 3\n *\n * _.toInteger(Number.MIN_VALUE);\n * // => 0\n *\n * _.toInteger(Infinity);\n * // => 1.7976931348623157e+308\n *\n * _.toInteger('3.2');\n * // => 3\n */\n function toInteger(value) {\n var result = toFinite(value),\n remainder = result % 1;\n\n return result === result ? (remainder ? result - remainder : result) : 0;\n }\n\n /**\n * Converts `value` to an integer suitable for use as the length of an\n * array-like object.\n *\n * **Note:** This method is based on\n * [`ToLength`](http://ecma-international.org/ecma-262/7.0/#sec-tolength).\n *\n * @static\n * @memberOf _\n * @since 4.0.0\n * @category Lang\n * @param {*} value The value to convert.\n * @returns {number} Returns the converted integer.\n * @example\n *\n * _.toLength(3.2);\n * // => 3\n *\n * _.toLength(Number.MIN_VALUE);\n * // => 0\n *\n * _.toLength(Infinity);\n * // => 4294967295\n *\n * _.toLength('3.2');\n * // => 3\n */\n function toLength(value) {\n return value ? baseClamp(toInteger(value), 0, MAX_ARRAY_LENGTH) : 0;\n }\n\n /**\n * Converts `value` to a number.\n *\n * @static\n * @memberOf _\n * @since 4.0.0\n * @category Lang\n * @param {*} value The value to process.\n * @returns {number} Returns the number.\n * @example\n *\n * _.toNumber(3.2);\n * // => 3.2\n *\n * _.toNumber(Number.MIN_VALUE);\n * // => 5e-324\n *\n * _.toNumber(Infinity);\n * // => Infinity\n *\n * _.toNumber('3.2');\n * // => 3.2\n */\n function toNumber(value) {\n if (typeof value == 'number') {\n return value;\n }\n if (isSymbol(value)) {\n return NAN;\n }\n if (isObject(value)) {\n var other = typeof value.valueOf == 'function' ? value.valueOf() : value;\n value = isObject(other) ? (other + '') : other;\n }\n if (typeof value != 'string') {\n return value === 0 ? value : +value;\n }\n value = baseTrim(value);\n var isBinary = reIsBinary.test(value);\n return (isBinary || reIsOctal.test(value))\n ? freeParseInt(value.slice(2), isBinary ? 2 : 8)\n : (reIsBadHex.test(value) ? 
NAN : +value);\n }\n\n /**\n * Converts `value` to a plain object flattening inherited enumerable string\n * keyed properties of `value` to own properties of the plain object.\n *\n * @static\n * @memberOf _\n * @since 3.0.0\n * @category Lang\n * @param {*} value The value to convert.\n * @returns {Object} Returns the converted plain object.\n * @example\n *\n * function Foo() {\n * this.b = 2;\n * }\n *\n * Foo.prototype.c = 3;\n *\n * _.assign({ 'a': 1 }, new Foo);\n * // => { 'a': 1, 'b': 2 }\n *\n * _.assign({ 'a': 1 }, _.toPlainObject(new Foo));\n * // => { 'a': 1, 'b': 2, 'c': 3 }\n */\n function toPlainObject(value) {\n return copyObject(value, keysIn(value));\n }\n\n /**\n * Converts `value` to a safe integer. A safe integer can be compared and\n * represented correctly.\n *\n * @static\n * @memberOf _\n * @since 4.0.0\n * @category Lang\n * @param {*} value The value to convert.\n * @returns {number} Returns the converted integer.\n * @example\n *\n * _.toSafeInteger(3.2);\n * // => 3\n *\n * _.toSafeInteger(Number.MIN_VALUE);\n * // => 0\n *\n * _.toSafeInteger(Infinity);\n * // => 9007199254740991\n *\n * _.toSafeInteger('3.2');\n * // => 3\n */\n function toSafeInteger(value) {\n return value\n ? baseClamp(toInteger(value), -MAX_SAFE_INTEGER, MAX_SAFE_INTEGER)\n : (value === 0 ? value : 0);\n }\n\n /**\n * Converts `value` to a string. An empty string is returned for `null`\n * and `undefined` values. The sign of `-0` is preserved.\n *\n * @static\n * @memberOf _\n * @since 4.0.0\n * @category Lang\n * @param {*} value The value to convert.\n * @returns {string} Returns the converted string.\n * @example\n *\n * _.toString(null);\n * // => ''\n *\n * _.toString(-0);\n * // => '-0'\n *\n * _.toString([1, 2, 3]);\n * // => '1,2,3'\n */\n function toString(value) {\n return value == null ? '' : baseToString(value);\n }\n\n /*------------------------------------------------------------------------*/\n\n /**\n * Assigns own enumerable string keyed properties of source objects to the\n * destination object. 
Source objects are applied from left to right.\n * Subsequent sources overwrite property assignments of previous sources.\n *\n * **Note:** This method mutates `object` and is loosely based on\n * [`Object.assign`](https://mdn.io/Object/assign).\n *\n * @static\n * @memberOf _\n * @since 0.10.0\n * @category Object\n * @param {Object} object The destination object.\n * @param {...Object} [sources] The source objects.\n * @returns {Object} Returns `object`.\n * @see _.assignIn\n * @example\n *\n * function Foo() {\n * this.a = 1;\n * }\n *\n * function Bar() {\n * this.c = 3;\n * }\n *\n * Foo.prototype.b = 2;\n * Bar.prototype.d = 4;\n *\n * _.assign({ 'a': 0 }, new Foo, new Bar);\n * // => { 'a': 1, 'c': 3 }\n */\n var assign = createAssigner(function(object, source) {\n if (isPrototype(source) || isArrayLike(source)) {\n copyObject(source, keys(source), object);\n return;\n }\n for (var key in source) {\n if (hasOwnProperty.call(source, key)) {\n assignValue(object, key, source[key]);\n }\n }\n });\n\n /**\n * This method is like `_.assign` except that it iterates over own and\n * inherited source properties.\n *\n * **Note:** This method mutates `object`.\n *\n * @static\n * @memberOf _\n * @since 4.0.0\n * @alias extend\n * @category Object\n * @param {Object} object The destination object.\n * @param {...Object} [sources] The source objects.\n * @returns {Object} Returns `object`.\n * @see _.assign\n * @example\n *\n * function Foo() {\n * this.a = 1;\n * }\n *\n * function Bar() {\n * this.c = 3;\n * }\n *\n * Foo.prototype.b = 2;\n * Bar.prototype.d = 4;\n *\n * _.assignIn({ 'a': 0 }, new Foo, new Bar);\n * // => { 'a': 1, 'b': 2, 'c': 3, 'd': 4 }\n */\n var assignIn = createAssigner(function(object, source) {\n copyObject(source, keysIn(source), object);\n });\n\n /**\n * This method is like `_.assignIn` except that it accepts `customizer`\n * which is invoked to produce the assigned values. If `customizer` returns\n * `undefined`, assignment is handled by the method instead. The `customizer`\n * is invoked with five arguments: (objValue, srcValue, key, object, source).\n *\n * **Note:** This method mutates `object`.\n *\n * @static\n * @memberOf _\n * @since 4.0.0\n * @alias extendWith\n * @category Object\n * @param {Object} object The destination object.\n * @param {...Object} sources The source objects.\n * @param {Function} [customizer] The function to customize assigned values.\n * @returns {Object} Returns `object`.\n * @see _.assignWith\n * @example\n *\n * function customizer(objValue, srcValue) {\n * return _.isUndefined(objValue) ? srcValue : objValue;\n * }\n *\n * var defaults = _.partialRight(_.assignInWith, customizer);\n *\n * defaults({ 'a': 1 }, { 'b': 2 }, { 'a': 3 });\n * // => { 'a': 1, 'b': 2 }\n */\n var assignInWith = createAssigner(function(object, source, srcIndex, customizer) {\n copyObject(source, keysIn(source), object, customizer);\n });\n\n /**\n * This method is like `_.assign` except that it accepts `customizer`\n * which is invoked to produce the assigned values. If `customizer` returns\n * `undefined`, assignment is handled by the method instead. 
The `customizer`\n * is invoked with five arguments: (objValue, srcValue, key, object, source).\n *\n * **Note:** This method mutates `object`.\n *\n * @static\n * @memberOf _\n * @since 4.0.0\n * @category Object\n * @param {Object} object The destination object.\n * @param {...Object} sources The source objects.\n * @param {Function} [customizer] The function to customize assigned values.\n * @returns {Object} Returns `object`.\n * @see _.assignInWith\n * @example\n *\n * function customizer(objValue, srcValue) {\n * return _.isUndefined(objValue) ? srcValue : objValue;\n * }\n *\n * var defaults = _.partialRight(_.assignWith, customizer);\n *\n * defaults({ 'a': 1 }, { 'b': 2 }, { 'a': 3 });\n * // => { 'a': 1, 'b': 2 }\n */\n var assignWith = createAssigner(function(object, source, srcIndex, customizer) {\n copyObject(source, keys(source), object, customizer);\n });\n\n /**\n * Creates an array of values corresponding to `paths` of `object`.\n *\n * @static\n * @memberOf _\n * @since 1.0.0\n * @category Object\n * @param {Object} object The object to iterate over.\n * @param {...(string|string[])} [paths] The property paths to pick.\n * @returns {Array} Returns the picked values.\n * @example\n *\n * var object = { 'a': [{ 'b': { 'c': 3 } }, 4] };\n *\n * _.at(object, ['a[0].b.c', 'a[1]']);\n * // => [3, 4]\n */\n var at = flatRest(baseAt);\n\n /**\n * Creates an object that inherits from the `prototype` object. If a\n * `properties` object is given, its own enumerable string keyed properties\n * are assigned to the created object.\n *\n * @static\n * @memberOf _\n * @since 2.3.0\n * @category Object\n * @param {Object} prototype The object to inherit from.\n * @param {Object} [properties] The properties to assign to the object.\n * @returns {Object} Returns the new object.\n * @example\n *\n * function Shape() {\n * this.x = 0;\n * this.y = 0;\n * }\n *\n * function Circle() {\n * Shape.call(this);\n * }\n *\n * Circle.prototype = _.create(Shape.prototype, {\n * 'constructor': Circle\n * });\n *\n * var circle = new Circle;\n * circle instanceof Circle;\n * // => true\n *\n * circle instanceof Shape;\n * // => true\n */\n function create(prototype, properties) {\n var result = baseCreate(prototype);\n return properties == null ? result : baseAssign(result, properties);\n }\n\n /**\n * Assigns own and inherited enumerable string keyed properties of source\n * objects to the destination object for all destination properties that\n * resolve to `undefined`. Source objects are applied from left to right.\n * Once a property is set, additional values of the same property are ignored.\n *\n * **Note:** This method mutates `object`.\n *\n * @static\n * @since 0.1.0\n * @memberOf _\n * @category Object\n * @param {Object} object The destination object.\n * @param {...Object} [sources] The source objects.\n * @returns {Object} Returns `object`.\n * @see _.defaultsDeep\n * @example\n *\n * _.defaults({ 'a': 1 }, { 'b': 2 }, { 'a': 3 });\n * // => { 'a': 1, 'b': 2 }\n */\n var defaults = baseRest(function(object, sources) {\n object = Object(object);\n\n var index = -1;\n var length = sources.length;\n var guard = length > 2 ? 
sources[2] : undefined;\n\n if (guard && isIterateeCall(sources[0], sources[1], guard)) {\n length = 1;\n }\n\n while (++index < length) {\n var source = sources[index];\n var props = keysIn(source);\n var propsIndex = -1;\n var propsLength = props.length;\n\n while (++propsIndex < propsLength) {\n var key = props[propsIndex];\n var value = object[key];\n\n if (value === undefined ||\n (eq(value, objectProto[key]) && !hasOwnProperty.call(object, key))) {\n object[key] = source[key];\n }\n }\n }\n\n return object;\n });\n\n /**\n * This method is like `_.defaults` except that it recursively assigns\n * default properties.\n *\n * **Note:** This method mutates `object`.\n *\n * @static\n * @memberOf _\n * @since 3.10.0\n * @category Object\n * @param {Object} object The destination object.\n * @param {...Object} [sources] The source objects.\n * @returns {Object} Returns `object`.\n * @see _.defaults\n * @example\n *\n * _.defaultsDeep({ 'a': { 'b': 2 } }, { 'a': { 'b': 1, 'c': 3 } });\n * // => { 'a': { 'b': 2, 'c': 3 } }\n */\n var defaultsDeep = baseRest(function(args) {\n args.push(undefined, customDefaultsMerge);\n return apply(mergeWith, undefined, args);\n });\n\n /**\n * This method is like `_.find` except that it returns the key of the first\n * element `predicate` returns truthy for instead of the element itself.\n *\n * @static\n * @memberOf _\n * @since 1.1.0\n * @category Object\n * @param {Object} object The object to inspect.\n * @param {Function} [predicate=_.identity] The function invoked per iteration.\n * @returns {string|undefined} Returns the key of the matched element,\n * else `undefined`.\n * @example\n *\n * var users = {\n * 'barney': { 'age': 36, 'active': true },\n * 'fred': { 'age': 40, 'active': false },\n * 'pebbles': { 'age': 1, 'active': true }\n * };\n *\n * _.findKey(users, function(o) { return o.age < 40; });\n * // => 'barney' (iteration order is not guaranteed)\n *\n * // The `_.matches` iteratee shorthand.\n * _.findKey(users, { 'age': 1, 'active': true });\n * // => 'pebbles'\n *\n * // The `_.matchesProperty` iteratee shorthand.\n * _.findKey(users, ['active', false]);\n * // => 'fred'\n *\n * // The `_.property` iteratee shorthand.\n * _.findKey(users, 'active');\n * // => 'barney'\n */\n function findKey(object, predicate) {\n return baseFindKey(object, getIteratee(predicate, 3), baseForOwn);\n }\n\n /**\n * This method is like `_.findKey` except that it iterates over elements of\n * a collection in the opposite order.\n *\n * @static\n * @memberOf _\n * @since 2.0.0\n * @category Object\n * @param {Object} object The object to inspect.\n * @param {Function} [predicate=_.identity] The function invoked per iteration.\n * @returns {string|undefined} Returns the key of the matched element,\n * else `undefined`.\n * @example\n *\n * var users = {\n * 'barney': { 'age': 36, 'active': true },\n * 'fred': { 'age': 40, 'active': false },\n * 'pebbles': { 'age': 1, 'active': true }\n * };\n *\n * _.findLastKey(users, function(o) { return o.age < 40; });\n * // => returns 'pebbles' assuming `_.findKey` returns 'barney'\n *\n * // The `_.matches` iteratee shorthand.\n * _.findLastKey(users, { 'age': 36, 'active': true });\n * // => 'barney'\n *\n * // The `_.matchesProperty` iteratee shorthand.\n * _.findLastKey(users, ['active', false]);\n * // => 'fred'\n *\n * // The `_.property` iteratee shorthand.\n * _.findLastKey(users, 'active');\n * // => 'pebbles'\n */\n function findLastKey(object, predicate) {\n return baseFindKey(object, getIteratee(predicate, 3), 
baseForOwnRight);\n }\n\n /**\n * Iterates over own and inherited enumerable string keyed properties of an\n * object and invokes `iteratee` for each property. The iteratee is invoked\n * with three arguments: (value, key, object). Iteratee functions may exit\n * iteration early by explicitly returning `false`.\n *\n * @static\n * @memberOf _\n * @since 0.3.0\n * @category Object\n * @param {Object} object The object to iterate over.\n * @param {Function} [iteratee=_.identity] The function invoked per iteration.\n * @returns {Object} Returns `object`.\n * @see _.forInRight\n * @example\n *\n * function Foo() {\n * this.a = 1;\n * this.b = 2;\n * }\n *\n * Foo.prototype.c = 3;\n *\n * _.forIn(new Foo, function(value, key) {\n * console.log(key);\n * });\n * // => Logs 'a', 'b', then 'c' (iteration order is not guaranteed).\n */\n function forIn(object, iteratee) {\n return object == null\n ? object\n : baseFor(object, getIteratee(iteratee, 3), keysIn);\n }\n\n /**\n * This method is like `_.forIn` except that it iterates over properties of\n * `object` in the opposite order.\n *\n * @static\n * @memberOf _\n * @since 2.0.0\n * @category Object\n * @param {Object} object The object to iterate over.\n * @param {Function} [iteratee=_.identity] The function invoked per iteration.\n * @returns {Object} Returns `object`.\n * @see _.forIn\n * @example\n *\n * function Foo() {\n * this.a = 1;\n * this.b = 2;\n * }\n *\n * Foo.prototype.c = 3;\n *\n * _.forInRight(new Foo, function(value, key) {\n * console.log(key);\n * });\n * // => Logs 'c', 'b', then 'a' assuming `_.forIn` logs 'a', 'b', then 'c'.\n */\n function forInRight(object, iteratee) {\n return object == null\n ? object\n : baseForRight(object, getIteratee(iteratee, 3), keysIn);\n }\n\n /**\n * Iterates over own enumerable string keyed properties of an object and\n * invokes `iteratee` for each property. The iteratee is invoked with three\n * arguments: (value, key, object). 
Iteratee functions may exit iteration\n * early by explicitly returning `false`.\n *\n * @static\n * @memberOf _\n * @since 0.3.0\n * @category Object\n * @param {Object} object The object to iterate over.\n * @param {Function} [iteratee=_.identity] The function invoked per iteration.\n * @returns {Object} Returns `object`.\n * @see _.forOwnRight\n * @example\n *\n * function Foo() {\n * this.a = 1;\n * this.b = 2;\n * }\n *\n * Foo.prototype.c = 3;\n *\n * _.forOwn(new Foo, function(value, key) {\n * console.log(key);\n * });\n * // => Logs 'a' then 'b' (iteration order is not guaranteed).\n */\n function forOwn(object, iteratee) {\n return object && baseForOwn(object, getIteratee(iteratee, 3));\n }\n\n /**\n * This method is like `_.forOwn` except that it iterates over properties of\n * `object` in the opposite order.\n *\n * @static\n * @memberOf _\n * @since 2.0.0\n * @category Object\n * @param {Object} object The object to iterate over.\n * @param {Function} [iteratee=_.identity] The function invoked per iteration.\n * @returns {Object} Returns `object`.\n * @see _.forOwn\n * @example\n *\n * function Foo() {\n * this.a = 1;\n * this.b = 2;\n * }\n *\n * Foo.prototype.c = 3;\n *\n * _.forOwnRight(new Foo, function(value, key) {\n * console.log(key);\n * });\n * // => Logs 'b' then 'a' assuming `_.forOwn` logs 'a' then 'b'.\n */\n function forOwnRight(object, iteratee) {\n return object && baseForOwnRight(object, getIteratee(iteratee, 3));\n }\n\n /**\n * Creates an array of function property names from own enumerable properties\n * of `object`.\n *\n * @static\n * @since 0.1.0\n * @memberOf _\n * @category Object\n * @param {Object} object The object to inspect.\n * @returns {Array} Returns the function names.\n * @see _.functionsIn\n * @example\n *\n * function Foo() {\n * this.a = _.constant('a');\n * this.b = _.constant('b');\n * }\n *\n * Foo.prototype.c = _.constant('c');\n *\n * _.functions(new Foo);\n * // => ['a', 'b']\n */\n function functions(object) {\n return object == null ? [] : baseFunctions(object, keys(object));\n }\n\n /**\n * Creates an array of function property names from own and inherited\n * enumerable properties of `object`.\n *\n * @static\n * @memberOf _\n * @since 4.0.0\n * @category Object\n * @param {Object} object The object to inspect.\n * @returns {Array} Returns the function names.\n * @see _.functions\n * @example\n *\n * function Foo() {\n * this.a = _.constant('a');\n * this.b = _.constant('b');\n * }\n *\n * Foo.prototype.c = _.constant('c');\n *\n * _.functionsIn(new Foo);\n * // => ['a', 'b', 'c']\n */\n function functionsIn(object) {\n return object == null ? [] : baseFunctions(object, keysIn(object));\n }\n\n /**\n * Gets the value at `path` of `object`. If the resolved value is\n * `undefined`, the `defaultValue` is returned in its place.\n *\n * @static\n * @memberOf _\n * @since 3.7.0\n * @category Object\n * @param {Object} object The object to query.\n * @param {Array|string} path The path of the property to get.\n * @param {*} [defaultValue] The value returned for `undefined` resolved values.\n * @returns {*} Returns the resolved value.\n * @example\n *\n * var object = { 'a': [{ 'b': { 'c': 3 } }] };\n *\n * _.get(object, 'a[0].b.c');\n * // => 3\n *\n * _.get(object, ['a', '0', 'b', 'c']);\n * // => 3\n *\n * _.get(object, 'a.b.c', 'default');\n * // => 'default'\n */\n function get(object, path, defaultValue) {\n var result = object == null ? undefined : baseGet(object, path);\n return result === undefined ? 
defaultValue : result;\n }\n\n /**\n * Checks if `path` is a direct property of `object`.\n *\n * @static\n * @since 0.1.0\n * @memberOf _\n * @category Object\n * @param {Object} object The object to query.\n * @param {Array|string} path The path to check.\n * @returns {boolean} Returns `true` if `path` exists, else `false`.\n * @example\n *\n * var object = { 'a': { 'b': 2 } };\n * var other = _.create({ 'a': _.create({ 'b': 2 }) });\n *\n * _.has(object, 'a');\n * // => true\n *\n * _.has(object, 'a.b');\n * // => true\n *\n * _.has(object, ['a', 'b']);\n * // => true\n *\n * _.has(other, 'a');\n * // => false\n */\n function has(object, path) {\n return object != null && hasPath(object, path, baseHas);\n }\n\n /**\n * Checks if `path` is a direct or inherited property of `object`.\n *\n * @static\n * @memberOf _\n * @since 4.0.0\n * @category Object\n * @param {Object} object The object to query.\n * @param {Array|string} path The path to check.\n * @returns {boolean} Returns `true` if `path` exists, else `false`.\n * @example\n *\n * var object = _.create({ 'a': _.create({ 'b': 2 }) });\n *\n * _.hasIn(object, 'a');\n * // => true\n *\n * _.hasIn(object, 'a.b');\n * // => true\n *\n * _.hasIn(object, ['a', 'b']);\n * // => true\n *\n * _.hasIn(object, 'b');\n * // => false\n */\n function hasIn(object, path) {\n return object != null && hasPath(object, path, baseHasIn);\n }\n\n /**\n * Creates an object composed of the inverted keys and values of `object`.\n * If `object` contains duplicate values, subsequent values overwrite\n * property assignments of previous values.\n *\n * @static\n * @memberOf _\n * @since 0.7.0\n * @category Object\n * @param {Object} object The object to invert.\n * @returns {Object} Returns the new inverted object.\n * @example\n *\n * var object = { 'a': 1, 'b': 2, 'c': 1 };\n *\n * _.invert(object);\n * // => { '1': 'c', '2': 'b' }\n */\n var invert = createInverter(function(result, value, key) {\n if (value != null &&\n typeof value.toString != 'function') {\n value = nativeObjectToString.call(value);\n }\n\n result[value] = key;\n }, constant(identity));\n\n /**\n * This method is like `_.invert` except that the inverted object is generated\n * from the results of running each element of `object` thru `iteratee`. The\n * corresponding inverted value of each inverted key is an array of keys\n * responsible for generating the inverted value. 
The iteratee is invoked\n * with one argument: (value).\n *\n * @static\n * @memberOf _\n * @since 4.1.0\n * @category Object\n * @param {Object} object The object to invert.\n * @param {Function} [iteratee=_.identity] The iteratee invoked per element.\n * @returns {Object} Returns the new inverted object.\n * @example\n *\n * var object = { 'a': 1, 'b': 2, 'c': 1 };\n *\n * _.invertBy(object);\n * // => { '1': ['a', 'c'], '2': ['b'] }\n *\n * _.invertBy(object, function(value) {\n * return 'group' + value;\n * });\n * // => { 'group1': ['a', 'c'], 'group2': ['b'] }\n */\n var invertBy = createInverter(function(result, value, key) {\n if (value != null &&\n typeof value.toString != 'function') {\n value = nativeObjectToString.call(value);\n }\n\n if (hasOwnProperty.call(result, value)) {\n result[value].push(key);\n } else {\n result[value] = [key];\n }\n }, getIteratee);\n\n /**\n * Invokes the method at `path` of `object`.\n *\n * @static\n * @memberOf _\n * @since 4.0.0\n * @category Object\n * @param {Object} object The object to query.\n * @param {Array|string} path The path of the method to invoke.\n * @param {...*} [args] The arguments to invoke the method with.\n * @returns {*} Returns the result of the invoked method.\n * @example\n *\n * var object = { 'a': [{ 'b': { 'c': [1, 2, 3, 4] } }] };\n *\n * _.invoke(object, 'a[0].b.c.slice', 1, 3);\n * // => [2, 3]\n */\n var invoke = baseRest(baseInvoke);\n\n /**\n * Creates an array of the own enumerable property names of `object`.\n *\n * **Note:** Non-object values are coerced to objects. See the\n * [ES spec](http://ecma-international.org/ecma-262/7.0/#sec-object.keys)\n * for more details.\n *\n * @static\n * @since 0.1.0\n * @memberOf _\n * @category Object\n * @param {Object} object The object to query.\n * @returns {Array} Returns the array of property names.\n * @example\n *\n * function Foo() {\n * this.a = 1;\n * this.b = 2;\n * }\n *\n * Foo.prototype.c = 3;\n *\n * _.keys(new Foo);\n * // => ['a', 'b'] (iteration order is not guaranteed)\n *\n * _.keys('hi');\n * // => ['0', '1']\n */\n function keys(object) {\n return isArrayLike(object) ? arrayLikeKeys(object) : baseKeys(object);\n }\n\n /**\n * Creates an array of the own and inherited enumerable property names of `object`.\n *\n * **Note:** Non-object values are coerced to objects.\n *\n * @static\n * @memberOf _\n * @since 3.0.0\n * @category Object\n * @param {Object} object The object to query.\n * @returns {Array} Returns the array of property names.\n * @example\n *\n * function Foo() {\n * this.a = 1;\n * this.b = 2;\n * }\n *\n * Foo.prototype.c = 3;\n *\n * _.keysIn(new Foo);\n * // => ['a', 'b', 'c'] (iteration order is not guaranteed)\n */\n function keysIn(object) {\n return isArrayLike(object) ? arrayLikeKeys(object, true) : baseKeysIn(object);\n }\n\n /**\n * The opposite of `_.mapValues`; this method creates an object with the\n * same values as `object` and keys generated by running each own enumerable\n * string keyed property of `object` thru `iteratee`. 
The iteratee is invoked\n * with three arguments: (value, key, object).\n *\n * @static\n * @memberOf _\n * @since 3.8.0\n * @category Object\n * @param {Object} object The object to iterate over.\n * @param {Function} [iteratee=_.identity] The function invoked per iteration.\n * @returns {Object} Returns the new mapped object.\n * @see _.mapValues\n * @example\n *\n * _.mapKeys({ 'a': 1, 'b': 2 }, function(value, key) {\n * return key + value;\n * });\n * // => { 'a1': 1, 'b2': 2 }\n */\n function mapKeys(object, iteratee) {\n var result = {};\n iteratee = getIteratee(iteratee, 3);\n\n baseForOwn(object, function(value, key, object) {\n baseAssignValue(result, iteratee(value, key, object), value);\n });\n return result;\n }\n\n /**\n * Creates an object with the same keys as `object` and values generated\n * by running each own enumerable string keyed property of `object` thru\n * `iteratee`. The iteratee is invoked with three arguments:\n * (value, key, object).\n *\n * @static\n * @memberOf _\n * @since 2.4.0\n * @category Object\n * @param {Object} object The object to iterate over.\n * @param {Function} [iteratee=_.identity] The function invoked per iteration.\n * @returns {Object} Returns the new mapped object.\n * @see _.mapKeys\n * @example\n *\n * var users = {\n * 'fred': { 'user': 'fred', 'age': 40 },\n * 'pebbles': { 'user': 'pebbles', 'age': 1 }\n * };\n *\n * _.mapValues(users, function(o) { return o.age; });\n * // => { 'fred': 40, 'pebbles': 1 } (iteration order is not guaranteed)\n *\n * // The `_.property` iteratee shorthand.\n * _.mapValues(users, 'age');\n * // => { 'fred': 40, 'pebbles': 1 } (iteration order is not guaranteed)\n */\n function mapValues(object, iteratee) {\n var result = {};\n iteratee = getIteratee(iteratee, 3);\n\n baseForOwn(object, function(value, key, object) {\n baseAssignValue(result, key, iteratee(value, key, object));\n });\n return result;\n }\n\n /**\n * This method is like `_.assign` except that it recursively merges own and\n * inherited enumerable string keyed properties of source objects into the\n * destination object. Source properties that resolve to `undefined` are\n * skipped if a destination value exists. Array and plain object properties\n * are merged recursively. Other objects and value types are overridden by\n * assignment. Source objects are applied from left to right. Subsequent\n * sources overwrite property assignments of previous sources.\n *\n * **Note:** This method mutates `object`.\n *\n * @static\n * @memberOf _\n * @since 0.5.0\n * @category Object\n * @param {Object} object The destination object.\n * @param {...Object} [sources] The source objects.\n * @returns {Object} Returns `object`.\n * @example\n *\n * var object = {\n * 'a': [{ 'b': 2 }, { 'd': 4 }]\n * };\n *\n * var other = {\n * 'a': [{ 'c': 3 }, { 'e': 5 }]\n * };\n *\n * _.merge(object, other);\n * // => { 'a': [{ 'b': 2, 'c': 3 }, { 'd': 4, 'e': 5 }] }\n */\n var merge = createAssigner(function(object, source, srcIndex) {\n baseMerge(object, source, srcIndex);\n });\n\n /**\n * This method is like `_.merge` except that it accepts `customizer` which\n * is invoked to produce the merged values of the destination and source\n * properties. If `customizer` returns `undefined`, merging is handled by the\n * method instead. 
The `customizer` is invoked with six arguments:\n * (objValue, srcValue, key, object, source, stack).\n *\n * **Note:** This method mutates `object`.\n *\n * @static\n * @memberOf _\n * @since 4.0.0\n * @category Object\n * @param {Object} object The destination object.\n * @param {...Object} sources The source objects.\n * @param {Function} customizer The function to customize assigned values.\n * @returns {Object} Returns `object`.\n * @example\n *\n * function customizer(objValue, srcValue) {\n * if (_.isArray(objValue)) {\n * return objValue.concat(srcValue);\n * }\n * }\n *\n * var object = { 'a': [1], 'b': [2] };\n * var other = { 'a': [3], 'b': [4] };\n *\n * _.mergeWith(object, other, customizer);\n * // => { 'a': [1, 3], 'b': [2, 4] }\n */\n var mergeWith = createAssigner(function(object, source, srcIndex, customizer) {\n baseMerge(object, source, srcIndex, customizer);\n });\n\n /**\n * The opposite of `_.pick`; this method creates an object composed of the\n * own and inherited enumerable property paths of `object` that are not omitted.\n *\n * **Note:** This method is considerably slower than `_.pick`.\n *\n * @static\n * @since 0.1.0\n * @memberOf _\n * @category Object\n * @param {Object} object The source object.\n * @param {...(string|string[])} [paths] The property paths to omit.\n * @returns {Object} Returns the new object.\n * @example\n *\n * var object = { 'a': 1, 'b': '2', 'c': 3 };\n *\n * _.omit(object, ['a', 'c']);\n * // => { 'b': '2' }\n */\n var omit = flatRest(function(object, paths) {\n var result = {};\n if (object == null) {\n return result;\n }\n var isDeep = false;\n paths = arrayMap(paths, function(path) {\n path = castPath(path, object);\n isDeep || (isDeep = path.length > 1);\n return path;\n });\n copyObject(object, getAllKeysIn(object), result);\n if (isDeep) {\n result = baseClone(result, CLONE_DEEP_FLAG | CLONE_FLAT_FLAG | CLONE_SYMBOLS_FLAG, customOmitClone);\n }\n var length = paths.length;\n while (length--) {\n baseUnset(result, paths[length]);\n }\n return result;\n });\n\n /**\n * The opposite of `_.pickBy`; this method creates an object composed of\n * the own and inherited enumerable string keyed properties of `object` that\n * `predicate` doesn't return truthy for. The predicate is invoked with two\n * arguments: (value, key).\n *\n * @static\n * @memberOf _\n * @since 4.0.0\n * @category Object\n * @param {Object} object The source object.\n * @param {Function} [predicate=_.identity] The function invoked per property.\n * @returns {Object} Returns the new object.\n * @example\n *\n * var object = { 'a': 1, 'b': '2', 'c': 3 };\n *\n * _.omitBy(object, _.isNumber);\n * // => { 'b': '2' }\n */\n function omitBy(object, predicate) {\n return pickBy(object, negate(getIteratee(predicate)));\n }\n\n /**\n * Creates an object composed of the picked `object` properties.\n *\n * @static\n * @since 0.1.0\n * @memberOf _\n * @category Object\n * @param {Object} object The source object.\n * @param {...(string|string[])} [paths] The property paths to pick.\n * @returns {Object} Returns the new object.\n * @example\n *\n * var object = { 'a': 1, 'b': '2', 'c': 3 };\n *\n * _.pick(object, ['a', 'c']);\n * // => { 'a': 1, 'c': 3 }\n */\n var pick = flatRest(function(object, paths) {\n return object == null ? {} : basePick(object, paths);\n });\n\n /**\n * Creates an object composed of the `object` properties `predicate` returns\n * truthy for. 
The predicate is invoked with two arguments: (value, key).\n *\n * @static\n * @memberOf _\n * @since 4.0.0\n * @category Object\n * @param {Object} object The source object.\n * @param {Function} [predicate=_.identity] The function invoked per property.\n * @returns {Object} Returns the new object.\n * @example\n *\n * var object = { 'a': 1, 'b': '2', 'c': 3 };\n *\n * _.pickBy(object, _.isNumber);\n * // => { 'a': 1, 'c': 3 }\n */\n function pickBy(object, predicate) {\n if (object == null) {\n return {};\n }\n var props = arrayMap(getAllKeysIn(object), function(prop) {\n return [prop];\n });\n predicate = getIteratee(predicate);\n return basePickBy(object, props, function(value, path) {\n return predicate(value, path[0]);\n });\n }\n\n /**\n * This method is like `_.get` except that if the resolved value is a\n * function it's invoked with the `this` binding of its parent object and\n * its result is returned.\n *\n * @static\n * @since 0.1.0\n * @memberOf _\n * @category Object\n * @param {Object} object The object to query.\n * @param {Array|string} path The path of the property to resolve.\n * @param {*} [defaultValue] The value returned for `undefined` resolved values.\n * @returns {*} Returns the resolved value.\n * @example\n *\n * var object = { 'a': [{ 'b': { 'c1': 3, 'c2': _.constant(4) } }] };\n *\n * _.result(object, 'a[0].b.c1');\n * // => 3\n *\n * _.result(object, 'a[0].b.c2');\n * // => 4\n *\n * _.result(object, 'a[0].b.c3', 'default');\n * // => 'default'\n *\n * _.result(object, 'a[0].b.c3', _.constant('default'));\n * // => 'default'\n */\n function result(object, path, defaultValue) {\n path = castPath(path, object);\n\n var index = -1,\n length = path.length;\n\n // Ensure the loop is entered when path is empty.\n if (!length) {\n length = 1;\n object = undefined;\n }\n while (++index < length) {\n var value = object == null ? undefined : object[toKey(path[index])];\n if (value === undefined) {\n index = length;\n value = defaultValue;\n }\n object = isFunction(value) ? value.call(object) : value;\n }\n return object;\n }\n\n /**\n * Sets the value at `path` of `object`. If a portion of `path` doesn't exist,\n * it's created. Arrays are created for missing index properties while objects\n * are created for all other missing properties. Use `_.setWith` to customize\n * `path` creation.\n *\n * **Note:** This method mutates `object`.\n *\n * @static\n * @memberOf _\n * @since 3.7.0\n * @category Object\n * @param {Object} object The object to modify.\n * @param {Array|string} path The path of the property to set.\n * @param {*} value The value to set.\n * @returns {Object} Returns `object`.\n * @example\n *\n * var object = { 'a': [{ 'b': { 'c': 3 } }] };\n *\n * _.set(object, 'a[0].b.c', 4);\n * console.log(object.a[0].b.c);\n * // => 4\n *\n * _.set(object, ['x', '0', 'y', 'z'], 5);\n * console.log(object.x[0].y.z);\n * // => 5\n */\n function set(object, path, value) {\n return object == null ? object : baseSet(object, path, value);\n }\n\n /**\n * This method is like `_.set` except that it accepts `customizer` which is\n * invoked to produce the objects of `path`. If `customizer` returns `undefined`\n * path creation is handled by the method instead. 
The `customizer` is invoked\n * with three arguments: (nsValue, key, nsObject).\n *\n * **Note:** This method mutates `object`.\n *\n * @static\n * @memberOf _\n * @since 4.0.0\n * @category Object\n * @param {Object} object The object to modify.\n * @param {Array|string} path The path of the property to set.\n * @param {*} value The value to set.\n * @param {Function} [customizer] The function to customize assigned values.\n * @returns {Object} Returns `object`.\n * @example\n *\n * var object = {};\n *\n * _.setWith(object, '[0][1]', 'a', Object);\n * // => { '0': { '1': 'a' } }\n */\n function setWith(object, path, value, customizer) {\n customizer = typeof customizer == 'function' ? customizer : undefined;\n return object == null ? object : baseSet(object, path, value, customizer);\n }\n\n /**\n * Creates an array of own enumerable string keyed-value pairs for `object`\n * which can be consumed by `_.fromPairs`. If `object` is a map or set, its\n * entries are returned.\n *\n * @static\n * @memberOf _\n * @since 4.0.0\n * @alias entries\n * @category Object\n * @param {Object} object The object to query.\n * @returns {Array} Returns the key-value pairs.\n * @example\n *\n * function Foo() {\n * this.a = 1;\n * this.b = 2;\n * }\n *\n * Foo.prototype.c = 3;\n *\n * _.toPairs(new Foo);\n * // => [['a', 1], ['b', 2]] (iteration order is not guaranteed)\n */\n var toPairs = createToPairs(keys);\n\n /**\n * Creates an array of own and inherited enumerable string keyed-value pairs\n * for `object` which can be consumed by `_.fromPairs`. If `object` is a map\n * or set, its entries are returned.\n *\n * @static\n * @memberOf _\n * @since 4.0.0\n * @alias entriesIn\n * @category Object\n * @param {Object} object The object to query.\n * @returns {Array} Returns the key-value pairs.\n * @example\n *\n * function Foo() {\n * this.a = 1;\n * this.b = 2;\n * }\n *\n * Foo.prototype.c = 3;\n *\n * _.toPairsIn(new Foo);\n * // => [['a', 1], ['b', 2], ['c', 3]] (iteration order is not guaranteed)\n */\n var toPairsIn = createToPairs(keysIn);\n\n /**\n * An alternative to `_.reduce`; this method transforms `object` to a new\n * `accumulator` object which is the result of running each of its own\n * enumerable string keyed properties thru `iteratee`, with each invocation\n * potentially mutating the `accumulator` object. If `accumulator` is not\n * provided, a new object with the same `[[Prototype]]` will be used. 
The\n * iteratee is invoked with four arguments: (accumulator, value, key, object).\n * Iteratee functions may exit iteration early by explicitly returning `false`.\n *\n * @static\n * @memberOf _\n * @since 1.3.0\n * @category Object\n * @param {Object} object The object to iterate over.\n * @param {Function} [iteratee=_.identity] The function invoked per iteration.\n * @param {*} [accumulator] The custom accumulator value.\n * @returns {*} Returns the accumulated value.\n * @example\n *\n * _.transform([2, 3, 4], function(result, n) {\n * result.push(n *= n);\n * return n % 2 == 0;\n * }, []);\n * // => [4, 9]\n *\n * _.transform({ 'a': 1, 'b': 2, 'c': 1 }, function(result, value, key) {\n * (result[value] || (result[value] = [])).push(key);\n * }, {});\n * // => { '1': ['a', 'c'], '2': ['b'] }\n */\n function transform(object, iteratee, accumulator) {\n var isArr = isArray(object),\n isArrLike = isArr || isBuffer(object) || isTypedArray(object);\n\n iteratee = getIteratee(iteratee, 4);\n if (accumulator == null) {\n var Ctor = object && object.constructor;\n if (isArrLike) {\n accumulator = isArr ? new Ctor : [];\n }\n else if (isObject(object)) {\n accumulator = isFunction(Ctor) ? baseCreate(getPrototype(object)) : {};\n }\n else {\n accumulator = {};\n }\n }\n (isArrLike ? arrayEach : baseForOwn)(object, function(value, index, object) {\n return iteratee(accumulator, value, index, object);\n });\n return accumulator;\n }\n\n /**\n * Removes the property at `path` of `object`.\n *\n * **Note:** This method mutates `object`.\n *\n * @static\n * @memberOf _\n * @since 4.0.0\n * @category Object\n * @param {Object} object The object to modify.\n * @param {Array|string} path The path of the property to unset.\n * @returns {boolean} Returns `true` if the property is deleted, else `false`.\n * @example\n *\n * var object = { 'a': [{ 'b': { 'c': 7 } }] };\n * _.unset(object, 'a[0].b.c');\n * // => true\n *\n * console.log(object);\n * // => { 'a': [{ 'b': {} }] };\n *\n * _.unset(object, ['a', '0', 'b', 'c']);\n * // => true\n *\n * console.log(object);\n * // => { 'a': [{ 'b': {} }] };\n */\n function unset(object, path) {\n return object == null ? true : baseUnset(object, path);\n }\n\n /**\n * This method is like `_.set` except that accepts `updater` to produce the\n * value to set. Use `_.updateWith` to customize `path` creation. The `updater`\n * is invoked with one argument: (value).\n *\n * **Note:** This method mutates `object`.\n *\n * @static\n * @memberOf _\n * @since 4.6.0\n * @category Object\n * @param {Object} object The object to modify.\n * @param {Array|string} path The path of the property to set.\n * @param {Function} updater The function to produce the updated value.\n * @returns {Object} Returns `object`.\n * @example\n *\n * var object = { 'a': [{ 'b': { 'c': 3 } }] };\n *\n * _.update(object, 'a[0].b.c', function(n) { return n * n; });\n * console.log(object.a[0].b.c);\n * // => 9\n *\n * _.update(object, 'x[0].y.z', function(n) { return n ? n + 1 : 0; });\n * console.log(object.x[0].y.z);\n * // => 0\n */\n function update(object, path, updater) {\n return object == null ? object : baseUpdate(object, path, castFunction(updater));\n }\n\n /**\n * This method is like `_.update` except that it accepts `customizer` which is\n * invoked to produce the objects of `path`. If `customizer` returns `undefined`\n * path creation is handled by the method instead. 
The `customizer` is invoked\n * with three arguments: (nsValue, key, nsObject).\n *\n * **Note:** This method mutates `object`.\n *\n * @static\n * @memberOf _\n * @since 4.6.0\n * @category Object\n * @param {Object} object The object to modify.\n * @param {Array|string} path The path of the property to set.\n * @param {Function} updater The function to produce the updated value.\n * @param {Function} [customizer] The function to customize assigned values.\n * @returns {Object} Returns `object`.\n * @example\n *\n * var object = {};\n *\n * _.updateWith(object, '[0][1]', _.constant('a'), Object);\n * // => { '0': { '1': 'a' } }\n */\n function updateWith(object, path, updater, customizer) {\n customizer = typeof customizer == 'function' ? customizer : undefined;\n return object == null ? object : baseUpdate(object, path, castFunction(updater), customizer);\n }\n\n /**\n * Creates an array of the own enumerable string keyed property values of `object`.\n *\n * **Note:** Non-object values are coerced to objects.\n *\n * @static\n * @since 0.1.0\n * @memberOf _\n * @category Object\n * @param {Object} object The object to query.\n * @returns {Array} Returns the array of property values.\n * @example\n *\n * function Foo() {\n * this.a = 1;\n * this.b = 2;\n * }\n *\n * Foo.prototype.c = 3;\n *\n * _.values(new Foo);\n * // => [1, 2] (iteration order is not guaranteed)\n *\n * _.values('hi');\n * // => ['h', 'i']\n */\n function values(object) {\n return object == null ? [] : baseValues(object, keys(object));\n }\n\n /**\n * Creates an array of the own and inherited enumerable string keyed property\n * values of `object`.\n *\n * **Note:** Non-object values are coerced to objects.\n *\n * @static\n * @memberOf _\n * @since 3.0.0\n * @category Object\n * @param {Object} object The object to query.\n * @returns {Array} Returns the array of property values.\n * @example\n *\n * function Foo() {\n * this.a = 1;\n * this.b = 2;\n * }\n *\n * Foo.prototype.c = 3;\n *\n * _.valuesIn(new Foo);\n * // => [1, 2, 3] (iteration order is not guaranteed)\n */\n function valuesIn(object) {\n return object == null ? [] : baseValues(object, keysIn(object));\n }\n\n /*------------------------------------------------------------------------*/\n\n /**\n * Clamps `number` within the inclusive `lower` and `upper` bounds.\n *\n * @static\n * @memberOf _\n * @since 4.0.0\n * @category Number\n * @param {number} number The number to clamp.\n * @param {number} [lower] The lower bound.\n * @param {number} upper The upper bound.\n * @returns {number} Returns the clamped number.\n * @example\n *\n * _.clamp(-10, -5, 5);\n * // => -5\n *\n * _.clamp(10, -5, 5);\n * // => 5\n */\n function clamp(number, lower, upper) {\n if (upper === undefined) {\n upper = lower;\n lower = undefined;\n }\n if (upper !== undefined) {\n upper = toNumber(upper);\n upper = upper === upper ? upper : 0;\n }\n if (lower !== undefined) {\n lower = toNumber(lower);\n lower = lower === lower ? lower : 0;\n }\n return baseClamp(toNumber(number), lower, upper);\n }\n\n /**\n * Checks if `n` is between `start` and up to, but not including, `end`. 
If\n * `end` is not specified, it's set to `start` with `start` then set to `0`.\n * If `start` is greater than `end` the params are swapped to support\n * negative ranges.\n *\n * @static\n * @memberOf _\n * @since 3.3.0\n * @category Number\n * @param {number} number The number to check.\n * @param {number} [start=0] The start of the range.\n * @param {number} end The end of the range.\n * @returns {boolean} Returns `true` if `number` is in the range, else `false`.\n * @see _.range, _.rangeRight\n * @example\n *\n * _.inRange(3, 2, 4);\n * // => true\n *\n * _.inRange(4, 8);\n * // => true\n *\n * _.inRange(4, 2);\n * // => false\n *\n * _.inRange(2, 2);\n * // => false\n *\n * _.inRange(1.2, 2);\n * // => true\n *\n * _.inRange(5.2, 4);\n * // => false\n *\n * _.inRange(-3, -2, -6);\n * // => true\n */\n function inRange(number, start, end) {\n start = toFinite(start);\n if (end === undefined) {\n end = start;\n start = 0;\n } else {\n end = toFinite(end);\n }\n number = toNumber(number);\n return baseInRange(number, start, end);\n }\n\n /**\n * Produces a random number between the inclusive `lower` and `upper` bounds.\n * If only one argument is provided a number between `0` and the given number\n * is returned. If `floating` is `true`, or either `lower` or `upper` are\n * floats, a floating-point number is returned instead of an integer.\n *\n * **Note:** JavaScript follows the IEEE-754 standard for resolving\n * floating-point values which can produce unexpected results.\n *\n * @static\n * @memberOf _\n * @since 0.7.0\n * @category Number\n * @param {number} [lower=0] The lower bound.\n * @param {number} [upper=1] The upper bound.\n * @param {boolean} [floating] Specify returning a floating-point number.\n * @returns {number} Returns the random number.\n * @example\n *\n * _.random(0, 5);\n * // => an integer between 0 and 5\n *\n * _.random(5);\n * // => also an integer between 0 and 5\n *\n * _.random(5, true);\n * // => a floating-point number between 0 and 5\n *\n * _.random(1.2, 5.2);\n * // => a floating-point number between 1.2 and 5.2\n */\n function random(lower, upper, floating) {\n if (floating && typeof floating != 'boolean' && isIterateeCall(lower, upper, floating)) {\n upper = floating = undefined;\n }\n if (floating === undefined) {\n if (typeof upper == 'boolean') {\n floating = upper;\n upper = undefined;\n }\n else if (typeof lower == 'boolean') {\n floating = lower;\n lower = undefined;\n }\n }\n if (lower === undefined && upper === undefined) {\n lower = 0;\n upper = 1;\n }\n else {\n lower = toFinite(lower);\n if (upper === undefined) {\n upper = lower;\n lower = 0;\n } else {\n upper = toFinite(upper);\n }\n }\n if (lower > upper) {\n var temp = lower;\n lower = upper;\n upper = temp;\n }\n if (floating || lower % 1 || upper % 1) {\n var rand = nativeRandom();\n return nativeMin(lower + (rand * (upper - lower + freeParseFloat('1e-' + ((rand + '').length - 1)))), upper);\n }\n return baseRandom(lower, upper);\n }\n\n /*------------------------------------------------------------------------*/\n\n /**\n * Converts `string` to [camel case](https://en.wikipedia.org/wiki/CamelCase).\n *\n * @static\n * @memberOf _\n * @since 3.0.0\n * @category String\n * @param {string} [string=''] The string to convert.\n * @returns {string} Returns the camel cased string.\n * @example\n *\n * _.camelCase('Foo Bar');\n * // => 'fooBar'\n *\n * _.camelCase('--foo-bar--');\n * // => 'fooBar'\n *\n * _.camelCase('__FOO_BAR__');\n * // => 'fooBar'\n */\n var camelCase = 
createCompounder(function(result, word, index) {\n word = word.toLowerCase();\n return result + (index ? capitalize(word) : word);\n });\n\n /**\n * Converts the first character of `string` to upper case and the remaining\n * to lower case.\n *\n * @static\n * @memberOf _\n * @since 3.0.0\n * @category String\n * @param {string} [string=''] The string to capitalize.\n * @returns {string} Returns the capitalized string.\n * @example\n *\n * _.capitalize('FRED');\n * // => 'Fred'\n */\n function capitalize(string) {\n return upperFirst(toString(string).toLowerCase());\n }\n\n /**\n * Deburrs `string` by converting\n * [Latin-1 Supplement](https://en.wikipedia.org/wiki/Latin-1_Supplement_(Unicode_block)#Character_table)\n * and [Latin Extended-A](https://en.wikipedia.org/wiki/Latin_Extended-A)\n * letters to basic Latin letters and removing\n * [combining diacritical marks](https://en.wikipedia.org/wiki/Combining_Diacritical_Marks).\n *\n * @static\n * @memberOf _\n * @since 3.0.0\n * @category String\n * @param {string} [string=''] The string to deburr.\n * @returns {string} Returns the deburred string.\n * @example\n *\n * _.deburr('déjà vu');\n * // => 'deja vu'\n */\n function deburr(string) {\n string = toString(string);\n return string && string.replace(reLatin, deburrLetter).replace(reComboMark, '');\n }\n\n /**\n * Checks if `string` ends with the given target string.\n *\n * @static\n * @memberOf _\n * @since 3.0.0\n * @category String\n * @param {string} [string=''] The string to inspect.\n * @param {string} [target] The string to search for.\n * @param {number} [position=string.length] The position to search up to.\n * @returns {boolean} Returns `true` if `string` ends with `target`,\n * else `false`.\n * @example\n *\n * _.endsWith('abc', 'c');\n * // => true\n *\n * _.endsWith('abc', 'b');\n * // => false\n *\n * _.endsWith('abc', 'b', 2);\n * // => true\n */\n function endsWith(string, target, position) {\n string = toString(string);\n target = baseToString(target);\n\n var length = string.length;\n position = position === undefined\n ? length\n : baseClamp(toInteger(position), 0, length);\n\n var end = position;\n position -= target.length;\n return position >= 0 && string.slice(position, end) == target;\n }\n\n /**\n * Converts the characters \"&\", \"<\", \">\", '\"', and \"'\" in `string` to their\n * corresponding HTML entities.\n *\n * **Note:** No other characters are escaped. To escape additional\n * characters use a third-party library like [_he_](https://mths.be/he).\n *\n * Though the \">\" character is escaped for symmetry, characters like\n * \">\" and \"/\" don't need escaping in HTML and have no special meaning\n * unless they're part of a tag or unquoted attribute value. See\n * [Mathias Bynens's article](https://mathiasbynens.be/notes/ambiguous-ampersands)\n * (under \"semi-related fun fact\") for more details.\n *\n * When working with HTML you should always\n * [quote attribute values](http://wonko.com/post/html-escaping) to reduce\n * XSS vectors.\n *\n * @static\n * @since 0.1.0\n * @memberOf _\n * @category String\n * @param {string} [string=''] The string to escape.\n * @returns {string} Returns the escaped string.\n * @example\n *\n * _.escape('fred, barney, & pebbles');\n * // => 'fred, barney, & pebbles'\n */\n function escape(string) {\n string = toString(string);\n return (string && reHasUnescapedHtml.test(string))\n ? 
string.replace(reUnescapedHtml, escapeHtmlChar)\n : string;\n }\n\n /**\n * Escapes the `RegExp` special characters \"^\", \"$\", \"\\\", \".\", \"*\", \"+\",\n * \"?\", \"(\", \")\", \"[\", \"]\", \"{\", \"}\", and \"|\" in `string`.\n *\n * @static\n * @memberOf _\n * @since 3.0.0\n * @category String\n * @param {string} [string=''] The string to escape.\n * @returns {string} Returns the escaped string.\n * @example\n *\n * _.escapeRegExp('[lodash](https://lodash.com/)');\n * // => '\\[lodash\\]\\(https://lodash\\.com/\\)'\n */\n function escapeRegExp(string) {\n string = toString(string);\n return (string && reHasRegExpChar.test(string))\n ? string.replace(reRegExpChar, '\\\\$&')\n : string;\n }\n\n /**\n * Converts `string` to\n * [kebab case](https://en.wikipedia.org/wiki/Letter_case#Special_case_styles).\n *\n * @static\n * @memberOf _\n * @since 3.0.0\n * @category String\n * @param {string} [string=''] The string to convert.\n * @returns {string} Returns the kebab cased string.\n * @example\n *\n * _.kebabCase('Foo Bar');\n * // => 'foo-bar'\n *\n * _.kebabCase('fooBar');\n * // => 'foo-bar'\n *\n * _.kebabCase('__FOO_BAR__');\n * // => 'foo-bar'\n */\n var kebabCase = createCompounder(function(result, word, index) {\n return result + (index ? '-' : '') + word.toLowerCase();\n });\n\n /**\n * Converts `string`, as space separated words, to lower case.\n *\n * @static\n * @memberOf _\n * @since 4.0.0\n * @category String\n * @param {string} [string=''] The string to convert.\n * @returns {string} Returns the lower cased string.\n * @example\n *\n * _.lowerCase('--Foo-Bar--');\n * // => 'foo bar'\n *\n * _.lowerCase('fooBar');\n * // => 'foo bar'\n *\n * _.lowerCase('__FOO_BAR__');\n * // => 'foo bar'\n */\n var lowerCase = createCompounder(function(result, word, index) {\n return result + (index ? ' ' : '') + word.toLowerCase();\n });\n\n /**\n * Converts the first character of `string` to lower case.\n *\n * @static\n * @memberOf _\n * @since 4.0.0\n * @category String\n * @param {string} [string=''] The string to convert.\n * @returns {string} Returns the converted string.\n * @example\n *\n * _.lowerFirst('Fred');\n * // => 'fred'\n *\n * _.lowerFirst('FRED');\n * // => 'fRED'\n */\n var lowerFirst = createCaseFirst('toLowerCase');\n\n /**\n * Pads `string` on the left and right sides if it's shorter than `length`.\n * Padding characters are truncated if they can't be evenly divided by `length`.\n *\n * @static\n * @memberOf _\n * @since 3.0.0\n * @category String\n * @param {string} [string=''] The string to pad.\n * @param {number} [length=0] The padding length.\n * @param {string} [chars=' '] The string used as padding.\n * @returns {string} Returns the padded string.\n * @example\n *\n * _.pad('abc', 8);\n * // => ' abc '\n *\n * _.pad('abc', 8, '_-');\n * // => '_-abc_-_'\n *\n * _.pad('abc', 3);\n * // => 'abc'\n */\n function pad(string, length, chars) {\n string = toString(string);\n length = toInteger(length);\n\n var strLength = length ? stringSize(string) : 0;\n if (!length || strLength >= length) {\n return string;\n }\n var mid = (length - strLength) / 2;\n return (\n createPadding(nativeFloor(mid), chars) +\n string +\n createPadding(nativeCeil(mid), chars)\n );\n }\n\n /**\n * Pads `string` on the right side if it's shorter than `length`. 
Padding\n * characters are truncated if they exceed `length`.\n *\n * @static\n * @memberOf _\n * @since 4.0.0\n * @category String\n * @param {string} [string=''] The string to pad.\n * @param {number} [length=0] The padding length.\n * @param {string} [chars=' '] The string used as padding.\n * @returns {string} Returns the padded string.\n * @example\n *\n * _.padEnd('abc', 6);\n * // => 'abc '\n *\n * _.padEnd('abc', 6, '_-');\n * // => 'abc_-_'\n *\n * _.padEnd('abc', 3);\n * // => 'abc'\n */\n function padEnd(string, length, chars) {\n string = toString(string);\n length = toInteger(length);\n\n var strLength = length ? stringSize(string) : 0;\n return (length && strLength < length)\n ? (string + createPadding(length - strLength, chars))\n : string;\n }\n\n /**\n * Pads `string` on the left side if it's shorter than `length`. Padding\n * characters are truncated if they exceed `length`.\n *\n * @static\n * @memberOf _\n * @since 4.0.0\n * @category String\n * @param {string} [string=''] The string to pad.\n * @param {number} [length=0] The padding length.\n * @param {string} [chars=' '] The string used as padding.\n * @returns {string} Returns the padded string.\n * @example\n *\n * _.padStart('abc', 6);\n * // => ' abc'\n *\n * _.padStart('abc', 6, '_-');\n * // => '_-_abc'\n *\n * _.padStart('abc', 3);\n * // => 'abc'\n */\n function padStart(string, length, chars) {\n string = toString(string);\n length = toInteger(length);\n\n var strLength = length ? stringSize(string) : 0;\n return (length && strLength < length)\n ? (createPadding(length - strLength, chars) + string)\n : string;\n }\n\n /**\n * Converts `string` to an integer of the specified radix. If `radix` is\n * `undefined` or `0`, a `radix` of `10` is used unless `value` is a\n * hexadecimal, in which case a `radix` of `16` is used.\n *\n * **Note:** This method aligns with the\n * [ES5 implementation](https://es5.github.io/#x15.1.2.2) of `parseInt`.\n *\n * @static\n * @memberOf _\n * @since 1.1.0\n * @category String\n * @param {string} string The string to convert.\n * @param {number} [radix=10] The radix to interpret `value` by.\n * @param- {Object} [guard] Enables use as an iteratee for methods like `_.map`.\n * @returns {number} Returns the converted integer.\n * @example\n *\n * _.parseInt('08');\n * // => 8\n *\n * _.map(['6', '08', '10'], _.parseInt);\n * // => [6, 8, 10]\n */\n function parseInt(string, radix, guard) {\n if (guard || radix == null) {\n radix = 0;\n } else if (radix) {\n radix = +radix;\n }\n return nativeParseInt(toString(string).replace(reTrimStart, ''), radix || 0);\n }\n\n /**\n * Repeats the given string `n` times.\n *\n * @static\n * @memberOf _\n * @since 3.0.0\n * @category String\n * @param {string} [string=''] The string to repeat.\n * @param {number} [n=1] The number of times to repeat the string.\n * @param- {Object} [guard] Enables use as an iteratee for methods like `_.map`.\n * @returns {string} Returns the repeated string.\n * @example\n *\n * _.repeat('*', 3);\n * // => '***'\n *\n * _.repeat('abc', 2);\n * // => 'abcabc'\n *\n * _.repeat('abc', 0);\n * // => ''\n */\n function repeat(string, n, guard) {\n if ((guard ? 
isIterateeCall(string, n, guard) : n === undefined)) {\n n = 1;\n } else {\n n = toInteger(n);\n }\n return baseRepeat(toString(string), n);\n }\n\n /**\n * Replaces matches for `pattern` in `string` with `replacement`.\n *\n * **Note:** This method is based on\n * [`String#replace`](https://mdn.io/String/replace).\n *\n * @static\n * @memberOf _\n * @since 4.0.0\n * @category String\n * @param {string} [string=''] The string to modify.\n * @param {RegExp|string} pattern The pattern to replace.\n * @param {Function|string} replacement The match replacement.\n * @returns {string} Returns the modified string.\n * @example\n *\n * _.replace('Hi Fred', 'Fred', 'Barney');\n * // => 'Hi Barney'\n */\n function replace() {\n var args = arguments,\n string = toString(args[0]);\n\n return args.length < 3 ? string : string.replace(args[1], args[2]);\n }\n\n /**\n * Converts `string` to\n * [snake case](https://en.wikipedia.org/wiki/Snake_case).\n *\n * @static\n * @memberOf _\n * @since 3.0.0\n * @category String\n * @param {string} [string=''] The string to convert.\n * @returns {string} Returns the snake cased string.\n * @example\n *\n * _.snakeCase('Foo Bar');\n * // => 'foo_bar'\n *\n * _.snakeCase('fooBar');\n * // => 'foo_bar'\n *\n * _.snakeCase('--FOO-BAR--');\n * // => 'foo_bar'\n */\n var snakeCase = createCompounder(function(result, word, index) {\n return result + (index ? '_' : '') + word.toLowerCase();\n });\n\n /**\n * Splits `string` by `separator`.\n *\n * **Note:** This method is based on\n * [`String#split`](https://mdn.io/String/split).\n *\n * @static\n * @memberOf _\n * @since 4.0.0\n * @category String\n * @param {string} [string=''] The string to split.\n * @param {RegExp|string} separator The separator pattern to split by.\n * @param {number} [limit] The length to truncate results to.\n * @returns {Array} Returns the string segments.\n * @example\n *\n * _.split('a-b-c', '-', 2);\n * // => ['a', 'b']\n */\n function split(string, separator, limit) {\n if (limit && typeof limit != 'number' && isIterateeCall(string, separator, limit)) {\n separator = limit = undefined;\n }\n limit = limit === undefined ? MAX_ARRAY_LENGTH : limit >>> 0;\n if (!limit) {\n return [];\n }\n string = toString(string);\n if (string && (\n typeof separator == 'string' ||\n (separator != null && !isRegExp(separator))\n )) {\n separator = baseToString(separator);\n if (!separator && hasUnicode(string)) {\n return castSlice(stringToArray(string), 0, limit);\n }\n }\n return string.split(separator, limit);\n }\n\n /**\n * Converts `string` to\n * [start case](https://en.wikipedia.org/wiki/Letter_case#Stylistic_or_specialised_usage).\n *\n * @static\n * @memberOf _\n * @since 3.1.0\n * @category String\n * @param {string} [string=''] The string to convert.\n * @returns {string} Returns the start cased string.\n * @example\n *\n * _.startCase('--foo-bar--');\n * // => 'Foo Bar'\n *\n * _.startCase('fooBar');\n * // => 'Foo Bar'\n *\n * _.startCase('__FOO_BAR__');\n * // => 'FOO BAR'\n */\n var startCase = createCompounder(function(result, word, index) {\n return result + (index ? 
' ' : '') + upperFirst(word);\n });\n\n /**\n * Checks if `string` starts with the given target string.\n *\n * @static\n * @memberOf _\n * @since 3.0.0\n * @category String\n * @param {string} [string=''] The string to inspect.\n * @param {string} [target] The string to search for.\n * @param {number} [position=0] The position to search from.\n * @returns {boolean} Returns `true` if `string` starts with `target`,\n * else `false`.\n * @example\n *\n * _.startsWith('abc', 'a');\n * // => true\n *\n * _.startsWith('abc', 'b');\n * // => false\n *\n * _.startsWith('abc', 'b', 1);\n * // => true\n */\n function startsWith(string, target, position) {\n string = toString(string);\n position = position == null\n ? 0\n : baseClamp(toInteger(position), 0, string.length);\n\n target = baseToString(target);\n return string.slice(position, position + target.length) == target;\n }\n\n /**\n * Creates a compiled template function that can interpolate data properties\n * in \"interpolate\" delimiters, HTML-escape interpolated data properties in\n * \"escape\" delimiters, and execute JavaScript in \"evaluate\" delimiters. Data\n * properties may be accessed as free variables in the template. If a setting\n * object is given, it takes precedence over `_.templateSettings` values.\n *\n * **Note:** In the development build `_.template` utilizes\n * [sourceURLs](http://www.html5rocks.com/en/tutorials/developertools/sourcemaps/#toc-sourceurl)\n * for easier debugging.\n *\n * For more information on precompiling templates see\n * [lodash's custom builds documentation](https://lodash.com/custom-builds).\n *\n * For more information on Chrome extension sandboxes see\n * [Chrome's extensions documentation](https://developer.chrome.com/extensions/sandboxingEval).\n *\n * @static\n * @since 0.1.0\n * @memberOf _\n * @category String\n * @param {string} [string=''] The template string.\n * @param {Object} [options={}] The options object.\n * @param {RegExp} [options.escape=_.templateSettings.escape]\n * The HTML \"escape\" delimiter.\n * @param {RegExp} [options.evaluate=_.templateSettings.evaluate]\n * The \"evaluate\" delimiter.\n * @param {Object} [options.imports=_.templateSettings.imports]\n * An object to import into the template as free variables.\n * @param {RegExp} [options.interpolate=_.templateSettings.interpolate]\n * The \"interpolate\" delimiter.\n * @param {string} [options.sourceURL='lodash.templateSources[n]']\n * The sourceURL of the compiled template.\n * @param {string} [options.variable='obj']\n * The data object variable name.\n * @param- {Object} [guard] Enables use as an iteratee for methods like `_.map`.\n * @returns {Function} Returns the compiled template function.\n * @example\n *\n * // Use the \"interpolate\" delimiter to create a compiled template.\n * var compiled = _.template('hello <%= user %>!');\n * compiled({ 'user': 'fred' });\n * // => 'hello fred!'\n *\n * // Use the HTML \"escape\" delimiter to escape data property values.\n * var compiled = _.template('<%- value %>');\n * compiled({ 'value': '