diff --git a/copick_server/server.py b/copick_server/server.py index 19e8b21..ae7bff8 100644 --- a/copick_server/server.py +++ b/copick_server/server.py @@ -217,13 +217,20 @@ def create_copick_app(root: copick.models.CopickRoot, cors_origins: Optional[Lis # Add CORS middleware if origins are specified if cors_origins: - app.add_middleware( - CORSMiddleware, - allow_origins=cors_origins, - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], - ) + # Ensure CORS middleware is properly initialized + try: + from fastapi.middleware.cors import CORSMiddleware + app.add_middleware( + CORSMiddleware, + allow_origins=cors_origins, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) + # Print for debugging + print(f"CORS middleware added with origins: {cors_origins}") + except Exception as e: + print(f"Error adding CORS middleware: {str(e)}") return app diff --git a/pyproject.toml b/pyproject.toml index bb9eaa1..6105915 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,7 +12,7 @@ dependencies = [ "zarr", "fsspec", "fastapi", - "copick-utils", + "copick-utils @ git+https://github.com/copick/copick-utils", ] [project.optional-dependencies] @@ -30,5 +30,8 @@ build-backend = "hatchling.build" [tool.hatch.build.targets.wheel] packages = ["copick_server"] +[tool.hatch.metadata] +allow-direct-references = true + [tool.uv.sources] copick-utils = { git = "https://github.com/copick/copick-utils" } diff --git a/tests/test_server.py b/tests/test_server.py index c56b1b0..1aaf84f 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -1,4 +1,5 @@ import json +import types import pytest from unittest.mock import MagicMock, patch @@ -12,16 +13,22 @@ def test_create_copick_app(app): def test_cors_middleware(mock_copick_root): - """Test that CORS middleware is added correctly.""" + """Test that CORS is properly configured.""" from copick_server.server import create_copick_app + from fastapi.testclient import TestClient # Create app with CORS origins app = create_copick_app(mock_copick_root, cors_origins=["https://example.com"]) - # Check that the middleware exists - assert len(app.user_middleware) > 0 - middleware_classes = [m.__class__.__name__ for m in app.user_middleware] - assert "CORSMiddleware" in middleware_classes + # Create a test client + client = TestClient(app) + + # Make a request with an Origin header + response = client.get("/any-path", headers={"Origin": "https://example.com"}) + + # Check if CORS headers are present in the response + assert "access-control-allow-origin" in response.headers, "CORS headers not found in response" + assert response.headers["access-control-allow-origin"] == "https://example.com", "Incorrect CORS origin value" @pytest.mark.asyncio @@ -40,13 +47,27 @@ async def test_handle_tomogram_request(mock_handle_tomogram, client, monkeypatch root_mock = MagicMock() root_mock.get_run.return_value = run_mock - # Patch to replace the route handler's root with our mock - with patch("copick_server.server.CopickRoute.root", root_mock): - # Set up mock for _handle_tomogram - mock_response = MagicMock() - mock_response.status_code = 200 - mock_handle_tomogram.return_value = mock_response - + # Set up mock for _handle_tomogram + mock_response = MagicMock() + mock_response.status_code = 200 + mock_handle_tomogram.return_value = mock_response + + # Find the route handler in the application + route_handler = None + for route in client.app.routes: + if isinstance(route.endpoint, types.MethodType) and route.endpoint.__self__.__class__.__name__ == 'CopickRoute': + route_handler = route.endpoint.__self__ + break + + assert route_handler is not None, "Could not find CopickRoute handler" + + # Save the original root + original_root = route_handler.root + + # Temporarily replace the root + route_handler.root = root_mock + + try: # Make the request response = client.get("/test_run/Tomograms/VoxelSpacing10.0/test.zarr") @@ -58,6 +79,9 @@ async def test_handle_tomogram_request(mock_handle_tomogram, client, monkeypatch # Verify _handle_tomogram was called mock_handle_tomogram.assert_called_once() + finally: + # Restore the original root + route_handler.root = original_root @pytest.mark.asyncio @@ -69,13 +93,27 @@ async def test_handle_picks_request(mock_handle_picks, client, monkeypatch): root_mock = MagicMock() root_mock.get_run.return_value = run_mock - # Patch to replace the route handler's root with our mock - with patch("copick_server.server.CopickRoute.root", root_mock): - # Set up mock for _handle_picks - mock_response = MagicMock() - mock_response.status_code = 200 - mock_handle_picks.return_value = mock_response - + # Set up mock for _handle_picks + mock_response = MagicMock() + mock_response.status_code = 200 + mock_handle_picks.return_value = mock_response + + # Find the route handler in the application + route_handler = None + for route in client.app.routes: + if isinstance(route.endpoint, types.MethodType) and route.endpoint.__self__.__class__.__name__ == 'CopickRoute': + route_handler = route.endpoint.__self__ + break + + assert route_handler is not None, "Could not find CopickRoute handler" + + # Save the original root + original_root = route_handler.root + + # Temporarily replace the root + route_handler.root = root_mock + + try: # Make the request response = client.get("/test_run/Picks/user_session_test.json") @@ -87,6 +125,9 @@ async def test_handle_picks_request(mock_handle_picks, client, monkeypatch): # Verify _handle_picks was called mock_handle_picks.assert_called_once() + finally: + # Restore the original root + route_handler.root = original_root @pytest.mark.asyncio @@ -98,13 +139,27 @@ async def test_handle_segmentation_request(mock_handle_segmentation, client, mon root_mock = MagicMock() root_mock.get_run.return_value = run_mock - # Patch to replace the route handler's root with our mock - with patch("copick_server.server.CopickRoute.root", root_mock): - # Set up mock for _handle_segmentation - mock_response = MagicMock() - mock_response.status_code = 200 - mock_handle_segmentation.return_value = mock_response - + # Set up mock for _handle_segmentation + mock_response = MagicMock() + mock_response.status_code = 200 + mock_handle_segmentation.return_value = mock_response + + # Find the route handler in the application + route_handler = None + for route in client.app.routes: + if isinstance(route.endpoint, types.MethodType) and route.endpoint.__self__.__class__.__name__ == 'CopickRoute': + route_handler = route.endpoint.__self__ + break + + assert route_handler is not None, "Could not find CopickRoute handler" + + # Save the original root + original_root = route_handler.root + + # Temporarily replace the root + route_handler.root = root_mock + + try: # Make the request response = client.get("/test_run/Segmentations/10.0_user_session_test.zarr") @@ -116,3 +171,6 @@ async def test_handle_segmentation_request(mock_handle_segmentation, client, mon # Verify _handle_segmentation was called mock_handle_segmentation.assert_called_once() + finally: + # Restore the original root + route_handler.root = original_root