Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/docs/reference/api/python/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ finally:
!!! info "NOTE:"
1. The `configuration` argument in the `apply_configuration` method can be either `dstack.api.Task`, `dstack.api.Service`, or `dstack.api.DevEnvironment`.
2. When you create `dstack.api.Task`, `dstack.api.Service`, or `dstack.api.DevEnvironment`, you can specify the `image` argument. If `image` isn't specified, the default image will be used. For a private Docker registry, ensure you also pass the `registry_auth` argument.
3. The `repo` argument in the `apply_configuration` method allows the mounting of a local folder, a remote repo, or a
3. The `repo` argument in the `apply_configuration` method allows the mounting of a remote repo or a
programmatically created repo. In this case, the `commands` argument can refer to the files within this repo.
4. The `attach` method waits for the run to start and, for `dstack.api.Task` sets up an SSH tunnel and forwards
configured `ports` to `localhost`.
Expand Down
19 changes: 8 additions & 11 deletions src/dstack/_internal/core/services/ssh/key_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
from dataclasses import dataclass
from datetime import datetime, timedelta
from pathlib import Path
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING

from dstack._internal.core.models.users import UserWithCreds
from dstack._internal.core.errors import ClientError

if TYPE_CHECKING:
from dstack.api.server import APIClient
Expand All @@ -24,26 +24,25 @@ def __init__(self, api_client: "APIClient", ssh_keys_dir: Path) -> None:
self._key_path = ssh_keys_dir / api_client.get_token_hash()
self._pub_key_path = self._key_path.with_suffix(".pub")

def get_user_key(self) -> Optional[UserSSHKey]:
def get_user_key(self) -> UserSSHKey:
"""
Return the up-to-date user key, or None if the user has no key (if created before 0.19.33)
Return the up-to-date user key
"""
if (
not self._key_path.exists()
or not self._pub_key_path.exists()
or datetime.now() - datetime.fromtimestamp(self._key_path.stat().st_mtime)
> KEY_REFRESH_RATE
):
if not self._download_user_key():
return None
self._download_user_key()
return UserSSHKey(
public_key=self._pub_key_path.read_text(), private_key_path=self._key_path
)

def _download_user_key(self) -> bool:
def _download_user_key(self) -> None:
user = self._api_client.users.get_my_user()
if not (isinstance(user, UserWithCreds) and user.ssh_public_key and user.ssh_private_key):
return False
if user.ssh_private_key is None or user.ssh_public_key is None:
raise ClientError("Server response does not contain user SSH key")

def key_opener(path, flags):
return os.open(path, flags, 0o600)
Expand All @@ -52,5 +51,3 @@ def key_opener(path, flags):
f.write(user.ssh_private_key)
with open(self._pub_key_path, "w") as f:
f.write(user.ssh_public_key)

