Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 14 additions & 9 deletions backend/api/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,13 @@
from fastapi import APIRouter, Depends, HTTPException
from fastapi.security import HTTPBearer

from models.response_schemas import ApiResponse, AuthResponse, LoginRequest, User
from models.response_schemas import (
ApiResponse,
AuthResponse,
LoginRequest,
RefreshTokenRequest,
User,
)
from models.user import UserInDB
from services.auth_service import AuthService

Expand Down Expand Up @@ -116,15 +122,15 @@ async def get_current_user(

@router.post("/logout")
async def logout(token: str = Depends(get_current_user_token)) -> ApiResponse[dict]:
"""Logout current user with enhanced logging"""
"""Logout current user with enhanced logging and token blacklisting"""
try:
logger.info("Received logout request")

# Verify token and get user for logging
user = auth_service.get_current_user(token)

# Revoke tokens (placeholder implementation)
success = auth_service.revoke_user_tokens(str(user.id))
# Revoke tokens with proper blacklisting
success = auth_service.revoke_user_tokens(str(user.id), access_token=token)

if success:
logger.info(f"Logout successful for user: {user.email}")
Expand All @@ -148,19 +154,18 @@ async def logout(token: str = Depends(get_current_user_token)) -> ApiResponse[di


@router.post("/refresh")
async def refresh_token(request: dict) -> ApiResponse[AuthResponse]:
async def refresh_token(request: RefreshTokenRequest) -> ApiResponse[AuthResponse]:
"""Refresh access token with enhanced validation"""
try:
logger.info("Received token refresh request")

# Validate request
refresh_token = request.get("refresh_token")
if not refresh_token or not refresh_token.strip():
if not request.refresh_token or not request.refresh_token.strip():
logger.warning("Empty refresh token received")
raise HTTPException(status_code=400, detail="Refresh token is required")

new_access_token, user = auth_service.refresh_access_token(
refresh_token.strip()
request.refresh_token.strip()
)

# Convert to response format
Expand All @@ -176,7 +181,7 @@ async def refresh_token(request: dict) -> ApiResponse[AuthResponse]:
auth_response = AuthResponse(
user=user_response,
access_token=new_access_token,
refresh_token=refresh_token, # Keep the same refresh token
refresh_token=request.refresh_token, # Keep the same refresh token
expires_in=auth_service.access_token_expire_minutes * 60,
)

Expand Down
82 changes: 72 additions & 10 deletions backend/services/auth_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import os
import uuid
from datetime import datetime, timedelta
from typing import Dict, Optional, Tuple
from typing import Dict, Optional, Set, Tuple

import jwt
from google.auth.exceptions import GoogleAuthError
Expand All @@ -24,6 +24,11 @@ class TokenData(BaseModel):
user_id: str
email: str
exp: datetime
jti: Optional[str] = None # JWT ID for token tracking


# In-memory token blacklist (in production, use Redis)
_token_blacklist: Set[str] = set()


class AuthService:
Expand Down Expand Up @@ -52,31 +57,44 @@ def __init__(self):
logger.info(f"Mock auth enabled: {self.enable_mock_auth}")

def create_access_token(self, user_id: str, email: str) -> str:
"""Create JWT access token"""
"""Create JWT access token with unique JWT ID"""
expire = datetime.utcnow() + timedelta(minutes=self.access_token_expire_minutes)
jti = str(uuid.uuid4()) # Unique token identifier
to_encode = {
"sub": user_id,
"email": email,
"exp": expire,
"iat": datetime.utcnow(),
"type": "access",
"jti": jti,
}
return jwt.encode(to_encode, self.jwt_secret, algorithm=self.algorithm)

def create_refresh_token(self, user_id: str, email: str) -> str:
"""Create JWT refresh token"""
"""Create JWT refresh token with unique JWT ID"""
expire = datetime.utcnow() + timedelta(days=self.refresh_token_expire_days)
jti = str(uuid.uuid4()) # Unique token identifier
to_encode = {
"sub": user_id,
"email": email,
"exp": expire,
"iat": datetime.utcnow(),
"type": "refresh",
"jti": jti,
}
return jwt.encode(to_encode, self.jwt_secret, algorithm=self.algorithm)

def _is_token_blacklisted(self, jti: str) -> bool:
"""Check if token is blacklisted"""
return jti in _token_blacklist

def _blacklist_token(self, jti: str) -> None:
"""Add token to blacklist"""
_token_blacklist.add(jti)
logger.info(f"Token blacklisted: {jti}")

def verify_token(self, token: str, token_type: str = "access") -> TokenData:
"""Verify JWT token and return token data"""
"""Verify JWT token and return token data with blacklist check"""
try:
payload = jwt.decode(token, self.jwt_secret, algorithms=[self.algorithm])

Expand All @@ -92,6 +110,11 @@ def verify_token(self, token: str, token_type: str = "access") -> TokenData:
):
raise jwt.InvalidTokenError("Token has expired")

# Check if token is blacklisted
jti = payload.get("jti")
if jti and self._is_token_blacklisted(jti):
raise jwt.InvalidTokenError("Token has been revoked")

return TokenData(
user_id=payload.get("sub"),
email=payload.get("email"),
Expand All @@ -100,6 +123,7 @@ def verify_token(self, token: str, token_type: str = "access") -> TokenData:
if exp_timestamp
else datetime.utcnow()
),
jti=jti,
)
except jwt.ExpiredSignatureError:
raise jwt.InvalidTokenError("Token has expired")
Expand Down Expand Up @@ -316,15 +340,44 @@ def get_current_user(self, access_token: str) -> UserInDB:
logger.error(f"Get current user failed: {str(e)}")
raise

