Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion examples/get_tenant_token.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
import sys
from getpass import getpass

from onekey_client import Client

API_URL = "https://app.eu.onekey.com/api"
Expand All @@ -15,7 +16,7 @@

print("Tenants:", ", ".join([tenant.name for tenant in tenants]))

if len(sys.argv) > 2:
if len(sys.argv) > 2: # noqa: PLR2004 (magic constant)
# Filter tenants that matches the provided pattern
tenants = filter(lambda tenant: sys.argv[2] in tenant.name, tenants)

Expand Down
1 change: 0 additions & 1 deletion examples/upload_firmware.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import sys

from getpass import getpass
from pathlib import Path

Expand Down
3 changes: 2 additions & 1 deletion onekey_client/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .client import Client as Client
from .models import Tenant as Tenant, FirmwareMetadata as FirmwareMetadata
from .models import FirmwareMetadata as FirmwareMetadata
from .models import Tenant as Tenant
19 changes: 7 additions & 12 deletions onekey_client/cli/ci.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
import sys
import time
from pathlib import Path
from typing import Optional
from uuid import UUID

import click
import httpx

from junit_xml import TestSuite, TestCase
from junit_xml import TestCase, TestSuite

from onekey_client import Client
from onekey_client.queries import load_query
Expand Down Expand Up @@ -42,9 +40,7 @@ def get_result(self):
except httpx.HTTPError as e:
if error_count <= self.retry_count:
click.echo(
"Error communicating with ONEKEY platform, retrying; error='{}'".format(
str(e)
)
f"Error communicating with ONEKEY platform, retrying; error='{e!s}'"
)
time.sleep(self.retry_wait * error_count)
error_count += 1
Expand Down Expand Up @@ -138,7 +134,7 @@ def wait_for_analysis_finish(self):
)
break
except Exception as e:
click.echo(f"Error fetching results {str(e)}")
click.echo(f"Error fetching results {e!s}")
sys.exit(10)

def get_recent_firmware_id(self):
Expand All @@ -156,11 +152,11 @@ def get_recent_firmware_id(self):
click.echo(
f"Latest firmware upload is not the current firmware, skipping comparison with previous, latest={latest_id}"
)
return
return None

if not firmware_ids:
click.echo("No previous firmware")
return
return None

return firmware_ids[0]

Expand Down Expand Up @@ -303,10 +299,9 @@ def ci_result(
retry_count: int,
retry_wait: int,
check_interval: int,
junit_path: Optional[Path],
junit_path: Path | None,
):
"""Fetch analysis results for CI"""

