diff --git a/.gitignore b/.gitignore index af37b90..7a4b798 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,4 @@ .setup +.venv +build/ +dist/ \ No newline at end of file diff --git a/src/tensorflow_module.py b/src/tensorflow_module.py index 6e64035..8b4716e 100755 --- a/src/tensorflow_module.py +++ b/src/tensorflow_module.py @@ -149,6 +149,7 @@ def _prepare_inputs(self, input_tensors: Dict[str, NDArray]) -> NDArray: input_t = tf.convert_to_tensor(input_t, dtype=self.input_info[i][2]) input_list.append(input_t) + if len(input_vars) == 1: return np.squeeze(np.asarray(input_list), axis=0) return np.asarray(input_list) @@ -279,7 +280,11 @@ def prepShape(tensorShape): # Want to return a simple string ("float32", "int64", etc.) def prepType(tensorType, is_keras): - if tensorType is None or not isinstance(tensorType, str): + if tensorType is None: + return "unknown" + if hasattr(tensorType, 'name'): + return tensorType.name + if not isinstance(tensorType, str): return "unknown" if is_keras: return tensorType diff --git a/tests/mock_model.keras b/tests/mock_model.keras new file mode 100644 index 0000000..a5c653b Binary files /dev/null and b/tests/mock_model.keras differ diff --git a/tests/test_tensorflow.py b/tests/test_tensorflow.py index f349496..9d66ab1 100644 --- a/tests/test_tensorflow.py +++ b/tests/test_tensorflow.py @@ -22,16 +22,26 @@ class TestTensorflowCPU: badconfig =make_component_config({ "model_path": "testModel" }) - goodconfig =make_component_config({ + saved_model_config =make_component_config({ "model_path": "./tests/EffNet", "label_path": "put/Labels/here.txt" }) + keras_config = make_component_config({ + "model_path": "./tests/mock_model.keras", + "label_path": "put/Labels/here.txt" + }) def getTFCPU(self): tfmodel = TensorflowModule("test") tfmodel.model = tf.saved_model.load("./tests/EffNet") return tfmodel + + def getKerasModel(self): + tfmodel = TensorflowModule("test") + tfmodel.model = tf.keras.models.load_model("./tests/mock_model.keras") + tfmodel.is_keras = True + return tfmodel def test_validate(self): @@ -40,13 +50,14 @@ def test_validate(self): response = tfm.validate_config(config=self.empty_config) with pytest.raises(Exception): response = tfm.validate_config(config=self.badconfig) - response = tfm.validate_config(config=self.goodconfig) + response = tfm.validate_config(config=self.saved_model_config) + response = tfm.validate_config(config=self.keras_config) @pytest.mark.asyncio - async def test_infer(self): + async def test_saved_model_infer(self): tfmodel = self.getTFCPU() - tfmodel.reconfigure(config=self.goodconfig, dependencies=None) + tfmodel.reconfigure(config=self.saved_model_config, dependencies=None) fakeInput = {"input": np.ones([1,10,10,3])} # make a fake input thingy out = await tfmodel.infer(input_tensors=fakeInput) assert isinstance(out, Dict) @@ -55,9 +66,31 @@ async def test_infer(self): @pytest.mark.asyncio - async def test_metadata(self): + async def test_saved_model_metadata(self): tfmodel = self.getTFCPU() - tfmodel.reconfigure(config=self.goodconfig, dependencies=None) + tfmodel.reconfigure(config=self.saved_model_config, dependencies=None) + md = await tfmodel.metadata() + assert isinstance(md, Metadata) + assert hasattr(md, "name") + assert hasattr(md, "input_info") + assert hasattr(md, "output_info") + + @pytest.mark.asyncio + async def test_keras_infer(self): + tfmodel = self.getKerasModel() + tfmodel.reconfigure(config=self.keras_config, dependencies=None) + + fakeInput = {"input": np.ones([1, 4]).astype(np.float32)} + + out = await tfmodel.infer(input_tensors=fakeInput) + assert isinstance(out, Dict) + for output in out: + assert isinstance(out[output], np.ndarray) + + @pytest.mark.asyncio + async def test_keras_metadata(self): + tfmodel = self.getKerasModel() + tfmodel.reconfigure(config=self.keras_config, dependencies=None) md = await tfmodel.metadata() assert isinstance(md, Metadata) assert hasattr(md, "name")