def revoke_user_tokens(self, user_id: str) -> bool:
def revoke_token_by_jti(self, jti: str) -> bool:
"""Revoke a specific token by its JWT ID"""
if not jti:
logger.warning("Attempted to revoke token without JTI")
return False

self._blacklist_token(jti)
return True

def revoke_user_tokens(
self, user_id: str, access_token: Optional[str] = None
) -> bool:
"""
Revoke all tokens for a user (logout)
Note: With JWT, we can't actually revoke tokens server-side without a blacklist.
This is a placeholder for future token blacklist implementation.
Revoke user tokens (logout with proper token blacklisting)
In a production system, you would query all active tokens for the user.
For now, we blacklist the current access token if provided.
"""
logger.info(f"Token revocation requested for user: {user_id}")
# In a production system, you would add the user's tokens to a blacklist
# For now, we just return True as logout is handled client-side

if access_token:
try:
# Verify the token to get its JTI before blacklisting
token_data = self.verify_token(access_token, token_type="access")
if token_data.jti:
self._blacklist_token(token_data.jti)
logger.info(f"Successfully revoked token for user: {user_id}")
return True
else:
logger.warning(f"Token missing JTI for user: {user_id}")
return False
except jwt.InvalidTokenError as e:
logger.warning(
f"Invalid token during revocation for user {user_id}: {str(e)}"
)
return False

# If no token provided, still consider it successful
# (client-side logout without server-side token invalidation)
return True

def validate_google_client_configuration(self) -> Dict[str, any]:
Expand All @@ -349,6 +402,14 @@ def validate_google_client_configuration(self) -> Dict[str, any]:

return config_status

def get_blacklist_stats(self) -> Dict[str, any]:
"""Get token blacklist statistics"""
return {
"blacklisted_tokens": len(_token_blacklist),
"implementation": "in_memory",
"note": "In production, use Redis for distributed blacklist",
}

def health_check(self) -> Dict[str, any]:
"""Enhanced health check for auth service"""
try:
Expand Down Expand Up @@ -383,6 +444,7 @@ def health_check(self) -> Dict[str, any]:
"jwt_working": jwt_working,
"user_service": user_health,
"google_oauth": google_config,
"token_blacklist": self.get_blacklist_stats(),
"environment": self.environment,
"access_token_expire_minutes": self.access_token_expire_minutes,
"refresh_token_expire_days": self.refresh_token_expire_days,
Expand Down
59 changes: 58 additions & 1 deletion backend/tests/test_auth_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ def test_create_access_token(self, auth_service):
assert payload["sub"] == user_id
assert payload["email"] == email
assert payload["type"] == "access"
assert "jti" in payload
assert payload["jti"] is not None

def test_create_refresh_token(self, auth_service):
"""Test refresh token creation"""
Expand All @@ -75,6 +77,8 @@ def test_create_refresh_token(self, auth_service):
assert payload["sub"] == user_id
assert payload["email"] == email
assert payload["type"] == "refresh"
assert "jti" in payload
assert payload["jti"] is not None

def test_verify_access_token_success(self, auth_service):
"""Test successful access token verification"""
Expand All @@ -87,6 +91,7 @@ def test_verify_access_token_success(self, auth_service):
assert isinstance(token_data, TokenData)
assert token_data.user_id == user_id
assert token_data.email == email
assert token_data.jti is not None

def test_verify_refresh_token_success(self, auth_service):
"""Test successful refresh token verification"""
Expand All @@ -99,6 +104,7 @@ def test_verify_refresh_token_success(self, auth_service):
assert isinstance(token_data, TokenData)
assert token_data.user_id == user_id
assert token_data.email == email
assert token_data.jti is not None

def test_verify_invalid_token(self, auth_service):
"""Test invalid token verification"""
Expand Down Expand Up @@ -403,6 +409,57 @@ def test_get_current_user_inactive(self, auth_service, sample_user):
auth_service.get_current_user(access_token)

def test_revoke_user_tokens(self, auth_service):
"""Test token revocation (placeholder implementation)"""
"""Test token revocation with proper blacklisting"""
# Test without access token
result = auth_service.revoke_user_tokens("test_user_123")
assert result is True

# Test with access token
access_token = auth_service.create_access_token(
"test_user_123", "test@example.com"
)
result = auth_service.revoke_user_tokens(
"test_user_123", access_token=access_token
)
assert result is True

# Verify token is now blacklisted
with pytest.raises(jwt.InvalidTokenError, match="Token has been revoked"):
auth_service.verify_token(access_token)

def test_revoke_token_by_jti(self, auth_service):
"""Test token revocation by JWT ID"""
# Test with valid JTI
result = auth_service.revoke_token_by_jti("test_jti_123")
assert result is True

# Test with empty JTI
result = auth_service.revoke_token_by_jti("")
assert result is False

def test_token_blacklisting(self, auth_service):
"""Test comprehensive token blacklisting functionality"""
user_id = "test_user_123"
email = "test@example.com"

# Create token
token = auth_service.create_access_token(user_id, email)

# Verify token works initially
token_data = auth_service.verify_token(token)
assert token_data.user_id == user_id
assert token_data.jti is not None

# Blacklist the token
auth_service._blacklist_token(token_data.jti)

# Verify token is now rejected
with pytest.raises(jwt.InvalidTokenError, match="Token has been revoked"):
auth_service.verify_token(token)

def test_get_blacklist_stats(self, auth_service):
"""Test blacklist statistics"""
stats = auth_service.get_blacklist_stats()
assert "blacklisted_tokens" in stats
assert "implementation" in stats
assert stats["implementation"] == "in_memory"
Loading
Loading