return True
4 changes: 2 additions & 2 deletions src/dstack/_internal/server/routers/runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ async def get_plan(
"""
user, project = user_project
if not user.ssh_public_key and not body.run_spec.ssh_key_pub:
await users.refresh_ssh_key(session=session, user=user, username=user.name)
await users.refresh_ssh_key(session=session, user=user)
run_plan = await runs.get_plan(
session=session,
project=project,
Expand Down Expand Up @@ -148,7 +148,7 @@ async def apply_plan(
"""
user, project = user_project
if not user.ssh_public_key and not body.plan.run_spec.ssh_key_pub:
await users.refresh_ssh_key(session=session, user=user, username=user.name)
await users.refresh_ssh_key(session=session, user=user)
return CustomORJSONResponse(
await runs.apply_plan(
session=session,
Expand Down
5 changes: 1 addition & 4 deletions src/dstack/_internal/server/routers/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,7 @@ async def get_my_user(
):
if user.ssh_private_key is None or user.ssh_public_key is None:
# Generate keys for pre-0.19.33 users
updated_user = await users.refresh_ssh_key(session=session, user=user, username=user.name)
if updated_user is None:
raise ResourceNotExistsError()
user = updated_user
await users.refresh_ssh_key(session=session, user=user)
return CustomORJSONResponse(users.user_model_to_user_with_creds(user))


Expand Down
4 changes: 3 additions & 1 deletion src/dstack/_internal/server/services/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,8 +147,10 @@ async def update_user(
async def refresh_ssh_key(
session: AsyncSession,
user: UserModel,
username: str,
username: Optional[str] = None,
) -> Optional[UserModel]:
if username is None:
username = user.name
logger.debug("Refreshing SSH key for user [code]%s[/code]", username)
if user.global_role != GlobalRole.ADMIN and user.name != username:
raise error_forbidden()
Expand Down
3 changes: 2 additions & 1 deletion src/dstack/_internal/server/testing/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,7 @@ def get_run_spec(
configuration_path: str = "dstack.yaml",
profile: Union[Profile, Callable[[], Profile], None] = lambda: Profile(name="default"),
configuration: Optional[AnyRunConfiguration] = None,
ssh_key_pub: Optional[str] = "user_ssh_key",
) -> RunSpec:
if callable(profile):
profile = profile()
Expand All @@ -288,7 +289,7 @@ def get_run_spec(
configuration_path=configuration_path,
configuration=configuration or DevEnvironmentConfiguration(ide="vscode"),
profile=profile,
ssh_key_pub="user_ssh_key",
ssh_key_pub=ssh_key_pub,
)


Expand Down
5 changes: 5 additions & 0 deletions src/dstack/api/_public/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
import enum


class Deprecated(enum.Enum):
PLACEHOLDER = "DEPRECATED"
47 changes: 22 additions & 25 deletions src/dstack/api/_public/runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from copy import copy
from datetime import datetime
from pathlib import Path
from typing import BinaryIO, Dict, Iterable, List, Optional
from typing import BinaryIO, Dict, Iterable, List, Optional, Union
from urllib.parse import urlencode, urlparse

from websocket import WebSocketApp
Expand Down Expand Up @@ -46,10 +46,10 @@
from dstack._internal.core.services.ssh.ports import PortsLock
from dstack._internal.server.schemas.logs import PollLogsRequest
from dstack._internal.utils.common import get_or_error, make_proxy_url
from dstack._internal.utils.crypto import generate_rsa_key_pair
from dstack._internal.utils.files import create_file_archive
from dstack._internal.utils.logging import get_logger
from dstack._internal.utils.path import PathLike
from dstack.api._public.common import Deprecated
from dstack.api.server import APIClient

logger = get_logger(__name__)
Expand Down Expand Up @@ -278,13 +278,11 @@ def attach(
if not ssh_identity_file:
config_manager = ConfigManager()
key_manager = UserSSHKeyManager(self._api_client, config_manager.dstack_ssh_dir)
if (
user_key := key_manager.get_user_key()
) and user_key.public_key == self._run.run_spec.ssh_key_pub:
user_key = key_manager.get_user_key()
if user_key.public_key == self._run.run_spec.ssh_key_pub:
ssh_identity_file = user_key.private_key_path
else:
if config_manager.dstack_key_path.exists():
# TODO: Remove since 0.19.40
logger.debug(f"Using legacy [code]{config_manager.dstack_key_path}[/code].")
ssh_identity_file = config_manager.dstack_key_path
else:
Expand Down Expand Up @@ -451,7 +449,7 @@ def get_run_plan(
repo: Optional[Repo] = None,
profile: Optional[Profile] = None,
configuration_path: Optional[str] = None,
repo_dir: Optional[str] = None,
repo_dir: Union[Deprecated, str, None] = Deprecated.PLACEHOLDER,
ssh_identity_file: Optional[PathLike] = None,
) -> RunPlan:
"""
Expand All @@ -465,9 +463,10 @@ def get_run_plan(
profile: The profile to use for the run.
configuration_path: The path to the configuration file. Omit if the configuration
is not loaded from a file.
repo_dir: The path of the cloned repo inside the run container. If not set,
defaults first to the `repos[0].path` property of the configuration (for remote
repos only).
ssh_identity_file: Path to the private SSH key file. The corresponding public key
(`.pub` file) is read and included in the run plan, allowing SSH access to the instances.
If the `.pub` file does not exist, it is generated automatically.
If ssh_identity_file is not specified, the user key is used.

Returns:
Run plan.
Expand All @@ -479,8 +478,15 @@ def get_run_plan(
with _prepare_code_file(repo) as (_, repo_code_hash):
pass

if repo_dir is None and configuration.repos:
if repo_dir is not Deprecated.PLACEHOLDER:
logger.warning(
"The repo_dir argument is deprecated, ignored, and will be removed soon."
" Remove it and use the repos[].path configuration property instead."
)
if configuration.repos:
repo_dir = configuration.repos[0].path
else:
repo_dir = None

self._validate_configuration_files(configuration, configuration_path)
file_archives: list[FileArchiveMapping] = []
Expand All @@ -497,20 +503,7 @@ def get_run_plan(
if ssh_identity_file:
ssh_key_pub = Path(ssh_identity_file).with_suffix(".pub").read_text()
else:
config_manager = ConfigManager()
key_manager = UserSSHKeyManager(self._api_client, config_manager.dstack_ssh_dir)
if key_manager.get_user_key():
ssh_key_pub = None # using the server-managed user key
else:
if not config_manager.dstack_key_path.exists():
generate_rsa_key_pair(private_key_path=config_manager.dstack_key_path)
logger.warning(
f"Using legacy [code]{config_manager.dstack_key_path.with_suffix('.pub')}[/code]."
" You will only be able to attach to the run from this client."
" Update the [code]dstack[/] server to [code]0.19.34[/]+ to switch to user keys"
" automatically replicated to all clients.",
)
ssh_key_pub = config_manager.dstack_key_path.with_suffix(".pub").read_text()
ssh_key_pub = None # using the server-managed user key
run_spec = RunSpec(
run_name=configuration.name,
repo_id=repo.repo_id,
Expand Down Expand Up @@ -587,6 +580,10 @@ def apply_configuration(
profile: The profile to use for the run.
configuration_path: The path to the configuration file. Omit if the configuration is not loaded from a file.
reserve_ports: Reserve local ports before applying. Use if you'll attach to the run.
ssh_identity_file: Path to the private SSH key file. The corresponding public key
(`.pub` file) is read and included in the run plan, allowing SSH access to the instances.
If the `.pub` file does not exist, it is generated automatically.
If ssh_identity_file is not specified, the user key is used.

Returns:
Submitted run.
Expand Down
21 changes: 3 additions & 18 deletions src/dstack/api/server/_users.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import List

from pydantic import ValidationError, parse_obj_as
from pydantic import parse_obj_as

from dstack._internal.core.models.users import GlobalRole, User, UserWithCreds
from dstack._internal.server.schemas.users import (
Expand All @@ -17,24 +17,9 @@ def list(self) -> List[User]:
resp = self._request("/api/users/list")
return parse_obj_as(List[User.__response__], resp.json())

def get_my_user(self) -> User:
"""
Returns `User` with pre-0.19.33 servers, or `UserWithCreds` with newer servers.
"""

def get_my_user(self) -> UserWithCreds:
resp = self._request("/api/users/get_my_user")
try:
return parse_obj_as(UserWithCreds.__response__, resp.json())
except ValidationError as e:
# Compatibility with pre-0.19.33 server
if (
len(e.errors()) == 1
and e.errors()[0]["loc"] == ("__root__", "creds")
and e.errors()[0]["type"] == "value_error.missing"
):
return parse_obj_as(User.__response__, resp.json())
else:
raise
return parse_obj_as(UserWithCreds.__response__, resp.json())

def get_user(self, username: str) -> User:
body = GetUserRequest(username=username)
Expand Down
11 changes: 0 additions & 11 deletions src/tests/_internal/core/services/ssh/test_key_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,17 +44,6 @@ def set_mtime(path: Path, ts: float):
os.utime(path, (ts, ts))


def test_get_user_key_returns_none_when_no_user_creds(tmp_path: Path):
api_client = make_api_client(
user=User.__response__.parse_obj(SAMPLE_USER.dict()), token_hash=SAMPLE_USER_TOKEN_HASH
)
manager = UserSSHKeyManager(api_client, tmp_path)

assert manager.get_user_key() is None
assert not (tmp_path / SAMPLE_USER_TOKEN_HASH).exists()
assert not (tmp_path / f"{SAMPLE_USER_TOKEN_HASH}.pub").exists()


def test_get_user_key_downloads_keys(tmp_path: Path):
api_client = make_api_client(user=SAMPLE_USER, token_hash=SAMPLE_USER_TOKEN_HASH)
manager = UserSSHKeyManager(api_client, tmp_path)
Expand Down
58 changes: 58 additions & 0 deletions src/tests/_internal/server/routers/test_runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1364,6 +1364,32 @@ async def test_returns_update_or_create_action_on_conf_change(
assert response_json["action"] == action
assert response_json["current_resource"] == json.loads(run.json())

@pytest.mark.asyncio
@pytest.mark.usefixtures("test_db")
async def test_generates_user_ssh_key(self, session: AsyncSession, client: AsyncClient):
user = await create_user(
session=session, global_role=GlobalRole.USER, ssh_public_key=None, ssh_private_key=None
)
project = await create_project(session=session, owner=user)
await add_project_member(
session=session, project=project, user=user, project_role=ProjectRole.USER
)
repo = await create_repo(session=session, project_id=project.id)
run_spec = get_run_spec(run_name="test-run", repo_id=repo.name, ssh_key_pub=None)

response = await client.post(
f"/api/project/{project.name}/runs/get_plan",
headers=get_auth_headers(user.token),
json={"run_spec": run_spec.dict()},
)

assert response.status_code == 200, response.json()
run_spec_ssh_public_key = response.json()["effective_run_spec"]["ssh_key_pub"]
assert run_spec_ssh_public_key is not None
await session.refresh(user)
assert user.ssh_public_key == run_spec_ssh_public_key
assert user.ssh_private_key is not None


class TestApplyPlan:
@pytest.mark.asyncio
Expand Down Expand Up @@ -1517,6 +1543,38 @@ async def test_creates_pending_run_if_run_is_scheduled(
assert run.status == RunStatus.PENDING
assert run.next_triggered_at == datetime(2023, 1, 2, 3, 10, tzinfo=timezone.utc)

@pytest.mark.asyncio
@pytest.mark.usefixtures("test_db")
async def test_generates_user_ssh_key(self, session: AsyncSession, client: AsyncClient):
user = await create_user(
session=session, global_role=GlobalRole.USER, ssh_public_key=None, ssh_private_key=None
)
project = await create_project(session=session, owner=user)
await add_project_member(
session=session, project=project, user=user, project_role=ProjectRole.USER
)
repo = await create_repo(session=session, project_id=project.id)
run_spec = get_run_spec(run_name="test-run", repo_id=repo.name, ssh_key_pub=None)

response = await client.post(
f"/api/project/{project.name}/runs/apply",
headers=get_auth_headers(user.token),
json={
"plan": {
"run_spec": run_spec.dict(),
"current_resource": None,
},
"force": False,
},
)

assert response.status_code == 200, response.json()
run_spec_ssh_public_key = response.json()["run_spec"]["ssh_key_pub"]
assert run_spec_ssh_public_key is not None
await session.refresh(user)
assert user.ssh_public_key == run_spec_ssh_public_key
assert user.ssh_private_key is not None


class TestSubmitRun:
@pytest.mark.asyncio
Expand Down