Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 8 additions & 14 deletions src/dstack/_internal/cli/commands/login.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from typing import Optional

from dstack._internal.cli.commands import BaseCommand
from dstack._internal.cli.utils.common import console
from dstack._internal.cli.utils.common import console, resolve_url
from dstack._internal.core.errors import ClientError, CLIError
from dstack._internal.core.models.users import UserWithCreds
from dstack.api._public.runs import ConfigManager
Expand Down Expand Up @@ -202,19 +202,13 @@ def _create_server(self, handler: type[BaseHTTPRequestHandler]) -> HTTPServer:


def _normalize_url_or_error(url: str) -> str:
if not url.startswith("http://") and not url.startswith("https://"):
url = "http://" + url
parsed = urllib.parse.urlparse(url)
if (
not parsed.scheme
or not parsed.hostname
or parsed.path not in ("", "/")
or parsed.params
or parsed.query
or parsed.fragment
or (parsed.port is not None and not (1 <= parsed.port <= 65535))
):
raise CLIError("Invalid server URL format. Format: --url https://sky.dstack.ai")
try:
# Validate the URL and determine the URL scheme.
# Need to resolve the scheme before making first POST request
# since for some redirect codes (301), clients change POST to GET.
url = resolve_url(url)
except ValueError as e:
raise CLIError(e.args[0])
return url


Expand Down
18 changes: 18 additions & 0 deletions src/dstack/_internal/cli/utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from pathlib import Path
from typing import Any, Dict, Optional, Union

import requests
from rich.console import Console
from rich.prompt import Confirm
from rich.table import Table
Expand Down Expand Up @@ -128,3 +129,20 @@ def get_start_time(since: Optional[str]) -> Optional[datetime]:
return parse_since(since)
except ValueError as e:
raise CLIError(e.args[0])


def resolve_url(url: str, timeout: float = 5.0) -> str:
"""
Starts with http:// and follows redirects. Returns the final URL (including scheme).
"""
if not url.startswith("http://") and not url.startswith("https://"):
url = "http://" + url
try:
response = requests.get(
url,
allow_redirects=True,
timeout=timeout,
)
except requests.exceptions.ConnectionError as e:
raise ValueError(f"Failed to resolve url {url}") from e
return response.url
8 changes: 8 additions & 0 deletions src/tests/_internal/cli/commands/test_login.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,12 @@ def test_login_no_projects(self, capsys: CaptureFixture, tmp_path: Path):
patch("dstack._internal.cli.commands.login.webbrowser") as webbrowser_mock,
patch("dstack._internal.cli.commands.login.APIClient") as APIClientMock,
patch("dstack._internal.cli.commands.login._LoginServer") as LoginServerMock,
patch(
"dstack._internal.cli.commands.login._normalize_url_or_error"
) as _normalize_url_or_error_mock,
):
webbrowser_mock.open.return_value = True
_normalize_url_or_error_mock.return_value = "http://127.0.0.1:31313"
APIClientMock.return_value.auth.list_providers.return_value = [
SimpleNamespace(name="github", enabled=True)
]
Expand Down Expand Up @@ -49,7 +53,11 @@ def test_login_configures_projects(self, capsys: CaptureFixture, tmp_path: Path)
patch("dstack._internal.cli.commands.login.APIClient") as APIClientMock,
patch("dstack._internal.cli.commands.login.ConfigManager") as ConfigManagerMock,
patch("dstack._internal.cli.commands.login._LoginServer") as LoginServerMock,
patch(
"dstack._internal.cli.commands.login._normalize_url_or_error"
) as _normalize_url_or_error_mock,
):
_normalize_url_or_error_mock.return_value = "http://127.0.0.1:31313"
webbrowser_mock.open.return_value = True
APIClientMock.return_value.auth.list_providers.return_value = [
SimpleNamespace(name="github", enabled=True)
Expand Down