Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
cc3d4ba
refactor: change async functions to synchronous in MainProvider
Naksen Dec 9, 2025
d39cec5
refactor: optimize query loading for entity type in SearchRequest
Naksen Dec 9, 2025
a8144ba
refactor: change loading strategy for Directory group in SearchRequest
Naksen Dec 9, 2025
51837a1
refactor: change loading strategy from selectinload to joinedload for…
Naksen Dec 9, 2025
349b760
refactor: change async functions to synchronous in DHCP manager and A…
Naksen Dec 9, 2025
7a3780b
refactor: change async iteration to synchronous for directory results…
Naksen Dec 9, 2025
4bba1ea
refactor: change loading strategy to use contains_eager for Directory…
Naksen Dec 9, 2025
0b8009a
refactor: reorganize import statements in search.py for clarity
Naksen Dec 9, 2025
e469e82
refactor: add CONTEXT_TYPE class variable to LDAP request classes
Naksen Dec 9, 2025
4e64bfc
refactor: LDAP context provider to use AsyncSessionSearchRequest and …
Naksen Dec 16, 2025
2428d93
test: enhance LDAP search request context provision with a dedicated …
Naksen Dec 16, 2025
00ae84f
refactor: optimize event processing logic in BaseRequest class
Naksen Dec 16, 2025
cf19cc9
refactor: add async method to provide LDAP search request context
Naksen Dec 16, 2025
8d78cff
refactor: add blank line for improved readability in LDAP context pro…
Naksen Dec 16, 2025
b1002c0
refactor: specify type for ctx parameter in handle method of AbandonR…
Naksen Dec 16, 2025
f8876c3
refactor: remove debug logging for group membership in SearchRequest …
Naksen Dec 16, 2025
95de34a
refactor: update context handling in BaseRequest class for improved p…
Naksen Dec 16, 2025
c07821f
refactor: change get_search_request_context method to synchronous
Naksen Dec 16, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 68 additions & 28 deletions app/ioc.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
from ldap_protocol.kerberos.service import KerberosService
from ldap_protocol.kerberos.template_render import KRBTemplateRenderer
from ldap_protocol.ldap_requests.contexts import (
AsyncSessionSearchRequest,
LDAPAddRequestContext,
LDAPBindRequestContext,
LDAPDeleteRequestContext,
Expand Down Expand Up @@ -205,7 +206,7 @@ async def get_kadmin_http(
yield KadminHTTPClient(client)

@provide(scope=Scope.REQUEST)
async def get_kadmin(
def get_kadmin(
self,
client: KadminHTTPClient,
kadmin_class: type[AbstractKadmin],
Expand Down Expand Up @@ -260,14 +261,14 @@ async def get_dns_http_client(
yield DNSManagerHTTPClient(client)

@provide(scope=Scope.REQUEST)
async def get_dns_mngr(
def get_dns_mngr(
self,
settings: DNSManagerSettings,
dns_manager_class: type[AbstractDNSManager],
http_client: DNSManagerHTTPClient,
) -> AsyncIterator[AbstractDNSManager]:
) -> AbstractDNSManager:
"""Get DNSManager class."""
yield dns_manager_class(settings=settings, http_client=http_client)
return dns_manager_class(settings=settings, http_client=http_client)

@provide(scope=Scope.APP)
async def get_redis_for_sessions(
Expand All @@ -284,7 +285,7 @@ async def get_redis_for_sessions(
await client.aclose()

@provide(scope=Scope.APP)
async def get_session_storage(
def get_session_storage(
self,
client: SessionStorageClient,
settings: Settings,
Expand All @@ -297,7 +298,7 @@ async def get_session_storage(
)

@provide()
async def get_normalized_audit_event(
def get_normalized_audit_event(
self,
) -> type[NormalizedAuditEvent]:
"""Get normalized audit event class."""
Expand All @@ -318,13 +319,13 @@ async def get_audit_redis_client(
await client.aclose()

@provide(scope=Scope.APP)
async def get_raw_audit_manager(
def get_raw_audit_manager(
self,
client: AuditRedisClient,
settings: Settings,
) -> AsyncIterator[RawAuditManager]:
) -> RawAuditManager:
"""Get raw audit manager."""
yield RawAuditManager(
return RawAuditManager(
client,
settings.RAW_EVENT_STREAM_NAME,
settings.EVENT_HANDLER_GROUP,
Expand All @@ -333,13 +334,13 @@ async def get_raw_audit_manager(
)

@provide(scope=Scope.APP)
async def get_normalized_audit_manager(
def get_normalized_audit_manager(
self,
client: AuditRedisClient,
settings: Settings,
) -> AsyncIterator[NormalizedAuditManager]:
) -> NormalizedAuditManager:
"""Get raw audit manager."""
yield NormalizedAuditManager(
return NormalizedAuditManager(
client,
settings.NORMALIZED_EVENT_STREAM_NAME,
settings.EVENT_SENDER_GROUP,
Expand All @@ -352,7 +353,7 @@ async def get_normalized_audit_manager(
audit_destination_dao = provide(AuditDestinationDAO, scope=Scope.REQUEST)

@provide(scope=Scope.REQUEST)
async def get_dhcp_manager_repository(
def get_dhcp_manager_repository(
self,
session: AsyncSession,
) -> DHCPManagerRepository:
Expand All @@ -368,20 +369,20 @@ async def get_dhcp_manager_state(
return await dhcp_manager_repository.ensure_state()

@provide(scope=Scope.REQUEST)
async def get_dhcp_mngr_class(
def get_dhcp_mngr_class(
self,
dhcp_state: DHCPManagerState,
) -> type[AbstractDHCPManager]:
"""Get DHCP manager type."""
return await get_dhcp_manager_class(dhcp_state)
return get_dhcp_manager_class(dhcp_state)

@provide(scope=Scope.REQUEST)
async def get_dhcp_api_repository_class(
def get_dhcp_api_repository_class(
self,
dhcp_state: DHCPManagerState,
) -> type[DHCPAPIRepository]:
"""Get DHCP API repository type."""
return await get_dhcp_api_repository_class(dhcp_state)
return get_dhcp_api_repository_class(dhcp_state)

@provide(scope=Scope.APP)
async def get_dhcp_http_client(
Expand All @@ -395,7 +396,7 @@ async def get_dhcp_http_client(
yield DHCPManagerHTTPClient(http_client)

@provide(scope=Scope.REQUEST)
async def get_dhcp_api_repository(
def get_dhcp_api_repository(
self,
http_client: DHCPManagerHTTPClient,
dhcp_api_repository_class: type[DHCPAPIRepository],
Expand All @@ -404,7 +405,7 @@ async def get_dhcp_api_repository(
return dhcp_api_repository_class(http_client)

@provide(scope=Scope.REQUEST)
async def get_dhcp_mngr(
def get_dhcp_mngr(
self,
dhcp_manager_class: type[AbstractDHCPManager],
dhcp_api_repository: DHCPAPIRepository,
Expand Down Expand Up @@ -445,7 +446,7 @@ async def get_dhcp_mngr(
)
password_utils = provide(PasswordUtils, scope=Scope.RUNTIME)

access_manager = provide(AccessManager, scope=Scope.REQUEST)
access_manager = provide(AccessManager, scope=Scope.RUNTIME)
role_dao = provide(RoleDAO, scope=Scope.REQUEST)
ace_dao = provide(AccessControlEntryDAO, scope=Scope.REQUEST)
role_use_case = provide(RoleUseCase, scope=Scope.REQUEST)
Expand Down Expand Up @@ -490,15 +491,37 @@ class LDAPContextProvider(Provider):
LDAPModifyDNRequestContext,
scope=Scope.REQUEST,
)
search_request_context = provide(
LDAPSearchRequestContext,
scope=Scope.REQUEST,
)
unbind_request_context = provide(
LDAPUnbindRequestContext,
scope=Scope.REQUEST,
)

@provide(scope=Scope.SESSION)
async def create_search_session(
self,
async_session: async_sessionmaker[AsyncSession],
) -> AsyncIterator[AsyncSessionSearchRequest]:
"""Create session for request."""
async with async_session() as session:
yield session # type: ignore

@provide(scope=Scope.SESSION, provides=LDAPSearchRequestContext)
def get_search_request_context(
self,
session: AsyncSessionSearchRequest,
ldap_session: LDAPSession,
settings: Settings,
access_manager: AccessManager,
) -> LDAPSearchRequestContext:
"""Get search request context."""
return LDAPSearchRequestContext(
session=session,
ldap_session=ldap_session,
settings=settings,
access_manager=access_manager,
rootdse_rd=RootDSEReader(settings, SADomainGateway(session)),
)


class HTTPProvider(LDAPContextProvider):
"""HTTP LDAP session."""
Expand All @@ -508,7 +531,7 @@ class HTTPProvider(LDAPContextProvider):
monitor_use_case = provide(AuditMonitorUseCase, scope=Scope.REQUEST)

@provide()
async def get_audit_monitor(
def get_audit_monitor(
self,
session: AsyncSession,
audit_use_case: "AuditUseCase",
Expand Down Expand Up @@ -568,7 +591,7 @@ def get_permissions_provider(
return auth_provider

@provide()
async def get_identity_provider(
def get_identity_provider(
self,
request: Request,
session_storage: SessionStorage,
Expand Down Expand Up @@ -649,6 +672,23 @@ def get_krb_template_render(
)
network_policy_gateway = provide(NetworkPolicyGateway, scope=Scope.REQUEST)

@provide(scope=Scope.REQUEST, provides=LDAPSearchRequestContext)
async def get_search_request_context(
self,
session: AsyncSession,
ldap_session: LDAPSession,
settings: Settings,
access_manager: AccessManager,
) -> LDAPSearchRequestContext:
"""Get search request context."""
return LDAPSearchRequestContext(
session=session, # type: ignore
ldap_session=ldap_session,
settings=settings,
access_manager=access_manager,
rootdse_rd=RootDSEReader(settings, SADomainGateway(session)),
)


class LDAPServerProvider(LDAPContextProvider):
"""Provider with session scope."""
Expand Down Expand Up @@ -739,7 +779,7 @@ async def get_client(
yield MFAHTTPClient(client)

@provide(provides=MultifactorAPI)
async def get_http_mfa(
def get_http_mfa(
self,
credentials: MFA_HTTP_Creds,
client: MFAHTTPClient,
Expand All @@ -761,7 +801,7 @@ async def get_http_mfa(
)

@provide(provides=LDAPMultiFactorAPI)
async def get_ldap_mfa(
def get_ldap_mfa(
self,
credentials: MFA_LDAP_Creds,
client: MFAHTTPClient,
Expand Down
4 changes: 2 additions & 2 deletions app/ldap_protocol/dhcp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from .stub import StubDHCPAPIRepository, StubDHCPManager


async def get_dhcp_manager_class(
def get_dhcp_manager_class(
dhcp_state: DHCPManagerState,
) -> type[AbstractDHCPManager]:
"""Get an instance of the DHCP manager."""
Expand All @@ -35,7 +35,7 @@ async def get_dhcp_manager_class(
return StubDHCPManager


async def get_dhcp_api_repository_class(
def get_dhcp_api_repository_class(
dhcp_state: DHCPManagerState,
) -> type[DHCPAPIRepository]:
"""Get an instance of the DHCP API repository."""
Expand Down
2 changes: 1 addition & 1 deletion app/ldap_protocol/ldap_requests/abandon.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def from_data(
"""Create structure from ASN1Row dataclass list."""
return cls(message_id=1)

async def handle(self) -> AsyncGenerator:
async def handle(self, ctx: None) -> AsyncGenerator: # noqa: ARG002
"""Handle message with current user."""
await asyncio.sleep(0)
return
Expand Down
1 change: 1 addition & 0 deletions app/ldap_protocol/ldap_requests/add.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ class AddRequest(BaseRequest):
"""

PROTOCOL_OP: ClassVar[int] = ProtocolRequests.ADD
CONTEXT_TYPE: ClassVar[type] = LDAPAddRequestContext

entry: str = Field(..., description="Any `DistinguishedName`")
attributes: list[PartialAttribute]
Expand Down
68 changes: 39 additions & 29 deletions app/ldap_protocol/ldap_requests/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from ldap_protocol.dependency import resolve_deps
from ldap_protocol.dialogue import LDAPSession
from ldap_protocol.ldap_responses import BaseResponse, LDAPResult
from ldap_protocol.objects import ProtocolRequests
from ldap_protocol.policies.audit.audit_use_case import AuditUseCase
from ldap_protocol.policies.audit.events.factory import (
RawAuditEventBuilderRedis,
Expand Down Expand Up @@ -62,6 +63,7 @@ class _APIProtocol: ...
class BaseRequest(ABC, _APIProtocol, BaseModel):
"""Base request builder."""

CONTEXT_TYPE: ClassVar[type]
handle: ClassVar[handler]
from_data: ClassVar[serializer]
__event_data: dict = {}
Expand Down Expand Up @@ -113,38 +115,42 @@ async def handle_tcp(
container: AsyncContainer,
) -> AsyncIterator[BaseResponse]:
"""Hanlde response with tcp."""
kwargs = await resolve_deps(func=self.handle, container=container)
responses = []
if self.PROTOCOL_OP != ProtocolRequests.ABANDON:
ctx = await container.get(self.CONTEXT_TYPE) # type: ignore
else:
ctx = None

async for response in self.handle(**kwargs):
responses = []
async for response in self.handle(ctx=ctx):
responses.append(response)
yield response

ldap_session = await container.get(LDAPSession)
settings = await container.get(Settings)
audit_use_case = await container.get(AuditUseCase)

if await audit_use_case.check_event_processing_enabled(
self.PROTOCOL_OP,
):
username = getattr(
ldap_session.user,
"user_principal_name",
"ANONYMOUS",
)
event = RawAuditEventBuilderRedis.from_ldap_request(
self,
responses=responses,
username=username,
ip=ldap_session.ip,
protocol="TCP_LDAP",
settings=settings,
context=self.get_event_data(),
)
if self.PROTOCOL_OP != ProtocolRequests.SEARCH:
ldap_session = await container.get(LDAPSession)
settings = await container.get(Settings)
audit_use_case = await container.get(AuditUseCase)

if await audit_use_case.check_event_processing_enabled(
self.PROTOCOL_OP,
):
username = getattr(
ldap_session.user,
"user_principal_name",
"ANONYMOUS",
)
event = RawAuditEventBuilderRedis.from_ldap_request(
self,
responses=responses,
username=username,
ip=ldap_session.ip,
protocol="TCP_LDAP",
settings=settings,
context=self.get_event_data(),
)

ldap_session.event_task_group.create_task(
audit_use_case.manager.send_event(event),
)
ldap_session.event_task_group.create_task(
audit_use_case.manager.send_event(event),
)

async def _handle_api(
self,
Expand All @@ -156,7 +162,11 @@ async def _handle_api(
:param AsyncSession session: db session
:return list[BaseResponse]: list of handled responses
"""
kwargs = await resolve_deps(func=self.handle, container=container)
if self.PROTOCOL_OP != ProtocolRequests.ABANDON:
ctx = await container.get(self.CONTEXT_TYPE) # type: ignore
else:
ctx = None

ldap_session = await container.get(LDAPSession)
settings = await container.get(Settings)
audit_use_case = await container.get(AuditUseCase)
Expand All @@ -168,7 +178,7 @@ async def _handle_api(
else:
log_api.info(f"{get_class_name(self)}[{un}]")

responses = [response async for response in self.handle(**kwargs)]
responses = [response async for response in self.handle(ctx=ctx)]

if settings.DEBUG:
for response in responses:
Expand Down
Loading