From 562fb529934af35b40d9e162cdf0c167c4ad2a0a Mon Sep 17 00:00:00 2001 From: Jvst Me Date: Tue, 16 Dec 2025 13:28:42 +0100 Subject: [PATCH] Add more events about users and projects - User updated - User token refreshed - User SSH key refreshed - User deleted - Project updated - Project deleted Also refactor the implementation of the relevant operations on users to enable more detailed event messages and to avoid race conditions and longer write transactions. --- src/dstack/_internal/server/routers/runs.py | 4 +- src/dstack/_internal/server/routers/users.py | 9 +- .../_internal/server/services/projects.py | 19 +- src/dstack/_internal/server/services/users.py | 211 +++++++++++------- .../_internal/server/routers/test_projects.py | 10 + .../_internal/server/routers/test_users.py | 13 ++ 6 files changed, 177 insertions(+), 89 deletions(-) diff --git a/src/dstack/_internal/server/routers/runs.py b/src/dstack/_internal/server/routers/runs.py index 24baee9179..a4a09b3fb8 100644 --- a/src/dstack/_internal/server/routers/runs.py +++ b/src/dstack/_internal/server/routers/runs.py @@ -118,7 +118,7 @@ async def get_plan( """ user, project = user_project if not user.ssh_public_key and not body.run_spec.ssh_key_pub: - await users.refresh_ssh_key(session=session, user=user) + await users.refresh_ssh_key(session=session, actor=user) run_plan = await runs.get_plan( session=session, project=project, @@ -148,7 +148,7 @@ async def apply_plan( """ user, project = user_project if not user.ssh_public_key and not body.plan.run_spec.ssh_key_pub: - await users.refresh_ssh_key(session=session, user=user) + await users.refresh_ssh_key(session=session, actor=user) return CustomORJSONResponse( await runs.apply_plan( session=session, diff --git a/src/dstack/_internal/server/routers/users.py b/src/dstack/_internal/server/routers/users.py index 2568c6ac29..1feac5da36 100644 --- a/src/dstack/_internal/server/routers/users.py +++ b/src/dstack/_internal/server/routers/users.py @@ -43,7 +43,7 @@ async def get_my_user( ): if user.ssh_private_key is None or user.ssh_public_key is None: # Generate keys for pre-0.19.33 users - await users.refresh_ssh_key(session=session, user=user) + await users.refresh_ssh_key(session=session, actor=user) return CustomORJSONResponse(users.user_model_to_user_with_creds(user)) @@ -86,6 +86,7 @@ async def update_user( ): res = await users.update_user( session=session, + actor=user, username=body.username, global_role=body.global_role, email=body.email, @@ -102,7 +103,7 @@ async def refresh_ssh_key( session: AsyncSession = Depends(get_session), user: UserModel = Depends(Authenticated()), ): - res = await users.refresh_ssh_key(session=session, user=user, username=body.username) + res = await users.refresh_ssh_key(session=session, actor=user, username=body.username) if res is None: raise ResourceNotExistsError() return CustomORJSONResponse(users.user_model_to_user_with_creds(res)) @@ -114,7 +115,7 @@ async def refresh_token( session: AsyncSession = Depends(get_session), user: UserModel = Depends(Authenticated()), ): - res = await users.refresh_user_token(session=session, user=user, username=body.username) + res = await users.refresh_user_token(session=session, actor=user, username=body.username) if res is None: raise ResourceNotExistsError() return CustomORJSONResponse(users.user_model_to_user_with_creds(res)) @@ -128,6 +129,6 @@ async def delete_users( ): await users.delete_users( session=session, - user=user, + actor=user, usernames=body.users, ) diff --git a/src/dstack/_internal/server/services/projects.py b/src/dstack/_internal/server/services/projects.py index 2004b5cccd..5e4842df56 100644 --- a/src/dstack/_internal/server/services/projects.py +++ b/src/dstack/_internal/server/services/projects.py @@ -169,8 +169,16 @@ async def update_project( project: ProjectModel, is_public: bool, ): - """Update project visibility (public/private).""" - project.is_public = is_public + updated_fields = [] + if is_public != project.is_public: + project.is_public = is_public + updated_fields.append(f"is_public={is_public}") + events.emit( + session, + f"Project updated. Updated fields: {', '.join(updated_fields) or ''}", + actor=events.UserActor.from_user(user), + targets=[events.Target.from_model(project)], + ) await session.commit() @@ -222,9 +230,14 @@ async def delete_projects( "deleted": True, } ) + events.emit( + session, + "Project deleted", + actor=events.UserActor.from_user(user), + targets=[events.Target.from_model(p)], + ) await session.execute(update(ProjectModel), updates) await session.commit() - logger.info("Deleted projects %s by user %s", projects_names, user.name) async def set_project_members( diff --git a/src/dstack/_internal/server/services/users.py b/src/dstack/_internal/server/services/users.py index 62fcc848ea..e8fbcde782 100644 --- a/src/dstack/_internal/server/services/users.py +++ b/src/dstack/_internal/server/services/users.py @@ -3,14 +3,19 @@ import re import secrets import uuid +from collections.abc import AsyncGenerator +from contextlib import asynccontextmanager from typing import Awaitable, Callable, List, Optional, Tuple -from sqlalchemy import delete, select, update +from sqlalchemy import delete, select from sqlalchemy import func as safunc from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import load_only -from dstack._internal.core.errors import ResourceExistsError, ServerClientError +from dstack._internal.core.errors import ( + ResourceExistsError, + ServerClientError, +) from dstack._internal.core.models.users import ( GlobalRole, User, @@ -19,8 +24,10 @@ UserTokenCreds, UserWithCreds, ) +from dstack._internal.server.db import get_db from dstack._internal.server.models import DecryptedString, MemberModel, UserModel from dstack._internal.server.services import events +from dstack._internal.server.services.locking import get_locker from dstack._internal.server.services.permissions import get_default_permissions from dstack._internal.server.utils.routers import error_forbidden from dstack._internal.utils import crypto @@ -123,114 +130,128 @@ async def create_user( async def update_user( session: AsyncSession, + actor: UserModel, username: str, global_role: GlobalRole, email: Optional[str] = None, active: bool = True, -) -> UserModel: - await session.execute( - update(UserModel) - .where( - UserModel.name == username, - UserModel.deleted == False, - ) - .values( - global_role=global_role, - email=email, - active=active, +) -> Optional[UserModel]: + async with get_user_model_by_name_for_update(session, username) as user: + if user is None: + return None + updated_fields = [] + if global_role != user.global_role: + user.global_role = global_role + updated_fields.append(f"global_role={global_role}") + if email != user.email: + user.email = email + updated_fields.append("email") # do not include potentially sensitive new value + if active != user.active: + user.active = active + updated_fields.append(f"active={active}") + events.emit( + session, + f"User updated. Updated fields: {', '.join(updated_fields) or ''}", + actor=events.UserActor.from_user(actor), + targets=[events.Target.from_model(user)], ) - ) - await session.commit() - return await get_user_model_by_name_or_error(session=session, username=username) + await session.commit() + return user async def refresh_ssh_key( session: AsyncSession, - user: UserModel, + actor: UserModel, username: Optional[str] = None, ) -> Optional[UserModel]: if username is None: - username = user.name - logger.debug("Refreshing SSH key for user [code]%s[/code]", username) - if user.global_role != GlobalRole.ADMIN and user.name != username: + username = actor.name + if actor.global_role != GlobalRole.ADMIN and actor.name != username: raise error_forbidden() - private_bytes, public_bytes = await run_async(crypto.generate_rsa_key_pair_bytes, username) - await session.execute( - update(UserModel) - .where( - UserModel.name == username, - UserModel.deleted == False, - ) - .values( - ssh_private_key=private_bytes.decode(), - ssh_public_key=public_bytes.decode(), + async with get_user_model_by_name_for_update(session, username) as user: + if user is None: + return None + private_bytes, public_bytes = await run_async(crypto.generate_rsa_key_pair_bytes, username) + user.ssh_private_key = private_bytes.decode() + user.ssh_public_key = public_bytes.decode() + events.emit( + session, + "User SSH key refreshed", + actor=events.UserActor.from_user(actor), + targets=[events.Target.from_model(user)], ) - ) - await session.commit() - return await get_user_model_by_name(session=session, username=username) + await session.commit() + return user async def refresh_user_token( session: AsyncSession, - user: UserModel, + actor: UserModel, username: str, ) -> Optional[UserModel]: - if user.global_role != GlobalRole.ADMIN and user.name != username: + if actor.global_role != GlobalRole.ADMIN and actor.name != username: raise error_forbidden() - new_token = str(uuid.uuid4()) - await session.execute( - update(UserModel) - .where( - UserModel.name == username, - UserModel.deleted == False, - ) - .values( - token=DecryptedString(plaintext=new_token), - token_hash=get_token_hash(new_token), + async with get_user_model_by_name_for_update(session, username) as user: + if user is None: + return None + new_token = str(uuid.uuid4()) + user.token = DecryptedString(plaintext=new_token) + user.token_hash = get_token_hash(new_token) + events.emit( + session, + "User token refreshed", + actor=events.UserActor.from_user(actor), + targets=[events.Target.from_model(user)], ) - ) - await session.commit() - return await get_user_model_by_name(session=session, username=username) + await session.commit() + return user async def delete_users( session: AsyncSession, - user: UserModel, + actor: UserModel, usernames: List[str], ): if _ADMIN_USERNAME in usernames: - raise ServerClientError("User 'admin' cannot be deleted") - - res = await session.execute( - select(UserModel) - .where( - UserModel.name.in_(usernames), - UserModel.deleted == False, - ) - .options(load_only(UserModel.id, UserModel.name)) - ) - users = res.scalars().all() - if len(users) != len(usernames): - raise ServerClientError("Failed to delete non-existent users") - - user_ids = [u.id for u in users] - timestamp = str(int(get_current_datetime().timestamp())) - updates = [] - for u in users: - updates.append( - { - "id": u.id, - "name": f"_deleted_{timestamp}_{secrets.token_hex(8)}", - "original_name": u.name, - "deleted": True, - "active": False, - } + raise ServerClientError(f"User {_ADMIN_USERNAME!r} cannot be deleted") + + filters = [ + UserModel.name.in_(usernames), + UserModel.deleted == False, + ] + res = await session.execute(select(UserModel.id).where(*filters)) + user_ids = list(res.scalars().all()) + user_ids.sort() + + async with get_locker(get_db().dialect_name).lock_ctx(UserModel.__tablename__, user_ids): + # Refetch after lock + res = await session.execute( + select(UserModel) + .where(UserModel.id.in_(user_ids), *filters) + .order_by(UserModel.id) # take locks in order + .options(load_only(UserModel.id, UserModel.name)) + .with_for_update(key_share=True) ) - await session.execute(update(UserModel), updates) - await session.execute(delete(MemberModel).where(MemberModel.user_id.in_(user_ids))) - # Projects are not deleted automatically if owners are deleted. - await session.commit() - logger.info("Deleted users %s by user %s", usernames, user.name) + users = list(res.scalars().all()) + if len(users) != len(usernames): + raise ServerClientError("Failed to delete non-existent users") + user_ids = [u.id for u in users] + timestamp = str(int(get_current_datetime().timestamp())) + for u in users: + event_target = events.Target.from_model(u) # build target before renaming the user + u.deleted = True + u.active = False + u.original_name = u.name + u.name = f"_deleted_{timestamp}_{secrets.token_hex(8)}" + events.emit( + session, + "User deleted", + actor=events.UserActor.from_user(actor), + targets=[event_target], + ) + await session.execute(delete(MemberModel).where(MemberModel.user_id.in_(user_ids))) + # Projects are not deleted automatically if owners are deleted. + await session.commit() async def get_user_model_by_name( @@ -257,6 +278,36 @@ async def get_user_model_by_name_or_error( ) +@asynccontextmanager +async def get_user_model_by_name_for_update( + session: AsyncSession, username: str +) -> AsyncGenerator[Optional[UserModel], None]: + """ + Fetch the user from the database and lock it for update. + + **NOTE**: commit changes to the database before exiting from this context manager, + so that in-memory locks are only released after commit. + """ + + filters = [ + UserModel.name == username, + UserModel.deleted == False, + ] + res = await session.execute(select(UserModel.id).where(*filters)) + user_id = res.scalar_one_or_none() + if user_id is None: + yield None + else: + async with get_locker(get_db().dialect_name).lock_ctx(UserModel.__tablename__, [user_id]): + # Refetch after lock + res = await session.execute( + select(UserModel) + .where(UserModel.id.in_([user_id]), *filters) + .with_for_update(key_share=True) + ) + yield res.scalar_one_or_none() + + async def log_in_with_token(session: AsyncSession, token: str) -> Optional[UserModel]: token_hash = get_token_hash(token) res = await session.execute( diff --git a/src/tests/_internal/server/routers/test_projects.py b/src/tests/_internal/server/routers/test_projects.py index 8e21957f5e..826ecbc096 100644 --- a/src/tests/_internal/server/routers/test_projects.py +++ b/src/tests/_internal/server/routers/test_projects.py @@ -495,6 +495,16 @@ async def test_deletes_projects( await session.refresh(project2) assert project1.deleted assert not project2.deleted + # Validate an event is emitted + response = await client.post( + "/api/events/list", headers=get_auth_headers(user.token), json={} + ) + assert response.status_code == 200 + assert len(response.json()) == 1 + assert response.json()[0]["message"] == "Project deleted" + assert len(response.json()[0]["targets"]) == 1 + assert response.json()[0]["targets"][0]["id"] == str(project1.id) + assert response.json()[0]["targets"][0]["name"] == project_name @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) diff --git a/src/tests/_internal/server/routers/test_users.py b/src/tests/_internal/server/routers/test_users.py index 8b8c7ca2a6..6c5b373a63 100644 --- a/src/tests/_internal/server/routers/test_users.py +++ b/src/tests/_internal/server/routers/test_users.py @@ -392,9 +392,22 @@ async def test_deletes_users( json={"users": [user.name]}, ) assert response.status_code == 200 + + # Validate the user is deleted res = await session.execute(select(UserModel).where(UserModel.name == user.name)) assert len(res.scalars().all()) == 0 + # Validate an event is emitted + response = await client.post( + "/api/events/list", headers=get_auth_headers(admin.token), json={} + ) + assert response.status_code == 200 + assert len(response.json()) == 1 + assert response.json()[0]["message"] == "User deleted" + assert len(response.json()[0]["targets"]) == 1 + assert response.json()[0]["targets"][0]["id"] == str(user.id) + assert response.json()[0]["targets"][0]["name"] == user.name + @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) async def test_returns_400_if_users_not_exist(