Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
297 changes: 203 additions & 94 deletions backend/api/auth.py
Original file line number Diff line number Diff line change
@@ -1,120 +1,229 @@
import os
import logging
import uuid
from datetime import datetime, timedelta
from typing import Any, Dict
from typing import Optional

import jwt
from fastapi import APIRouter, Depends, HTTPException
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from fastapi.security import HTTPBearer

from models.response_schemas import (
ApiResponse,
AuthResponse,
LoginRequest,
RefreshTokenRequest,
User,
)
from models.response_schemas import ApiResponse, AuthResponse, LoginRequest, User
from models.user import UserInDB
from services.auth_service import AuthService

router = APIRouter(prefix="/auth", tags=["authentication"])
# Configure logging
logger = logging.getLogger(__name__)

router = APIRouter(prefix="/auth", tags=["Authentication"])
auth_service = AuthService()
security = HTTPBearer()

# Mock user database
MOCK_USERS = {
"google_user_123": {
"id": "user_001",
"email": "john.doe@example.com",
"name": "John Doe",
"avatar_url": "https://lh3.googleusercontent.com/a/default-user",
"created_at": "2025-01-01T00:00:00Z",
"last_sign_in_at": "2025-01-01T12:00:00Z",
}
}

# Mock JWT settings
JWT_SECRET = os.getenv("JWT_SECRET", "mock_secret_key_for_development")
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 60


def create_access_token(data: Dict[str, Any]) -> str:
"""Create JWT access token"""
to_encode = data.copy()
expire = datetime.utcnow() + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
to_encode.update({"exp": expire})
return jwt.encode(to_encode, JWT_SECRET, algorithm=ALGORITHM)


def verify_token(credentials: HTTPAuthorizationCredentials = Depends(security)) -> str:
"""Verify JWT token and return user_id"""
try:
payload = jwt.decode(
credentials.credentials, JWT_SECRET, algorithms=[ALGORITHM]
)
user_id: str = payload.get("sub")
if user_id is None:
raise HTTPException(status_code=401, detail="Invalid token")
return user_id
except jwt.PyJWTError:
raise HTTPException(status_code=401, detail="Invalid token")

def get_current_user_token(token: str = Depends(security)) -> str:
"""Extract token from Authorization header"""
return token.credentials


@router.post("/google")
async def login_with_google(request: LoginRequest) -> ApiResponse[AuthResponse]:
"""Mock Google OAuth login"""
# Mock Google token validation
if not request.google_token.startswith("mock_google_token"):
raise HTTPException(status_code=401, detail="Invalid Google token")
"""Google OAuth login with enhanced error handling"""
try:
logger.info("Received Google OAuth login request")

# Mock user from Google token
user_data = MOCK_USERS["google_user_123"]
user = User(**user_data)
# Validate request
if not request.google_token or not request.google_token.strip():
logger.warning("Empty Google token received")
raise HTTPException(status_code=400, detail="Google token is required")

# Create JWT tokens
access_token = create_access_token(data={"sub": user.id})
refresh_token = str(uuid.uuid4())
user, access_token, refresh_token, is_new_user = auth_service.login_with_google(
request.google_token.strip()
)

auth_response = AuthResponse(
user=user,
access_token=access_token,
refresh_token=refresh_token,
expires_in=ACCESS_TOKEN_EXPIRE_MINUTES * 60,
)
# Convert UserInDB to the response model directly
user_response = User(
id=str(user.id),
email=user.email,
name=user.name,
avatar_url=user.avatar_url,
created_at=user.created_at.isoformat(),
last_sign_in_at=user.updated_at.isoformat(), # Using updated_at for last sign-in
)

return ApiResponse(success=True, data=auth_response)
auth_response = AuthResponse(
user=user_response,
access_token=access_token,
refresh_token=refresh_token,
expires_in=auth_service.access_token_expire_minutes * 60,
)

logger.info(
f"Google OAuth login successful for user: {user.email}, is_new_user: {is_new_user}"
)
return ApiResponse(
success=True,
data=auth_response,
message=(
"Login successful"
if not is_new_user
else "Account created and login successful"
),
)

except HTTPException:
# Re-raise HTTPException without modification
raise
except ValueError as e:
logger.error(f"Google OAuth validation error: {str(e)}")
raise HTTPException(status_code=401, detail=f"Invalid Google token: {str(e)}")
except Exception as e:
logger.error(f"Google OAuth login failed: {str(e)}")
raise HTTPException(status_code=500, detail=f"Authentication failed: {str(e)}")


@router.get("/me")
async def get_current_user(user_id: str = Depends(verify_token)) -> ApiResponse[User]:
"""Get current user information"""
# Mock user lookup
for mock_user in MOCK_USERS.values():
if mock_user["id"] == user_id:
user = User(**mock_user)
return ApiResponse(success=True, data=user)
async def get_current_user(
token: str = Depends(get_current_user_token),
) -> ApiResponse[User]:
"""Get current user information with enhanced error handling"""
try:
logger.info("Received current user request")

user = auth_service.get_current_user(token)

# Convert UserInDB to the response model directly
user_response = User(
id=str(user.id),
email=user.email,
name=user.name,
avatar_url=user.avatar_url,
created_at=user.created_at.isoformat(),
last_sign_in_at=user.updated_at.isoformat(), # Using updated_at for last sign-in
)

logger.info(f"Current user request successful for: {user.email}")
return ApiResponse(success=True, data=user_response)

