Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions app/extra/scripts/check_ldap_principal.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@
)
from ldap_protocol.utils.queries import get_base_directories

from .task_base import task_metadata


@task_metadata(repeat=0, one_time=True, global_=True)
async def check_ldap_principal(
kadmin: AbstractKadmin,
session: AsyncSession,
Expand Down
3 changes: 3 additions & 0 deletions app/extra/scripts/krb_pass_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,13 @@
from models import User
from security import get_password_hash

from .task_base import task_metadata

_LOCK_FILE = ".lock"
_PATH = "/var/spool/krb5-sync"


@task_metadata(repeat=1.5)
async def read_and_save_krb_pwds(session: AsyncSession) -> None:
"""Process file queue with lock.

Expand Down
3 changes: 3 additions & 0 deletions app/extra/scripts/principal_block_user_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,10 @@
)
from models import Attribute, Directory, User

from .task_base import task_metadata


@task_metadata(repeat=60.0, global_=True)
async def principal_block_sync(
session: AsyncSession, settings: Settings,
) -> None:
Expand Down
58 changes: 58 additions & 0 deletions app/extra/scripts/task_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
"""Task utils.

Copyright (c) 2024 MultiFactor
License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE
"""
from __future__ import annotations

from functools import wraps
from typing import Any, Awaitable, Callable, Coroutine, Generic, TypeVar

from redis.asyncio import Redis

T = TypeVar("T", bound=Callable[..., Awaitable | Coroutine])


class Task(Generic[T]):
"""Task."""

def __init__(
self,
f: T,
repeat: float,
one_time: bool,
global_: bool,
) -> None:
"""Init.

:param T f: function
:param int repeat: repeat time
:param bool one_time: flag to run task only once
:param bool global_: flag to run task in lock mode
"""
self.f = f
self.repeat = repeat
self.one_time = one_time
self.global_ = global_

async def __call__(self, storage: Redis) -> None:
"""Call."""
if self.global_:
async with storage.lock(self.f.__name__):
await self.f()
else:
await self.f()


def task_metadata(
repeat: float,
one_time: bool = False,
global_: bool = False,
) -> Callable[[T], Task[T]]:
"""Decorate a Task."""
def decorator(f: T) -> Task[T]:
@wraps(f)
def wrapper(*args: Any, **kwargs: Any) -> Task[T]:
return Task(f, repeat, one_time, global_)
return wrapper
return decorator
3 changes: 3 additions & 0 deletions app/extra/scripts/uac_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@
from ldap_protocol.utils.queries import add_lock_and_expire_attributes
from models import Attribute, User

from .task_base import task_metadata


@task_metadata(repeat=600.0, global_=True)
async def disable_accounts(
session: AsyncSession, kadmin: AbstractKadmin, settings: Settings,
) -> None:
Expand Down
3 changes: 3 additions & 0 deletions app/extra/scripts/update_krb5_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@
from ldap_protocol.kerberos import AbstractKadmin
from ldap_protocol.utils.queries import get_base_directories

from .task_base import task_metadata


@task_metadata(repeat=0.0, one_time=True, global_=True)
async def update_krb5_config(
kadmin: AbstractKadmin,
session: AsyncSession,
Expand Down
27 changes: 13 additions & 14 deletions app/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,24 +11,23 @@
from extra.scripts.check_ldap_principal import check_ldap_principal
from extra.scripts.krb_pass_sync import read_and_save_krb_pwds
from extra.scripts.principal_block_user_sync import principal_block_sync
from extra.scripts.task_base import Task
from extra.scripts.uac_sync import disable_accounts
from extra.scripts.update_krb5_config import update_krb5_config
from ioc import MainProvider
from ldap_protocol.dependency import resolve_deps

task_type: TypeAlias = Callable[..., Coroutine]

_TASKS: set[tuple[task_type, float]] = {
(read_and_save_krb_pwds, 1.5),
(disable_accounts, 600.0),
(principal_block_sync, 60.0),
(check_ldap_principal, -1.0),
(update_krb5_config, -1.0),
_TASKS: set[Task] = {
read_and_save_krb_pwds,
disable_accounts,
principal_block_sync,
check_ldap_principal,
update_krb5_config,
}


async def _schedule(
task: task_type,
task: Task,
wait: float,
container: AsyncContainer,
) -> None:
Expand All @@ -38,14 +37,14 @@ async def _schedule(
:param AsyncContainer container: container
:param float wait: time to wait after execution
"""
logger.info("Registered: {}", task.__name__)
logger.info("Registered: {}", task.f.__name__)
while True:
async with container(scope=Scope.REQUEST) as ctnr:
handler = await resolve_deps(func=task, container=ctnr)
handler = await resolve_deps(func=task.f, container=ctnr)
await handler()

# NOTE: one-time tasks
if wait < 0:
if wait < 0.0:
break

await asyncio.sleep(wait)
Expand All @@ -59,8 +58,8 @@ async def runner(settings: Settings) -> None:
)

async with asyncio.TaskGroup() as tg:
for task, timeout in _TASKS:
tg.create_task(_schedule(task, timeout, container))
for task in _TASKS:
tg.create_task(_schedule(task, task.repeat, container))

def _run() -> None:
uvloop.run(runner(settings))
Expand Down
Loading