Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,9 @@ DB_NAME=proto
DB_USER=dev
DB_PASSWORD=dev
CONNECTIONS_POOL_MIN_SIZE=10
CONNECTIONS_POOL_MAX_OVERFLOW=25
CONNECTIONS_POOL_MAX_OVERFLOW=30
CONNECTIONS_POOL_RECYCLE=3600
CONNECTIONS_POOL_TIMEOUT=30

# Redis
# ------------------------------------------------------------------------------
Expand Down
4 changes: 3 additions & 1 deletion src/app/config/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,9 @@ class SettingsBase(PydanticSettings):
DB_URL_SYNC: str = f"postgresql://{DB_USER}:{DB_PASSWORD}@{DB_HOST}:{DB_PORT}/{DB_NAME}"

CONNECTIONS_POOL_MIN_SIZE: int = env.int("CONNECTIONS_POOL_MIN_SIZE", 5)
CONNECTIONS_POOL_MAX_OVERFLOW: int = env.int("CONNECTIONS_POOL_MAX_OVERFLOW", 25)
CONNECTIONS_POOL_MAX_OVERFLOW: int = env.int("CONNECTIONS_POOL_MAX_OVERFLOW", 35)
CONNECTIONS_POOL_RECYCLE: int = env.int("CONNECTIONS_POOL_RECYCLE", 3600) # 1 hour in seconds
CONNECTIONS_POOL_TIMEOUT: int = env.int("CONNECTIONS_POOL_TIMEOUT", 30) # seconds

# Redis Settings
# --------------------------------------------------------------------------
Expand Down
55 changes: 44 additions & 11 deletions src/app/infrastructure/extensions/psql_ext/psql_ext.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,71 @@
from contextlib import asynccontextmanager
from typing import AsyncGenerator

from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
from sqlalchemy import create_engine, text
from sqlalchemy.orm import DeclarativeBase, sessionmaker

from src.app.config.settings import settings

Base = declarative_base()

# Shared declarative base for all ORM models (SQLAlchemy 2.0 DeclarativeBase style).
class Base(DeclarativeBase):
    pass


CONNECTIONS_POOL_USE_LIFO: bool = True # LIFO for better connection reuse
DB_JIT_DISABLED: bool = True # Disable JIT
DB_ISOLATION_LEVEL: str = "READ COMMITTED"

# Init connection for own database ...
default_engine = create_async_engine(
settings.DB_URL,
pool_size=settings.CONNECTIONS_POOL_MIN_SIZE,
max_overflow=settings.CONNECTIONS_POOL_MAX_OVERFLOW,
pool_recycle=60 * 60 * 3, # recycle after 3 hours
pool_recycle=settings.CONNECTIONS_POOL_RECYCLE,
pool_timeout=settings.CONNECTIONS_POOL_TIMEOUT,
pool_use_lifo=CONNECTIONS_POOL_USE_LIFO,
pool_pre_ping=True,
future=True,
echo_pool=True,
echo=settings.SHOW_SQL,
connect_args={"server_settings": {"jit": "off"}},
isolation_level=DB_ISOLATION_LEVEL,
connect_args={"server_settings": {"jit": "off" if DB_JIT_DISABLED else "on"}},
)