raise HTTPException(status_code=404, detail="User not found")
except jwt.InvalidTokenError as e:
logger.warning(f"Invalid token in current user request: {str(e)}")
raise HTTPException(
status_code=401, detail=f"Invalid or expired token: {str(e)}"
)
except Exception as e:
logger.error(f"Current user request failed: {str(e)}")
raise HTTPException(
status_code=500, detail=f"Failed to get user information: {str(e)}"
)


@router.post("/logout")
async def logout(user_id: str = Depends(verify_token)) -> ApiResponse[Dict[str, str]]:
"""Logout current user"""
return ApiResponse(success=True, data={"message": "Logged out successfully"})
async def logout(token: str = Depends(get_current_user_token)) -> ApiResponse[dict]:
"""Logout current user with enhanced logging"""
try:
logger.info("Received logout request")

# Verify token and get user for logging
user = auth_service.get_current_user(token)

# Revoke tokens (placeholder implementation)
success = auth_service.revoke_user_tokens(str(user.id))

if success:
logger.info(f"Logout successful for user: {user.email}")
return ApiResponse(
success=True,
data={"message": "Logged out successfully"},
message="You have been logged out",
)
else:
logger.error(f"Token revocation failed for user: {user.email}")
raise HTTPException(status_code=500, detail="Logout failed")

except jwt.InvalidTokenError as e:
logger.warning(f"Invalid token in logout request: {str(e)}")
raise HTTPException(
status_code=401, detail=f"Invalid or expired token: {str(e)}"
)
except Exception as e:
logger.error(f"Logout failed: {str(e)}")
raise HTTPException(status_code=500, detail=f"Logout failed: {str(e)}")


@router.post("/refresh")
async def refresh_token(request: RefreshTokenRequest) -> ApiResponse[Dict[str, Any]]:
"""Refresh access token"""
# Mock refresh token validation
if not request.refresh_token:
raise HTTPException(status_code=401, detail="Invalid refresh token")

# Create new access token
new_access_token = create_access_token(data={"sub": "user_001"})

return ApiResponse(
success=True,
data={
"access_token": new_access_token,
"expires_in": ACCESS_TOKEN_EXPIRE_MINUTES * 60,
},
)
async def refresh_token(request: dict) -> ApiResponse[AuthResponse]:
"""Refresh access token with enhanced validation"""
try:
logger.info("Received token refresh request")

# Validate request
refresh_token = request.get("refresh_token")
if not refresh_token or not refresh_token.strip():
logger.warning("Empty refresh token received")
raise HTTPException(status_code=400, detail="Refresh token is required")

new_access_token, user = auth_service.refresh_access_token(
refresh_token.strip()
)

# Convert to response format
user_response = User(
id=str(user.id),
email=user.email,
name=user.name,
avatar_url=user.avatar_url,
created_at=user.created_at.isoformat(),
last_sign_in_at=user.updated_at.isoformat(),
)

auth_response = AuthResponse(
user=user_response,
access_token=new_access_token,
refresh_token=refresh_token, # Keep the same refresh token
expires_in=auth_service.access_token_expire_minutes * 60,
)

logger.info(f"Token refresh successful for user: {user.email}")
return ApiResponse(
success=True, data=auth_response, message="Token refreshed successfully"
)

except HTTPException:
# Re-raise HTTPException without modification
raise
except jwt.InvalidTokenError as e:
logger.warning(f"Invalid refresh token: {str(e)}")
raise HTTPException(
status_code=401, detail=f"Invalid or expired refresh token: {str(e)}"
)
except Exception as e:
logger.error(f"Token refresh failed: {str(e)}")
raise HTTPException(status_code=500, detail=f"Token refresh failed: {str(e)}")


@router.get("/health")
async def auth_health_check() -> ApiResponse[dict]:
"""Enhanced authentication service health check"""
try:
logger.info("Received auth health check request")

health_data = auth_service.health_check()

# Determine HTTP status based on health
if health_data.get("status") == "healthy":
logger.info("Auth health check passed")
return ApiResponse(
success=True,
data=health_data,
message="Authentication service is healthy",
)
else:
logger.warning(f"Auth health check failed: {health_data}")
raise HTTPException(
status_code=503,
detail=f"Authentication service is unhealthy: {health_data.get('error', 'Unknown error')}",
)

except HTTPException:
# Re-raise HTTPException without modification
raise
except Exception as e:
logger.error(f"Auth health check error: {str(e)}")
raise HTTPException(status_code=500, detail=f"Health check failed: {str(e)}")
2 changes: 1 addition & 1 deletion backend/api/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@

from fastapi import APIRouter, Depends, HTTPException, Query

from api.auth import verify_token
from api.projects import MOCK_PROJECTS
from middleware.auth_middleware import verify_token
from models.response_schemas import (
ApiResponse,
ChatMessage,
Expand Down
4 changes: 2 additions & 2 deletions backend/api/health.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from fastapi import APIRouter

from services.database_service import db_service
from services.database_service import get_db_service
from services.redis_service import redis_service
from services.storage_service import storage_service

Expand Down Expand Up @@ -44,7 +44,7 @@ async def health_check() -> Dict[str, Any]:
}

# Check all services in production
database_health = db_service.health_check()
database_health = get_db_service().health_check()
redis_health = redis_service.health_check()
storage_health = storage_service.health_check()

Expand Down
2 changes: 1 addition & 1 deletion backend/api/projects.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from fastapi import APIRouter, Depends, HTTPException, Query

from api.auth import verify_token
from middleware.auth_middleware import verify_token
from models.response_schemas import (
ApiResponse,
ColumnMetadata,
Expand Down
Loading
Loading