3 changes: 3 additions & 0 deletions .gitignore
@@ -1 +1,4 @@
.setup
.venv
build/
dist/
7 changes: 6 additions & 1 deletion src/tensorflow_module.py
@@ -149,6 +149,7 @@ def _prepare_inputs(self, input_tensors: Dict[str, NDArray]) -> NDArray:
input_t = tf.convert_to_tensor(input_t, dtype=self.input_info[i][2])
input_list.append(input_t)


if len(input_vars) == 1:
return np.squeeze(np.asarray(input_list), axis=0)
return np.asarray(input_list)
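
The single-input branch above relies on NumPy's squeeze dropping the extra list axis. A minimal illustrative sketch of that behavior (shapes are assumed for illustration; input_vars and input_list come from the surrounding method):

import numpy as np

single = [np.ones((1, 10, 10, 3))]                   # one declared input -> a list of length 1
print(np.squeeze(np.asarray(single), axis=0).shape)  # (1, 10, 10, 3): leading list axis removed
multi = [np.ones((1, 4)), np.ones((1, 4))]           # several inputs -> keep the stacking axis
print(np.asarray(multi).shape)                       # (2, 1, 4)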
@@ -279,7 +280,11 @@ def prepShape(tensorShape):

# Want to return a simple string ("float32", "int64", etc.)
def prepType(tensorType, is_keras):
if tensorType is None or not isinstance(tensorType, str):
if tensorType is None:
return "unknown"
if hasattr(tensorType, 'name'):
return tensorType.name
if not isinstance(tensorType, str):
return "unknown"
if is_keras:
return tensorType
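
As a hedged sketch of how the reworked dtype normalization behaves (import path assumed for illustration; the non-Keras string branch continues below the visible hunk):

import tensorflow as tf
from src.tensorflow_module import prepType  # assumed import path

print(prepType(tf.float32, is_keras=False))  # "float32": TF DType objects expose a .name attribute
print(prepType("int64", is_keras=True))      # "int64": Keras string dtypes are passed through
print(prepType(None, is_keras=False))        # "unknown"
print(prepType(42, is_keras=False))          # "unknown": neither a DType nor a string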
Binary file added tests/mock_model.keras
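The binary fixture isn't reproducible from the diff itself, but a tiny Keras model such as the following would be one way to generate it (hypothetical sketch; the exact architecture of tests/mock_model.keras is not shown in the PR, and the only visible constraint from the tests below is a float32 input of shape (1, 4)):

import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),  # matches the (1, 4) float32 input fed in test_keras_infer below
    tf.keras.layers.Dense(2),
])
model.save("tests/mock_model.keras")  # native .keras serialization format
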
45 changes: 39 additions & 6 deletions tests/test_tensorflow.py
@@ -22,16 +22,26 @@ class TestTensorflowCPU:
badconfig = make_component_config({
"model_path": "testModel"
})
goodconfig =make_component_config({
saved_model_config = make_component_config({
"model_path": "./tests/EffNet",
"label_path": "put/Labels/here.txt"
})
keras_config = make_component_config({
"model_path": "./tests/mock_model.keras",
"label_path": "put/Labels/here.txt"
})


def getTFCPU(self):
tfmodel = TensorflowModule("test")
tfmodel.model = tf.saved_model.load("./tests/EffNet")
return tfmodel

def getKerasModel(self):
tfmodel = TensorflowModule("test")
tfmodel.model = tf.keras.models.load_model("./tests/mock_model.keras")
tfmodel.is_keras = True
return tfmodel


def test_validate(self):
@@ -40,13 +50,14 @@ def test_validate(self):
response = tfm.validate_config(config=self.empty_config)
with pytest.raises(Exception):
response = tfm.validate_config(config=self.badconfig)
response = tfm.validate_config(config=self.goodconfig)
response = tfm.validate_config(config=self.saved_model_config)
response = tfm.validate_config(config=self.keras_config)


@pytest.mark.asyncio
async def test_infer(self):
async def test_saved_model_infer(self):
tfmodel = self.getTFCPU()
tfmodel.reconfigure(config=self.goodconfig, dependencies=None)
tfmodel.reconfigure(config=self.saved_model_config, dependencies=None)
fakeInput = {"input": np.ones([1,10,10,3])} # make a fake input thingy
out = await tfmodel.infer(input_tensors=fakeInput)
assert isinstance(out, Dict)
@@ -55,9 +66,31 @@ async def test_infer(self):


@pytest.mark.asyncio
async def test_metadata(self):
async def test_saved_model_metadata(self):
tfmodel = self.getTFCPU()
tfmodel.reconfigure(config=self.goodconfig, dependencies=None)
tfmodel.reconfigure(config=self.saved_model_config, dependencies=None)
md = await tfmodel.metadata()
assert isinstance(md, Metadata)
assert hasattr(md, "name")
assert hasattr(md, "input_info")
assert hasattr(md, "output_info")

@pytest.mark.asyncio
async def test_keras_infer(self):
tfmodel = self.getKerasModel()
tfmodel.reconfigure(config=self.keras_config, dependencies=None)

fakeInput = {"input": np.ones([1, 4]).astype(np.float32)}

out = await tfmodel.infer(input_tensors=fakeInput)
assert isinstance(out, Dict)
for output in out:
assert isinstance(out[output], np.ndarray)

@pytest.mark.asyncio
async def test_keras_metadata(self):
tfmodel = self.getKerasModel()
tfmodel.reconfigure(config=self.keras_config, dependencies=None)
md = await tfmodel.metadata()
assert isinstance(md, Metadata)
assert hasattr(md, "name")