default_session = sessionmaker(
default_engine, # type: ignore
class_=AsyncSession, # type: ignore
default_session = async_sessionmaker(
default_engine,
class_=AsyncSession,
expire_on_commit=True,
)


# Allowed isolation levels for validation
ALLOWED_ISOLATION_LEVELS = {
"READ UNCOMMITTED",
"READ COMMITTED",
"REPEATABLE READ",
"SERIALIZABLE",
}


@asynccontextmanager
async def get_session(expire_on_commit: bool = False) -> AsyncGenerator:
async def get_session(
expire_on_commit: bool = False,
isolation_level: str | None = None,
) -> AsyncGenerator:

# Validate isolation level to prevent SQL injection
if isolation_level and isolation_level not in ALLOWED_ISOLATION_LEVELS:
raise ValueError(
f"Invalid isolation level: '{isolation_level}'. "
f"Allowed values: {', '.join(sorted(ALLOWED_ISOLATION_LEVELS))}"
)

try:
async with default_session(expire_on_commit=expire_on_commit) as session:
if isolation_level:
# Safe to use string formatting after validation
await session.execute(text(f"SET TRANSACTION ISOLATION LEVEL {isolation_level}"))
yield session
except Exception as e:
await session.rollback()
Expand Down
62 changes: 19 additions & 43 deletions src/app/infrastructure/repositories/base/base_psql_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
from src.app.infrastructure.repositories.base.abstract import (
AbstractBaseRepository,
OuterGenericType,
BaseModel,
RepositoryError,
)
from src.app.infrastructure.utils.common import generate_str
Expand Down Expand Up @@ -624,7 +623,7 @@ def query_builder(cls) -> Type[QueryBuilder]:
return cls._QUERY_BUILDER_CLASS

@classmethod
def model(cls) -> Type[BaseModel]:
def model(cls) -> Type[Base]:
"""Get the SQLAlchemy model class for this repository"""
if not cls.MODEL:
raise AttributeError("Model class not configured")
Expand Down Expand Up @@ -679,7 +678,7 @@ async def count(cls, filter_data: Optional[dict] = None) -> int:
stmt: Select = select(func.count(cls.model().id)) # type: ignore
stmt = cls.query_builder().apply_where(stmt, filter_data=filter_data_, model_class=cls.model())

async with get_session(expire_on_commit=True) as session:
async with get_session(expire_on_commit=False) as session:
result = await session.execute(stmt)
return result.scalars().first()

Expand All @@ -706,7 +705,7 @@ async def get_first(
stmt: Select = select(cls.model())
stmt = cls.query_builder().apply_where(stmt, filter_data=filter_data_, model_class=cls.model())

async with get_session(expire_on_commit=True) as session:
async with get_session(expire_on_commit=False) as session:
result = await session.execute(stmt)

raw = result.scalars().first()
Expand All @@ -733,7 +732,7 @@ async def get_list(
stmt = cls.query_builder().apply_ordering(stmt, order_data=order_data, model_class=cls.model())
stmt = cls.query_builder().apply_pagination(stmt, filter_data=filter_data_)

async with get_session(expire_on_commit=True) as session:
async with get_session(expire_on_commit=False) as session:
result = await session.execute(stmt)

raw_items = result.scalars().all()
Expand All @@ -757,7 +756,7 @@ async def create(

cls._set_timestamps_on_create(items=[data_copy])

async with get_session(expire_on_commit=True) as session:
async with get_session(expire_on_commit=False) as session:
if is_return_require:
# Use RETURNING to get specific columns instead of the whole model
model_class = cls.model() # type: ignore
Expand Down Expand Up @@ -802,7 +801,7 @@ async def update(
stmt = stmt.values(**data_copy)
stmt.execution_options(synchronize_session="fetch")

async with get_session(expire_on_commit=True) as session:
async with get_session(expire_on_commit=False) as session:
await session.execute(stmt)
await session.commit()

Expand Down Expand Up @@ -863,15 +862,17 @@ async def create_bulk(
# Add timestamps to all items
cls._set_timestamps_on_create(items=items_copy)

async with get_session(expire_on_commit=True) as session:
# No need to keep objects attached, we use RETURNING clause
async with get_session(expire_on_commit=False) as session:
model_class = cls.model() # type: ignore
model_table = model_class.__table__ # type: ignore
if is_return_require:
# Use RETURNING to get created records efficiently
# Use RETURNING to get created records efficiently in single query
stmt = insert(model_class).values(items_copy).returning(*model_table.columns.values())
result = await session.execute(stmt)
await session.commit()

# Process results immediately after commit, before session closes
raw_items = result.fetchall()
out_entity_, _ = cls.out_dataclass_with_columns(out_dataclass=out_dataclass)
created_items = []
Expand All @@ -891,12 +892,13 @@ async def create_bulk(
async def update_bulk(
cls, items: List[dict], is_return_require: bool = False, out_dataclass: Optional[OuterGenericType] = None
) -> List[OuterGenericType] | None:
"""Update multiple records in optimized bulk operation
"""
Update multiple records in optimized bulk operation.

Note: Currently uses 2 queries for returning case:
- Option 1: Keep current ORM approach (cleaner, 2 queries for returning)
- Option 2: Go back to raw SQL (1 query, but more complex)
- Option 3: Hybrid approach - use ORM for non-returning, raw SQL for returning
Performance notes:
- Uses expire_on_commit=False to avoid unnecessary object expiration
- When is_return_require=True: 2 queries (bulk update + select)
- When is_return_require=False: 1 query (bulk update only)
"""
if not items:
return None
Expand All @@ -905,7 +907,8 @@ async def update_bulk(

cls._set_timestamps_on_update(items=items_copy)

async with get_session(expire_on_commit=True) as session:
# expire_on_commit=False for better performance, no ORM objects to track
async with get_session(expire_on_commit=False) as session:
if is_return_require:
return await cls._bulk_update_with_returning(session, items_copy, out_dataclass)
else:
Expand Down Expand Up @@ -936,7 +939,7 @@ async def _bulk_update_with_returning(
return []

# Query the updated records
stmt = select(model_class).where(model_class.id.in_(updated_ids))
stmt = select(model_class).where(model_class.id.in_(updated_ids)) # type: ignore[attr-defined]
result = await session.execute(stmt)
updated_records = result.scalars().all()

Expand All @@ -962,33 +965,6 @@ async def _bulk_update_without_returning(cls, session: Any, items: List[dict]) -
await session.execute(update(model_class), items, execution_options={"synchronize_session": False})
await session.commit()

@classmethod
async def _update_single_with_returning(
cls, session: Any, item_data: dict, out_entity_: Callable
) -> OuterGenericType | None:
"""Update a single item and return the updated entity (legacy method)"""
if "id" not in item_data:
return None

model_class = cls.model() # type: ignore
model_table = model_class.__table__ # type: ignore

item_id = item_data.pop("id")
stmt = (
update(model_class)
.where(model_class.id == item_id) # type: ignore
.values(**item_data)
.returning(*model_table.columns.values())
)
result = await session.execute(stmt)
raw = result.fetchone()
if raw:
# Convert Row to dict using column names
column_names = [col.name for col in model_table.columns.values()]
entity_data = dict(zip(column_names, raw))
return out_entity_(**entity_data)
return None

# ==========================================
# UTILITY METHODS
# ==========================================
Expand Down
14 changes: 10 additions & 4 deletions src/app/interfaces/api/v1/endpoints/debug/resources.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from typing import Annotated, Dict
from typing import Annotated

from fastapi import APIRouter, Body, Request
from fastapi.responses import JSONResponse
from starlette.status import HTTP_200_OK, HTTP_400_BAD_REQUEST

from src.app.application.container import container as services_container
from src.app.interfaces.api.v1.endpoints.debug.schemas.req_schemas import MessageReq
from src.app.config.settings import settings
from src.app.infrastructure.messaging.mq_client import mq_client
from src.app.interfaces.api.v1.endpoints.debug.schemas.req_schemas import MessageReq

router = APIRouter(prefix="/debug")

Expand Down Expand Up @@ -41,7 +44,10 @@ async def send_message(
@router.get("/health-check/", status_code=200)
async def health_check(
request: Request,
) -> Dict[str, str]:
) -> JSONResponse:
is_healthy = await services_container.common_service.is_healthy()
status = "OK" if is_healthy else "NOT OK"
return {"status": status}
status_code = HTTP_200_OK if is_healthy else HTTP_400_BAD_REQUEST
resp = JSONResponse(content={"status": status}, status_code=status_code)

return resp
2 changes: 1 addition & 1 deletion src/app/interfaces/grpc/services/debug_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,5 +27,5 @@ async def SendMessage(self, request, context) -> pb2.MessageResp: # type: ignor

async def HealthCheck(self, request, context) -> pb2.HealthCheckResp: # type: ignore
is_healthy = await services_container.common_service.is_healthy()
status = "OK" if is_healthy else "NOT OK"
status = "SERVING" if is_healthy else "NOT_SERVING"
return pb2.HealthCheckResp(status=status) # type: ignore