Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions src/tensorflow_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,10 @@ def new_service(
def validate_config(
cls, config: ServiceConfig
) -> Tuple[Sequence[str], Sequence[str]]:
model_path_err = (
"model_path must be the location of the Tensorflow SavedModel directory "
"or the location of a Keras model file (.keras)"
)

model_path = config.attributes.fields["model_path"].string_value

if model_path == "":
raise Exception(model_path_err)
raise Exception(get_model_path_error(model_path))

# If it's a Keras model file, okay. Otherwise, it must be a SavedModel directory
_, ext = os.path.splitext(model_path)
Expand All @@ -64,7 +60,7 @@ def validate_config(
# and that the model file isn't too big (>500 MB)
isValidSavedModel = False
if not os.path.isdir(model_path):
raise Exception(model_path_err)
raise Exception(get_model_path_error(model_path))
for file in os.listdir(model_path):
if ".pb" in file:
isValidSavedModel = True
Expand All @@ -77,7 +73,7 @@ def validate_config(
)

if not isValidSavedModel:
raise Exception(model_path_err)
raise Exception(get_model_path_error(model_path))

return ([], [])

Expand Down Expand Up @@ -293,3 +289,8 @@ def prepType(tensorType, is_keras):
s = str(tensorType)
inds = [i for i, letter in enumerate(s) if letter == "'"]
return s[inds[0] + 1 : inds[1]]


def get_model_path_error(model_path: str) -> str:
"""Generate informative error message for model_path validation failures."""
return f"Invalid model_path '{model_path}'. model_path must be the location of the Tensorflow SavedModel directory or the location of a Keras model file (.keras)"
Loading