diff --git a/tests/e2e/configuration/lightspeed-stack-invalid-feedback-storage.yaml b/tests/e2e/configuration/lightspeed-stack-invalid-feedback-storage.yaml new file mode 100644 index 000000000..2474cefa4 --- /dev/null +++ b/tests/e2e/configuration/lightspeed-stack-invalid-feedback-storage.yaml @@ -0,0 +1,25 @@ +name: Lightspeed Core Service (LCS) +service: + host: 0.0.0.0 + port: 8080 + auth_enabled: false + workers: 1 + color_log: true + access_log: true +llama_stack: + # Uses a remote llama-stack service + # The instance would have already been started with a llama-stack-run.yaml file + use_as_library_client: false + # Alternative for "as library use" + # use_as_library_client: true + # library_client_config_path: + url: http://llama-stack:8321 + api_key: xyzzy +user_data_collection: + feedback_enabled: true + feedback_storage: "/invalid" + transcripts_enabled: true + transcripts_storage: "/tmp/data/transcripts" + +authentication: + module: "noop-with-token" diff --git a/tests/e2e/features/environment.py b/tests/e2e/features/environment.py index 2c1cbfc4d..c46311601 100644 --- a/tests/e2e/features/environment.py +++ b/tests/e2e/features/environment.py @@ -7,12 +7,18 @@ 4. after_scenario """ +import requests import subprocess import time from behave.model import Scenario, Feature from behave.runner import Context -from tests.e2e.utils.utils import switch_config_and_restart +from tests.e2e.utils.utils import ( + switch_config, + restart_container, + remove_config_backup, + create_config_backup, +) try: import os # noqa: F401 @@ -32,10 +38,18 @@ def before_scenario(context: Context, scenario: Scenario) -> None: if "local" in scenario.effective_tags and not context.local: scenario.skip("Marked with @local") return + if "InvalidFeedbackStorageConfig" in scenario.effective_tags: + context.scenario_config = ( + "tests/e2e/configuration/lightspeed-stack-invalid-feedback-storage.yaml" + ) def after_scenario(context: Context, scenario: Scenario) -> None: """Run after each scenario is run.""" + if "InvalidFeedbackStorageConfig" in scenario.effective_tags: + switch_config(context.feature_config) + restart_container("lightspeed-stack") + # Restore Llama Stack connection if it was disrupted if hasattr(context, "llama_stack_was_running") and context.llama_stack_was_running: try: @@ -87,19 +101,28 @@ def after_scenario(context: Context, scenario: Scenario) -> None: def before_feature(context: Context, feature: Feature) -> None: """Run before each feature file is exercised.""" if "Authorized" in feature.tags: - context.backup_file = switch_config_and_restart( - "lightspeed-stack.yaml", - "tests/e2e/configuration/lightspeed-stack-auth-noop-token.yaml", - "lightspeed-stack", + context.feature_config = ( + "tests/e2e/configuration/lightspeed-stack-auth-noop-token.yaml" ) + context.default_config_backup = create_config_backup("lightspeed-stack.yaml") + switch_config(context.feature_config) + restart_container("lightspeed-stack") + + if "Feedback" in feature.tags: + context.feedback_conversations = [] def after_feature(context: Context, feature: Feature) -> None: """Run after each feature file is exercised.""" if "Authorized" in feature.tags: - switch_config_and_restart( - "lightspeed-stack.yaml", - context.backup_file, - "lightspeed-stack", - cleanup=True, - ) + switch_config(context.default_config_backup) + restart_container("lightspeed-stack") + remove_config_backup(context.default_config_backup) + + if "Feedback" in feature.tags: + print(context.feedback_conversations) + for conversation_id in context.feedback_conversations: + url = f"http://localhost:8080/v1/conversations/{conversation_id}" + headers = context.auth_headers if hasattr(context, "auth_headers") else {} + response = requests.delete(url, headers=headers) + assert response.status_code == 200, url diff --git a/tests/e2e/features/feedback.feature b/tests/e2e/features/feedback.feature index bf5b43c20..dce0d0514 100644 --- a/tests/e2e/features/feedback.feature +++ b/tests/e2e/features/feedback.feature @@ -1,90 +1,293 @@ -# Feature: feedback endpoint API tests +@Authorized @Feedback +Feature: feedback endpoint API tests -# Background: -# Given The service is started locally -# And REST API service hostname is localhost -# And REST API service port is 8080 -# And REST API service prefix is /v1 + Background: + Given The service is started locally + And REST API service hostname is localhost + And REST API service port is 8080 + And REST API service prefix is /v1 + And I set the Authorization header to Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6Ikpva + Scenario: Check if enabling the feedback is working + Given The system is in default state + When The feedback is enabled + Then The status code of the response is 200 + And And the body of the response has the following structure + """ + { + "status": + { + "updated_status": true + } + } + """ + + Scenario: Check if disabling the feedback is working + Given The system is in default state + When The feedback is disabled + Then The status code of the response is 200 + And And the body of the response has the following structure + """ + { + "status": + { + "updated_status": false + } + } + """ -# Scenario: Check if feedback endpoint is working -# Given The system is in default state -# When I access endpoint "feedback" using HTTP POST with conversation ID conversationID -# """ -# { -# "llm_response": "bar", -# "sentiment": -1, -# "user_feedback": "Not satisfied with the response quality", -# "user_question": "random question" -# } -# """ -# Then The status code of the response is 200 -# And The body of the response is the following -# """ -# {"response": "feedback received"} -# """ + Scenario: Check if toggling the feedback with incorrect attribute name fails + Given The system is in default state + When I update feedback status with + """ + { + "no_status": true + } + """ + Then The status code of the response is 422 + And And the body of the response has the following structure + """ + { + "detail": [ + { + "type": "extra_forbidden", + "loc": [ + "body", + "no_status" + ], + "msg": "Extra inputs are not permitted", + "input": true + } + ] + } + """ -# Scenario: Check if feedback endpoint is not working when not authorized -# Given The system is in default state -# And I remove the auth header -# When I access endpoint "feedback" using HTTP POST with conversation ID conversationID -# """ -# { -# "llm_response": "bar", -# "sentiment": -1, -# "user_feedback": "Not satisfied with the response quality", -# "user_question": "random question" -# } -# """ -# Then The status code of the response is 400 -# And The body of the response is the following -# """ -# {"response": "feedback received"} -# """ + Scenario: Check if getting feedback status returns true when feedback is enabled + Given The system is in default state + And The feedback is enabled + When I retreive the current feedback status + Then The status code of the response is 200 + And The body of the response is the following + """ + { + "functionality": "feedback", + "status": { + "enabled": true + } + } + """ -# Scenario: Check if feedback endpoint is not working when feedback is disabled -# Given The system is in default state -# And I disable the feedback -# When I access endpoint "feedback" using HTTP POST with conversation ID conversationID -# """ -# { -# "llm_response": "bar", -# "sentiment": -1, -# "user_feedback": "Not satisfied with the response quality", -# "user_question": "random question" -# } -# """ -# Then The status code of the response is 403 -# And The body of the response is the following -# """ -# {"response": "feedback received"} -# """ + Scenario: Check if getting feedback status returns false when feedback is disabled + Given The system is in default state + And The feedback is disabled + When I retreive the current feedback status + Then The status code of the response is 200 + And The body of the response is the following + """ + { + "functionality": "feedback", + "status": { + "enabled": false + } + } + """ -# Scenario: Check if feedback endpoint fails with incorrect body format when conversationID is not present -# Given The system is in default state -# When I access endpoint "feedback" using HTTP POST method -# """ -# { -# "llm_response": "bar", -# "sentiment": -1, -# "user_feedback": "Not satisfied with the response quality", -# "user_question": "random question" -# } -# """ -# Then The status code of the response is 422 -# And The body of the response is the following -# """ -# { "type": "missing", "loc": [ "body", "conversation_id" ], "msg": "Field required", } -# """ + Scenario: Check if feedback endpoint is not working when feedback is disabled + Given The system is in default state + And A new conversation is initialized + And The feedback is disabled + When I submit the following feedback for the conversation created before + """ + { + "llm_response": "bar", + "sentiment": -1, + "user_feedback": "Not satisfied with the response quality", + "user_question": "Sample Question" + } + """ + Then The status code of the response is 403 + And The body of the response is the following + """ + { + "detail": "Forbidden: User is not authorized to access this resource" + } + """ -# Scenario: Check if feedback/status endpoint is working -# Given The system is in default state -# When I access REST API endpoint "feedback/status" using HTTP GET method -# Then The status code of the response is 200 -# And The body of the response is the following -# """ -# {"functionality": "feedback", "status": { "enabled": true}} -# """ + Scenario: Check if feedback endpoint fails when required fields are not specified + Given The system is in default state + And The feedback is enabled + When I submit the following feedback without specifying conversation ID + """ + { + } + """ + Then The status code of the response is 422 + And And the body of the response has the following structure + """ + { + "detail": [ + { + "type": "missing", + "loc": [ + "body", + "conversation_id" + ], + "msg": "Field required" + }, + { + "type": "missing", + "loc": [ + "body", + "user_question" + ], + "msg": "Field required" + }, + { + "type": "missing", + "loc": [ + "body", + "llm_response" + ], + "msg": "Field required" + } + ] + } + """ + Scenario: Check if feedback endpoint is working when sentiment is negative + Given The system is in default state + And A new conversation is initialized + And The feedback is enabled + When I submit the following feedback for the conversation created before + """ + { + "llm_response": "bar", + "sentiment": -1, + "user_feedback": "Not satisfied with the response quality", + "user_question": "Sample Question" + } + """ + Then The status code of the response is 200 + And The body of the response is the following + """ + { + "response": "feedback received" + } + """ + Scenario: Check if feedback endpoint is working when sentiment is positive + Given The system is in default state + And A new conversation is initialized + And The feedback is enabled + When I submit the following feedback for the conversation created before + """ + { + "llm_response": "bar", + "sentiment": 1, + "user_feedback": "Satisfied with the response quality", + "user_question": "Sample Question" + } + """ + Then The status code of the response is 200 + And The body of the response is the following + """ + { + "response": "feedback received" + } + """ + + Scenario: Check if feedback submittion fails when invald sentiment is passed + Given The system is in default state + And A new conversation is initialized + And The feedback is enabled + When I submit the following feedback for the conversation created before + """ + { + "llm_response": "Sample Response", + "sentiment": 0, + "user_feedback": "Not satisfied with the response quality", + "user_question": "Sample Question" + } + """ + Then The status code of the response is 422 + And And the body of the response has the following structure + """ + { + "detail": [{ + "type": "value_error", + "loc": ["body", "sentiment"], + "msg": "Value error, Improper sentiment value of 0, needs to be -1 or 1", + "input": 0 + }] + } + """ + + @skip + Scenario: Check if feedback submittion fails when nonexisting conversation ID is passed + Given The system is in default state + And A new conversation is initialized + And The feedback is enabled + When I submit the following feedback for nonexisting conversation "12345678-abcd-0000-0123-456789abcdef" + """ + { + "llm_response": "Sample Response", + "sentiment": -1, + "user_feedback": "Not satisfied with the response quality", + "user_question": "Sample Question" + } + """ + Then The status code of the response is 422 + And The body of the response is the following + """ + { + "response": "User has no access to this conversation" + } + """ + + Scenario: Check if feedback endpoint is not working when not authorized + Given The system is in default state + And A new conversation is initialized + And I remove the auth header + When I submit the following feedback for the conversation created before + """ + { + "llm_response": "Sample Response", + "sentiment": -1, + "user_feedback": "Not satisfied with the response quality", + "user_question": "Sample Question" + } + """ + Then The status code of the response is 400 + And The body of the response is the following + """ + { + "detail": "No Authorization header found" + } + """ + + @InvalidFeedbackStorageConfig + Scenario: Check if feedback submittion fails when invalid feedback storage path is configured + Given The system is in default state + And The feedback is enabled + And An invalid feedback storage path is configured + And A new conversation is initialized + When I submit the following feedback for the conversation created before + """ + { + "llm_response": "Sample Response", + "sentiment": -1, + "user_feedback": "Not satisfied with the response quality", + "user_question": "Sample Question" + } + """ + Then The status code of the response is 500 + And The body of the response is the following + """ + { + "detail": { + "response": "Error storing user feedback", + "cause": "[Errno 13] Permission denied: '/invalid'" + } + } + """ \ No newline at end of file diff --git a/tests/e2e/features/steps/auth.py b/tests/e2e/features/steps/auth.py index 64b5cb006..c2bf3c7d0 100644 --- a/tests/e2e/features/steps/auth.py +++ b/tests/e2e/features/steps/auth.py @@ -15,6 +15,13 @@ def set_authorization_header_custom(context: Context, header_value: str) -> None print(f"🔑 Set Authorization header to: {header_value}") +@given("I remove the auth header") # type: ignore +def remove_authorization_header(context: Context) -> None: + """Remove Authorization header.""" + if hasattr(context, "auth_headers") and "Authorization" in context.auth_headers: + del context.auth_headers["Authorization"] + + @when("I access endpoint {endpoint} using HTTP POST method with user_id {user_id}") def access_rest_api_endpoint_post( context: Context, endpoint: str, user_id: str diff --git a/tests/e2e/features/steps/common_http.py b/tests/e2e/features/steps/common_http.py index 6b1e56a75..33ed48cbb 100644 --- a/tests/e2e/features/steps/common_http.py +++ b/tests/e2e/features/steps/common_http.py @@ -5,7 +5,11 @@ import requests from behave import then, when, step # pyright: ignore[reportAttributeAccessIssue] from behave.runner import Context -from tests.e2e.utils.utils import normalize_endpoint, validate_json +from tests.e2e.utils.utils import ( + normalize_endpoint, + validate_json, + validate_json_partially, +) # default timeout for HTTP operations DEFAULT_TIMEOUT = 10 @@ -235,15 +239,16 @@ def access_rest_api_endpoint_get(context: Context, endpoint: str) -> None: base = f"http://{context.hostname}:{context.port}" path = f"{context.api_prefix}/{endpoint}".replace("//", "/") url = base + path + headers = context.auth_headers if hasattr(context, "auth_headers") else {} # initial value context.response = None # perform REST API call - context.response = requests.get(url, timeout=DEFAULT_TIMEOUT) + context.response = requests.get(url, headers=headers, timeout=DEFAULT_TIMEOUT) @when("I access endpoint {endpoint} using HTTP POST method") -def access_rest_api_endpoint_post(context: Context, endpoint: str) -> None: +def access_non_rest_api_endpoint_post(context: Context, endpoint: str) -> None: """Send POST HTTP request with JSON payload to tested service. The JSON payload is retrieved from `context.text` attribute, @@ -257,11 +262,64 @@ def access_rest_api_endpoint_post(context: Context, endpoint: str) -> None: assert context.text is not None, "Payload needs to be specified" data = json.loads(context.text) + headers = context.auth_headers if hasattr(context, "auth_headers") else {} + # initial value + context.response = None + + # perform REST API call + context.response = requests.post( + url, json=data, headers=headers, timeout=DEFAULT_TIMEOUT + ) + + +@when("I access REST API endpoint {endpoint} using HTTP POST method") +def access_rest_api_endpoint_post(context: Context, endpoint: str) -> None: + """Send POST HTTP request with JSON payload to tested service. + + The JSON payload is retrieved from `context.text` attribute, + which must not be None. The response is stored in + `context.response` attribute. + """ + endpoint = normalize_endpoint(endpoint) + base = f"http://{context.hostname}:{context.port}" + path = f"{context.api_prefix}/{endpoint}".replace("//", "/") + url = base + path + + assert context.text is not None, "Payload needs to be specified" + data = json.loads(context.text) + headers = context.auth_headers if hasattr(context, "auth_headers") else {} + # initial value + context.response = None + + # perform REST API call + context.response = requests.post( + url, json=data, headers=headers, timeout=DEFAULT_TIMEOUT + ) + + +@when("I access REST API endpoint {endpoint} using HTTP PUT method") +def access_rest_api_endpoint_put(context: Context, endpoint: str) -> None: + """Send PUT HTTP request with JSON payload to tested service. + + The JSON payload is retrieved from `context.text` attribute, + which must not be None. The response is stored in + `context.response` attribute. + """ + endpoint = normalize_endpoint(endpoint) + base = f"http://{context.hostname}:{context.port}" + path = f"{context.api_prefix}/{endpoint}".replace("//", "/") + url = base + path + + assert context.text is not None, "Payload needs to be specified" + data = json.loads(context.text) + headers = context.auth_headers if hasattr(context, "auth_headers") else {} # initial value context.response = None # perform REST API call - context.response = requests.post(url, json=data, timeout=DEFAULT_TIMEOUT) + context.response = requests.put( + url, json=data, headers=headers, timeout=DEFAULT_TIMEOUT + ) @then('The status message of the response is "{expected_message}"') @@ -303,3 +361,16 @@ def check_for_null_attribute(context: Context, attribute: str) -> None: assert ( value is None ), f"Attribute {attribute} should be null, but it contains {value}" + + +@then("And the body of the response has the following structure") +def check_response_partially(context: Context) -> None: + """Validate that the response body matches the expected JSON structure. + + Compares the actual response JSON against the expected structure defined + in `context.text`, ignoring extra keys or values not specified. + """ + assert context.response is not None, "Request needs to be performed first" + body = context.response.json() + expected = json.loads(context.text or "{}") + validate_json_partially(body, expected) diff --git a/tests/e2e/features/steps/feedback.py b/tests/e2e/features/steps/feedback.py index 5184db8da..ccd92db39 100644 --- a/tests/e2e/features/steps/feedback.py +++ b/tests/e2e/features/steps/feedback.py @@ -1,38 +1,122 @@ -"""Implementation of common test steps.""" +"""Implementation of common test steps for the feedback API.""" -from behave import given, when # pyright: ignore[reportAttributeAccessIssue] +from behave import given, when, step # pyright: ignore[reportAttributeAccessIssue] from behave.runner import Context import requests +import json +from tests.e2e.utils.utils import switch_config, restart_container +from tests.e2e.features.steps.common_http import access_rest_api_endpoint_get # default timeout for HTTP operations DEFAULT_TIMEOUT = 10 -@when( - "I access endpoint {endpoint:w} using HTTP POST with conversation ID {conversationID:w}" -) -def access_rest_api_endpoint_post( - context: Context, endpoint: str, conversation_id: str +@step("The feedback is enabled") # type: ignore +def enable_feedback(context: Context) -> None: + """Enable the feedback endpoint and assert success.""" + assert context is not None + payload = {"status": True} + access_feedback_put_endpoint(context, payload) + assert context.response.status_code == 200, "Enabling feedback was unsuccessful" + + +@step("The feedback is disabled") # type: ignore +def disable_feedback(context: Context) -> None: + """Disable the feedback endpoint and assert success.""" + assert context is not None + payload = {"status": False} + access_feedback_put_endpoint(context, payload) + assert context.response.status_code == 200, "Disabling feedback was unsuccessful" + + +@when("I update feedback status with") # type: ignore +def set_feedback(context: Context) -> None: + """Enable or disable feedback via PUT request.""" + assert context.text is not None, "Payload needs to be specified" + payload = json.loads(context.text or "{}") + access_feedback_put_endpoint(context, payload) + + +def access_feedback_put_endpoint(context: Context, payload: dict) -> None: + """Update feedback using a JSON payload.""" + assert context is not None + endpoint = "feedback/status" + base = f"http://{context.hostname}:{context.port}" + path = f"{context.api_prefix}/{endpoint}".replace("//", "/") + url = base + path + headers = context.auth_headers if hasattr(context, "auth_headers") else {} + response = requests.put(url, headers=headers, json=payload) + context.response = response + + +@when("I submit the following feedback for the conversation created before") # type: ignore +def submit_feedback_valid_conversation(context: Context) -> None: + """Submit feedback for previousl created conversation.""" + assert ( + hasattr(context, "conversation_id") and context.conversation_id is not None + ), "Conversation for feedback submission is not created" + access_feedback_post_endpoint(context, context.conversation_id) + + +@when('I submit the following feedback for nonexisting conversation "{conversation_id}"') # type: ignore +def submit_feedback_nonexisting_conversation( + context: Context, conversation_id: str ) -> None: - """Send POST HTTP request with JSON payload to tested service. + """Submit feedback for a non-existing conversation ID.""" + access_feedback_post_endpoint(context, conversation_id) + + +@when("I submit the following feedback without specifying conversation ID") # type: ignore +def submit_feedback_without_conversation(context: Context) -> None: + """Submit feedback with no conversation ID.""" + access_feedback_post_endpoint(context, None) - The JSON payload is retrieved from `context.text` attribute, - which must not be None. The response is stored in - `context.response` attribute. - """ + +def access_feedback_post_endpoint( + context: Context, conversation_id: str | None +) -> None: + """Send POST HTTP request with JSON payload to tested service.""" + endpoint = "feedback" base = f"http://{context.hostname}:{context.port}" path = f"{context.api_prefix}/{endpoint}".replace("//", "/") url = base + path + payload = json.loads(context.text or "{}") + if conversation_id is not None: + payload["conversation_id"] = conversation_id + headers = context.auth_headers if hasattr(context, "auth_headers") else {} + context.response = requests.post(url, headers=headers, json=payload) - assert conversation_id is not None, "Payload needs to be specified" - # TODO: finish the conversation ID handling - # perform REST API call - context.response = requests.post(url, timeout=DEFAULT_TIMEOUT) +@when("I retreive the current feedback status") # type: ignore +def access_feedback_get_endpoint(context: Context) -> None: + """Retrieve the current feedback status via GET request.""" + access_rest_api_endpoint_get(context, "feedback/status") -@given("I disable the feedback") -def disable_feedback(context: Context) -> None: - """Disable feedback.""" - # TODO: add step implementation - assert context is not None +@given("A new conversation is initialized") # type: ignore +def initialize_conversation(context: Context) -> None: + """Create a conversation for submitting feedback.""" + endpoint = "query" + base = f"http://{context.hostname}:{context.port}" + path = f"{context.api_prefix}/{endpoint}".replace("//", "/") + url = base + path + headers = context.auth_headers if hasattr(context, "auth_headers") else {} + payload = {"query": "Say Hello.", "system_prompt": "You are a helpful assistant"} + + response = requests.post(url, headers=headers, json=payload) + assert ( + response.status_code == 200 + ), f"Failed to create conversation: {response.text}" + + body = response.json() + context.conversation_id = body["conversation_id"] + assert context.conversation_id, "Conversation was not created." + context.feedback_conversations.append(context.conversation_id) + context.response = response + + +@given("An invalid feedback storage path is configured") # type: ignore +def configure_invalid_feedback_storage_path(context: Context) -> None: + """Set an invalid feedback storage path and restart the container.""" + switch_config(context.scenario_config) + restart_container("lightspeed-stack") diff --git a/tests/e2e/utils/utils.py b/tests/e2e/utils/utils.py index 5f189ff90..54407350a 100644 --- a/tests/e2e/utils/utils.py +++ b/tests/e2e/utils/utils.py @@ -4,9 +4,8 @@ import shutil import subprocess import time -from typing import Any - import jsonschema +from typing import Any def normalize_endpoint(endpoint: str) -> str: @@ -67,39 +66,68 @@ def wait_for_container_health(container_name: str, max_attempts: int = 3) -> Non print(f"Could not check health status for {container_name}") -def switch_config_and_restart( - original_file: str, - replacement_file: str, - container_name: str, - cleanup: bool = False, -) -> str: - """Switch configuration file and restart container. +def validate_json_partially(actual: Any, expected: Any): + """Recursively validate that `actual` JSON contains all keys and values specified in `expected`. - Args: - original_file: Path to the original configuration file - replacement_file: Path to the replacement configuration file - container_name: Name of the container to restart - cleanup: If True, remove the backup file after restoration (default: False) - - Returns: - str: Path to the backup file for restoration + Extra elements/keys are ignored. Raises AssertionError if validation fails. """ - backup_file = f"{original_file}.backup" + if isinstance(expected, dict): + for key, expected_value in expected.items(): + assert key in actual, f"Missing key in JSON: {key}" + validate_json_partially(actual[key], expected_value) + + elif isinstance(expected, list): + for schema_item in expected: + matched = False + for item in actual: + try: + validate_json_partially(item, schema_item) + matched = True + break + except AssertionError: + continue + assert ( + matched + ), f"No matching element found in list for schema item {schema_item}" + + else: + assert actual == expected, f"Value mismatch: expected {expected}, got {actual}" + + +def switch_config( + source_path: str, destination_path: str = "lightspeed-stack.yaml" +) -> None: + """Overwrite the config in `destination_path` by `source_path`.""" + try: + shutil.copy(source_path, destination_path) + except (FileNotFoundError, PermissionError, OSError) as e: + print(f"Failed to copy replacement file: {e}") + raise + - if not cleanup and not os.path.exists(backup_file): +def create_config_backup(config_path: str) -> str: + """Create a backup of `config_path` if it does not already exist.""" + backup_file = f"{config_path}.backup" + if not os.path.exists(backup_file): try: - shutil.copy(original_file, backup_file) + shutil.copy(config_path, backup_file) except (FileNotFoundError, PermissionError, OSError) as e: print(f"Failed to create backup: {e}") raise + return backup_file + + +def remove_config_backup(backup_path: str) -> None: + """Delete the backup file at `backup_path` if it exists.""" + if os.path.exists(backup_path): + try: + os.remove(backup_path) + except OSError as e: + print(f"Warning: Could not remove backup file {backup_path}: {e}") - try: - shutil.copy(replacement_file, original_file) - except (FileNotFoundError, PermissionError, OSError) as e: - print(f"Failed to copy replacement file: {e}") - raise - # Restart container +def restart_container(container_name: str) -> None: + """Restart a Docker container by name and wait until it is healthy.""" try: subprocess.run( ["docker", "restart", container_name], @@ -113,12 +141,3 @@ def switch_config_and_restart( # Wait for container to be healthy wait_for_container_health(container_name) - - # Clean up backup file - if cleanup and os.path.exists(backup_file): - try: - os.remove(backup_file) - except OSError as e: - print(f"Warning: Could not remove backup file {backup_file}: {e}") - - return backup_file