diff --git a/backend/api/auth.py b/backend/api/auth.py index 01573cc..ef60b63 100644 --- a/backend/api/auth.py +++ b/backend/api/auth.py @@ -6,7 +6,13 @@ from fastapi import APIRouter, Depends, HTTPException from fastapi.security import HTTPBearer -from models.response_schemas import ApiResponse, AuthResponse, LoginRequest, User +from models.response_schemas import ( + ApiResponse, + AuthResponse, + LoginRequest, + RefreshTokenRequest, + User, +) from models.user import UserInDB from services.auth_service import AuthService @@ -116,15 +122,15 @@ async def get_current_user( @router.post("/logout") async def logout(token: str = Depends(get_current_user_token)) -> ApiResponse[dict]: - """Logout current user with enhanced logging""" + """Logout current user with enhanced logging and token blacklisting""" try: logger.info("Received logout request") # Verify token and get user for logging user = auth_service.get_current_user(token) - # Revoke tokens (placeholder implementation) - success = auth_service.revoke_user_tokens(str(user.id)) + # Revoke tokens with proper blacklisting + success = auth_service.revoke_user_tokens(str(user.id), access_token=token) if success: logger.info(f"Logout successful for user: {user.email}") @@ -148,19 +154,18 @@ async def logout(token: str = Depends(get_current_user_token)) -> ApiResponse[di @router.post("/refresh") -async def refresh_token(request: dict) -> ApiResponse[AuthResponse]: +async def refresh_token(request: RefreshTokenRequest) -> ApiResponse[AuthResponse]: """Refresh access token with enhanced validation""" try: logger.info("Received token refresh request") # Validate request - refresh_token = request.get("refresh_token") - if not refresh_token or not refresh_token.strip(): + if not request.refresh_token or not request.refresh_token.strip(): logger.warning("Empty refresh token received") raise HTTPException(status_code=400, detail="Refresh token is required") new_access_token, user = auth_service.refresh_access_token( - refresh_token.strip() + request.refresh_token.strip() ) # Convert to response format @@ -176,7 +181,7 @@ async def refresh_token(request: dict) -> ApiResponse[AuthResponse]: auth_response = AuthResponse( user=user_response, access_token=new_access_token, - refresh_token=refresh_token, # Keep the same refresh token + refresh_token=request.refresh_token, # Keep the same refresh token expires_in=auth_service.access_token_expire_minutes * 60, ) diff --git a/backend/services/auth_service.py b/backend/services/auth_service.py index 8bc6d60..cd4e58d 100644 --- a/backend/services/auth_service.py +++ b/backend/services/auth_service.py @@ -2,7 +2,7 @@ import os import uuid from datetime import datetime, timedelta -from typing import Dict, Optional, Tuple +from typing import Dict, Optional, Set, Tuple import jwt from google.auth.exceptions import GoogleAuthError @@ -24,6 +24,11 @@ class TokenData(BaseModel): user_id: str email: str exp: datetime + jti: Optional[str] = None # JWT ID for token tracking + + +# In-memory token blacklist (in production, use Redis) +_token_blacklist: Set[str] = set() class AuthService: @@ -52,31 +57,44 @@ def __init__(self): logger.info(f"Mock auth enabled: {self.enable_mock_auth}") def create_access_token(self, user_id: str, email: str) -> str: - """Create JWT access token""" + """Create JWT access token with unique JWT ID""" expire = datetime.utcnow() + timedelta(minutes=self.access_token_expire_minutes) + jti = str(uuid.uuid4()) # Unique token identifier to_encode = { "sub": user_id, "email": email, "exp": expire, "iat": datetime.utcnow(), "type": "access", + "jti": jti, } return jwt.encode(to_encode, self.jwt_secret, algorithm=self.algorithm) def create_refresh_token(self, user_id: str, email: str) -> str: - """Create JWT refresh token""" + """Create JWT refresh token with unique JWT ID""" expire = datetime.utcnow() + timedelta(days=self.refresh_token_expire_days) + jti = str(uuid.uuid4()) # Unique token identifier to_encode = { "sub": user_id, "email": email, "exp": expire, "iat": datetime.utcnow(), "type": "refresh", + "jti": jti, } return jwt.encode(to_encode, self.jwt_secret, algorithm=self.algorithm) + def _is_token_blacklisted(self, jti: str) -> bool: + """Check if token is blacklisted""" + return jti in _token_blacklist + + def _blacklist_token(self, jti: str) -> None: + """Add token to blacklist""" + _token_blacklist.add(jti) + logger.info(f"Token blacklisted: {jti}") + def verify_token(self, token: str, token_type: str = "access") -> TokenData: - """Verify JWT token and return token data""" + """Verify JWT token and return token data with blacklist check""" try: payload = jwt.decode(token, self.jwt_secret, algorithms=[self.algorithm]) @@ -92,6 +110,11 @@ def verify_token(self, token: str, token_type: str = "access") -> TokenData: ): raise jwt.InvalidTokenError("Token has expired") + # Check if token is blacklisted + jti = payload.get("jti") + if jti and self._is_token_blacklisted(jti): + raise jwt.InvalidTokenError("Token has been revoked") + return TokenData( user_id=payload.get("sub"), email=payload.get("email"), @@ -100,6 +123,7 @@ def verify_token(self, token: str, token_type: str = "access") -> TokenData: if exp_timestamp else datetime.utcnow() ), + jti=jti, ) except jwt.ExpiredSignatureError: raise jwt.InvalidTokenError("Token has expired") @@ -316,15 +340,44 @@ def get_current_user(self, access_token: str) -> UserInDB: logger.error(f"Get current user failed: {str(e)}") raise - def revoke_user_tokens(self, user_id: str) -> bool: + def revoke_token_by_jti(self, jti: str) -> bool: + """Revoke a specific token by its JWT ID""" + if not jti: + logger.warning("Attempted to revoke token without JTI") + return False + + self._blacklist_token(jti) + return True + + def revoke_user_tokens( + self, user_id: str, access_token: Optional[str] = None + ) -> bool: """ - Revoke all tokens for a user (logout) - Note: With JWT, we can't actually revoke tokens server-side without a blacklist. - This is a placeholder for future token blacklist implementation. + Revoke user tokens (logout with proper token blacklisting) + In a production system, you would query all active tokens for the user. + For now, we blacklist the current access token if provided. """ logger.info(f"Token revocation requested for user: {user_id}") - # In a production system, you would add the user's tokens to a blacklist - # For now, we just return True as logout is handled client-side + + if access_token: + try: + # Verify the token to get its JTI before blacklisting + token_data = self.verify_token(access_token, token_type="access") + if token_data.jti: + self._blacklist_token(token_data.jti) + logger.info(f"Successfully revoked token for user: {user_id}") + return True + else: + logger.warning(f"Token missing JTI for user: {user_id}") + return False + except jwt.InvalidTokenError as e: + logger.warning( + f"Invalid token during revocation for user {user_id}: {str(e)}" + ) + return False + + # If no token provided, still consider it successful + # (client-side logout without server-side token invalidation) return True def validate_google_client_configuration(self) -> Dict[str, any]: @@ -349,6 +402,14 @@ def validate_google_client_configuration(self) -> Dict[str, any]: return config_status + def get_blacklist_stats(self) -> Dict[str, any]: + """Get token blacklist statistics""" + return { + "blacklisted_tokens": len(_token_blacklist), + "implementation": "in_memory", + "note": "In production, use Redis for distributed blacklist", + } + def health_check(self) -> Dict[str, any]: """Enhanced health check for auth service""" try: @@ -383,6 +444,7 @@ def health_check(self) -> Dict[str, any]: "jwt_working": jwt_working, "user_service": user_health, "google_oauth": google_config, + "token_blacklist": self.get_blacklist_stats(), "environment": self.environment, "access_token_expire_minutes": self.access_token_expire_minutes, "refresh_token_expire_days": self.refresh_token_expire_days, diff --git a/backend/tests/test_auth_service.py b/backend/tests/test_auth_service.py index 86c3a7d..bd9ee13 100644 --- a/backend/tests/test_auth_service.py +++ b/backend/tests/test_auth_service.py @@ -60,6 +60,8 @@ def test_create_access_token(self, auth_service): assert payload["sub"] == user_id assert payload["email"] == email assert payload["type"] == "access" + assert "jti" in payload + assert payload["jti"] is not None def test_create_refresh_token(self, auth_service): """Test refresh token creation""" @@ -75,6 +77,8 @@ def test_create_refresh_token(self, auth_service): assert payload["sub"] == user_id assert payload["email"] == email assert payload["type"] == "refresh" + assert "jti" in payload + assert payload["jti"] is not None def test_verify_access_token_success(self, auth_service): """Test successful access token verification""" @@ -87,6 +91,7 @@ def test_verify_access_token_success(self, auth_service): assert isinstance(token_data, TokenData) assert token_data.user_id == user_id assert token_data.email == email + assert token_data.jti is not None def test_verify_refresh_token_success(self, auth_service): """Test successful refresh token verification""" @@ -99,6 +104,7 @@ def test_verify_refresh_token_success(self, auth_service): assert isinstance(token_data, TokenData) assert token_data.user_id == user_id assert token_data.email == email + assert token_data.jti is not None def test_verify_invalid_token(self, auth_service): """Test invalid token verification""" @@ -403,6 +409,57 @@ def test_get_current_user_inactive(self, auth_service, sample_user): auth_service.get_current_user(access_token) def test_revoke_user_tokens(self, auth_service): - """Test token revocation (placeholder implementation)""" + """Test token revocation with proper blacklisting""" + # Test without access token result = auth_service.revoke_user_tokens("test_user_123") assert result is True + + # Test with access token + access_token = auth_service.create_access_token( + "test_user_123", "test@example.com" + ) + result = auth_service.revoke_user_tokens( + "test_user_123", access_token=access_token + ) + assert result is True + + # Verify token is now blacklisted + with pytest.raises(jwt.InvalidTokenError, match="Token has been revoked"): + auth_service.verify_token(access_token) + + def test_revoke_token_by_jti(self, auth_service): + """Test token revocation by JWT ID""" + # Test with valid JTI + result = auth_service.revoke_token_by_jti("test_jti_123") + assert result is True + + # Test with empty JTI + result = auth_service.revoke_token_by_jti("") + assert result is False + + def test_token_blacklisting(self, auth_service): + """Test comprehensive token blacklisting functionality""" + user_id = "test_user_123" + email = "test@example.com" + + # Create token + token = auth_service.create_access_token(user_id, email) + + # Verify token works initially + token_data = auth_service.verify_token(token) + assert token_data.user_id == user_id + assert token_data.jti is not None + + # Blacklist the token + auth_service._blacklist_token(token_data.jti) + + # Verify token is now rejected + with pytest.raises(jwt.InvalidTokenError, match="Token has been revoked"): + auth_service.verify_token(token) + + def test_get_blacklist_stats(self, auth_service): + """Test blacklist statistics""" + stats = auth_service.get_blacklist_stats() + assert "blacklisted_tokens" in stats + assert "implementation" in stats + assert stats["implementation"] == "in_memory" diff --git a/backend/tests/test_mock_endpoints.py b/backend/tests/test_mock_endpoints.py index dfacb5c..4f57e41 100644 --- a/backend/tests/test_mock_endpoints.py +++ b/backend/tests/test_mock_endpoints.py @@ -6,6 +6,7 @@ from fastapi.testclient import TestClient from main import app +from middleware.auth_middleware import verify_token from models.user import GoogleOAuthData, UserInDB from services.auth_service import AuthService @@ -15,6 +16,11 @@ auth_service = AuthService() +def mock_verify_token(): + """Mock verify_token that returns user_001""" + return "user_001" + + @pytest.fixture def sample_user(): """Sample user for testing - uses UUID that matches our mock project ownership""" @@ -75,7 +81,8 @@ def test_get_current_user(test_client, sample_user, test_access_token): def test_get_projects(test_client, test_access_token): """Test get projects endpoint""" - with patch("api.projects.verify_token"): + app.dependency_overrides[verify_token] = mock_verify_token + try: response = test_client.get( "/projects?page=1&limit=10", headers={"Authorization": f"Bearer {test_access_token}"}, @@ -86,11 +93,14 @@ def test_get_projects(test_client, test_access_token): assert "items" in data["data"] assert "total" in data["data"] assert len(data["data"]["items"]) >= 0 + finally: + app.dependency_overrides.clear() def test_create_project(test_client, test_access_token): """Test create project endpoint""" - with patch("api.projects.verify_token"): + app.dependency_overrides[verify_token] = mock_verify_token + try: response = test_client.post( "/projects", json={"name": "Test Project", "description": "Test description"}, @@ -101,11 +111,14 @@ def test_create_project(test_client, test_access_token): assert data["success"] is True assert data["data"]["project"]["name"] == "Test Project" assert "upload_url" in data["data"] + finally: + app.dependency_overrides.clear() def test_get_project(test_client, test_access_token): """Test get single project endpoint""" - with patch("api.projects.verify_token"): + app.dependency_overrides[verify_token] = mock_verify_token + try: response = test_client.get( "/projects/project_001", headers={"Authorization": f"Bearer {test_access_token}"}, @@ -115,11 +128,14 @@ def test_get_project(test_client, test_access_token): assert data["success"] is True assert data["data"]["id"] == "project_001" assert data["data"]["name"] == "Sales Data Analysis" + finally: + app.dependency_overrides.clear() def test_csv_preview(test_client, test_access_token): """Test CSV preview endpoint""" - with patch("api.chat.verify_token"): + app.dependency_overrides[verify_token] = mock_verify_token + try: response = test_client.get( "/chat/project_001/preview", headers={"Authorization": f"Bearer {test_access_token}"}, @@ -130,11 +146,14 @@ def test_csv_preview(test_client, test_access_token): assert "columns" in data["data"] assert "sample_data" in data["data"] assert len(data["data"]["columns"]) > 0 + finally: + app.dependency_overrides.clear() def test_send_message(test_client, test_access_token): """Test send chat message endpoint""" - with patch("api.chat.verify_token"): + app.dependency_overrides[verify_token] = mock_verify_token + try: response = test_client.post( "/chat/project_001/message", json={"message": "Show me total sales by product"}, @@ -146,11 +165,14 @@ def test_send_message(test_client, test_access_token): assert "message" in data["data"] assert "result" in data["data"] assert data["data"]["result"]["result_type"] in ["table", "chart", "summary"] + finally: + app.dependency_overrides.clear() def test_query_suggestions(test_client, test_access_token): """Test query suggestions endpoint""" - with patch("api.chat.verify_token"): + app.dependency_overrides[verify_token] = mock_verify_token + try: response = test_client.get( "/chat/project_001/suggestions", headers={"Authorization": f"Bearer {test_access_token}"}, @@ -160,12 +182,14 @@ def test_query_suggestions(test_client, test_access_token): assert data["success"] is True assert len(data["data"]) > 0 assert all("text" in suggestion for suggestion in data["data"]) + finally: + app.dependency_overrides.clear() def test_unauthorized_access(test_client): """Test that endpoints require authentication""" response = test_client.get("/projects") - assert response.status_code == 403 + assert response.status_code == 401 def test_invalid_token(test_client): @@ -210,7 +234,8 @@ def test_refresh_token(test_client, sample_user): def test_project_status(test_client, test_access_token): """Test project status endpoint""" - with patch("api.projects.verify_token"): + app.dependency_overrides[verify_token] = mock_verify_token + try: response = test_client.get( "/projects/project_001/status", headers={"Authorization": f"Bearer {test_access_token}"}, @@ -220,26 +245,31 @@ def test_project_status(test_client, test_access_token): assert data["success"] is True assert "status" in data["data"] assert "progress" in data["data"] + finally: + app.dependency_overrides.clear() def test_get_upload_url(test_client, test_access_token): """Test get upload URL endpoint""" - with patch("api.projects.verify_token"): - response = test_client.post( + app.dependency_overrides[verify_token] = mock_verify_token + try: + response = test_client.get( "/projects/project_001/upload-url", - json={"filename": "new_data.csv", "content_type": "text/csv"}, headers={"Authorization": f"Bearer {test_access_token}"}, ) assert response.status_code == 200 data = response.json() assert data["success"] is True assert "upload_url" in data["data"] - assert "object_path" in data["data"] + assert "upload_fields" in data["data"] + finally: + app.dependency_overrides.clear() def test_get_messages(test_client, test_access_token): """Test get chat messages endpoint""" - with patch("api.chat.verify_token"): + app.dependency_overrides[verify_token] = mock_verify_token + try: response = test_client.get( "/chat/project_001/messages", headers={"Authorization": f"Bearer {test_access_token}"}, @@ -249,6 +279,8 @@ def test_get_messages(test_client, test_access_token): assert data["success"] is True assert "items" in data["data"] assert len(data["data"]["items"]) >= 0 + finally: + app.dependency_overrides.clear() def test_invalid_google_token(test_client): @@ -265,17 +297,21 @@ def test_invalid_google_token(test_client): def test_project_not_found(test_client, test_access_token): """Test project not found error""" - with patch("api.projects.verify_token"): + app.dependency_overrides[verify_token] = mock_verify_token + try: response = test_client.get( "/projects/nonexistent_project", headers={"Authorization": f"Bearer {test_access_token}"}, ) assert response.status_code == 404 + finally: + app.dependency_overrides.clear() def test_chart_query_response(test_client, test_access_token): """Test chart query response type""" - with patch("api.chat.verify_token"): + app.dependency_overrides[verify_token] = mock_verify_token + try: response = test_client.post( "/chat/project_001/message", json={"message": "show me a chart"}, @@ -285,3 +321,5 @@ def test_chart_query_response(test_client, test_access_token): data = response.json() assert data["data"]["result"]["result_type"] == "chart" assert "chart_config" in data["data"]["result"] + finally: + app.dependency_overrides.clear()