diff --git a/.gitignore b/.gitignore
index 7a4b798..172ea14 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,4 +1,214 @@
+# Largely copied from this link with a few additions at the bottom: https://github.com/github/gitignore/blob/main/Python.gitignore
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[codz]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py.cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# UV
+# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+#uv.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+#poetry.toml
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+# pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
+# https://pdm-project.org/en/latest/usage/project/#working-with-version-control
+#pdm.lock
+#pdm.toml
+.pdm-python
+.pdm-build/
+
+# pixi
+# Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
+#pixi.lock
+# Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
+# in the .venv directory. It is recommended not to include this directory in version control.
+.pixi
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.envrc
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. For a more nuclear
+# option (not recommended) you can uncomment the following to ignore the entire idea folder.
+#.idea/
+
+# Abstra
+# Abstra is an AI-powered process automation framework.
+# Ignore directories containing user credentials, local state, and settings.
+# Learn more at https://abstra.io/docs
+.abstra/
+
+# Visual Studio Code
+# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
+# that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
+# and can be added to the global gitignore or merged into this file. However, if you prefer,
+# you could uncomment the following to ignore the entire vscode folder
+# .vscode/
+
+# Ruff stuff:
+.ruff_cache/
+
+# PyPI configuration file
+.pypirc
+
+# Marimo
+marimo/_static/
+marimo/_lsp/
+__marimo__/
+
+# Streamlit
+.streamlit/secrets.toml
+
+# START OF ADDITIONS
+
+# Keras model files (saved with model.save())
+*.keras
+
 .setup
 .venv
 build/
-dist/
\ No newline at end of file
+dist/
diff --git a/requirements.txt b/requirements.txt
index 157fcb0..6004c28 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -8,4 +8,4 @@ pyinstaller
 pylint
 pytest
 pytest-asyncio
-tensorflow_probability == 0.24.0
\ No newline at end of file
+tensorflow_probability == 0.24.0
diff --git a/src/tensorflow_module.py b/src/tensorflow_module.py
index 8b4716e..2a78fac 100755
--- a/src/tensorflow_module.py
+++ b/src/tensorflow_module.py
@@ -14,6 +14,7 @@
 import numpy as np
 import google.protobuf.struct_pb2 as pb
 import tensorflow as tf
+import keras
 
 LOGGER = getLogger(__name__)
@@ -52,7 +53,6 @@ def validate_config(
         LOGGER.info(
             "Detected Keras model file at "
             + model_path
-            + ". Please note Keras support is limited."
         )
 
     return ([], [])
@@ -87,35 +87,54 @@ def reconfigure(
         self.model_path = config.attributes.fields["model_path"].string_value
         self.label_path = config.attributes.fields["label_path"].string_value
         self.is_keras = False
-        self.input_info = [] # input and output info are lists of tuples (name, shape, underlying type)
+        self.input_info = []  # input and output info are lists of tuples (name, shape, underlying type)
         self.output_info = []
 
         _, ext = os.path.splitext(self.model_path)
         if ext.lower() == ".keras":
             # If it's a Keras model, load it using the Keras API
-            self.model = tf.keras.models.load_model(self.model_path)
+            self.model = keras.models.load_model(self.model_path)
             self.is_keras = True
-            # For now, we use first and last layer to get input and output info
-            in_config = self.model.layers[0].get_config()
-            out_config = self.model.layers[-1].get_config()
-
-            # Keras model's output config's dtype is (sometimes?) a whole dict
-            outType = out_config.get("dtype")
-            if not isinstance(outType, str):
-                outType = None
-
-            self.input_info.append(
-                (
-                    in_config.get("name"),
-                    in_config.get("batch_shape"),
-                    in_config.get("dtype"),
+            # Instead of handling only a single input and output layer (as when the model is created with the
+            # Sequential API), we also need to support the Functional API, where a model may have multiple inputs and outputs.
+            # If the model does not expose its inputs/outputs, fall back to the first and last layer of the model.
+            try:
+                inputs = self.model.inputs
+                if inputs:
+                    self.input_info = [(i.name, i.shape, i.dtype) for i in inputs]
+                else:
+                    raise AttributeError("'inputs' attribute not defined on the model, defaulting to the first layer instead")
+            except AttributeError:
+                in_config = self.model.layers[0].get_config()
+                self.input_info.append(
+                    (
+                        in_config.get("name"),
+                        in_config.get("batch_shape"),
+                        in_config.get("dtype"),
+                    )
                 )
-            )
 
-            self.output_info.append(
-                (out_config.get("name"), out_config.get("batch_shape"), outType)
-            )
+            try:
+                outputs = self.model.outputs
+                if outputs:
+                    self.output_info = [(o.name, o.shape, o.dtype) for o in outputs]
+                else:
+                    raise AttributeError("'outputs' attribute not defined on the model, defaulting to the last layer instead")
+            except AttributeError:
+                out_config = self.model.layers[-1].get_config()
+                # A Keras layer config's dtype is sometimes a whole dict rather than a string
+                outType = out_config.get("dtype")
+                if not isinstance(outType, str):
+                    LOGGER.info("Output dtype is not a string, using 'None' instead")
+                    outType = None
+                self.output_info.append(
+                    (
+                        out_config.get("name"),
+                        out_config.get("batch_shape"),
+                        outType,
+                    )
+                )
 
             return
 
         # This is where we do the actual loading of the SavedModel
@@ -227,7 +246,6 @@ async def metadata(
         Returns:
             Metadata: The metadata
         """
-
         extra = pb.Struct()
         extra["labels"] = self.label_path
 
diff --git a/tests/test_tensorflow.py b/tests/test_tensorflow.py
index 9d66ab1..c3b2320 100644
--- a/tests/test_tensorflow.py
+++ b/tests/test_tensorflow.py
@@ -7,18 +7,23 @@
 import tensorflow as tf
 import numpy as np
 from numpy.typing import NDArray
-
+import keras
 
 def make_component_config(dictionary: Mapping[str, Any]) -> ComponentConfig:
     struct = Struct()
     struct.update(dictionary=dictionary)
     return ComponentConfig(attributes=struct)
-
+def make_sequential_keras_model():
+    model = keras.Sequential([keras.layers.Dense(10), keras.layers.Dense(5)])
+    model.build(input_shape=(None, 10))
+    model.save("./tests/testmodel.keras")
 
 class TestTensorflowCPU:
     empty_config = make_component_config({})
+    make_sequential_keras_model()
+
     badconfig =make_component_config({
         "model_path": "testModel"
     })
@@ -59,7 +64,7 @@ async def test_saved_model_infer(self):
         tfmodel = self.getTFCPU()
         tfmodel.reconfigure(config=self.saved_model_config, dependencies=None)
         fakeInput = {"input": np.ones([1,10,10,3])}  # make a fake input thingy
-        out = await tfmodel.infer(input_tensors=fakeInput) 
+        out = await tfmodel.infer(input_tensors=fakeInput)
         assert isinstance(out, Dict)
         for output in out:
             assert isinstance(out[output], np.ndarray)
@@ -95,4 +100,30 @@ async def test_keras_metadata(self):
         assert isinstance(md, Metadata)
         assert hasattr(md, "name")
         assert hasattr(md, "input_info")
-        assert hasattr(md, "output_info")
\ No newline at end of file
+        assert hasattr(md, "output_info")
+
+    # KERAS TESTS
+    def getTFCPUKeras(self):
+        tfmodel = TensorflowModule("test")
+        tfmodel.model = keras.models.load_model("./tests/testmodel.keras")
+        return tfmodel
+
+    @pytest.mark.asyncio
+    async def test_infer_keras(self):
+        tf_keras_model = self.getTFCPUKeras()
+        tf_keras_model.reconfigure(config=self.keras_config, dependencies=None)
+        fakeInput = {"input_1": np.ones([1, 4])}
+        out = await tf_keras_model.infer(input_tensors=fakeInput)
+        assert isinstance(out, Dict)
+        for output in out.values():
+            assert isinstance(output, np.ndarray)
+
+    @pytest.mark.asyncio
+    async def test_metadata_keras(self):
+        tf_keras_model = self.getTFCPUKeras()
+        tf_keras_model.reconfigure(config=self.keras_config, dependencies=None)
+        md = await tf_keras_model.metadata()
+        assert isinstance(md, Metadata)
+        assert hasattr(md, "name")
+        assert hasattr(md, "input_info")
+        assert hasattr(md, "output_info")
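
Note on the reconfigure() change above: the new code relies on Keras models exposing
"inputs" and "outputs" as lists of KerasTensors, each carrying a name, shape, and dtype.
Below is a minimal sketch (separate from the patch itself; the input names "a" and "b"
are made up for illustration) of that introspection on a Functional model with two
inputs, the case the old first-layer/last-layer approach could not describe:

    import keras

    # Functional model with two inputs and one output.
    # The names "a" and "b" are arbitrary examples, not taken from the patch.
    a = keras.Input(shape=(4,), name="a")
    b = keras.Input(shape=(3,), name="b")
    merged = keras.layers.Concatenate()([a, b])
    model = keras.Model(inputs=[a, b], outputs=keras.layers.Dense(2)(merged))

    # The same (name, shape, dtype) tuples reconfigure() now collects:
    print([(t.name, t.shape, t.dtype) for t in model.inputs])
    # e.g. [('a', (None, 4), 'float32'), ('b', (None, 3), 'float32')]
    print([(t.name, t.shape, t.dtype) for t in model.outputs])

A Sequential model may not expose inputs/outputs until it has been built (the test
helper builds one via model.build()), which is why the except AttributeError branch
that falls back to the first/last layer configs is kept.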