diff --git a/.env.example b/.env.example index 2c8d9a8..5b5e87c 100644 --- a/.env.example +++ b/.env.example @@ -41,7 +41,9 @@ DB_NAME=proto DB_USER=dev DB_PASSWORD=dev CONNECTIONS_POOL_MIN_SIZE=10 -CONNECTIONS_POOL_MAX_OVERFLOW=25 +CONNECTIONS_POOL_MAX_OVERFLOW=30 +CONNECTIONS_POOL_RECYCLE=3600 +CONNECTIONS_POOL_TIMEOUT=30 # Redis # ------------------------------------------------------------------------------ diff --git a/src/app/config/settings.py b/src/app/config/settings.py index 2168d60..44e64cf 100644 --- a/src/app/config/settings.py +++ b/src/app/config/settings.py @@ -66,7 +66,9 @@ class SettingsBase(PydanticSettings): DB_URL_SYNC: str = f"postgresql://{DB_USER}:{DB_PASSWORD}@{DB_HOST}:{DB_PORT}/{DB_NAME}" CONNECTIONS_POOL_MIN_SIZE: int = env.int("CONNECTIONS_POOL_MIN_SIZE", 5) - CONNECTIONS_POOL_MAX_OVERFLOW: int = env.int("CONNECTIONS_POOL_MAX_OVERFLOW", 25) + CONNECTIONS_POOL_MAX_OVERFLOW: int = env.int("CONNECTIONS_POOL_MAX_OVERFLOW", 30) + CONNECTIONS_POOL_RECYCLE: int = env.int("CONNECTIONS_POOL_RECYCLE", 3600) # 1 hour in seconds + CONNECTIONS_POOL_TIMEOUT: int = env.int("CONNECTIONS_POOL_TIMEOUT", 30) # seconds # Redis Settings # -------------------------------------------------------------------------- diff --git a/src/app/infrastructure/extensions/psql_ext/psql_ext.py b/src/app/infrastructure/extensions/psql_ext/psql_ext.py index 3c2550a..cbc83e7 100644 --- a/src/app/infrastructure/extensions/psql_ext/psql_ext.py +++ b/src/app/infrastructure/extensions/psql_ext/psql_ext.py @@ -1,38 +1,71 @@ from contextlib import asynccontextmanager from typing import AsyncGenerator -from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine -from sqlalchemy import create_engine -from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import sessionmaker +from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker +from sqlalchemy import create_engine, text +from sqlalchemy.orm import DeclarativeBase, 
sessionmaker from src.app.config.settings import settings -Base = declarative_base() + +class Base(DeclarativeBase): + pass + + +CONNECTIONS_POOL_USE_LIFO: bool = True # LIFO for better connection reuse +DB_JIT_DISABLED: bool = True # Disable JIT +DB_ISOLATION_LEVEL: str = "READ COMMITTED" # Init connection for own database ... default_engine = create_async_engine( settings.DB_URL, pool_size=settings.CONNECTIONS_POOL_MIN_SIZE, max_overflow=settings.CONNECTIONS_POOL_MAX_OVERFLOW, - pool_recycle=60 * 60 * 3, # recycle after 3 hours + pool_recycle=settings.CONNECTIONS_POOL_RECYCLE, + pool_timeout=settings.CONNECTIONS_POOL_TIMEOUT, + pool_use_lifo=CONNECTIONS_POOL_USE_LIFO, pool_pre_ping=True, future=True, echo_pool=True, echo=settings.SHOW_SQL, - connect_args={"server_settings": {"jit": "off"}}, + isolation_level=DB_ISOLATION_LEVEL, + connect_args={"server_settings": {"jit": "off" if DB_JIT_DISABLED else "on"}}, ) -default_session = sessionmaker( - default_engine, # type: ignore - class_=AsyncSession, # type: ignore +default_session = async_sessionmaker( + default_engine, + class_=AsyncSession, + expire_on_commit=True, ) +# Allowed isolation levels for validation +ALLOWED_ISOLATION_LEVELS = { + "READ UNCOMMITTED", + "READ COMMITTED", + "REPEATABLE READ", + "SERIALIZABLE", +} + + @asynccontextmanager -async def get_session(expire_on_commit: bool = False) -> AsyncGenerator: +async def get_session( + expire_on_commit: bool = False, + isolation_level: str | None = None, +) -> AsyncGenerator: + + # Validate isolation level to prevent SQL injection + if isolation_level and isolation_level not in ALLOWED_ISOLATION_LEVELS: + raise ValueError( + f"Invalid isolation level: '{isolation_level}'. 
" + f"Allowed values: {', '.join(sorted(ALLOWED_ISOLATION_LEVELS))}" + ) + try: async with default_session(expire_on_commit=expire_on_commit) as session: + if isolation_level: + # Safe to use string formatting after validation + await session.execute(text(f"SET TRANSACTION ISOLATION LEVEL {isolation_level}")) yield session except Exception as e: await session.rollback() diff --git a/src/app/infrastructure/repositories/base/base_psql_repository.py b/src/app/infrastructure/repositories/base/base_psql_repository.py index 0c9f88e..b4e0842 100644 --- a/src/app/infrastructure/repositories/base/base_psql_repository.py +++ b/src/app/infrastructure/repositories/base/base_psql_repository.py @@ -28,7 +28,6 @@ from src.app.infrastructure.repositories.base.abstract import ( AbstractBaseRepository, OuterGenericType, - BaseModel, RepositoryError, ) from src.app.infrastructure.utils.common import generate_str @@ -624,7 +623,7 @@ def query_builder(cls) -> Type[QueryBuilder]: return cls._QUERY_BUILDER_CLASS @classmethod - def model(cls) -> Type[BaseModel]: + def model(cls) -> Type[Base]: """Get the SQLAlchemy model class for this repository""" if not cls.MODEL: raise AttributeError("Model class not configured") @@ -679,7 +678,7 @@ async def count(cls, filter_data: Optional[dict] = None) -> int: stmt: Select = select(func.count(cls.model().id)) # type: ignore stmt = cls.query_builder().apply_where(stmt, filter_data=filter_data_, model_class=cls.model()) - async with get_session(expire_on_commit=True) as session: + async with get_session(expire_on_commit=False) as session: result = await session.execute(stmt) return result.scalars().first() @@ -706,7 +705,7 @@ async def get_first( stmt: Select = select(cls.model()) stmt = cls.query_builder().apply_where(stmt, filter_data=filter_data_, model_class=cls.model()) - async with get_session(expire_on_commit=True) as session: + async with get_session(expire_on_commit=False) as session: result = await session.execute(stmt) raw = 
result.scalars().first() @@ -733,7 +732,7 @@ async def get_list( stmt = cls.query_builder().apply_ordering(stmt, order_data=order_data, model_class=cls.model()) stmt = cls.query_builder().apply_pagination(stmt, filter_data=filter_data_) - async with get_session(expire_on_commit=True) as session: + async with get_session(expire_on_commit=False) as session: result = await session.execute(stmt) raw_items = result.scalars().all() @@ -757,7 +756,7 @@ async def create( cls._set_timestamps_on_create(items=[data_copy]) - async with get_session(expire_on_commit=True) as session: + async with get_session(expire_on_commit=False) as session: if is_return_require: # Use RETURNING to get specific columns instead of the whole model model_class = cls.model() # type: ignore @@ -802,7 +801,7 @@ async def update( stmt = stmt.values(**data_copy) stmt.execution_options(synchronize_session="fetch") - async with get_session(expire_on_commit=True) as session: + async with get_session(expire_on_commit=False) as session: await session.execute(stmt) await session.commit() @@ -863,15 +862,17 @@ async def create_bulk( # Add timestamps to all items cls._set_timestamps_on_create(items=items_copy) - async with get_session(expire_on_commit=True) as session: + # No need to keep objects attached, we use RETURNING clause + async with get_session(expire_on_commit=False) as session: model_class = cls.model() # type: ignore model_table = model_class.__table__ # type: ignore if is_return_require: - # Use RETURNING to get created records efficiently + # Use RETURNING to get created records efficiently in single query stmt = insert(model_class).values(items_copy).returning(*model_table.columns.values()) result = await session.execute(stmt) await session.commit() + # Process results immediately after commit, before session closes raw_items = result.fetchall() out_entity_, _ = cls.out_dataclass_with_columns(out_dataclass=out_dataclass) created_items = [] @@ -891,12 +892,13 @@ async def create_bulk( async def 
update_bulk( cls, items: List[dict], is_return_require: bool = False, out_dataclass: Optional[OuterGenericType] = None ) -> List[OuterGenericType] | None: - """Update multiple records in optimized bulk operation + """ + Update multiple records in optimized bulk operation. - Note: Currently uses 2 queries for returning case: - - Option 1: Keep current ORM approach (cleaner, 2 queries for returning) - - Option 2: Go back to raw SQL (1 query, but more complex) - - Option 3: Hybrid approach - use ORM for non-returning, raw SQL for returning + Performance notes: + - Uses expire_on_commit=False to avoid unnecessary object expiration + - When is_return_require=True: 2 queries (bulk update + select) + - When is_return_require=False: 1 query (bulk update only) """ if not items: return None @@ -905,7 +907,8 @@ async def update_bulk( cls._set_timestamps_on_update(items=items_copy) - async with get_session(expire_on_commit=True) as session: + # expire_on_commit=False for better performance, no ORM objects to track + async with get_session(expire_on_commit=False) as session: if is_return_require: return await cls._bulk_update_with_returning(session, items_copy, out_dataclass) else: @@ -936,7 +939,7 @@ async def _bulk_update_with_returning( return [] # Query the updated records - stmt = select(model_class).where(model_class.id.in_(updated_ids)) + stmt = select(model_class).where(model_class.id.in_(updated_ids)) # type: ignore[attr-defined] result = await session.execute(stmt) updated_records = result.scalars().all() @@ -962,33 +965,6 @@ async def _bulk_update_without_returning(cls, session: Any, items: List[dict]) - await session.execute(update(model_class), items, execution_options={"synchronize_session": False}) await session.commit() - @classmethod - async def _update_single_with_returning( - cls, session: Any, item_data: dict, out_entity_: Callable - ) -> OuterGenericType | None: - """Update a single item and return the updated entity (legacy method)""" - if "id" not in 
item_data: - return None - - model_class = cls.model() # type: ignore - model_table = model_class.__table__ # type: ignore - - item_id = item_data.pop("id") - stmt = ( - update(model_class) - .where(model_class.id == item_id) # type: ignore - .values(**item_data) - .returning(*model_table.columns.values()) - ) - result = await session.execute(stmt) - raw = result.fetchone() - if raw: - # Convert Row to dict using column names - column_names = [col.name for col in model_table.columns.values()] - entity_data = dict(zip(column_names, raw)) - return out_entity_(**entity_data) - return None - # ========================================== # UTILITY METHODS # ========================================== diff --git a/src/app/interfaces/api/v1/endpoints/debug/resources.py b/src/app/interfaces/api/v1/endpoints/debug/resources.py index b981c05..ded5bde 100644 --- a/src/app/interfaces/api/v1/endpoints/debug/resources.py +++ b/src/app/interfaces/api/v1/endpoints/debug/resources.py @@ -1,10 +1,13 @@ -from typing import Annotated, Dict +from typing import Annotated from fastapi import APIRouter, Body, Request +from fastapi.responses import JSONResponse +from starlette.status import HTTP_200_OK, HTTP_400_BAD_REQUEST + from src.app.application.container import container as services_container -from src.app.interfaces.api.v1.endpoints.debug.schemas.req_schemas import MessageReq from src.app.config.settings import settings from src.app.infrastructure.messaging.mq_client import mq_client +from src.app.interfaces.api.v1.endpoints.debug.schemas.req_schemas import MessageReq router = APIRouter(prefix="/debug") @@ -41,7 +44,10 @@ async def send_message( @router.get("/health-check/", status_code=200) async def health_check( request: Request, -) -> Dict[str, str]: +) -> JSONResponse: is_healthy = await services_container.common_service.is_healthy() status = "OK" if is_healthy else "NOT OK" - return {"status": status} + status_code = HTTP_200_OK if is_healthy else HTTP_400_BAD_REQUEST + resp = 
JSONResponse(content={"status": status}, status_code=status_code) + + return resp diff --git a/src/app/interfaces/grpc/services/debug_service.py b/src/app/interfaces/grpc/services/debug_service.py index fee8e1f..0330b87 100644 --- a/src/app/interfaces/grpc/services/debug_service.py +++ b/src/app/interfaces/grpc/services/debug_service.py @@ -27,5 +27,5 @@ async def SendMessage(self, request, context) -> pb2.MessageResp: # type: ignor async def HealthCheck(self, request, context) -> pb2.HealthCheckResp: # type: ignore is_healthy = await services_container.common_service.is_healthy() - status = "OK" if is_healthy else "NOT OK" + status = "SERVING" if is_healthy else "NOT_SERVING" return pb2.HealthCheckResp(status=status) # type: ignore