diff --git a/copick_server/client.py b/copick_server/client.py index e5e0312..88d004b 100644 --- a/copick_server/client.py +++ b/copick_server/client.py @@ -5,7 +5,7 @@ # First read the tomogram to get the shape store = get_mapper("http://localhost:8017/16463/Tomograms/VoxelSpacing10.012/wbp.zarr") -tomo = zarr.open(store, mode='r') +tomo = zarr.open(store, mode="r") full_shape = tomo["0"].shape print(f"Tomogram shape: {full_shape}") @@ -39,4 +39,4 @@ if response.status_code == 200: print("Successfully wrote segmentation") else: - print(f"Failed to write segmentation: {response.status_code} - {response.text}") \ No newline at end of file + print(f"Failed to write segmentation: {response.status_code} - {response.text}") diff --git a/copick_server/server.py b/copick_server/server.py index ae7bff8..342ad4a 100644 --- a/copick_server/server.py +++ b/copick_server/server.py @@ -1,5 +1,5 @@ import json -from typing import Dict, List, Optional, Union +from typing import List, Optional import click import copick @@ -7,42 +7,45 @@ import uvicorn import threading -import zarr -from fsspec import AbstractFileSystem from fastapi import FastAPI, Request, Response -from fastapi.middleware.cors import CORSMiddleware -from fastapi.routing import APIRoute + class CopickRoute: """Route handler for Copick data entities.""" - + def __init__(self, root: copick.models.CopickRoot): self.root = root - + async def handle_request(self, request: Request, path: str): # Parse path parameters path_parts = path.split("/") - + # Handle different path patterns try: if len(path_parts) >= 3: run_name = path_parts[0] data_type = path_parts[1] - + # Get the run run = self.root.get_run(run_name) if run is None: return Response(status_code=404) - + if data_type == "Tomograms": - return await self._handle_tomogram(request, run, "/".join(path_parts[2:])) + return await self._handle_tomogram( + request, run, "/".join(path_parts[2:]) + ) elif data_type == "Picks": - return await self._handle_picks(request, run, "/".join(path_parts[2:])) + return await self._handle_picks( + request, run, "/".join(path_parts[2:]) + ) elif data_type == "Segmentations": - return await self._handle_segmentation(request, run, "/".join(path_parts[2:])) - + return await self._handle_segmentation( + request, run, "/".join(path_parts[2:]) + ) + return Response(status_code=404) - + except Exception as e: print(f"Error handling request: {str(e)}") return Response(status_code=500) @@ -52,24 +55,24 @@ async def _handle_tomogram(self, request, run, path): parts = path.split("/") if len(parts) < 2: return Response(status_code=404) - + vs_str = parts[0].replace("VoxelSpacing", "") try: voxel_spacing = float(vs_str) except ValueError: return Response(status_code=404) - + tomo_type = parts[1].replace(".zarr", "") # Get the tomogram vs = run.get_voxel_spacing(voxel_spacing) if vs is None: return Response(status_code=404) - + tomogram = vs.get_tomogram(tomo_type) if tomogram is None: return Response(status_code=404) - + # Handle the request if request.method == "PUT" and not tomogram.read_only: try: @@ -92,20 +95,22 @@ async def _handle_picks(self, request, run, path): parts = path.split("/") if len(parts) < 1: return Response(status_code=404) - + pick_file = parts[0] pick_parts = pick_file.split("_") if len(pick_parts) != 3: return Response(status_code=404) - + user_id, session_id, object_name = pick_parts object_name = object_name.replace(".json", "") - + # Get or create picks picks = None if request.method == "PUT": try: - picks = run.new_picks(object_name=object_name, user_id=user_id, session_id=session_id) + picks = run.new_picks( + object_name=object_name, user_id=user_id, session_id=session_id + ) data = await request.json() picks.meta = copick.models.CopickPicksFile(**data) picks.store() @@ -114,13 +119,15 @@ async def _handle_picks(self, request, run, path): print(f"Picks write error: {str(e)}") return Response(status_code=500) else: - picks = run.get_picks(object_name=object_name, user_id=user_id, session_id=session_id) + picks = run.get_picks( + object_name=object_name, user_id=user_id, session_id=session_id + ) if not picks: return Response(status_code=404) - + if request.method == "HEAD": return Response(status_code=200) - + return Response(json.dumps(picks[0].meta.dict()), status_code=200) async def _handle_segmentation(self, request, run, path): @@ -128,33 +135,33 @@ async def _handle_segmentation(self, request, run, path): parts = path.split("/") if len(parts) < 1: return Response(status_code=404) - + seg_file = parts[0].replace(".zarr", "") seg_parts = seg_file.split("_") if len(seg_parts) < 4: return Response(status_code=404) - + voxel_size = float(seg_parts[0]) user_id = seg_parts[1] session_id = seg_parts[2] name = "_".join(seg_parts[3:]) is_multilabel = "multilabel" in name - + # Get or create segmentation if request.method == "PUT": try: # Get the data from the request body blob = await request.body() - + # Extract shape information (first 24 bytes contain 3 int64 values) shape = np.frombuffer(blob[:24], dtype=np.int64) - + # Extract the actual data and reshape it data = np.frombuffer(blob[24:], dtype=np.uint8).reshape(shape) - + # Import the writer utility from copick_utils.writers.write import segmentation - + # Use the utility function to write the segmentation seg = segmentation( run=run, @@ -163,9 +170,9 @@ async def _handle_segmentation(self, request, run, path): name=name.replace("-multilabel", ""), session_id=session_id, voxel_size=voxel_size, - multilabel=is_multilabel + multilabel=is_multilabel, ) - + return Response(status_code=200) except Exception as e: print(f"Segmentation write error: {str(e)}") @@ -176,11 +183,11 @@ async def _handle_segmentation(self, request, run, path): name=name.replace("-multilabel", ""), user_id=user_id, session_id=session_id, - is_multilabel=is_multilabel + is_multilabel=is_multilabel, ) if not segs: return Response(status_code=404) - + seg = segs[0] try: body = seg.zarr()["/".join(parts[1:])] @@ -190,16 +197,19 @@ async def _handle_segmentation(self, request, run, path): except KeyError: return Response(status_code=404) -def create_copick_app(root: copick.models.CopickRoot, cors_origins: Optional[List[str]] = None) -> FastAPI: + +def create_copick_app( + root: copick.models.CopickRoot, cors_origins: Optional[List[str]] = None +) -> FastAPI: """Create a FastAPI app for serving a Copick project. - + Parameters ---------- root : copick.models.CopickRoot Copick project root to serve cors_origins : list of str, optional List of allowed CORS origins. Use ["*"] to allow all. - + Returns ------- app : FastAPI @@ -207,19 +217,18 @@ def create_copick_app(root: copick.models.CopickRoot, cors_origins: Optional[Lis """ app = FastAPI() route_handler = CopickRoute(root) - + # Add the catch-all route app.add_api_route( - "/{path:path}", - route_handler.handle_request, - methods=["GET", "HEAD", "PUT"] + "/{path:path}", route_handler.handle_request, methods=["GET", "HEAD", "PUT"] ) - + # Add CORS middleware if origins are specified if cors_origins: # Ensure CORS middleware is properly initialized try: from fastapi.middleware.cors import CORSMiddleware + app.add_middleware( CORSMiddleware, allow_origins=cors_origins, @@ -231,12 +240,19 @@ def create_copick_app(root: copick.models.CopickRoot, cors_origins: Optional[Lis print(f"CORS middleware added with origins: {cors_origins}") except Exception as e: print(f"Error adding CORS middleware: {str(e)}") - + return app -def serve_copick(config_path: Optional[str] = None, dataset_ids: Optional[List[int]] = None, overlay_root: str = "/tmp/overlay_root", allowed_origins: Optional[List[str]] = None, **kwargs): + +def serve_copick( + config_path: Optional[str] = None, + dataset_ids: Optional[List[int]] = None, + overlay_root: str = "/tmp/overlay_root", + allowed_origins: Optional[List[str]] = None, + **kwargs, +): """Start an HTTP server serving a Copick project. - + Parameters ---------- config_path : str, optional @@ -249,13 +265,15 @@ def serve_copick(config_path: Optional[str] = None, dataset_ids: Optional[List[i List of allowed CORS origins. Use ["*"] to allow all. **kwargs Additional arguments passed to uvicorn.run() - + Notes ----- Either config_path or dataset_ids must be provided, but not both. """ if config_path and dataset_ids: - raise ValueError("Either config_path or dataset_ids must be provided, but not both.") + raise ValueError( + "Either config_path or dataset_ids must be provided, but not both." + ) elif config_path: root = copick.from_file(config_path) elif dataset_ids: @@ -266,14 +284,21 @@ def serve_copick(config_path: Optional[str] = None, dataset_ids: Optional[List[i ) else: raise ValueError("Either config_path or dataset_ids must be provided.") - + app = create_copick_app(root, allowed_origins) uvicorn.run(app, **kwargs) return app -def serve_copick_threaded(config_path: Optional[str] = None, dataset_ids: Optional[List[int]] = None, overlay_root: str = "/tmp/overlay_root", allowed_origins: Optional[List[str]] = None, **kwargs): + +def serve_copick_threaded( + config_path: Optional[str] = None, + dataset_ids: Optional[List[int]] = None, + overlay_root: str = "/tmp/overlay_root", + allowed_origins: Optional[List[str]] = None, + **kwargs, +): """Start an HTTP server in a background thread and return the app. - + Parameters ---------- config_path : str, optional @@ -286,18 +311,20 @@ def serve_copick_threaded(config_path: Optional[str] = None, dataset_ids: Option List of allowed CORS origins. Use ["*"] to allow all. **kwargs Additional arguments passed to uvicorn.run() - + Returns ------- app : FastAPI FastAPI application - + Notes ----- Either config_path or dataset_ids must be provided, but not both. """ if config_path and dataset_ids: - raise ValueError("Either config_path or dataset_ids must be provided, but not both.") + raise ValueError( + "Either config_path or dataset_ids must be provided, but not both." + ) elif config_path: root = copick.from_file(config_path) elif dataset_ids: @@ -308,20 +335,21 @@ def serve_copick_threaded(config_path: Optional[str] = None, dataset_ids: Option ) else: raise ValueError("Either config_path or dataset_ids must be provided.") - + app = create_copick_app(root, allowed_origins) - + # Start the server in a background thread server_thread = threading.Thread( target=uvicorn.run, args=(app,), kwargs=kwargs, - daemon=True # This makes the thread exit when the main thread exits + daemon=True, # This makes the thread exit when the main thread exits ) server_thread.start() - + return app + @click.group() @click.pass_context def cli(ctx): @@ -374,13 +402,22 @@ def cli(ctx): ) @click.option("--reload", is_flag=True, default=False, help="Enable auto-reload.") @click.pass_context -def serve(ctx, config: Optional[str] = None, dataset_ids: Optional[tuple] = None, overlay_root: str = "/tmp/overlay_root", cors: Optional[str] = None, host: str = "127.0.0.1", port: int = 8000, reload: bool = False): +def serve( + ctx, + config: Optional[str] = None, + dataset_ids: Optional[tuple] = None, + overlay_root: str = "/tmp/overlay_root", + cors: Optional[str] = None, + host: str = "127.0.0.1", + port: int = 8000, + reload: bool = False, +): """Serve a Copick project over HTTP.""" if config and dataset_ids: ctx.fail("Either --config or --dataset-ids must be provided, not both.") elif not config and not dataset_ids: ctx.fail("Either --config or --dataset-ids must be provided.") - + try: serve_copick( config_path=config, @@ -389,7 +426,7 @@ def serve(ctx, config: Optional[str] = None, dataset_ids: Optional[tuple] = None allowed_origins=[cors] if cors else None, host=host, port=port, - reload=reload + reload=reload, ) except Exception as e: ctx.fail(f"Error serving Copick project: {str(e)}") @@ -398,5 +435,6 @@ def serve(ctx, config: Optional[str] = None, dataset_ids: Optional[tuple] = None def main(): cli() + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/tests/conftest.py b/tests/conftest.py index cb21d13..d607e3d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,4 @@ import json -import os import tempfile from pathlib import Path @@ -17,7 +16,7 @@ def example_config(): # Create the overlay directory structure overlay_path = Path(temp_dir) / "overlay" overlay_path.mkdir() - + # Create a simple Copick config config_data = { "name": "Test Project", @@ -29,23 +28,21 @@ def example_config(): "is_particle": True, "label": 1, "color": [0, 117, 220, 255], - "radius": 150.0 + "radius": 150.0, } ], "user_id": "test_user", "config_type": "cryoet_data_portal", "overlay_root": f"local://{str(overlay_path)}/", "dataset_ids": [12345], - "overlay_fs_args": { - "auto_mkdir": True - } + "overlay_fs_args": {"auto_mkdir": True}, } - + # Write config to a temporary file config_path = Path(temp_dir) / "test_config.json" with open(config_path, "w") as f: json.dump(config_data, f) - + yield str(config_path) @@ -54,7 +51,7 @@ def mock_copick_root(monkeypatch, example_config): """Mock a CopickRoot instance for testing.""" # Create a minimal mock root with just enough functionality for testing # In a real setup, you would mock specific methods of CopickRoot - + # This is a simple case - for more complex mocking, consider using unittest.mock return copick.from_file(example_config) diff --git a/tests/test_route_handlers.py b/tests/test_route_handlers.py index aa56aeb..54882e4 100644 --- a/tests/test_route_handlers.py +++ b/tests/test_route_handlers.py @@ -1,4 +1,3 @@ -import json import pytest import numpy as np from unittest.mock import MagicMock, patch, AsyncMock @@ -8,12 +7,12 @@ async def test_handle_tomogram_invalid_path(): """Test handling of an invalid tomogram path.""" from copick_server.server import CopickRoute - + # Create mocks run_mock = MagicMock() request_mock = MagicMock() route_handler = CopickRoute(MagicMock()) - + # Test with too short path response = await route_handler._handle_tomogram(request_mock, run_mock, "invalid") assert response.status_code == 404 @@ -23,14 +22,16 @@ async def test_handle_tomogram_invalid_path(): async def test_handle_tomogram_invalid_voxel_spacing(): """Test handling of an invalid voxel spacing in tomogram path.""" from copick_server.server import CopickRoute - + # Create mocks run_mock = MagicMock() request_mock = MagicMock() route_handler = CopickRoute(MagicMock()) - + # Test with invalid voxel spacing - response = await route_handler._handle_tomogram(request_mock, run_mock, "VoxelSpacingXYZ/test.zarr") + response = await route_handler._handle_tomogram( + request_mock, run_mock, "VoxelSpacingXYZ/test.zarr" + ) assert response.status_code == 404 @@ -38,15 +39,17 @@ async def test_handle_tomogram_invalid_voxel_spacing(): async def test_handle_tomogram_unknown_voxel_spacing(): """Test handling of an unknown voxel spacing.""" from copick_server.server import CopickRoute - + # Create mocks run_mock = MagicMock() run_mock.get_voxel_spacing.return_value = None request_mock = MagicMock() route_handler = CopickRoute(MagicMock()) - + # Test with unknown voxel spacing - response = await route_handler._handle_tomogram(request_mock, run_mock, "VoxelSpacing10.0/test.zarr") + response = await route_handler._handle_tomogram( + request_mock, run_mock, "VoxelSpacing10.0/test.zarr" + ) assert response.status_code == 404 run_mock.get_voxel_spacing.assert_called_once_with(10.0) @@ -55,7 +58,7 @@ async def test_handle_tomogram_unknown_voxel_spacing(): async def test_handle_tomogram_unknown_tomogram(): """Test handling of an unknown tomogram type.""" from copick_server.server import CopickRoute - + # Create mocks vs_mock = MagicMock() vs_mock.get_tomogram.return_value = None @@ -63,9 +66,11 @@ async def test_handle_tomogram_unknown_tomogram(): run_mock.get_voxel_spacing.return_value = vs_mock request_mock = MagicMock() route_handler = CopickRoute(MagicMock()) - + # Test with unknown tomogram - response = await route_handler._handle_tomogram(request_mock, run_mock, "VoxelSpacing10.0/test.zarr") + response = await route_handler._handle_tomogram( + request_mock, run_mock, "VoxelSpacing10.0/test.zarr" + ) assert response.status_code == 404 vs_mock.get_tomogram.assert_called_once_with("test") @@ -74,7 +79,7 @@ async def test_handle_tomogram_unknown_tomogram(): async def test_handle_tomogram_put_readonly(): """Test handling of PUT request to readonly tomogram.""" from copick_server.server import CopickRoute - + # Create mocks tomogram_mock = MagicMock() tomogram_mock.read_only = True @@ -85,19 +90,21 @@ async def test_handle_tomogram_put_readonly(): request_mock = MagicMock() request_mock.method = "PUT" route_handler = CopickRoute(MagicMock()) - + # Test PUT to readonly tomogram body_mock = AsyncMock() request_mock.body = body_mock - + # Mock the zarr method to return a dict-like object zarr_mock = MagicMock() zarr_mock.__getitem__.return_value = b"test_data" tomogram_mock.zarr.return_value = zarr_mock - - response = await route_handler._handle_tomogram(request_mock, run_mock, "VoxelSpacing10.0/test.zarr/0") + + response = await route_handler._handle_tomogram( + request_mock, run_mock, "VoxelSpacing10.0/test.zarr/0" + ) assert response.status_code == 200 - + # Since it's readonly, zarr.__getitem__ should be called, not zarr.__setitem__ zarr_mock.__getitem__.assert_called_once_with("0") zarr_mock.__setitem__.assert_not_called() @@ -107,12 +114,12 @@ async def test_handle_tomogram_put_readonly(): async def test_handle_picks_invalid_path(): """Test handling of an invalid picks path.""" from copick_server.server import CopickRoute - + # Create mocks run_mock = MagicMock() request_mock = MagicMock() route_handler = CopickRoute(MagicMock()) - + # Test with empty path response = await route_handler._handle_picks(request_mock, run_mock, "") assert response.status_code == 404 @@ -122,14 +129,16 @@ async def test_handle_picks_invalid_path(): async def test_handle_picks_invalid_format(): """Test handling of picks with invalid filename format.""" from copick_server.server import CopickRoute - + # Create mocks run_mock = MagicMock() request_mock = MagicMock() route_handler = CopickRoute(MagicMock()) - + # Test with invalid format (should have 3 parts: user_session_object.json) - response = await route_handler._handle_picks(request_mock, run_mock, "invalid_format.json") + response = await route_handler._handle_picks( + request_mock, run_mock, "invalid_format.json" + ) assert response.status_code == 404 @@ -137,16 +146,18 @@ async def test_handle_picks_invalid_format(): async def test_handle_picks_get_not_found(): """Test handling of GET request for non-existent picks.""" from copick_server.server import CopickRoute - + # Create mocks run_mock = MagicMock() run_mock.get_picks.return_value = [] # No picks found request_mock = MagicMock() request_mock.method = "GET" route_handler = CopickRoute(MagicMock()) - + # Test GET for non-existent picks - response = await route_handler._handle_picks(request_mock, run_mock, "user_session_object.json") + response = await route_handler._handle_picks( + request_mock, run_mock, "user_session_object.json" + ) assert response.status_code == 404 run_mock.get_picks.assert_called_once_with( object_name="object", user_id="user", session_id="session" @@ -157,12 +168,12 @@ async def test_handle_picks_get_not_found(): async def test_handle_segmentation_invalid_path(): """Test handling of an invalid segmentation path.""" from copick_server.server import CopickRoute - + # Create mocks run_mock = MagicMock() request_mock = MagicMock() route_handler = CopickRoute(MagicMock()) - + # Test with empty path response = await route_handler._handle_segmentation(request_mock, run_mock, "") assert response.status_code == 404 @@ -172,14 +183,16 @@ async def test_handle_segmentation_invalid_path(): async def test_handle_segmentation_invalid_format(): """Test handling of segmentation with invalid filename format.""" from copick_server.server import CopickRoute - + # Create mocks run_mock = MagicMock() request_mock = MagicMock() route_handler = CopickRoute(MagicMock()) - + # Test with invalid format (should have at least 4 parts: voxel_user_session_name.zarr) - response = await route_handler._handle_segmentation(request_mock, run_mock, "invalid_format.zarr") + response = await route_handler._handle_segmentation( + request_mock, run_mock, "invalid_format.zarr" + ) assert response.status_code == 404 @@ -188,29 +201,31 @@ async def test_handle_segmentation_invalid_format(): async def test_handle_segmentation_put(mock_write_segmentation): """Test handling of PUT request for a segmentation.""" from copick_server.server import CopickRoute - + # Create mocks run_mock = MagicMock() request_mock = MagicMock() request_mock.method = "PUT" route_handler = CopickRoute(MagicMock()) - + # Mock the request body to return a byte array with shape info and data # (24 bytes for shape + data) shape = np.array([10, 10, 10], dtype=np.int64) data = np.zeros((10, 10, 10), dtype=np.uint8) body = shape.tobytes() + data.tobytes() - + body_mock = AsyncMock() body_mock.return_value = body request_mock.body = body_mock - + # Test PUT for segmentation - response = await route_handler._handle_segmentation(request_mock, run_mock, "10.0_user_session_object.zarr") - + response = await route_handler._handle_segmentation( + request_mock, run_mock, "10.0_user_session_object.zarr" + ) + assert response.status_code == 200 mock_write_segmentation.assert_called_once() - + # Check that the parameters are correct args, kwargs = mock_write_segmentation.call_args assert kwargs["run"] == run_mock @@ -218,8 +233,8 @@ async def test_handle_segmentation_put(mock_write_segmentation): assert kwargs["session_id"] == "session" assert kwargs["name"] == "object" assert kwargs["voxel_size"] == 10.0 - assert kwargs["multilabel"] == False - + assert not kwargs["multilabel"] + # Check that the segmentation volume shape is correct assert kwargs["segmentation_volume"].shape == (10, 10, 10) @@ -228,22 +243,24 @@ async def test_handle_segmentation_put(mock_write_segmentation): async def test_handle_segmentation_get_not_found(): """Test handling of GET request for non-existent segmentation.""" from copick_server.server import CopickRoute - + # Create mocks run_mock = MagicMock() run_mock.get_segmentations.return_value = [] # No segmentations found request_mock = MagicMock() request_mock.method = "GET" route_handler = CopickRoute(MagicMock()) - + # Test GET for non-existent segmentation - response = await route_handler._handle_segmentation(request_mock, run_mock, "10.0_user_session_object.zarr") - + response = await route_handler._handle_segmentation( + request_mock, run_mock, "10.0_user_session_object.zarr" + ) + assert response.status_code == 404 run_mock.get_segmentations.assert_called_once_with( voxel_size=10.0, name="object", user_id="user", session_id="session", - is_multilabel=False + is_multilabel=False, ) diff --git a/tests/test_server.py b/tests/test_server.py index 1aaf84f..4dce6cf 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -1,4 +1,3 @@ -import json import types import pytest from unittest.mock import MagicMock, patch @@ -16,19 +15,23 @@ def test_cors_middleware(mock_copick_root): """Test that CORS is properly configured.""" from copick_server.server import create_copick_app from fastapi.testclient import TestClient - + # Create app with CORS origins app = create_copick_app(mock_copick_root, cors_origins=["https://example.com"]) - + # Create a test client client = TestClient(app) - + # Make a request with an Origin header response = client.get("/any-path", headers={"Origin": "https://example.com"}) - + # Check if CORS headers are present in the response - assert "access-control-allow-origin" in response.headers, "CORS headers not found in response" - assert response.headers["access-control-allow-origin"] == "https://example.com", "Incorrect CORS origin value" + assert ( + "access-control-allow-origin" in response.headers + ), "CORS headers not found in response" + assert ( + response.headers["access-control-allow-origin"] == "https://example.com" + ), "Incorrect CORS origin value" @pytest.mark.asyncio @@ -46,37 +49,40 @@ async def test_handle_tomogram_request(mock_handle_tomogram, client, monkeypatch run_mock = MagicMock() root_mock = MagicMock() root_mock.get_run.return_value = run_mock - + # Set up mock for _handle_tomogram mock_response = MagicMock() mock_response.status_code = 200 mock_handle_tomogram.return_value = mock_response - + # Find the route handler in the application route_handler = None for route in client.app.routes: - if isinstance(route.endpoint, types.MethodType) and route.endpoint.__self__.__class__.__name__ == 'CopickRoute': + if ( + isinstance(route.endpoint, types.MethodType) + and route.endpoint.__self__.__class__.__name__ == "CopickRoute" + ): route_handler = route.endpoint.__self__ break - + assert route_handler is not None, "Could not find CopickRoute handler" - + # Save the original root original_root = route_handler.root - + # Temporarily replace the root route_handler.root = root_mock - + try: # Make the request response = client.get("/test_run/Tomograms/VoxelSpacing10.0/test.zarr") - + # Verify the response assert response.status_code == 200 - + # Verify the correct run was obtained root_mock.get_run.assert_called_once_with("test_run") - + # Verify _handle_tomogram was called mock_handle_tomogram.assert_called_once() finally: @@ -92,37 +98,40 @@ async def test_handle_picks_request(mock_handle_picks, client, monkeypatch): run_mock = MagicMock() root_mock = MagicMock() root_mock.get_run.return_value = run_mock - + # Set up mock for _handle_picks mock_response = MagicMock() mock_response.status_code = 200 mock_handle_picks.return_value = mock_response - + # Find the route handler in the application route_handler = None for route in client.app.routes: - if isinstance(route.endpoint, types.MethodType) and route.endpoint.__self__.__class__.__name__ == 'CopickRoute': + if ( + isinstance(route.endpoint, types.MethodType) + and route.endpoint.__self__.__class__.__name__ == "CopickRoute" + ): route_handler = route.endpoint.__self__ break - + assert route_handler is not None, "Could not find CopickRoute handler" - + # Save the original root original_root = route_handler.root - + # Temporarily replace the root route_handler.root = root_mock - + try: # Make the request response = client.get("/test_run/Picks/user_session_test.json") - + # Verify the response assert response.status_code == 200 - + # Verify the correct run was obtained root_mock.get_run.assert_called_once_with("test_run") - + # Verify _handle_picks was called mock_handle_picks.assert_called_once() finally: @@ -132,43 +141,48 @@ async def test_handle_picks_request(mock_handle_picks, client, monkeypatch): @pytest.mark.asyncio @patch("copick_server.server.CopickRoute._handle_segmentation") -async def test_handle_segmentation_request(mock_handle_segmentation, client, monkeypatch): +async def test_handle_segmentation_request( + mock_handle_segmentation, client, monkeypatch +): """Test that segmentation requests are routed correctly.""" # Mock the get_run method to return a valid run run_mock = MagicMock() root_mock = MagicMock() root_mock.get_run.return_value = run_mock - + # Set up mock for _handle_segmentation mock_response = MagicMock() mock_response.status_code = 200 mock_handle_segmentation.return_value = mock_response - + # Find the route handler in the application route_handler = None for route in client.app.routes: - if isinstance(route.endpoint, types.MethodType) and route.endpoint.__self__.__class__.__name__ == 'CopickRoute': + if ( + isinstance(route.endpoint, types.MethodType) + and route.endpoint.__self__.__class__.__name__ == "CopickRoute" + ): route_handler = route.endpoint.__self__ break - + assert route_handler is not None, "Could not find CopickRoute handler" - + # Save the original root original_root = route_handler.root - + # Temporarily replace the root route_handler.root = root_mock - + try: # Make the request response = client.get("/test_run/Segmentations/10.0_user_session_test.zarr") - + # Verify the response assert response.status_code == 200 - + # Verify the correct run was obtained root_mock.get_run.assert_called_once_with("test_run") - + # Verify _handle_segmentation was called mock_handle_segmentation.assert_called_once() finally: