From 529cbf94fe0b3044f4a6e52f3e3394efb3329c14 Mon Sep 17 00:00:00 2001 From: Vignesh Date: Fri, 19 Sep 2025 13:10:05 -0400 Subject: [PATCH] also return model_path --- src/tensorflow_module.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/src/tensorflow_module.py b/src/tensorflow_module.py index 8b4716e..c7626a2 100755 --- a/src/tensorflow_module.py +++ b/src/tensorflow_module.py @@ -37,14 +37,10 @@ def new_service( def validate_config( cls, config: ServiceConfig ) -> Tuple[Sequence[str], Sequence[str]]: - model_path_err = ( - "model_path must be the location of the Tensorflow SavedModel directory " - "or the location of a Keras model file (.keras)" - ) - model_path = config.attributes.fields["model_path"].string_value + if model_path == "": - raise Exception(model_path_err) + raise Exception(get_model_path_error(model_path)) # If it's a Keras model file, okay. Otherwise, it must be a SavedModel directory _, ext = os.path.splitext(model_path) @@ -64,7 +60,7 @@ def validate_config( # and that the model file isn't too big (>500 MB) isValidSavedModel = False if not os.path.isdir(model_path): - raise Exception(model_path_err) + raise Exception(get_model_path_error(model_path)) for file in os.listdir(model_path): if ".pb" in file: isValidSavedModel = True @@ -77,7 +73,7 @@ def validate_config( ) if not isValidSavedModel: - raise Exception(model_path_err) + raise Exception(get_model_path_error(model_path)) return ([], []) @@ -293,3 +289,8 @@ def prepType(tensorType, is_keras): s = str(tensorType) inds = [i for i, letter in enumerate(s) if letter == "'"] return s[inds[0] + 1 : inds[1]] + + +def get_model_path_error(model_path: str) -> str: + """Generate informative error message for model_path validation failures.""" + return f"Invalid model_path '{model_path}'. model_path must be the location of the Tensorflow SavedModel directory or the location of a Keras model file (.keras)"