212 changes: 211 additions & 1 deletion .gitignore
@@ -1,4 +1,214 @@
# Largely copied from this link with a few additions at the bottom: https://github.com/github/gitignore/blob/main/Python.gitignore
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[codz]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py.cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# UV
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
#uv.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
#poetry.toml

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
# pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
# https://pdm-project.org/en/latest/usage/project/#working-with-version-control
#pdm.lock
#pdm.toml
.pdm-python
.pdm-build/

# pixi
# Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
#pixi.lock
# Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
# in the .venv directory. It is recommended not to include this directory in version control.
.pixi

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.envrc
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

# Abstra
# Abstra is an AI-powered process automation framework.
# Ignore directories containing user credentials, local state, and settings.
# Learn more at https://abstra.io/docs
.abstra/

# Visual Studio Code
# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
# that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
# and can be added to the global gitignore or merged into this file. However, if you prefer,
# you could uncomment the following to ignore the entire vscode folder
# .vscode/

# Ruff stuff:
.ruff_cache/

# PyPI configuration file
.pypirc

# Marimo
marimo/_static/
marimo/_lsp/
__marimo__/

# Streamlit
.streamlit/secrets.toml

# START OF ADDITIONS

# Keras model files
*.keras

.setup
.venv
build/
dist/
2 changes: 1 addition & 1 deletion requirements.txt
@@ -8,4 +8,4 @@ pyinstaller
pylint
pytest
pytest-asyncio
tensorflow_probability == 0.24.0
tensorflow_probability == 0.24.0
62 changes: 40 additions & 22 deletions src/tensorflow_module.py
@@ -14,6 +14,7 @@
import numpy as np
import google.protobuf.struct_pb2 as pb
import tensorflow as tf
import keras


LOGGER = getLogger(__name__)
@@ -52,7 +53,6 @@ def validate_config(
LOGGER.info(
"Detected Keras model file at "
+ model_path
+ ". Please note Keras support is limited."
)
return ([], [])

@@ -87,35 +87,54 @@ def reconfigure(
self.model_path = config.attributes.fields["model_path"].string_value
self.label_path = config.attributes.fields["label_path"].string_value
self.is_keras = False
self.input_info = [] # input and output info are lists of tuples (name, shape, underlying type)
self.input_info = [] # input and output info are lists of tuples (name, shape, underlying type)
self.output_info = []

_, ext = os.path.splitext(self.model_path)
if ext.lower() == ".keras":
# If it's a Keras model, load it using the Keras API
self.model = tf.keras.models.load_model(self.model_path)
self.model = keras.models.load_model(self.model_path)
self.is_keras = True

# For now, we use first and last layer to get input and output info
in_config = self.model.layers[0].get_config()
out_config = self.model.layers[-1].get_config()

# Keras model's output config's dtype is (sometimes?) a whole dict
outType = out_config.get("dtype")
if not isinstance(outType, str):
outType = None

self.input_info.append(
(
in_config.get("name"),
in_config.get("batch_shape"),
in_config.get("dtype"),
# Instead of handling only a single-input, single-output layer (as when the model is built with the
# Sequential API), we also need to support the Functional API, which may have multiple inputs and outputs.
# If the model does not expose inputs/outputs, fall back to the first and last layer of the model.
try:
inputs = self.model.inputs
if inputs:
self.input_info = [(i.name, i.shape, i.dtype) for i in inputs]
else:
raise AttributeError("'inputs' attributed not defined on the model, defaulting to the first layer instead")
except AttributeError:
in_config = self.model.layers[0].get_config()
self.input_info.append(
(
in_config.get("name"),
in_config.get("batch_shape"),
in_config.get("dtype"),
)
)
)
self.output_info.append(
(out_config.get("name"), out_config.get("batch_shape"), outType)
)

try:
outputs = self.model.outputs
if outputs:
self.output_info = [(o.name, o.shape, o.dtype) for o in outputs]
else:
raise AttributeError("'outputs' attributed not defined on the model, defaulting to the last layer instead")
except AttributeError:
out_config = self.model.layers[-1].get_config()
# Keras model's output config's dtype is (sometimes?) a whole dict
outType = out_config.get("dtype")
if not isinstance(outType, str):
LOGGER.info("Output dtype is not a string, using 'None' instead")
outType = None
self.output_info.append(
(
out_config.get("name"),
out_config.get("batch_shape"),
outType,
)
)
return

# This is where we do the actual loading of the SavedModel
@@ -227,7 +246,6 @@ async def metadata(
Returns:
Metadata: The metadata
"""

extra = pb.Struct()
extra["labels"] = self.label_path

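The new input/output introspection in reconfigure can be summarized in isolation. Below is a minimal sketch of that pattern, not the module itself (the TensorflowModule plumbing, config handling, and logging are omitted, and the helper name describe_io is illustrative), assuming Keras 3 semantics in which Model.inputs and Model.outputs expose tensors with name, shape, and dtype attributes:

import keras

def describe_io(model):
    # Prefer the model-level tensor specs: this covers Functional and built
    # Sequential models, including multi-input / multi-output graphs.
    try:
        if model.inputs and model.outputs:
            input_info = [(t.name, t.shape, t.dtype) for t in model.inputs]
            output_info = [(t.name, t.shape, t.dtype) for t in model.outputs]
            return input_info, output_info
    except AttributeError:
        pass
    # Fall back to the first and last layer configs, mirroring the branch above.
    # The layer config's dtype can be a policy dict rather than a string, in
    # which case it is dropped.
    in_cfg = model.layers[0].get_config()
    out_cfg = model.layers[-1].get_config()
    out_dtype = out_cfg.get("dtype")
    if not isinstance(out_dtype, str):
        out_dtype = None
    input_info = [(in_cfg.get("name"), in_cfg.get("batch_shape"), in_cfg.get("dtype"))]
    output_info = [(out_cfg.get("name"), out_cfg.get("batch_shape"), out_dtype)]
    return input_info, output_info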
39 changes: 35 additions & 4 deletions tests/test_tensorflow.py
@@ -7,18 +7,23 @@
import tensorflow as tf
import numpy as np
from numpy.typing import NDArray

import keras

def make_component_config(dictionary: Mapping[str, Any]) -> ComponentConfig:
struct = Struct()
struct.update(dictionary=dictionary)
return ComponentConfig(attributes=struct)


def make_sequential_keras_model():
model = keras.Sequential([keras.layers.Dense(10), keras.layers.Dense(5)])
model.build(input_shape=(None, 10))
model.save("./tests/testmodel.keras")

class TestTensorflowCPU:

empty_config = make_component_config({})
make_sequential_keras_model()

badconfig = make_component_config({
"model_path": "testModel"
})
@@ -59,7 +64,7 @@ async def test_saved_model_infer(self):
tfmodel = self.getTFCPU()
tfmodel.reconfigure(config=self.saved_model_config, dependencies=None)
fakeInput = {"input": np.ones([1,10,10,3])} # make a fake input thingy
out = await tfmodel.infer(input_tensors=fakeInput)
out = await tfmodel.infer(input_tensors=fakeInput)
assert isinstance(out, Dict)
for output in out:
assert isinstance(out[output], np.ndarray)
@@ -95,4 +100,30 @@ async def test_keras_metadata(self):
assert isinstance(md, Metadata)
assert hasattr(md, "name")
assert hasattr(md, "input_info")
assert hasattr(md, "output_info")
assert hasattr(md, "output_info")

# KERAS TESTS
def getTFCPUKeras(self):
tfmodel = TensorflowModule("test")
tfmodel.model = tf.keras.models.load_model("./tests/testmodel.keras")
return tfmodel

@pytest.mark.asyncio
async def test_infer_keras(self):
tf_keras_model = self.getTFCPUKeras()
tf_keras_model.reconfigure(config=self.keras_config, dependencies=None)
fakeInput = {"input_1": np.ones([1, 4])}
out = await tf_keras_model.infer(input_tensors=fakeInput)
assert isinstance(out, Dict)
for output in out.values():
assert isinstance(output, np.ndarray)

@pytest.mark.asyncio
async def test_metadata_keras(self):
tf_keras_model = self.getTFCPUKeras()
tf_keras_model.reconfigure(config=self.keras_config, dependencies=None)
md = await tf_keras_model.metadata()
assert isinstance(md, Metadata)
assert hasattr(md, "name")
assert hasattr(md, "input_info")
assert hasattr(md, "output_info")