Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions docs/Flash_Deploy_Guide.md
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,28 @@ graph TB

## CLI Commands Reference

### flash login

Authenticate the CLI and store a RunPod API key locally.

```bash
flash login [--no-open] [--timeout <seconds>]
```

**What it does:**
1. Creates an auth request via GraphQL
2. Opens a browser to approve the request
3. Polls for approval and receives an API key
4. Stores the key at `~/.config/runpod/credentials.toml` (or `RUNPOD_CREDENTIALS_FILE`)

**Notes:**
- `RUNPOD_API_KEY` still takes precedence if set
- Use `--no-open` to print the URL only

**Implementation:** `src/tetra_rp/cli/commands/login.py`

---

### flash deploy new

Create a new deployment environment (mothership).
Expand Down
8 changes: 6 additions & 2 deletions src/runpod_flash/cli/commands/init.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def init_command(
title = "Project Initialized" if is_current_dir else "Project Created"
console.print(Panel(panel_content, title=title, expand=False))

# Next steps
# next steps
console.print("\n[bold]Next steps:[/bold]")
steps_table = Table(show_header=False, box=None, padding=(0, 1))
steps_table.add_column("Step", style="bold cyan")
Expand All @@ -111,13 +111,17 @@ def init_command(
step_num += 1
steps_table.add_row(f"{step_num}.", "cp .env.example .env")
step_num += 1
steps_table.add_row(f"{step_num}.", "Add your RUNPOD_API_KEY to .env")
steps_table.add_row(
f"{step_num}.", "Add your RUNPOD_API_KEY to .env (or run flash login)"
)
step_num += 1
steps_table.add_row(f"{step_num}.", "flash run")

console.print(steps_table)

console.print("\n[bold]Get your API key:[/bold]")
console.print(" https://docs.runpod.io/get-started/api-keys")
console.print("\n[bold]Or authenticate with flash:[/bold]")
console.print(" flash login")
console.print("\nVisit http://localhost:8888/docs after running")
console.print("\nCheck out the README.md for more")
90 changes: 90 additions & 0 deletions src/runpod_flash/cli/commands/login.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import asyncio
import datetime as dt
from typing import Optional

import typer
from rich.console import Console
from rich.panel import Panel

from runpod_flash.core.api.runpod import RunpodGraphQLClient
from runpod_flash.core.credentials import save_api_key
from runpod_flash.core.resources.constants import CONSOLE_BASE_URL

console = Console()

POLL_INTERVAL_SECONDS = 2.0
DEFAULT_TIMEOUT_SECONDS = 600.0


def _parse_expires_at(value: Optional[str]) -> Optional[dt.datetime]:
if not value:
return None
try:
return dt.datetime.fromisoformat(value.replace("Z", "+00:00"))
except ValueError:
return None


async def _login_async(open_browser: bool, timeout_seconds: float) -> None:
async with RunpodGraphQLClient(require_api_key=False) as client:
request = await client.create_flash_auth_request()
request_id = request.get("id")
if not request_id:
raise RuntimeError("auth request failed to initialize")

auth_url = f"{CONSOLE_BASE_URL}/flash/login?request={request_id}"
console.print(
Panel(
f"[bold]open this url to authorize flash:[/bold]\n{auth_url}",
title="flash login",
expand=False,
)
)
if open_browser:
typer.launch(auth_url)

expires_at = _parse_expires_at(request.get("expiresAt"))
deadline = dt.datetime.now(dt.timezone.utc) + dt.timedelta(
seconds=timeout_seconds
)
if expires_at and expires_at < deadline:
deadline = expires_at

with console.status("[cyan]waiting for authorization...[/cyan]"):
while True:
status_payload = await client.get_flash_auth_request_status(request_id)
status = status_payload.get("status")
api_key = status_payload.get("apiKey")

if status == "APPROVED" and api_key:
path = save_api_key(api_key)
console.print(
Panel(
f"[green]logged in![/green]\ncredentials saved to {path}",
title="flash login",
expand=False,
)
)
return

if status in {"DENIED", "EXPIRED", "CONSUMED"}:
raise RuntimeError(f"login failed: {status.lower()}")

if dt.datetime.now(dt.timezone.utc) >= deadline:
raise RuntimeError("login timed out")

await asyncio.sleep(POLL_INTERVAL_SECONDS)


def login_command(
no_open: bool = typer.Option(False, "--no-open", help="do not open the browser"),
timeout: float = typer.Option(
DEFAULT_TIMEOUT_SECONDS, "--timeout", help="max wait time in seconds"
),
):
"""Authenticate and save a Runpod API key for flash."""
try:
asyncio.run(_login_async(open_browser=not no_open, timeout_seconds=timeout))
except RuntimeError as exc:
console.print(f"[red]error:[/red] {exc}")
raise typer.Exit(code=1)
2 changes: 2 additions & 0 deletions src/runpod_flash/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
deploy,
apps,
undeploy,
login,
)


Expand All @@ -38,6 +39,7 @@ def get_version() -> str:
app.command("init")(init.init_command)
app.command("run")(run.run_command)
app.command("build")(build.build_command)
app.command("login")(login.login_command)
# app.command("report")(resource.report_command)


Expand Down
55 changes: 40 additions & 15 deletions src/runpod_flash/core/api/runpod.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import aiohttp
from aiohttp.resolver import ThreadedResolver

from runpod_flash.core.credentials import get_api_key
from runpod_flash.core.exceptions import RunpodAPIKeyError
from runpod_flash.runtime.exceptions import GraphQLMutationError, GraphQLQueryError

Expand Down Expand Up @@ -59,9 +60,9 @@ class RunpodGraphQLClient:

GRAPHQL_URL = f"{RUNPOD_API_BASE_URL}/graphql"

def __init__(self, api_key: Optional[str] = None):
self.api_key = api_key or os.getenv("RUNPOD_API_KEY")
if not self.api_key:
def __init__(self, api_key: Optional[str] = None, require_api_key: bool = True):
self.api_key = api_key or get_api_key()
if require_api_key and not self.api_key:
raise RunpodAPIKeyError()

self.session: Optional[aiohttp.ClientSession] = None
Expand All @@ -71,12 +72,12 @@ async def _get_session(self) -> aiohttp.ClientSession:
if self.session is None or self.session.closed:
timeout = aiohttp.ClientTimeout(total=300) # 5 minute timeout
connector = aiohttp.TCPConnector(resolver=ThreadedResolver())
headers = {"Content-Type": "application/json"}
if self.api_key:
headers["Authorization"] = f"Bearer {self.api_key}"
self.session = aiohttp.ClientSession(
timeout=timeout,
headers={
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
},
headers=headers,
connector=connector,
)
return self.session
Expand Down Expand Up @@ -748,6 +749,33 @@ async def endpoint_exists(self, endpoint_id: str) -> bool:
log.error(f"Error checking endpoint existence: {e}")
return False

async def create_flash_auth_request(self) -> Dict[str, Any]:
mutation = """
mutation createFlashAuthRequest {
createFlashAuthRequest {
id
status
expiresAt
}
}
"""
result = await self._execute_graphql(mutation)
return result.get("createFlashAuthRequest", {})

async def get_flash_auth_request_status(self, request_id: str) -> Dict[str, Any]:
query = """
query flashAuthRequestStatus($flashAuthRequestId: String!) {
flashAuthRequestStatus(flashAuthRequestId: $flashAuthRequestId) {
id
status
expiresAt
apiKey
}
}
"""
result = await self._execute_graphql(query, {"flashAuthRequestId": request_id})
return result.get("flashAuthRequestStatus", {})

async def close(self):
"""Close the HTTP session."""
if self.session and not self.session.closed:
Expand All @@ -767,7 +795,7 @@ class RunpodRestClient:
"""

def __init__(self, api_key: Optional[str] = None):
self.api_key = api_key or os.getenv("RUNPOD_API_KEY")
self.api_key = api_key or get_api_key()
if not self.api_key:
raise RunpodAPIKeyError()

Expand All @@ -777,13 +805,10 @@ async def _get_session(self) -> aiohttp.ClientSession:
"""Get or create an aiohttp session."""
if self.session is None or self.session.closed:
timeout = aiohttp.ClientTimeout(total=300) # 5 minute timeout
self.session = aiohttp.ClientSession(
timeout=timeout,
headers={
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
},
)
headers = {"Content-Type": "application/json"}
if self.api_key:
headers["Authorization"] = f"Bearer {self.api_key}"
self.session = aiohttp.ClientSession(timeout=timeout, headers=headers)
return self.session

async def _execute_rest(
Expand Down
57 changes: 57 additions & 0 deletions src/runpod_flash/core/credentials.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from __future__ import annotations

import os
from pathlib import Path
from typing import Optional

try:
import tomllib
except ImportError: # python < 3.11
import tomli as tomllib


def get_credentials_path() -> Path:
credentials_file = os.getenv("RUNPOD_CREDENTIALS_FILE")
if credentials_file:
return Path(credentials_file).expanduser()

config_home = os.getenv("XDG_CONFIG_HOME")
base_dir = (
Path(config_home).expanduser() if config_home else Path.home() / ".config"
)
return base_dir / "runpod" / "credentials.toml"


def _read_credentials() -> dict:
path = get_credentials_path()
if not path.exists():
return {}

try:
with path.open("rb") as handle:
return tomllib.load(handle)
except (OSError, ValueError):
return {}


def get_api_key() -> Optional[str]:
api_key = os.getenv("RUNPOD_API_KEY")
if api_key and api_key.strip():
return api_key

stored = _read_credentials().get("api_key")
if isinstance(stored, str) and stored.strip():
return stored

return None


def save_api_key(api_key: str) -> Path:
path = get_credentials_path()
path.parent.mkdir(parents=True, exist_ok=True)
path.write_text(f'api_key = "{api_key}"\n', encoding="utf-8")
try:
os.chmod(path, 0o600)
except OSError:
pass
return path
11 changes: 10 additions & 1 deletion src/runpod_flash/core/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
Provides clear, actionable error messages for common failure scenarios.
"""

from runpod_flash.core.credentials import get_credentials_path


class RunpodAPIKeyError(Exception):
"""Raised when RUNPOD_API_KEY environment variable is missing or invalid.
Expand All @@ -28,7 +30,8 @@ def _default_message() -> str:
Returns:
Formatted error message with actionable steps.
"""
return """RUNPOD_API_KEY environment variable is required but not set.
credentials_path = get_credentials_path()
return f"""RUNPOD_API_KEY environment variable is required but not set.

To use Flash remote execution features, you need a Runpod API key.

Expand All @@ -46,5 +49,11 @@ def _default_message() -> str:
3. In your shell profile (~/.bashrc, ~/.zshrc):
echo 'export RUNPOD_API_KEY=your_api_key_here' >> ~/.bashrc

4. Use the flash login flow:
flash login

5. Save directly to credentials file:
{credentials_path}

Note: If you created a .env file, make sure it's in your current directory
or project root where Flash can find it."""
5 changes: 3 additions & 2 deletions src/runpod_flash/core/resources/cloud.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import runpod

runpod.api_key = os.getenv("RUNPOD_API_KEY")
from runpod_flash.core.credentials import get_api_key

runpod.api_key = get_api_key()
Loading
Loading