Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/dstack/_internal/server/routers/runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ async def get_plan(
"""
user, project = user_project
if not user.ssh_public_key and not body.run_spec.ssh_key_pub:
await users.refresh_ssh_key(session=session, user=user)
await users.refresh_ssh_key(session=session, actor=user)
run_plan = await runs.get_plan(
session=session,
project=project,
Expand Down Expand Up @@ -148,7 +148,7 @@ async def apply_plan(
"""
user, project = user_project
if not user.ssh_public_key and not body.plan.run_spec.ssh_key_pub:
await users.refresh_ssh_key(session=session, user=user)
await users.refresh_ssh_key(session=session, actor=user)
return CustomORJSONResponse(
await runs.apply_plan(
session=session,
Expand Down
9 changes: 5 additions & 4 deletions src/dstack/_internal/server/routers/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ async def get_my_user(
):
if user.ssh_private_key is None or user.ssh_public_key is None:
# Generate keys for pre-0.19.33 users
await users.refresh_ssh_key(session=session, user=user)
await users.refresh_ssh_key(session=session, actor=user)
return CustomORJSONResponse(users.user_model_to_user_with_creds(user))


Expand Down Expand Up @@ -86,6 +86,7 @@ async def update_user(
):
res = await users.update_user(
session=session,
actor=user,
username=body.username,
global_role=body.global_role,
email=body.email,
Expand All @@ -102,7 +103,7 @@ async def refresh_ssh_key(
session: AsyncSession = Depends(get_session),
user: UserModel = Depends(Authenticated()),
):
res = await users.refresh_ssh_key(session=session, user=user, username=body.username)
res = await users.refresh_ssh_key(session=session, actor=user, username=body.username)
if res is None:
raise ResourceNotExistsError()
return CustomORJSONResponse(users.user_model_to_user_with_creds(res))
Expand All @@ -114,7 +115,7 @@ async def refresh_token(
session: AsyncSession = Depends(get_session),
user: UserModel = Depends(Authenticated()),
):
res = await users.refresh_user_token(session=session, user=user, username=body.username)
res = await users.refresh_user_token(session=session, actor=user, username=body.username)
if res is None:
raise ResourceNotExistsError()
return CustomORJSONResponse(users.user_model_to_user_with_creds(res))
Expand All @@ -128,6 +129,6 @@ async def delete_users(
):
await users.delete_users(
session=session,
user=user,
actor=user,
usernames=body.users,
)
19 changes: 16 additions & 3 deletions src/dstack/_internal/server/services/projects.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,8 +169,16 @@ async def update_project(
project: ProjectModel,
is_public: bool,
):
"""Update project visibility (public/private)."""
project.is_public = is_public
updated_fields = []
if is_public != project.is_public:
project.is_public = is_public
updated_fields.append(f"is_public={is_public}")
events.emit(
session,
f"Project updated. Updated fields: {', '.join(updated_fields) or '<none>'}",
actor=events.UserActor.from_user(user),
targets=[events.Target.from_model(project)],
)
await session.commit()


Expand Down Expand Up @@ -222,9 +230,14 @@ async def delete_projects(
"deleted": True,
}
)
events.emit(
session,
"Project deleted",
actor=events.UserActor.from_user(user),
targets=[events.Target.from_model(p)],
)
await session.execute(update(ProjectModel), updates)
await session.commit()
logger.info("Deleted projects %s by user %s", projects_names, user.name)


async def set_project_members(
Expand Down
211 changes: 131 additions & 80 deletions src/dstack/_internal/server/services/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,19 @@
import re
import secrets
import uuid
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from typing import Awaitable, Callable, List, Optional, Tuple

from sqlalchemy import delete, select, update
from sqlalchemy import delete, select
from sqlalchemy import func as safunc
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import load_only

from dstack._internal.core.errors import ResourceExistsError, ServerClientError
from dstack._internal.core.errors import (
ResourceExistsError,
ServerClientError,
)
from dstack._internal.core.models.users import (
GlobalRole,
User,
Expand All @@ -19,8 +24,10 @@
UserTokenCreds,
UserWithCreds,
)
from dstack._internal.server.db import get_db
from dstack._internal.server.models import DecryptedString, MemberModel, UserModel
from dstack._internal.server.services import events
from dstack._internal.server.services.locking import get_locker
from dstack._internal.server.services.permissions import get_default_permissions
from dstack._internal.server.utils.routers import error_forbidden
from dstack._internal.utils import crypto
Expand Down Expand Up @@ -123,114 +130,128 @@ async def create_user(

async def update_user(
session: AsyncSession,
actor: UserModel,
username: str,
global_role: GlobalRole,
email: Optional[str] = None,
active: bool = True,
) -> UserModel:
await session.execute(
update(UserModel)
.where(
UserModel.name == username,
UserModel.deleted == False,
)
.values(
global_role=global_role,
email=email,
active=active,
) -> Optional[UserModel]:
async with get_user_model_by_name_for_update(session, username) as user:
if user is None:
return None
updated_fields = []
if global_role != user.global_role:
user.global_role = global_role
updated_fields.append(f"global_role={global_role}")
if email != user.email:
user.email = email
updated_fields.append("email") # do not include potentially sensitive new value
if active != user.active:
user.active = active
updated_fields.append(f"active={active}")
events.emit(
session,
f"User updated. Updated fields: {', '.join(updated_fields) or '<none>'}",
actor=events.UserActor.from_user(actor),
targets=[events.Target.from_model(user)],
)
)
await session.commit()
return await get_user_model_by_name_or_error(session=session, username=username)
await session.commit()
return user


async def refresh_ssh_key(
session: AsyncSession,
user: UserModel,
actor: UserModel,
username: Optional[str] = None,
) -> Optional[UserModel]:
if username is None:
username = user.name
logger.debug("Refreshing SSH key for user [code]%s[/code]", username)
if user.global_role != GlobalRole.ADMIN and user.name != username:
username = actor.name
if actor.global_role != GlobalRole.ADMIN and actor.name != username:
raise error_forbidden()
private_bytes, public_bytes = await run_async(crypto.generate_rsa_key_pair_bytes, username)
await session.execute(
update(UserModel)
.where(
UserModel.name == username,
UserModel.deleted == False,
)
.values(
ssh_private_key=private_bytes.decode(),
ssh_public_key=public_bytes.decode(),
async with get_user_model_by_name_for_update(session, username) as user:
if user is None:
return None
private_bytes, public_bytes = await run_async(crypto.generate_rsa_key_pair_bytes, username)
user.ssh_private_key = private_bytes.decode()
user.ssh_public_key = public_bytes.decode()
events.emit(
session,
"User SSH key refreshed",
actor=events.UserActor.from_user(actor),
targets=[events.Target.from_model(user)],
)
)
await session.commit()
return await get_user_model_by_name(session=session, username=username)
await session.commit()
return user


async def refresh_user_token(
session: AsyncSession,
user: UserModel,
actor: UserModel,
username: str,
) -> Optional[UserModel]:
if user.global_role != GlobalRole.ADMIN and user.name != username:
if actor.global_role != GlobalRole.ADMIN and actor.name != username:
raise error_forbidden()
new_token = str(uuid.uuid4())
await session.execute(
update(UserModel)
.where(
UserModel.name == username,
UserModel.deleted == False,
)
.values(
token=DecryptedString(plaintext=new_token),
token_hash=get_token_hash(new_token),
async with get_user_model_by_name_for_update(session, username) as user:
if user is None:
return None
new_token = str(uuid.uuid4())
user.token = DecryptedString(plaintext=new_token)
user.token_hash = get_token_hash(new_token)
events.emit(
session,
"User token refreshed",
actor=events.UserActor.from_user(actor),
targets=[events.Target.from_model(user)],
)
)
await session.commit()
return await get_user_model_by_name(session=session, username=username)
await session.commit()
return user


async def delete_users(
session: AsyncSession,
user: UserModel,
actor: UserModel,
usernames: List[str],
):
if _ADMIN_USERNAME in usernames:
raise ServerClientError("User 'admin' cannot be deleted")

res = await session.execute(
select(UserModel)
.where(
UserModel.name.in_(usernames),
UserModel.deleted == False,
)
.options(load_only(UserModel.id, UserModel.name))
)
users = res.scalars().all()
if len(users) != len(usernames):
raise ServerClientError("Failed to delete non-existent users")

user_ids = [u.id for u in users]
timestamp = str(int(get_current_datetime().timestamp()))
updates = []
for u in users:
updates.append(
{
"id": u.id,
"name": f"_deleted_{timestamp}_{secrets.token_hex(8)}",
"original_name": u.name,
"deleted": True,
"active": False,
}
raise ServerClientError(f"User {_ADMIN_USERNAME!r} cannot be deleted")

filters = [
UserModel.name.in_(usernames),
UserModel.deleted == False,
]
res = await session.execute(select(UserModel.id).where(*filters))
user_ids = list(res.scalars().all())
user_ids.sort()

async with get_locker(get_db().dialect_name).lock_ctx(UserModel.__tablename__, user_ids):
# Refetch after lock
res = await session.execute(
select(UserModel)
.where(UserModel.id.in_(user_ids), *filters)
.order_by(UserModel.id) # take locks in order
.options(load_only(UserModel.id, UserModel.name))
.with_for_update(key_share=True)
)
await session.execute(update(UserModel), updates)
await session.execute(delete(MemberModel).where(MemberModel.user_id.in_(user_ids)))
# Projects are not deleted automatically if owners are deleted.
await session.commit()
logger.info("Deleted users %s by user %s", usernames, user.name)
users = list(res.scalars().all())
if len(users) != len(usernames):
raise ServerClientError("Failed to delete non-existent users")
user_ids = [u.id for u in users]
timestamp = str(int(get_current_datetime().timestamp()))
for u in users:
event_target = events.Target.from_model(u) # build target before renaming the user
u.deleted = True
u.active = False
u.original_name = u.name
u.name = f"_deleted_{timestamp}_{secrets.token_hex(8)}"
events.emit(
session,
"User deleted",
actor=events.UserActor.from_user(actor),
targets=[event_target],
)
await session.execute(delete(MemberModel).where(MemberModel.user_id.in_(user_ids)))
# Projects are not deleted automatically if owners are deleted.
await session.commit()


async def get_user_model_by_name(
Expand All @@ -257,6 +278,36 @@ async def get_user_model_by_name_or_error(
)


@asynccontextmanager
async def get_user_model_by_name_for_update(
session: AsyncSession, username: str
) -> AsyncGenerator[Optional[UserModel], None]:
"""
Fetch the user from the database and lock it for update.

**NOTE**: commit changes to the database before exiting from this context manager,
so that in-memory locks are only released after commit.
"""

filters = [
UserModel.name == username,
UserModel.deleted == False,
]
res = await session.execute(select(UserModel.id).where(*filters))
user_id = res.scalar_one_or_none()
if user_id is None:
yield None
else:
async with get_locker(get_db().dialect_name).lock_ctx(UserModel.__tablename__, [user_id]):
# Refetch after lock
res = await session.execute(
select(UserModel)
.where(UserModel.id.in_([user_id]), *filters)
.with_for_update(key_share=True)
)
yield res.scalar_one_or_none()


async def log_in_with_token(session: AsyncSession, token: str) -> Optional[UserModel]:
token_hash = get_token_hash(token)
res = await session.execute(
Expand Down
10 changes: 10 additions & 0 deletions src/tests/_internal/server/routers/test_projects.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,6 +495,16 @@ async def test_deletes_projects(
await session.refresh(project2)
assert project1.deleted
assert not project2.deleted
# Validate an event is emitted
response = await client.post(
"/api/events/list", headers=get_auth_headers(user.token), json={}
)
assert response.status_code == 200
assert len(response.json()) == 1
assert response.json()[0]["message"] == "Project deleted"
assert len(response.json()[0]["targets"]) == 1
assert response.json()[0]["targets"][0]["id"] == str(project1.id)
assert response.json()[0]["targets"][0]["name"] == project_name

@pytest.mark.asyncio
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
Expand Down
Loading