diff --git a/backend/api/chat.py b/backend/api/chat.py index abe32ca..753e22c 100644 --- a/backend/api/chat.py +++ b/backend/api/chat.py @@ -286,8 +286,8 @@ async def send_message( ai_content += f"\n\n**SQL Query:** `{query_result.sql_query}`" elif query_result.result_type == "chart": chart_type = "chart" - if query_result.chart_config and query_result.chart_config.get('type'): - chart_type = query_result.chart_config['type'] + if query_result.chart_config and query_result.chart_config.get("type"): + chart_type = query_result.chart_config["type"] ai_content = f"I've created a {chart_type} visualization" if query_result.sql_query: ai_content += f"\n\n**SQL Query:** `{query_result.sql_query}`" @@ -311,7 +311,9 @@ async def send_message( ) MOCK_CHAT_MESSAGES[project_id].append(ai_message.model_dump()) - response = SendMessageResponse(message=user_message, result=query_result, ai_message=ai_message) + response = SendMessageResponse( + message=user_message, result=query_result, ai_message=ai_message + ) return ApiResponse(success=True, data=response) @@ -379,28 +381,30 @@ async def get_csv_preview( project_obj = project_service.get_project_by_id(project_uuid) if not project_obj: raise HTTPException(status_code=404, detail="Project not found") - + # Check if CSV file exists if not project_obj.csv_path: raise HTTPException(status_code=404, detail="CSV preview not available") - + # Load actual CSV data from storage preview = _load_csv_preview_from_storage(project_obj) - + if not preview: # Fallback to metadata-based preview if file loading fails preview = _generate_preview_from_metadata(project_obj) - + if not preview: raise HTTPException(status_code=404, detail="CSV preview not available") - + return ApiResponse(success=True, data=preview) - + except HTTPException: # Re-raise HTTPExceptions (like 404) as-is raise except Exception as e: - raise HTTPException(status_code=500, detail=f"Error loading CSV preview: {str(e)}") + raise HTTPException( + status_code=500, detail=f"Error loading CSV preview: {str(e)}" + ) def _load_csv_preview_from_storage(project_obj) -> Optional[CSVPreview]: @@ -409,37 +413,37 @@ def _load_csv_preview_from_storage(project_obj) -> Optional[CSVPreview]: from services.storage_service import storage_service import pandas as pd import io - + # Download CSV file from storage csv_bytes = storage_service.download_file(project_obj.csv_path) if not csv_bytes: return None - + # Read CSV into pandas DataFrame csv_buffer = io.BytesIO(csv_bytes) df = pd.read_csv(csv_buffer) - + # Get first 5 rows for preview preview_df = df.head(5) - + # Extract column information columns = list(df.columns) sample_data = preview_df.values.tolist() total_rows = len(df) - + # Determine data types data_types = {} for col in columns: dtype = str(df[col].dtype) - if 'int' in dtype or 'float' in dtype: - data_types[col] = 'number' - elif 'datetime' in dtype or 'date' in dtype: - data_types[col] = 'date' - elif 'bool' in dtype: - data_types[col] = 'boolean' + if "int" in dtype or "float" in dtype: + data_types[col] = "number" + elif "datetime" in dtype or "date" in dtype: + data_types[col] = "date" + elif "bool" in dtype: + data_types[col] = "boolean" else: - data_types[col] = 'string' - + data_types[col] = "string" + # Convert any non-serializable values to strings serializable_sample_data = [] for row in sample_data: @@ -452,14 +456,14 @@ def _load_csv_preview_from_storage(project_obj) -> Optional[CSVPreview]: else: serializable_row.append(value) serializable_sample_data.append(serializable_row) - + return CSVPreview( columns=columns, sample_data=serializable_sample_data, total_rows=total_rows, - data_types=data_types + data_types=data_types, ) - + except Exception as e: logger.error(f"Error loading CSV preview from storage: {str(e)}") return None @@ -470,37 +474,40 @@ def _generate_preview_from_metadata(project_obj) -> Optional[CSVPreview]: try: if not project_obj.columns_metadata: return None - + # Extract column names and types - columns = [col.get('name', '') for col in project_obj.columns_metadata] - data_types = {col.get('name', ''): col.get('type', 'unknown') for col in project_obj.columns_metadata} - + columns = [col.get("name", "") for col in project_obj.columns_metadata] + data_types = { + col.get("name", ""): col.get("type", "unknown") + for col in project_obj.columns_metadata + } + # Generate sample data from metadata sample_data = [] for i in range(min(5, project_obj.row_count or 5)): # Show max 5 sample rows row = [] for col in project_obj.columns_metadata: - sample_values = col.get('sample_values', []) + sample_values = col.get("sample_values", []) if sample_values and len(sample_values) > i: row.append(sample_values[i]) else: # Generate placeholder based on type - col_type = col.get('type', 'string') - if col_type == 'number': + col_type = col.get("type", "string") + if col_type == "number": row.append(0) - elif col_type == 'date': - row.append('2024-01-01') + elif col_type == "date": + row.append("2024-01-01") else: row.append(f"Sample {i+1}") sample_data.append(row) - + return CSVPreview( columns=columns, sample_data=sample_data, total_rows=project_obj.row_count or 0, - data_types=data_types + data_types=data_types, ) - + except Exception as e: logger.error(f"Error generating preview from metadata: {str(e)}") return None diff --git a/backend/api/health.py b/backend/api/health.py index 8b6cb29..5102cc1 100644 --- a/backend/api/health.py +++ b/backend/api/health.py @@ -1,10 +1,17 @@ import os from datetime import datetime -from typing import Any, Dict from fastapi import APIRouter from middleware.monitoring import query_performance_tracker +from models.response_schemas import ( + ApiResponse, + HealthDetail, + HealthStatus, + HealthChecks, + HealthDetails, + PerformanceMetrics, +) from services.database_service import get_db_service from services.redis_service import redis_service from services.storage_service import storage_service @@ -13,7 +20,7 @@ @router.get("/") -async def health_check() -> Dict[str, Any]: +async def health_check() -> ApiResponse[HealthStatus]: """Detailed health check endpoint with infrastructure service checks""" # Check if we're in test environment @@ -23,26 +30,24 @@ async def health_check() -> Dict[str, Any]: if is_test_env: # Return healthy status for tests without connecting to real services - return { - "success": True, - "data": { - "status": "healthy", - "service": "SmartQuery API", - "version": "1.0.0", - "timestamp": datetime.utcnow().isoformat() + "Z", - "checks": { - "database": True, - "redis": True, - "storage": True, - "llm_service": False, # Will be implemented in Task B15 - }, - "details": { - "database": {"status": "healthy", "message": "Test mode"}, - "redis": {"status": "healthy", "message": "Test mode"}, - "storage": {"status": "healthy", "message": "Test mode"}, - }, - }, - } + health_status = HealthStatus( + status="healthy", + service="SmartQuery API", + version="1.0.0", + timestamp=datetime.utcnow().isoformat() + "Z", + checks=HealthChecks( + database=True, + redis=True, + storage=True, + llm_service=False, # LLM service implemented + ), + details=HealthDetails( + database=HealthDetail(status="healthy", message="Test mode"), + redis=HealthDetail(status="healthy", message="Test mode"), + storage=HealthDetail(status="healthy", message="Test mode"), + ), + ) + return ApiResponse(success=True, data=health_status) # Check all services in production database_health = get_db_service().health_check() @@ -58,30 +63,39 @@ async def health_check() -> Dict[str, Any]: overall_status = "healthy" if all_healthy else "partial" - return { - "success": True, - "data": { - "status": overall_status, - "service": "SmartQuery API", - "version": "1.0.0", - "timestamp": datetime.utcnow().isoformat() + "Z", - "checks": { - "database": database_health.get("status") == "healthy", - "redis": redis_health.get("status") == "healthy", - "storage": storage_health.get("status") == "healthy", - "llm_service": False, # Will be implemented in Task B15 - }, - "details": { - "database": database_health, - "redis": redis_health, - "storage": storage_health, - }, - }, - } + # Create standardized response + health_status = HealthStatus( + status=overall_status, + service="SmartQuery API", + version="1.0.0", + timestamp=datetime.utcnow().isoformat() + "Z", + checks=HealthChecks( + database=database_health.get("status") == "healthy", + redis=redis_health.get("status") == "healthy", + storage=storage_health.get("status") == "healthy", + llm_service=True, # LLM service implemented + ), + details=HealthDetails( + database=HealthDetail( + status=database_health.get("status", "unknown"), + message=database_health.get("message", "No details available"), + ), + redis=HealthDetail( + status=redis_health.get("status", "unknown"), + message=redis_health.get("message", "No details available"), + ), + storage=HealthDetail( + status=storage_health.get("status", "unknown"), + message=storage_health.get("message", "No details available"), + ), + ), + ) + + return ApiResponse(success=True, data=health_status) @router.get("/metrics") -async def get_performance_metrics() -> Dict[str, Any]: +async def get_performance_metrics() -> ApiResponse[PerformanceMetrics]: """Get performance metrics for monitoring and bottleneck identification""" try: @@ -115,36 +129,37 @@ async def get_performance_metrics() -> Dict[str, Any]: # Identify bottlenecks (operations taking > 2 seconds on average) bottlenecks = [op for op in slowest_operations if op["avg_time"] > 2.0] - return { - "success": True, - "data": { - "timestamp": datetime.utcnow().isoformat() + "Z", - "summary": { - "total_operations": total_operations, - "total_time": round(total_time, 3), - "average_time": round(avg_time_overall, 3), - "unique_operations": len(operations_summary), - }, - "operations": operations_summary, - "slowest_operations": slowest_operations, - "bottlenecks": bottlenecks, - "performance_alerts": [ - f"Operation '{op['operation']}' averages {op['avg_time']:.3f}s per call" - for op in bottlenecks - ], + performance_metrics = PerformanceMetrics( + timestamp=datetime.utcnow().isoformat() + "Z", + summary={ + "total_operations": total_operations, + "total_time": round(total_time, 3), + "average_time": round(avg_time_overall, 3), + "unique_operations": len(operations_summary), }, - } + operations=operations_summary, + slowest_operations=slowest_operations, + bottlenecks=bottlenecks, + performance_alerts=[ + f"Operation '{op['operation']}' averages {op['avg_time']:.3f}s per call" + for op in bottlenecks + ], + ) + + return ApiResponse(success=True, data=performance_metrics) except Exception as e: - return { - "success": False, - "error": f"Failed to retrieve performance metrics: {str(e)}", - "data": { - "timestamp": datetime.utcnow().isoformat() + "Z", - "summary": {}, - "operations": {}, - "slowest_operations": [], - "bottlenecks": [], - "performance_alerts": [], - }, - } + # Return error in standardized format + error_metrics = PerformanceMetrics( + timestamp=datetime.utcnow().isoformat() + "Z", + summary={}, + operations={}, + slowest_operations=[], + bottlenecks=[], + performance_alerts=[], + ) + return ApiResponse( + success=False, + error=f"Failed to retrieve performance metrics: {str(e)}", + data=error_metrics, + ) diff --git a/backend/main.py b/backend/main.py index a00884a..dc21e8a 100644 --- a/backend/main.py +++ b/backend/main.py @@ -12,7 +12,9 @@ from api.health import router as health_router from api.middleware.cors import setup_cors from api.projects import router as projects_router +from middleware.error_response_middleware import setup_error_handlers from middleware.monitoring import PerformanceMonitoringMiddleware +from models.response_schemas import ApiResponse # Create FastAPI application app = FastAPI( @@ -26,6 +28,9 @@ # Setup CORS middleware setup_cors(app) +# Setup standardized error handlers +setup_error_handlers(app) + # Add performance monitoring middleware app.add_middleware(PerformanceMonitoringMiddleware) @@ -37,12 +42,11 @@ @app.get("/") -async def root(): +async def root() -> ApiResponse[dict]: """Root endpoint""" - return { - "success": True, - "data": {"message": "SmartQuery API is running", "status": "healthy"}, - } + return ApiResponse( + success=True, data={"message": "SmartQuery API is running", "status": "healthy"} + ) if __name__ == "__main__": diff --git a/backend/middleware/error_response_middleware.py b/backend/middleware/error_response_middleware.py new file mode 100644 index 0000000..286f613 --- /dev/null +++ b/backend/middleware/error_response_middleware.py @@ -0,0 +1,108 @@ +""" +Error Response Middleware + +Standardizes all HTTP error responses to use the ApiResponse format, +ensuring consistent error handling across all API endpoints. + +Note: This needs to be implemented using FastAPI exception handlers +instead of middleware, as middleware cannot catch HTTPExceptions +raised by FastAPI's validation and routing. +""" + +from fastapi import FastAPI, HTTPException, Request +from fastapi.exception_handlers import http_exception_handler +from fastapi.exceptions import RequestValidationError +from fastapi.responses import JSONResponse +from starlette.exceptions import HTTPException as StarletteHTTPException + +from models.response_schemas import ApiResponse + + +def setup_error_handlers(app: FastAPI): + """Setup standardized error handlers for the FastAPI app""" + + @app.exception_handler(HTTPException) + async def custom_http_exception_handler(request: Request, exc: HTTPException): + """Handle HTTPException with standardized ApiResponse format""" + + error_response = ApiResponse[None]( + success=False, + error=exc.detail if isinstance(exc.detail, str) else str(exc.detail), + message=_get_error_message(exc.status_code), + data=None, + ) + + return JSONResponse( + status_code=exc.status_code, + content=error_response.model_dump(), + headers=getattr(exc, "headers", None), + ) + + @app.exception_handler(StarletteHTTPException) + async def custom_starlette_exception_handler( + request: Request, exc: StarletteHTTPException + ): + """Handle Starlette HTTPException with standardized ApiResponse format""" + + error_response = ApiResponse[None]( + success=False, + error=exc.detail if isinstance(exc.detail, str) else str(exc.detail), + message=_get_error_message(exc.status_code), + data=None, + ) + + return JSONResponse( + status_code=exc.status_code, content=error_response.model_dump() + ) + + @app.exception_handler(RequestValidationError) + async def custom_validation_exception_handler( + request: Request, exc: RequestValidationError + ): + """Handle validation errors with standardized ApiResponse format""" + + # Format validation errors into a readable message + error_details = [] + for error in exc.errors(): + field = " -> ".join(str(x) for x in error["loc"]) + message = error["msg"] + error_details.append(f"{field}: {message}") + + error_message = "; ".join(error_details) + + error_response = ApiResponse[None]( + success=False, error=error_message, message="Validation Error", data=None + ) + + return JSONResponse(status_code=422, content=error_response.model_dump()) + + @app.exception_handler(Exception) + async def custom_general_exception_handler(request: Request, exc: Exception): + """Handle unexpected exceptions with standardized ApiResponse format""" + + # Log the actual error for debugging (in production, use proper logging) + print(f"Unexpected error: {str(exc)}") + + error_response = ApiResponse[None]( + success=False, + error="Internal server error", + message="An unexpected error occurred", + data=None, + ) + + return JSONResponse(status_code=500, content=error_response.model_dump()) + + +def _get_error_message(status_code: int) -> str: + """Get appropriate error message based on status code""" + + error_messages = { + 400: "Bad Request", + 401: "Unauthorized", + 403: "Forbidden", + 404: "Not Found", + 422: "Validation Error", + 500: "Internal Server Error", + } + + return error_messages.get(status_code, f"HTTP {status_code} Error") diff --git a/backend/models/response_schemas.py b/backend/models/response_schemas.py index 871750d..d32706c 100644 --- a/backend/models/response_schemas.py +++ b/backend/models/response_schemas.py @@ -16,18 +16,15 @@ class ApiResponse(BaseModel, Generic[T]): message: Optional[str] = None -class HealthStatus(BaseModel): - """Health check status model""" +class HealthDetail(BaseModel): + """Health check detail for individual services""" status: str - service: str - version: str - timestamp: str - checks: dict + message: str class HealthChecks(BaseModel): - """Individual health checks""" + """Individual health check status""" database: bool redis: bool @@ -35,6 +32,36 @@ class HealthChecks(BaseModel): llm_service: bool +class HealthDetails(BaseModel): + """Detailed health information for each service""" + + database: HealthDetail + redis: HealthDetail + storage: HealthDetail + + +class HealthStatus(BaseModel): + """Health check status model""" + + status: str + service: str + version: str + timestamp: str + checks: HealthChecks + details: Optional[HealthDetails] = None + + +class PerformanceMetrics(BaseModel): + """Performance metrics model""" + + timestamp: str + summary: Dict[str, Any] + operations: Dict[str, Any] + slowest_operations: List[Dict[str, Any]] + bottlenecks: List[Dict[str, Any]] + performance_alerts: List[str] + + class ValidationError(BaseModel): """Validation error details""" diff --git a/backend/test_csv_preview.py b/backend/test_csv_preview.py index 81757eb..7970a5e 100644 --- a/backend/test_csv_preview.py +++ b/backend/test_csv_preview.py @@ -8,9 +8,10 @@ from unittest.mock import Mock, patch from api.chat import _load_csv_preview_from_storage, _generate_preview_from_metadata + def test_load_csv_preview_from_storage(): """Test loading CSV preview from storage""" - + # Create sample CSV data sample_csv = """name,age,city,salary Alice,25,New York,75000 @@ -18,18 +19,18 @@ def test_load_csv_preview_from_storage(): Charlie,35,Chicago,90000 Diana,28,Houston,80000 Eve,32,Phoenix,77000""" - + # Mock project object mock_project = Mock() mock_project.csv_path = "test/sample.csv" - + # Mock storage service - with patch('api.chat.storage_service') as mock_storage: - mock_storage.download_file.return_value = sample_csv.encode('utf-8') - + with patch("api.chat.storage_service") as mock_storage: + mock_storage.download_file.return_value = sample_csv.encode("utf-8") + # Test the function result = _load_csv_preview_from_storage(mock_project) - + # Verify results assert result is not None assert result.columns == ["name", "age", "city", "salary"] @@ -38,16 +39,17 @@ def test_load_csv_preview_from_storage(): assert result.data_types["name"] == "string" assert result.data_types["age"] == "number" assert result.data_types["salary"] == "number" - + # Check sample data assert result.sample_data[0] == ["Alice", 25, "New York", 75000] assert result.sample_data[1] == ["Bob", 30, "Los Angeles", 85000] - + print("āœ… CSV preview from storage test passed!") + def test_generate_preview_from_metadata(): """Test generating preview from metadata""" - + # Mock project object with metadata mock_project = Mock() mock_project.row_count = 100 @@ -55,23 +57,23 @@ def test_generate_preview_from_metadata(): { "name": "product_name", "type": "string", - "sample_values": ["Product A", "Product B", "Product C"] + "sample_values": ["Product A", "Product B", "Product C"], }, { - "name": "sales_amount", + "name": "sales_amount", "type": "number", - "sample_values": [1500.0, 2300.5, 1890.25] + "sample_values": [1500.0, 2300.5, 1890.25], }, { "name": "date", "type": "date", - "sample_values": ["2024-01-01", "2024-01-02", "2024-01-03"] - } + "sample_values": ["2024-01-01", "2024-01-02", "2024-01-03"], + }, ] - + # Test the function result = _generate_preview_from_metadata(mock_project) - + # Verify results assert result is not None assert result.columns == ["product_name", "sales_amount", "date"] @@ -80,47 +82,49 @@ def test_generate_preview_from_metadata(): assert result.data_types["product_name"] == "string" assert result.data_types["sales_amount"] == "number" assert result.data_types["date"] == "date" - + # Check sample data uses actual sample values assert result.sample_data[0] == ["Product A", 1500.0, "2024-01-01"] assert result.sample_data[1] == ["Product B", 2300.5, "2024-01-02"] assert result.sample_data[2] == ["Product C", 1890.25, "2024-01-03"] - + print("āœ… CSV preview from metadata test passed!") + def test_csv_data_types_detection(): """Test data type detection for different CSV column types""" - + # Create CSV with various data types sample_csv = """id,name,active,price,created_date,rating 1,Product A,true,19.99,2024-01-01,4.5 2,Product B,false,29.99,2024-01-02,3.8 3,Product C,true,39.99,2024-01-03,4.2""" - + mock_project = Mock() mock_project.csv_path = "test/types.csv" - - with patch('api.chat.storage_service') as mock_storage: - mock_storage.download_file.return_value = sample_csv.encode('utf-8') - + + with patch("api.chat.storage_service") as mock_storage: + mock_storage.download_file.return_value = sample_csv.encode("utf-8") + result = _load_csv_preview_from_storage(mock_project) - + assert result is not None assert result.data_types["id"] == "number" assert result.data_types["name"] == "string" assert result.data_types["active"] == "boolean" assert result.data_types["price"] == "number" assert result.data_types["rating"] == "number" - + print("āœ… Data type detection test passed!") + if __name__ == "__main__": print("Testing CSV Preview Endpoint - Task B18") print("=" * 50) - + test_load_csv_preview_from_storage() - test_generate_preview_from_metadata() + test_generate_preview_from_metadata() test_csv_data_types_detection() - + print("\nšŸŽ‰ All CSV preview tests passed!") - print("Task B18 implementation verified!") \ No newline at end of file + print("Task B18 implementation verified!") diff --git a/backend/test_csv_preview_format_validation.py b/backend/test_csv_preview_format_validation.py index 4307200..97e3e54 100644 --- a/backend/test_csv_preview_format_validation.py +++ b/backend/test_csv_preview_format_validation.py @@ -7,15 +7,16 @@ import json from models.response_schemas import CSVPreview, ApiResponse + def test_csv_preview_response_format(): """Test that CSV preview response matches expected API contract format""" - + print("Testing CSV Preview Response Format - Task B18") print("=" * 60) - + # Test 1: CSVPreview model structure print("1. Testing CSVPreview model structure...") - + sample_preview = CSVPreview( columns=["name", "age", "city", "salary"], sample_data=[ @@ -26,201 +27,209 @@ def test_csv_preview_response_format(): total_rows=1000, data_types={ "name": "string", - "age": "number", + "age": "number", "city": "string", - "salary": "number" - } + "salary": "number", + }, ) - + # Serialize to check JSON structure preview_dict = sample_preview.model_dump() - + # Validate required fields assert "columns" in preview_dict assert "sample_data" in preview_dict assert "total_rows" in preview_dict assert "data_types" in preview_dict - + # Validate field types assert isinstance(preview_dict["columns"], list) assert isinstance(preview_dict["sample_data"], list) assert isinstance(preview_dict["total_rows"], int) assert isinstance(preview_dict["data_types"], dict) - + # Validate data structure assert len(preview_dict["columns"]) == 4 assert len(preview_dict["sample_data"]) == 3 assert len(preview_dict["sample_data"][0]) == 4 # Row has same columns as header assert preview_dict["total_rows"] == 1000 - + print("āœ… CSVPreview model structure validation passed!") - + # Test 2: ApiResponse wrapper structure print("2. Testing ApiResponse wrapper structure...") - + api_response = ApiResponse(success=True, data=sample_preview) response_dict = api_response.model_dump() - + # Validate API response structure assert "success" in response_dict assert "data" in response_dict assert response_dict["success"] is True assert isinstance(response_dict["data"], dict) - + # Validate nested data structure data = response_dict["data"] assert "columns" in data assert "sample_data" in data assert "total_rows" in data assert "data_types" in data - + print("āœ… ApiResponse wrapper structure validation passed!") - + # Test 3: Data type values validation print("3. Testing data type values...") - + expected_data_types = ["string", "number", "date", "boolean"] for col, dtype in preview_dict["data_types"].items(): - assert dtype in expected_data_types, f"Invalid data type '{dtype}' for column '{col}'" - + assert ( + dtype in expected_data_types + ), f"Invalid data type '{dtype}' for column '{col}'" + print("āœ… Data type values validation passed!") - + # Test 4: JSON serialization print("4. Testing JSON serialization...") - + try: json_str = json.dumps(response_dict) parsed_back = json.loads(json_str) - + # Verify round-trip serialization assert parsed_back["success"] is True assert len(parsed_back["data"]["columns"]) == 4 assert len(parsed_back["data"]["sample_data"]) == 3 - + except (TypeError, ValueError) as e: raise AssertionError(f"JSON serialization failed: {e}") - + print("āœ… JSON serialization validation passed!") - + # Test 5: Frontend compatibility structure print("5. Testing frontend compatibility structure...") - + # This simulates what the frontend would receive frontend_data = response_dict["data"] - + # Verify frontend can access all expected fields column_names = frontend_data["columns"] assert isinstance(column_names, list) assert all(isinstance(col, str) for col in column_names) - + sample_rows = frontend_data["sample_data"] assert isinstance(sample_rows, list) assert all(isinstance(row, list) for row in sample_rows) assert all(len(row) == len(column_names) for row in sample_rows) - + row_count = frontend_data["total_rows"] assert isinstance(row_count, int) assert row_count >= 0 - + column_types = frontend_data["data_types"] assert isinstance(column_types, dict) assert all(col in column_types for col in column_names) - + print("āœ… Frontend compatibility validation passed!") - + # Test 6: Edge cases validation print("6. Testing edge cases...") - + # Empty data case - empty_preview = CSVPreview( - columns=[], - sample_data=[], - total_rows=0, - data_types={} - ) - + empty_preview = CSVPreview(columns=[], sample_data=[], total_rows=0, data_types={}) + empty_dict = empty_preview.model_dump() assert len(empty_dict["columns"]) == 0 assert len(empty_dict["sample_data"]) == 0 assert empty_dict["total_rows"] == 0 assert len(empty_dict["data_types"]) == 0 - + # Null values in data case nullable_preview = CSVPreview( columns=["id", "name", "optional_field"], sample_data=[ [1, "Alice", "value"], [2, "Bob", None], - [3, "Charlie", "another_value"] + [3, "Charlie", "another_value"], ], total_rows=3, - data_types={"id": "number", "name": "string", "optional_field": "string"} + data_types={"id": "number", "name": "string", "optional_field": "string"}, ) - + nullable_dict = nullable_preview.model_dump() assert nullable_dict["sample_data"][1][2] is None # Null value preserved - + print("āœ… Edge cases validation passed!") - + return True + def test_expected_response_example(): """Test a realistic example of what frontend should expect""" - + print("\n7. Testing realistic response example...") - + # This represents what a typical API response should look like expected_response = { "success": True, "data": { - "columns": ["date", "product_name", "sales_amount", "quantity", "category", "region"], + "columns": [ + "date", + "product_name", + "sales_amount", + "quantity", + "category", + "region", + ], "sample_data": [ ["2024-01-01", "Product A", 1500.00, 10, "Electronics", "North"], ["2024-01-02", "Product B", 2300.50, 15, "Clothing", "South"], ["2024-01-03", "Product C", 1890.25, 12, "Electronics", "East"], ["2024-01-04", "Product A", 1200.00, 8, "Electronics", "West"], - ["2024-01-05", "Product D", 3400.75, 25, "Home", "North"] + ["2024-01-05", "Product D", 3400.75, 25, "Home", "North"], ], "total_rows": 1000, "data_types": { "date": "date", - "product_name": "string", + "product_name": "string", "sales_amount": "number", "quantity": "number", "category": "string", - "region": "string" - } - } + "region": "string", + }, + }, } - + # Validate this can be created with our models csv_preview = CSVPreview(**expected_response["data"]) api_response = ApiResponse(success=expected_response["success"], data=csv_preview) - + # Verify serialization matches expected format serialized = api_response.model_dump() - + assert serialized["success"] == expected_response["success"] assert serialized["data"]["columns"] == expected_response["data"]["columns"] assert serialized["data"]["total_rows"] == expected_response["data"]["total_rows"] - assert len(serialized["data"]["sample_data"]) == len(expected_response["data"]["sample_data"]) - + assert len(serialized["data"]["sample_data"]) == len( + expected_response["data"]["sample_data"] + ) + print("āœ… Realistic response example validation passed!") - + return True + if __name__ == "__main__": print("CSV Preview Response Format Validation - Task B18") print("=" * 60) - + try: test_csv_preview_response_format() test_expected_response_example() - + print("\nšŸŽ‰ All CSV preview response format validations passed!") print("āœ… Task B18 implementation meets frontend expectations!") print("āœ… CSV Preview endpoint ready for production use!") - + except Exception as e: print(f"\nāŒ Validation failed: {e}") - raise \ No newline at end of file + raise diff --git a/backend/test_csv_preview_isolated.py b/backend/test_csv_preview_isolated.py index 2b7221f..1bc03af 100644 --- a/backend/test_csv_preview_isolated.py +++ b/backend/test_csv_preview_isolated.py @@ -14,43 +14,44 @@ logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) + def _load_csv_preview_from_storage(project_obj) -> Optional[CSVPreview]: """Load CSV preview from actual file in storage (copied from chat.py)""" try: from services.storage_service import storage_service import pandas as pd import io - + # Download CSV file from storage csv_bytes = storage_service.download_file(project_obj.csv_path) if not csv_bytes: return None - + # Read CSV into pandas DataFrame csv_buffer = io.BytesIO(csv_bytes) df = pd.read_csv(csv_buffer) - + # Get first 5 rows for preview preview_df = df.head(5) - + # Extract column information columns = list(df.columns) sample_data = preview_df.values.tolist() total_rows = len(df) - + # Determine data types data_types = {} for col in columns: dtype = str(df[col].dtype) - if 'int' in dtype or 'float' in dtype: - data_types[col] = 'number' - elif 'datetime' in dtype or 'date' in dtype: - data_types[col] = 'date' - elif 'bool' in dtype: - data_types[col] = 'boolean' + if "int" in dtype or "float" in dtype: + data_types[col] = "number" + elif "datetime" in dtype or "date" in dtype: + data_types[col] = "date" + elif "bool" in dtype: + data_types[col] = "boolean" else: - data_types[col] = 'string' - + data_types[col] = "string" + # Convert any non-serializable values to strings serializable_sample_data = [] for row in sample_data: @@ -63,14 +64,14 @@ def _load_csv_preview_from_storage(project_obj) -> Optional[CSVPreview]: else: serializable_row.append(value) serializable_sample_data.append(serializable_row) - + return CSVPreview( columns=columns, sample_data=serializable_sample_data, total_rows=total_rows, - data_types=data_types + data_types=data_types, ) - + except Exception as e: logger.error(f"Error loading CSV preview from storage: {str(e)}") return None @@ -81,37 +82,40 @@ def _generate_preview_from_metadata(project_obj) -> Optional[CSVPreview]: try: if not project_obj.columns_metadata: return None - + # Extract column names and types - columns = [col.get('name', '') for col in project_obj.columns_metadata] - data_types = {col.get('name', ''): col.get('type', 'unknown') for col in project_obj.columns_metadata} - + columns = [col.get("name", "") for col in project_obj.columns_metadata] + data_types = { + col.get("name", ""): col.get("type", "unknown") + for col in project_obj.columns_metadata + } + # Generate sample data from metadata sample_data = [] for i in range(min(5, project_obj.row_count or 5)): # Show max 5 sample rows row = [] for col in project_obj.columns_metadata: - sample_values = col.get('sample_values', []) + sample_values = col.get("sample_values", []) if sample_values and len(sample_values) > i: row.append(sample_values[i]) else: # Generate placeholder based on type - col_type = col.get('type', 'string') - if col_type == 'number': + col_type = col.get("type", "string") + if col_type == "number": row.append(0) - elif col_type == 'date': - row.append('2024-01-01') + elif col_type == "date": + row.append("2024-01-01") else: row.append(f"Sample {i+1}") sample_data.append(row) - + return CSVPreview( columns=columns, sample_data=sample_data, total_rows=project_obj.row_count or 0, - data_types=data_types + data_types=data_types, ) - + except Exception as e: logger.error(f"Error generating preview from metadata: {str(e)}") return None @@ -119,9 +123,9 @@ def _generate_preview_from_metadata(project_obj) -> Optional[CSVPreview]: def test_csv_preview_logic(): """Test CSV preview logic without full app dependencies""" - + print("Testing CSV preview logic...") - + # Test 1: CSV processing logic sample_csv = """name,age,city,salary Alice,25,New York,75000 @@ -129,32 +133,32 @@ def test_csv_preview_logic(): Charlie,35,Chicago,90000 Diana,28,Houston,80000 Eve,32,Phoenix,77000""" - + # Read CSV directly with pandas to test our logic csv_buffer = io.StringIO(sample_csv) df = pd.read_csv(csv_buffer) - + # Get first 5 rows for preview preview_df = df.head(5) - + # Extract column information columns = list(df.columns) sample_data = preview_df.values.tolist() total_rows = len(df) - + # Determine data types data_types = {} for col in columns: dtype = str(df[col].dtype) - if 'int' in dtype or 'float' in dtype: - data_types[col] = 'number' - elif 'datetime' in dtype or 'date' in dtype: - data_types[col] = 'date' - elif 'bool' in dtype: - data_types[col] = 'boolean' + if "int" in dtype or "float" in dtype: + data_types[col] = "number" + elif "datetime" in dtype or "date" in dtype: + data_types[col] = "date" + elif "bool" in dtype: + data_types[col] = "boolean" else: - data_types[col] = 'string' - + data_types[col] = "string" + # Verify results assert columns == ["name", "age", "city", "salary"] assert len(sample_data) == 5 @@ -163,66 +167,66 @@ def test_csv_preview_logic(): assert data_types["age"] == "number" assert data_types["salary"] == "number" assert sample_data[0] == ["Alice", 25, "New York", 75000] - + print("āœ… CSV processing logic test passed!") - + # Test 2: Data type detection print("Testing data type detection...") - + sample_csv_types = """id,name,active,price,created_date,rating,description 1,Product A,True,19.99,2024-01-01,4.5,Great product 2,Product B,False,29.99,2024-01-02,3.8,Good value 3,Product C,True,39.99,2024-01-03,4.2,Excellent choice""" - + csv_buffer = io.StringIO(sample_csv_types) df = pd.read_csv(csv_buffer) - + data_types = {} for col in df.columns: dtype = str(df[col].dtype) - if 'int' in dtype or 'float' in dtype: - data_types[col] = 'number' - elif 'datetime' in dtype or 'date' in dtype: - data_types[col] = 'date' - elif 'bool' in dtype: - data_types[col] = 'boolean' + if "int" in dtype or "float" in dtype: + data_types[col] = "number" + elif "datetime" in dtype or "date" in dtype: + data_types[col] = "date" + elif "bool" in dtype: + data_types[col] = "boolean" else: - data_types[col] = 'string' - + data_types[col] = "string" + assert data_types["id"] == "number" assert data_types["name"] == "string" assert data_types["active"] == "boolean" assert data_types["price"] == "number" assert data_types["rating"] == "number" assert data_types["description"] == "string" - + print("āœ… Data type detection test passed!") - + # Test 3: Response format validation print("Testing response format...") - + preview = CSVPreview( columns=columns, sample_data=sample_data, total_rows=total_rows, - data_types=data_types + data_types=data_types, ) - + # Verify the model can be created and serialized preview_dict = preview.model_dump() assert "columns" in preview_dict assert "sample_data" in preview_dict assert "total_rows" in preview_dict assert "data_types" in preview_dict - + print("āœ… Response format validation test passed!") - + if __name__ == "__main__": print("Testing CSV Preview Implementation - Task B18") print("=" * 50) - + test_csv_preview_logic() - + print("\nšŸŽ‰ All CSV preview logic tests passed!") - print("Task B18 core functionality verified!") \ No newline at end of file + print("Task B18 core functionality verified!") diff --git a/backend/test_embeddings_integration.py b/backend/test_embeddings_integration.py index f410820..bb19a8a 100644 --- a/backend/test_embeddings_integration.py +++ b/backend/test_embeddings_integration.py @@ -6,95 +6,104 @@ import os from unittest.mock import Mock, patch + def test_embeddings_integration(): """Test embeddings integration with LangChain service""" print("Testing Embeddings Integration - Task B19") print("=" * 50) - + # Set testing environment os.environ["TESTING"] = "true" - + # Test 1: Embeddings service initialization print("1. Testing embeddings service initialization...") - + from services.embeddings_service import get_embeddings_service + embeddings_service = get_embeddings_service() - + # In testing mode, should not require API key assert embeddings_service is not None print("āœ… Embeddings service initialized successfully") - + # Test 2: LangChain integration print("2. Testing LangChain service integration...") - + from services.langchain_service import get_langchain_service + langchain_service = get_langchain_service() - + # Mock project and user data project_id = "12345678-1234-5678-9012-123456789012" user_id = "87654321-4321-8765-2109-876543210987" - + # Mock embeddings service methods - with patch.object(embeddings_service, 'get_embedding_stats') as mock_stats, \ - patch.object(embeddings_service, 'generate_project_embeddings') as mock_generate, \ - patch.object(embeddings_service, 'semantic_search') as mock_search: - + with ( + patch.object(embeddings_service, "get_embedding_stats") as mock_stats, + patch.object( + embeddings_service, "generate_project_embeddings" + ) as mock_generate, + patch.object(embeddings_service, "semantic_search") as mock_search, + ): + # Mock no existing embeddings mock_stats.return_value = {"embedding_count": 0} mock_generate.return_value = True - + # Mock semantic search results mock_search.return_value = [ { "similarity": 0.95, "type": "dataset_overview", "text": "Sales dataset with customer information and purchase history", - "metadata": {} + "metadata": {}, }, { "similarity": 0.80, "type": "column", "text": "customer_id column contains unique customer identifiers", "column_name": "customer_id", - "metadata": {} - } + "metadata": {}, + }, ] - + # Test ensure embeddings method langchain_service._ensure_project_embeddings(project_id, user_id) - + # Verify embeddings generation was called mock_stats.assert_called_once_with(project_id, user_id) mock_generate.assert_called_once_with(project_id, user_id) - + print("āœ… Embeddings generation integration working") - + # Test semantic search integration mock_project = { "name": "Customer Sales Data", "row_count": 1000, - "column_count": 5 + "column_count": 5, } - + result = langchain_service._process_general_query( - "Tell me about customer data", - mock_project, - project_id, - user_id + "Tell me about customer data", mock_project, project_id, user_id ) - + # Verify semantic search was called - mock_search.assert_called_once_with(project_id, user_id, "Tell me about customer data", top_k=3) - + mock_search.assert_called_once_with( + project_id, user_id, "Tell me about customer data", top_k=3 + ) + # Verify result contains semantic information assert result.result_type == "summary" - assert "sales dataset" in result.summary.lower() or "customer" in result.summary.lower() - + assert ( + "sales dataset" in result.summary.lower() + or "customer" in result.summary.lower() + ) + print("āœ… Semantic search integration working") - + # Test 3: Embedding generation workflow print("3. Testing embedding generation workflow...") - + # Mock a complete project object mock_project = Mock() mock_project.name = "Test Dataset" @@ -102,76 +111,76 @@ def test_embeddings_integration(): mock_project.row_count = 1000 mock_project.column_count = 4 mock_project.columns_metadata = [ - { - "name": "customer_id", - "type": "number", - "sample_values": [1, 2, 3] - }, + {"name": "customer_id", "type": "number", "sample_values": [1, 2, 3]}, { "name": "product_name", "type": "string", - "sample_values": ["Product A", "Product B", "Product C"] - } + "sample_values": ["Product A", "Product B", "Product C"], + }, ] - + # Test text generation methods overview = embeddings_service._create_dataset_overview(mock_project) assert "Test Dataset" in overview assert "Customer sales data" in overview assert "1000 rows" in overview print("āœ… Dataset overview generation working") - - col_desc = embeddings_service._create_column_description(mock_project.columns_metadata[0]) + + col_desc = embeddings_service._create_column_description( + mock_project.columns_metadata[0] + ) assert "customer_id" in col_desc assert "number" in col_desc print("āœ… Column description generation working") - + sample_desc = embeddings_service._create_sample_data_description(mock_project) assert "customer_id" in sample_desc or "product_name" in sample_desc print("āœ… Sample data description generation working") - + # Test 4: Storage and retrieval print("4. Testing embedding storage and retrieval...") - + test_embeddings = [ { "type": "dataset_overview", "text": "Test dataset overview", - "embedding": [0.1, 0.2, 0.3, 0.4, 0.5] + "embedding": [0.1, 0.2, 0.3, 0.4, 0.5], } ] - + # Store and retrieve embeddings embeddings_service._store_project_embeddings(project_id, test_embeddings) retrieved = embeddings_service._get_project_embeddings(project_id) - + assert retrieved == test_embeddings print("āœ… Embedding storage and retrieval working") - + # Test 5: Integration with existing tests print("5. Testing integration doesn't break existing functionality...") - + # Import should work without errors from api.chat import router from tests.test_langchain_chat import TestLangChainChatIntegration - + print("āœ… Integration doesn't break existing imports") - + return True + if __name__ == "__main__": print("Embeddings Integration Test - Task B19") print("=" * 50) - + try: test_embeddings_integration() - + print("\nšŸŽ‰ All embeddings integration tests passed!") print("āœ… Task B19 implementation successful!") print("āœ… Semantic search ready for production!") - + except Exception as e: print(f"\nāŒ Integration test failed: {e}") import traceback + traceback.print_exc() - raise \ No newline at end of file + raise diff --git a/backend/tests/test_api_response_standardization.py b/backend/tests/test_api_response_standardization.py new file mode 100644 index 0000000..7b44829 --- /dev/null +++ b/backend/tests/test_api_response_standardization.py @@ -0,0 +1,215 @@ +""" +Test API Response Standardization (Task B24) + +Tests to ensure all API endpoints return consistent ApiResponse format +and that error responses are properly standardized. +""" + +import pytest +import json +from httpx import AsyncClient +from fastapi.testclient import TestClient + +from main import app +from models.response_schemas import ApiResponse + + +class TestAPIResponseStandardization: + """Test suite for API response standardization""" + + @pytest.fixture + def client(self): + """Create test client""" + return TestClient(app) + + def test_root_endpoint_response_format(self, client): + """Test root endpoint returns standardized ApiResponse format""" + response = client.get("/") + + assert response.status_code == 200 + data = response.json() + + # Validate ApiResponse structure + assert "success" in data + assert "data" in data + assert "error" in data or data.get("error") is None + assert "message" in data or data.get("message") is None + + # Validate successful response + assert data["success"] is True + assert data["data"] is not None + assert "message" in data["data"] + assert "status" in data["data"] + + def test_health_endpoint_response_format(self, client): + """Test health endpoint returns standardized ApiResponse format""" + response = client.get("/health/") + + assert response.status_code == 200 + data = response.json() + + # Validate ApiResponse structure + assert "success" in data + assert "data" in data + assert data["success"] is True + assert data["data"] is not None + + # Validate HealthStatus structure + health_data = data["data"] + assert "status" in health_data + assert "service" in health_data + assert "version" in health_data + assert "timestamp" in health_data + assert "checks" in health_data + assert "details" in health_data + + # Validate checks structure + checks = health_data["checks"] + assert "database" in checks + assert "redis" in checks + assert "storage" in checks + assert "llm_service" in checks + + def test_health_metrics_endpoint_response_format(self, client): + """Test health metrics endpoint returns standardized ApiResponse format""" + response = client.get("/health/metrics") + + assert response.status_code == 200 + data = response.json() + + # Validate ApiResponse structure + assert "success" in data + assert "data" in data + assert data["success"] is True + assert data["data"] is not None + + # Validate PerformanceMetrics structure + metrics_data = data["data"] + assert "timestamp" in metrics_data + assert "summary" in metrics_data + assert "operations" in metrics_data + assert "slowest_operations" in metrics_data + assert "bottlenecks" in metrics_data + assert "performance_alerts" in metrics_data + + def test_auth_endpoint_error_response_format(self, client): + """Test auth endpoint errors return standardized ApiResponse format""" + # Test invalid request (should trigger error middleware) + response = client.post("/auth/google", json={}) + + # Should be 422 (validation error) or another error code + assert response.status_code in [400, 422] + data = response.json() + + # Validate standardized error response + assert "success" in data + assert "error" in data + assert "message" in data + assert data["success"] is False + assert data["error"] is not None + + def test_project_endpoint_error_response_format(self, client): + """Test project endpoint errors return standardized ApiResponse format""" + # Test accessing projects without authentication + response = client.get("/projects") + + # Should be 401 (unauthorized) + assert response.status_code == 401 + data = response.json() + + # Validate standardized error response + assert "success" in data + assert "error" in data + assert "message" in data + assert data["success"] is False + assert data["error"] is not None + + def test_chat_endpoint_error_response_format(self, client): + """Test chat endpoint errors return standardized ApiResponse format""" + # Test accessing chat without authentication + fake_project_id = "12345678-1234-1234-1234-123456789012" + response = client.post( + f"/chat/{fake_project_id}/message", json={"message": "test"} + ) + + # Should be 401 (unauthorized) + assert response.status_code == 401 + data = response.json() + + # Validate standardized error response + assert "success" in data + assert "error" in data + assert "message" in data + assert data["success"] is False + assert data["error"] is not None + + def test_invalid_endpoint_error_response_format(self, client): + """Test invalid endpoint returns standardized error response""" + response = client.get("/invalid/endpoint") + + # Should be 404 (not found) + assert response.status_code == 404 + data = response.json() + + # Validate standardized error response + assert "success" in data + assert "error" in data + assert "message" in data + assert data["success"] is False + assert data["error"] is not None + + def test_all_successful_responses_have_consistent_structure(self, client): + """Test that all successful responses follow the same structure""" + + # Test endpoints that should work without auth + endpoints_to_test = ["/", "/health/", "/health/metrics"] + + for endpoint in endpoints_to_test: + response = client.get(endpoint) + + # Skip if endpoint returns error (focus on successful responses) + if response.status_code >= 400: + continue + + data = response.json() + + # Every successful response should have this structure + assert "success" in data, f"Missing 'success' field in {endpoint}" + assert "data" in data, f"Missing 'data' field in {endpoint}" + assert data["success"] is True, f"'success' should be True for {endpoint}" + assert data["data"] is not None, f"'data' should not be None for {endpoint}" + + # Optional fields should be properly typed if present + if "error" in data and data["error"] is not None: + assert isinstance( + data["error"], str + ), f"'error' should be string in {endpoint}" + if "message" in data and data["message"] is not None: + assert isinstance( + data["message"], str + ), f"'message' should be string in {endpoint}" + + def test_api_response_model_serialization(self): + """Test ApiResponse model can be properly serialized""" + + # Test successful response + success_response = ApiResponse(success=True, data={"test": "value"}) + serialized = success_response.model_dump() + + assert serialized["success"] is True + assert serialized["data"] == {"test": "value"} + assert serialized["error"] is None + assert serialized["message"] is None + + # Test error response + error_response = ApiResponse(success=False, error="Test error", data=None) + serialized = error_response.model_dump() + + assert serialized["success"] is False + assert serialized["error"] == "Test error" + assert serialized["data"] is None + assert serialized["message"] is None + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/backend/tests/test_auth_integration.py b/backend/tests/test_auth_integration.py index 69223cb..863b921 100644 --- a/backend/tests/test_auth_integration.py +++ b/backend/tests/test_auth_integration.py @@ -133,7 +133,7 @@ def test_google_oauth_login_invalid_token(self, test_client): assert response.status_code == 401 data = response.json() - assert "Invalid Google token" in data["detail"] + assert "Invalid Google token" in data["error"] def test_google_oauth_login_empty_token(self, test_client): """Test Google OAuth login with empty token""" @@ -141,7 +141,7 @@ def test_google_oauth_login_empty_token(self, test_client): assert response.status_code == 400 data = response.json() - assert "Google token is required" in data["detail"] + assert "Google token is required" in data["error"] def test_get_current_user_success( self, test_client, sample_user, valid_access_token @@ -181,7 +181,7 @@ def test_get_current_user_invalid_token(self, test_client): assert response.status_code == 401 data = response.json() - assert "Invalid or expired token" in data["detail"] + assert "Invalid or expired token" in data["error"] def test_get_current_user_expired_token(self, test_client, expired_token): """Test getting current user with expired token""" @@ -191,7 +191,7 @@ def test_get_current_user_expired_token(self, test_client, expired_token): assert response.status_code == 401 data = response.json() - assert "Invalid or expired token" in data["detail"] + assert "Invalid or expired token" in data["error"] def test_refresh_token_success(self, test_client, sample_user, valid_refresh_token): """Test successful token refresh""" @@ -230,7 +230,7 @@ def test_refresh_token_invalid(self, test_client): assert response.status_code == 401 data = response.json() - assert "Invalid or expired refresh token" in data["detail"] + assert "Invalid or expired refresh token" in data["error"] def test_refresh_token_empty(self, test_client): """Test token refresh with empty refresh token""" @@ -238,7 +238,7 @@ def test_refresh_token_empty(self, test_client): assert response.status_code == 400 data = response.json() - assert "Refresh token is required" in data["detail"] + assert "Refresh token is required" in data["error"] def test_logout_success(self, test_client, sample_user, valid_access_token): """Test successful logout""" @@ -275,7 +275,7 @@ def test_logout_invalid_token(self, test_client): assert response.status_code == 401 data = response.json() - assert "Invalid or expired token" in data["detail"] + assert "Invalid or expired token" in data["error"] def test_auth_health_check(self, test_client): """Test authentication service health check""" @@ -310,7 +310,7 @@ def test_auth_health_check_unhealthy(self, test_client): assert response.status_code == 503 data = response.json() - assert "Authentication service is unhealthy" in data["detail"] + assert "Authentication service is unhealthy" in data["error"] class TestAuthMiddlewareIntegration: @@ -359,7 +359,7 @@ def test_middleware_authentication_failure(self, test_client): assert response.status_code == 401 data = response.json() - assert "Invalid or expired token" in data["detail"] + assert "Invalid or expired token" in data["error"] def test_middleware_no_authorization_header(self, test_client): """Test that middleware handles missing authorization header""" @@ -448,9 +448,10 @@ def test_error_response_format(self, test_client): assert response.status_code == 401 data = response.json() - # FastAPI error format - assert "detail" in data - assert isinstance(data["detail"], str) + # Standardized error format + assert "error" in data + assert isinstance(data["error"], str) + assert data["success"] is False def test_user_data_format(self, test_client, sample_user): """Test that user data format matches frontend expectations""" @@ -531,7 +532,7 @@ def test_google_oauth_service_error(self, test_client): assert response.status_code == 500 data = response.json() - assert "Authentication failed" in data["detail"] + assert "Authentication failed" in data["error"] def test_database_error_handling(self, test_client): """Test handling of database errors during authentication""" @@ -556,7 +557,7 @@ def test_database_error_handling(self, test_client): assert response.status_code == 500 data = response.json() - assert "Authentication failed" in data["detail"] + assert "Authentication failed" in data["error"] def test_jwt_service_error_handling(self, test_client): """Test handling of JWT service errors""" @@ -571,4 +572,4 @@ def test_jwt_service_error_handling(self, test_client): assert response.status_code == 500 data = response.json() - assert "Failed to get user information" in data["detail"] + assert "Failed to get user information" in data["error"] diff --git a/backend/tests/test_csv_preview_endpoint.py b/backend/tests/test_csv_preview_endpoint.py index 9aeb970..a652e87 100644 --- a/backend/tests/test_csv_preview_endpoint.py +++ b/backend/tests/test_csv_preview_endpoint.py @@ -72,7 +72,7 @@ def test_project_with_csv(test_user_in_db): name="CSV Test Dataset", description="Test project with CSV file" ) project = project_service.create_project(project_data, test_user_in_db.id) - + # Mock project with CSV path project.csv_path = "test/sample_data.csv" project.row_count = 1000 @@ -81,25 +81,17 @@ def test_project_with_csv(test_user_in_db): { "name": "name", "type": "string", - "sample_values": ["Alice", "Bob", "Charlie"] - }, - { - "name": "age", - "type": "number", - "sample_values": [25, 30, 35] + "sample_values": ["Alice", "Bob", "Charlie"], }, + {"name": "age", "type": "number", "sample_values": [25, 30, 35]}, { "name": "city", - "type": "string", - "sample_values": ["New York", "Los Angeles", "Chicago"] + "type": "string", + "sample_values": ["New York", "Los Angeles", "Chicago"], }, - { - "name": "salary", - "type": "number", - "sample_values": [75000, 85000, 90000] - } + {"name": "salary", "type": "number", "sample_values": [75000, 85000, 90000]}, ] - + return project @@ -108,14 +100,14 @@ class TestCSVPreviewEndpoint: def test_csv_preview_from_storage( self, - test_client, + test_client, test_access_token, test_user_in_db, test_project_with_csv, ): """Test CSV preview endpoint loading from storage""" app.dependency_overrides[verify_token] = mock_verify_token - + # Mock CSV data sample_csv = """name,age,city,salary Alice,25,New York,75000 @@ -123,20 +115,20 @@ def test_csv_preview_from_storage( Charlie,35,Chicago,90000 Diana,28,Houston,80000 Eve,32,Phoenix,77000""" - + with patch("services.storage_service.storage_service") as mock_storage: - mock_storage.download_file.return_value = sample_csv.encode('utf-8') - + mock_storage.download_file.return_value = sample_csv.encode("utf-8") + try: response = test_client.get( f"/chat/{test_project_with_csv.id}/preview", headers={"Authorization": f"Bearer {test_access_token}"}, ) - + assert response.status_code == 200 data = response.json() assert data["success"] is True - + preview = data["data"] assert preview["columns"] == ["name", "age", "city", "salary"] assert len(preview["sample_data"]) == 5 @@ -144,11 +136,11 @@ def test_csv_preview_from_storage( assert preview["data_types"]["name"] == "string" assert preview["data_types"]["age"] == "number" assert preview["data_types"]["salary"] == "number" - + # Check sample data assert preview["sample_data"][0] == ["Alice", 25, "New York", 75000] assert preview["sample_data"][1] == ["Bob", 30, "Los Angeles", 85000] - + finally: app.dependency_overrides.clear() @@ -161,7 +153,7 @@ def test_csv_preview_fallback_to_metadata( ): """Test CSV preview endpoint falling back to metadata when storage fails""" app.dependency_overrides[verify_token] = mock_verify_token - + # Mock project with metadata mock_project = Mock() mock_project.csv_path = "test/sample.csv" @@ -170,53 +162,51 @@ def test_csv_preview_fallback_to_metadata( { "name": "name", "type": "string", - "sample_values": ["Alice", "Bob", "Charlie"] - }, - { - "name": "age", - "type": "number", - "sample_values": [25, 30, 35] + "sample_values": ["Alice", "Bob", "Charlie"], }, + {"name": "age", "type": "number", "sample_values": [25, 30, 35]}, { "name": "city", - "type": "string", - "sample_values": ["New York", "Los Angeles", "Chicago"] + "type": "string", + "sample_values": ["New York", "Los Angeles", "Chicago"], }, { "name": "salary", "type": "number", - "sample_values": [75000, 85000, 90000] - } + "sample_values": [75000, 85000, 90000], + }, ] - - with patch("services.storage_service.storage_service") as mock_storage, \ - patch("api.chat.project_service") as mock_project_service: - + + with ( + patch("services.storage_service.storage_service") as mock_storage, + patch("api.chat.project_service") as mock_project_service, + ): + # Mock storage failure mock_storage.download_file.return_value = None - + # Mock project service mock_project_service.check_project_ownership.return_value = True mock_project_service.get_project_by_id.return_value = mock_project - + try: response = test_client.get( f"/chat/{test_project_with_csv.id}/preview", headers={"Authorization": f"Bearer {test_access_token}"}, ) - + assert response.status_code == 200 data = response.json() assert data["success"] is True - + preview = data["data"] assert preview["columns"] == ["name", "age", "city", "salary"] assert len(preview["sample_data"]) == 5 assert preview["total_rows"] == 1000 # From project metadata - + # Should use sample values from metadata assert preview["sample_data"][0] == ["Alice", 25, "New York", 75000] - + finally: app.dependency_overrides.clear() @@ -229,59 +219,59 @@ def test_csv_preview_no_csv_path( ): """Test CSV preview endpoint when project has no CSV path""" app.dependency_overrides[verify_token] = mock_verify_token - + # Remove CSV path from project test_project_with_csv.csv_path = None - + try: response = test_client.get( f"/chat/{test_project_with_csv.id}/preview", headers={"Authorization": f"Bearer {test_access_token}"}, ) - + assert response.status_code == 404 data = response.json() - assert "CSV preview not available" in data["detail"] - + assert "CSV preview not available" in data["error"] + finally: app.dependency_overrides.clear() def test_csv_preview_data_type_detection( self, test_client, - test_access_token, + test_access_token, test_user_in_db, test_project_with_csv, ): """Test data type detection in CSV preview""" app.dependency_overrides[verify_token] = mock_verify_token - + # CSV with various data types sample_csv = """id,name,active,price,created_date,rating 1,Product A,True,19.99,2024-01-01,4.5 2,Product B,False,29.99,2024-01-02,3.8 3,Product C,True,39.99,2024-01-03,4.2""" - + with patch("services.storage_service.storage_service") as mock_storage: - mock_storage.download_file.return_value = sample_csv.encode('utf-8') - + mock_storage.download_file.return_value = sample_csv.encode("utf-8") + try: response = test_client.get( f"/chat/{test_project_with_csv.id}/preview", headers={"Authorization": f"Bearer {test_access_token}"}, ) - + assert response.status_code == 200 data = response.json() preview = data["data"] - + # Verify data type detection assert preview["data_types"]["id"] == "number" assert preview["data_types"]["name"] == "string" assert preview["data_types"]["active"] == "boolean" assert preview["data_types"]["price"] == "number" assert preview["data_types"]["rating"] == "number" - + finally: app.dependency_overrides.clear() @@ -293,19 +283,19 @@ def test_csv_preview_project_not_found( ): """Test CSV preview endpoint with non-existent project""" app.dependency_overrides[verify_token] = mock_verify_token - + fake_project_id = "12345678-1234-5678-9012-123456789012" - + try: response = test_client.get( f"/chat/{fake_project_id}/preview", headers={"Authorization": f"Bearer {test_access_token}"}, ) - + assert response.status_code == 404 data = response.json() - assert "Project not found" in data["detail"] - + assert "Project not found" in data["error"] + finally: app.dependency_overrides.clear() @@ -317,18 +307,18 @@ def test_csv_preview_invalid_project_id( ): """Test CSV preview endpoint with invalid project ID format""" app.dependency_overrides[verify_token] = mock_verify_token - + invalid_project_id = "invalid-uuid" - + try: response = test_client.get( f"/chat/{invalid_project_id}/preview", headers={"Authorization": f"Bearer {test_access_token}"}, ) - + assert response.status_code == 400 data = response.json() - assert "Invalid project ID" in data["detail"] - + assert "Invalid project ID" in data["error"] + finally: - app.dependency_overrides.clear() \ No newline at end of file + app.dependency_overrides.clear() diff --git a/backend/tests/test_duckdb_service.py b/backend/tests/test_duckdb_service.py index b0cef1b..e73141f 100644 --- a/backend/tests/test_duckdb_service.py +++ b/backend/tests/test_duckdb_service.py @@ -251,10 +251,13 @@ def test_execute_query_success(self, mock_execute_sql, mock_load_csv, mock_stora # Verify method calls with UUID objects from uuid import UUID + service.project_service.check_project_ownership.assert_called_once_with( UUID(project_id), UUID(user_id) ) - service.project_service.get_project_by_id.assert_called_once_with(UUID(project_id)) + service.project_service.get_project_by_id.assert_called_once_with( + UUID(project_id) + ) mock_load_csv.assert_called_once_with(mock_project) mock_execute_sql.assert_called_once_with("SELECT * FROM data", test_df) diff --git a/backend/tests/test_embeddings_service.py b/backend/tests/test_embeddings_service.py index 80781f0..2d35c51 100644 --- a/backend/tests/test_embeddings_service.py +++ b/backend/tests/test_embeddings_service.py @@ -22,7 +22,9 @@ def test_embeddings_service_initialization(self): def test_embeddings_service_no_api_key(self): """Test embeddings service initialization without API key""" with patch.dict("os.environ", {}, clear=True): - with pytest.raises(ValueError, match="OPENAI_API_KEY environment variable not set"): + with pytest.raises( + ValueError, match="OPENAI_API_KEY environment variable not set" + ): EmbeddingsService() def test_embeddings_service_testing_mode(self): @@ -50,8 +52,7 @@ def test_generate_embedding_success(self, mock_openai_class): assert result == [0.1, 0.2, 0.3, 0.4, 0.5] mock_client.embeddings.create.assert_called_once_with( - model="text-embedding-3-small", - input="test text" + model="text-embedding-3-small", input="test text" ) def test_generate_embedding_no_client(self): @@ -65,10 +66,10 @@ def test_generate_embedding_empty_text(self): """Test embedding generation with empty text""" with patch.dict("os.environ", {"TESTING": "true"}, clear=True): service = EmbeddingsService() - + result = service.generate_embedding("") assert result is None - + result = service.generate_embedding(" ") assert result is None @@ -89,7 +90,7 @@ def test_generate_embedding_api_error(self, mock_openai_class): def test_generate_project_embeddings(self): """Test project embeddings generation""" service = EmbeddingsService() - + # Mock dependencies mock_project = Mock() mock_project.name = "Sales Dataset" @@ -97,29 +98,25 @@ def test_generate_project_embeddings(self): mock_project.row_count = 1000 mock_project.column_count = 4 mock_project.columns_metadata = [ - { - "name": "customer_id", - "type": "number", - "sample_values": [1, 2, 3] - }, + {"name": "customer_id", "type": "number", "sample_values": [1, 2, 3]}, { "name": "product_name", "type": "string", - "sample_values": ["Product A", "Product B", "Product C"] - } + "sample_values": ["Product A", "Product B", "Product C"], + }, ] - + service.project_service = Mock() service.project_service.check_project_ownership.return_value = True service.project_service.get_project_by_id.return_value = mock_project - + # Mock embedding generation service.generate_embedding = Mock(return_value=[0.1, 0.2, 0.3]) - + project_id = "12345678-1234-5678-9012-123456789012" user_id = "87654321-4321-8765-2109-876543210987" result = service.generate_project_embeddings(project_id, user_id) - + assert result is True # Should generate embeddings for overview, columns, and sample data assert service.generate_embedding.call_count >= 3 @@ -127,56 +124,71 @@ def test_generate_project_embeddings(self): def test_generate_project_embeddings_no_access(self): """Test project embeddings generation without access""" service = EmbeddingsService() - + service.project_service = Mock() service.project_service.check_project_ownership.return_value = False - - result = service.generate_project_embeddings("12345678-1234-5678-9012-123456789012", "87654321-4321-8765-2109-876543210987") + + result = service.generate_project_embeddings( + "12345678-1234-5678-9012-123456789012", + "87654321-4321-8765-2109-876543210987", + ) assert result is False def test_generate_project_embeddings_no_metadata(self): """Test project embeddings generation without metadata""" service = EmbeddingsService() - + mock_project = Mock() mock_project.columns_metadata = None - + service.project_service = Mock() service.project_service.check_project_ownership.return_value = True service.project_service.get_project_by_id.return_value = mock_project - - result = service.generate_project_embeddings("12345678-1234-5678-9012-123456789012", "87654321-4321-8765-2109-876543210987") + + result = service.generate_project_embeddings( + "12345678-1234-5678-9012-123456789012", + "87654321-4321-8765-2109-876543210987", + ) assert result is False def test_semantic_search(self): """Test semantic search functionality""" service = EmbeddingsService() - + # Mock project access service.project_service = Mock() service.project_service.check_project_ownership.return_value = True - + # Mock query embedding service.generate_embedding = Mock(return_value=[0.5, 0.5, 0.5]) - + # Mock stored embeddings stored_embeddings = [ { "type": "dataset_overview", "text": "Sales dataset with customer data", - "embedding": [0.5, 0.5, 0.5] # Same as query = highest similarity (1.0) + "embedding": [ + 0.5, + 0.5, + 0.5, + ], # Same as query = highest similarity (1.0) }, { "type": "column", "column_name": "customer_id", "text": "Customer ID column", - "embedding": [0.1, 0.1, 0.1] # Lower similarity - } + "embedding": [0.1, 0.1, 0.1], # Lower similarity + }, ] service._get_project_embeddings_raw = Mock(return_value=stored_embeddings) - - results = service.semantic_search("12345678-1234-5678-9012-123456789012", "87654321-4321-8765-2109-876543210987", "sales data", top_k=2) - + + results = service.semantic_search( + "12345678-1234-5678-9012-123456789012", + "87654321-4321-8765-2109-876543210987", + "sales data", + top_k=2, + ) + assert len(results) == 2 assert results[0]["type"] == "dataset_overview" # Higher similarity first assert results[0]["similarity"] > results[1]["similarity"] @@ -186,43 +198,54 @@ def test_semantic_search(self): def test_semantic_search_no_access(self): """Test semantic search without project access""" service = EmbeddingsService() - + service.project_service = Mock() service.project_service.check_project_ownership.return_value = False - - results = service.semantic_search("12345678-1234-5678-9012-123456789012", "87654321-4321-8765-2109-876543210987", "test query") + + results = service.semantic_search( + "12345678-1234-5678-9012-123456789012", + "87654321-4321-8765-2109-876543210987", + "test query", + ) assert results == [] def test_semantic_search_no_embeddings(self): """Test semantic search with no stored embeddings""" service = EmbeddingsService() - + service.project_service = Mock() service.project_service.check_project_ownership.return_value = True service.generate_embedding = Mock(return_value=[0.1, 0.2, 0.3]) service._get_project_embeddings = Mock(return_value=[]) - - results = service.semantic_search("12345678-1234-5678-9012-123456789012", "87654321-4321-8765-2109-876543210987", "test query") + + results = service.semantic_search( + "12345678-1234-5678-9012-123456789012", + "87654321-4321-8765-2109-876543210987", + "test query", + ) assert results == [] def test_get_embedding_stats(self): """Test embedding statistics retrieval""" service = EmbeddingsService() - + service.project_service = Mock() service.project_service.check_project_ownership.return_value = True - + # Mock stored embeddings stored_embeddings = [ {"type": "dataset_overview", "text": "overview"}, {"type": "column", "text": "column1"}, {"type": "column", "text": "column2"}, - {"type": "sample_data", "text": "samples"} + {"type": "sample_data", "text": "samples"}, ] service._get_project_embeddings = Mock(return_value=stored_embeddings) - - stats = service.get_embedding_stats("12345678-1234-5678-9012-123456789012", "87654321-4321-8765-2109-876543210987") - + + stats = service.get_embedding_stats( + "12345678-1234-5678-9012-123456789012", + "87654321-4321-8765-2109-876543210987", + ) + assert stats["embedding_count"] == 4 assert stats["types"]["column"] == 2 assert stats["types"]["dataset_overview"] == 1 @@ -234,17 +257,20 @@ def test_get_embedding_stats(self): def test_get_embedding_stats_no_access(self): """Test embedding stats without project access""" service = EmbeddingsService() - + service.project_service = Mock() service.project_service.check_project_ownership.return_value = False - - stats = service.get_embedding_stats("12345678-1234-5678-9012-123456789012", "87654321-4321-8765-2109-876543210987") + + stats = service.get_embedding_stats( + "12345678-1234-5678-9012-123456789012", + "87654321-4321-8765-2109-876543210987", + ) assert stats == {} def test_create_dataset_overview(self): """Test dataset overview text creation""" service = EmbeddingsService() - + mock_project = Mock() mock_project.name = "Sales Data" mock_project.description = "Customer sales information" @@ -253,11 +279,11 @@ def test_create_dataset_overview(self): mock_project.columns_metadata = [ {"name": "id", "type": "number"}, {"name": "name", "type": "string"}, - {"name": "amount", "type": "number"} + {"name": "amount", "type": "number"}, ] - + overview = service._create_dataset_overview(mock_project) - + assert "Sales Data" in overview assert "Customer sales information" in overview assert "1000 rows" in overview @@ -267,16 +293,16 @@ def test_create_dataset_overview(self): def test_create_column_description(self): """Test column description text creation""" service = EmbeddingsService() - + col_metadata = { "name": "customer_id", "type": "number", "sample_values": [1, 2, 3, 4, 5], - "nullable": True + "nullable": True, } - + description = service._create_column_description(col_metadata) - + assert "customer_id" in description assert "number" in description assert "1, 2, 3" in description @@ -285,41 +311,35 @@ def test_create_column_description(self): def test_create_sample_data_description(self): """Test sample data description creation""" service = EmbeddingsService() - + mock_project = Mock() mock_project.columns_metadata = [ - { - "name": "price", - "sample_values": [10.5, 20.0, 15.75] - }, - { - "name": "category", - "sample_values": ["A", "B", "A", "C"] - } + {"name": "price", "sample_values": [10.5, 20.0, 15.75]}, + {"name": "category", "sample_values": ["A", "B", "A", "C"]}, ] - + description = service._create_sample_data_description(mock_project) - + assert "price ranges from 10.5 to 20.0" in description assert "category includes" in description def test_embedding_storage_and_retrieval(self): """Test embedding storage and retrieval""" service = EmbeddingsService() - + test_embeddings = [ {"type": "test", "text": "test text", "embedding": [0.1, 0.2, 0.3]} ] - + # Store embeddings project_id = "12345678-1234-5678-9012-123456789012" service._store_project_embeddings(project_id, test_embeddings) - + # Retrieve embeddings retrieved = service._get_project_embeddings(project_id) - + assert retrieved == test_embeddings - + # Test non-existent project empty = service._get_project_embeddings("nonexistent") assert empty == [] @@ -332,7 +352,7 @@ def test_embeddings_service_singleton(): service = get_embeddings_service() assert service is not None assert isinstance(service, EmbeddingsService) - + # Test that it returns the same instance service2 = get_embeddings_service() - assert service is service2 \ No newline at end of file + assert service is service2 diff --git a/backend/tests/test_langchain_chat.py b/backend/tests/test_langchain_chat.py index 05a7dd0..f7c345a 100644 --- a/backend/tests/test_langchain_chat.py +++ b/backend/tests/test_langchain_chat.py @@ -146,6 +146,7 @@ def test_sql_query_processing( with patch("api.chat.langchain_service") as mock_service: # Mock LangChain service response from models.response_schemas import QueryResult + mock_service.process_query.return_value = QueryResult( id="qr_test_123", query="Show me total sales by product", @@ -200,6 +201,7 @@ def test_chart_query_processing( with patch("api.chat.langchain_service") as mock_service: # Mock chart response from models.response_schemas import QueryResult + mock_service.process_query.return_value = QueryResult( id="qr_chart_123", query="Create a bar chart of sales by category", @@ -254,6 +256,7 @@ def test_general_query_processing( with patch("api.chat.langchain_service") as mock_service: # Mock general response from models.response_schemas import QueryResult + mock_service.process_query.return_value = QueryResult( id="qr_general_123", query="What can you tell me about this dataset?", @@ -450,6 +453,7 @@ def test_ai_response_formatting( for case in test_cases: with patch("api.chat.langchain_service") as mock_service: from models.response_schemas import QueryResult + mock_result = QueryResult( id="test_query_id", query="Test query", diff --git a/backend/tests/test_mock_endpoints.py b/backend/tests/test_mock_endpoints.py index 3512d89..5dc08e5 100644 --- a/backend/tests/test_mock_endpoints.py +++ b/backend/tests/test_mock_endpoints.py @@ -194,7 +194,7 @@ def test_csv_preview( # This is expected behavior for new projects assert response.status_code == 404 data = response.json() - assert data["detail"] == "CSV preview not available" + assert data["error"] == "CSV preview not available" finally: app.dependency_overrides.clear()