"""Fetch analysis results for CI."""
handler = ResultHandler(
client,
firmware_id,
Expand Down
5 changes: 3 additions & 2 deletions onekey_client/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
import httpx

from onekey_client import Client
from .firmware_upload import upload_firmware
from .misc import list_tenants, get_tenant_token

from .ci import ci_result
from .firmware_upload import upload_firmware
from .misc import get_tenant_token, list_tenants


@click.group()
Expand Down
16 changes: 7 additions & 9 deletions onekey_client/cli/firmware_upload.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import sys
from pathlib import Path
from typing import Optional

import click

from onekey_client import FirmwareMetadata, Client
from onekey_client import Client, FirmwareMetadata
from onekey_client.errors import QueryError


Expand Down Expand Up @@ -41,12 +40,11 @@ def upload_firmware(
vendor_name: str,
product_group_name: str,
analysis_configuration_name: str,
version: Optional[str],
name: Optional[str],
version: str | None,
name: str | None,
filename: Path,
):
"""Uploads a firmware to the ONEKEY platform"""

"""Upload a firmware to the ONEKEY platform."""
product_group_id = _get_product_group_id_by_name(client, product_group_name)
analysis_configuration_id = _get_analysis_configuration_id_by_name(
client, analysis_configuration_name
Expand All @@ -73,7 +71,7 @@ def upload_firmware(
click.echo(res["id"])
except QueryError as e:
click.echo("Error during firmware upload:")
for error in e._errors:
for error in e.errors:
click.echo(f"- {error['message']}")
sys.exit(11)

Expand All @@ -86,7 +84,7 @@ def _get_product_group_id_by_name(client: Client, product_group_name: str):
except KeyError:
click.echo(f"Missing product group: {product_group_name}")
click.echo("Available product groups:")
for pg in product_groups.keys():
for pg in product_groups:
click.echo(f"- {pg}")
sys.exit(10)

Expand All @@ -101,6 +99,6 @@ def _get_analysis_configuration_id_by_name(
except KeyError:
click.echo(f"Missing analysis configuration {analysis_configuration_name}")
click.echo("Available analysis configurations:")
for config in analysis_configurations.keys():
for config in analysis_configurations:
click.echo(f"- {config}")
sys.exit(12)
6 changes: 2 additions & 4 deletions onekey_client/cli/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@
@click.command()
@click.pass_obj
def list_tenants(client: Client):
"""List available tenants"""

"""List available tenants."""
tenants = client.get_all_tenants()
for tenant in tenants:
click.echo(f"{tenant.name} ({tenant.id}")
Expand All @@ -18,6 +17,5 @@ def list_tenants(client: Client):
@click.command()
@click.pass_obj
def get_tenant_token(client: Client):
"""Get tenant specific Bearer token"""

"""Get tenant specific Bearer token."""
click.echo(json.dumps(client.get_auth_headers()))
57 changes: 26 additions & 31 deletions onekey_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,27 +2,24 @@
import gc
import secrets
from pathlib import Path
from typing import Optional, List, Dict

from httpx import URL

try:
from importlib import resources
except ImportError:
import importlib_resources as resources

import httpx
from pydantic import parse_obj_as
from authlib.oidc.core import IDToken
from authlib.jose import jwt
from .queries import load_query
from . import errors
from . import models as m
from . import keys
from authlib.oidc.core import IDToken
from httpx import URL
from pydantic import parse_obj_as

from . import errors, keys
from . import models as m
from .queries import load_query

CLIENT_ID = "ONEKEY Python SDK"
TOKEN_NAMESPACE = "https://www.onekey.com/"
TOKEN_NAMESPACE = "https://www.onekey.com/" # noqa: S105 (hardcoded credential)


def _login_required(func):
Expand Down Expand Up @@ -51,8 +48,8 @@ class Client:
def __init__(
self,
api_url: str,
ca_bundle: Optional[Path] = None,
disable_tls_verify: Optional[bool] = False,
ca_bundle: Path | None = None,
disable_tls_verify: bool | None = False,
):
self._api_url = URL(api_url)
self._client = self._setup_httpx_client(api_url, ca_bundle, disable_tls_verify)
Expand All @@ -66,29 +63,27 @@ def __init__(
def _setup_httpx_client(
self,
api_url: str,
ca_bundle: Optional[Path] = None,
disable_tls_verify: Optional[bool] = False,
ca_bundle: Path | None = None,
disable_tls_verify: bool | None = False,
):
if disable_tls_verify:
return httpx.Client(base_url=api_url, verify=False)
return httpx.Client(base_url=api_url, verify=False) # noqa: S501 (TLS certificate validation disabled)

if ca_bundle is not None:
ca = ca_bundle.expanduser()
if not ca.exists():
raise errors.InvalidCABundle

return httpx.Client(base_url=api_url, verify=str(ca))
else:
with resources.path(keys, "ca.pem") as ca:
return httpx.Client(base_url=api_url, verify=str(ca))
with resources.path(keys, "ca.pem") as ca:
return httpx.Client(base_url=api_url, verify=str(ca))

def _load_key(self, key_name: str, path: Optional[Path] = None):
def _load_key(self, key_name: str, path: Path | None = None):
if path is not None:
return path.read_bytes()
else:
response = self._client.get(f"/{key_name}.pem")
response.raise_for_status()
return response.read()
response = self._client.get(f"/{key_name}.pem")
response.raise_for_status()
return response.read()

@property
def api_url(self) -> URL:
Expand All @@ -111,7 +106,7 @@ def login(self, email: str, password: str):
claims_cls=IDToken,
)
tenants = id_token[TOKEN_NAMESPACE + "tenants"]
tenants = parse_obj_as(List[m.Tenant], tenants)
tenants = parse_obj_as(list[m.Tenant], tenants)
self._state.tenants = {e.name: e for e in tenants}
self._state.email = email
self._state.raw_id_token = json_res["id_token"]
Expand All @@ -120,7 +115,7 @@ def use_token(self, token: str):
try:
tenant_id, _ = token.split("/", 1)
except ValueError:
raise errors.InvalidAPIToken()
raise errors.InvalidAPIToken from None

self._state.raw_tenant_token = token

Expand All @@ -131,7 +126,7 @@ def use_token(self, token: str):
self._state.tenants = {tenant.name: tenant}
self._state.tenant = tenant

def _post(self, path: str, headers: Optional[Dict] = None, **kwargs):
def _post(self, path: str, headers: dict | None = None, **kwargs):
response = self._client.post(path, headers=headers, **kwargs)
response.raise_for_status()
return response.json()
Expand All @@ -152,7 +147,7 @@ def get_tenant(self, name: str):
return self._state.tenants[name]

@_login_required
def get_all_tenants(self) -> List[m.Tenant]:
def get_all_tenants(self) -> list[m.Tenant]:
"""Get the list of Tenants you have access to."""
return list(self._state.tenants.values())

Expand Down Expand Up @@ -182,8 +177,8 @@ def refresh_tenant_token(self):
self.use_tenant(self._state.tenant)

@_tenant_required
def query(self, query: str, variables: Optional[Dict] = None, timeout=60):
"""Issues a GraphQL query and returns the results"""
def query(self, query: str, variables: dict | None = None, timeout=60):
"""Issues a GraphQL query and returns the results."""
res = self._post_with_token(
"/graphql", json={"query": query, "variables": variables}, timeout=timeout
)
Expand Down Expand Up @@ -224,10 +219,9 @@ def upload_firmware(
raise errors.QueryError(res["createFirmwareUpload"]["errors"])

upload_url = res["createFirmwareUpload"]["uploadUrl"]
res = self._post_with_token(
return self._post_with_token(
upload_url, files={"firmware": path.open("rb")}, timeout=timeout
)
return res

@_tenant_required
def get_product_groups(self):
Expand Down Expand Up @@ -271,6 +265,7 @@ def _verify_token(

class _LoginState:
"""Keeps state after login.

Client.logout() will simply delete the instance from memory.
"""

Expand Down
7 changes: 3 additions & 4 deletions onekey_client/errors.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import json
from typing import Optional


class ClientError(Exception):
"""Base class for all Client errors."""

def __init__(self, message: Optional[str] = None):
def __init__(self, message: str | None = None):
super().__init__(message or self.MESSAGE)


Expand Down Expand Up @@ -35,7 +34,7 @@ class QueryError(ClientError):
"""raised when a GraphQL query returns errors."""

def __init__(self, errors_json: dict):
self._errors = errors_json
self.errors = errors_json

def __str__(self):
return json.dumps(self._errors, indent=4)
return json.dumps(self.errors, indent=4)
10 changes: 5 additions & 5 deletions onekey_client/models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import datetime as dt
from typing import Optional
from uuid import UUID

from pydantic import BaseModel


Expand All @@ -11,11 +11,11 @@ class Tenant(BaseModel):

class FirmwareMetadata(BaseModel):
name: str
version: Optional[str] = None
release_date: Optional[dt.datetime] = None
notes: Optional[str] = None
version: str | None = None
release_date: dt.datetime | None = None
notes: str | None = None
vendor_name: str
product_name: str
product_category: Optional[str] = None
product_category: str | None = None
product_group_id: UUID
analysis_configuration_id: UUID
3 changes: 2 additions & 1 deletion onekey_client/queries/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import functools

from .. import queries

try:
Expand All @@ -7,7 +8,7 @@
import importlib_resources as resources


@functools.lru_cache()
@functools.lru_cache
def load_query(query_name) -> str:
"""Load a predefined GraphQL query and cache it."""
assert query_name.endswith(".graphql")
Expand Down
Loading