Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ client/configs/lab/*
client/venv/*
logs/*

# scratch space
scratch/*

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
Expand Down
2 changes: 2 additions & 0 deletions client/lomas_client/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,7 @@

SNSYNTH_DEFAULT_SAMPLES_NB = 200

OIDC_REQUIRED_SCOPES = "openid profile email offline_access"

# Only for testing
DEFAULT_EPSILON = 1.0
144 changes: 111 additions & 33 deletions client/lomas_client/http_client.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
import json
import os
from time import sleep
import tempfile
import time

import requests
from oauthlib.oauth2 import LegacyApplicationClient, TokenExpiredError
from authlib.integrations.base_client.errors import OAuthError
from authlib.integrations.requests_client import OAuth2Session
from authlib.oauth2.rfc6749.errors import OAuth2Error
from opentelemetry.instrumentation.requests import RequestsInstrumentor
from requests_oauthlib import OAuth2Session

from lomas_client.constants import CONNECT_TIMEOUT, DEFAULT_READ_TIMEOUT
from lomas_client.constants import CONNECT_TIMEOUT, DEFAULT_READ_TIMEOUT, OIDC_REQUIRED_SCOPES
from lomas_client.models.config import ClientConfig
from lomas_core.constants import OIDC_LOMAS_CLIENT__CLIENT_ID
from lomas_core.models.config import OIDCDeviceCodeResponse
from lomas_core.models.constants import init_logging
from lomas_core.models.requests import LomasRequestModel
from lomas_core.models.responses import Job
Expand All @@ -28,29 +32,108 @@ def __init__(self, config: ClientConfig) -> None:
self.config = config

if not self.config.oidc_use_tls or not self.config.lomas_service_use_tls:
logger.warning(
"OIDC IdP or Lomas service configured without TLS -> using oauthlib insecure transport"
)
os.environ["OAUTHLIB_INSECURE_TRANSPORT"] = "1"
else:
# Reset in case it was changed before
os.environ["OAUTHLIB_INSECURE_TRANSPORT"] = "0"
logger.warning("OIDC IdP or Lomas service configured without TLS -> using insecure transport")

self._oauth2_session = OAuth2Session(
client_id="lomas_client",
token_endpoint=self.config.oidc_config.token_endpoint,
scope=OIDC_REQUIRED_SCOPES,
update_token=self._save_token,
token=self._load_token(),
token_endpoint_auth_method="none",
leeway=30, # refresh token 30 seconds before expiry
)

oauth_client = LegacyApplicationClient(OIDC_LOMAS_CLIENT__CLIENT_ID)
self._oauth2_session = OAuth2Session(client=oauth_client)
try:
self._oauth2_session.refresh_token()
except (OAuth2Error, AttributeError, requests.HTTPError):
# Fallback to authorize
# We catch http errors because dex fails when it cannot link a token to existing user.
# We catch attribute error in case the token is none
self._authorize()

def _get_token_file(self) -> str:
"""Returns a temp filename for saving/loading the token."""
return os.path.join(
tempfile.gettempdir(), f"lomas_{self.config.user_name}_{self.config.dataset_name}_token.json"
)
Comment on lines +57 to +59
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no pathlib ? :sadface:


# Fetch first token:
self._fetch_token()
def _save_token(self, token: dict, refresh_token: str | None = None) -> None:
"""Saves the token to disk."""
with open(self._get_token_file(), "w") as f:
json.dump(token, f)

def _load_token(self) -> dict | None:
"""Tries to load the saved token from disk."""
if os.path.exists(self._get_token_file()):
with open(self._get_token_file()) as f:
return json.load(f)
return None

def _authorize(self) -> None:
"""Chooses the right grant and gets access token."""
if self.config.use_password_flow:
self._password_flow()
else:
self._device_flow()

def _fetch_token(self) -> None:
"""Fetches an authorization token and stores it."""
def _password_flow(self) -> None:
"""Performs a legacy password flow to fetch an access token."""
self._oauth2_session.fetch_token(
str(self.config.oidc_config.token_endpoint),
self.config.oidc_config.token_endpoint,
username=self.config.user_name,
password=self.config.user_password,
scope=["openid", "profile", "email"],
grant_type="password",
)

def _device_flow(self) -> None:
"""Fetches an access token using the device auth flow.

Waits until the user has authorized the python client.

Raises:
TimeoutError: In case the user did not authorize the Lomas Python client in time.
"""
print("Authorizing Lomas Python client")

device_data_resp = requests.post(
str(self.config.oidc_config.device_authorization_endpoint),
data={"client_id": OIDC_LOMAS_CLIENT__CLIENT_ID, "scope": OIDC_REQUIRED_SCOPES},
)
device_data_resp.raise_for_status()
device_data = OIDCDeviceCodeResponse.model_validate(device_data_resp.json())

if not device_data.verification_uri_complete:
print(f"Go to: {device_data.verification_uri}")
print(f"Log in and authorize the Lomas Python client with this code {device_data.user_code}")
else:
print(f"Go to: {device_data.verification_uri_complete}")
print("Log in and authorize the Lomas Python client.")

Comment on lines +106 to +112
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

rich.pretty for flex ?

print("This will hang until the authorization is complete...")

interval = 5
while True:
try:
self._oauth2_session.fetch_token(
self.config.oidc_config.token_endpoint,
grant_type="urn:ietf:params:oauth:grant-type:device_code",
device_code=device_data.device_code,
)
break
except (OAuth2Error, OAuthError) as e:
if e.error == "authorization_pending":
time.sleep(interval)
elif e.error == "slow_down":
interval += 5
time.sleep(interval)
elif e.error == "expired_token":
raise TimeoutError("Lomas Python client was not authorized soon enough.") from e
else:
raise e

print("Authorization process complete.")

def post(
self,
endpoint: str,
Expand Down Expand Up @@ -88,17 +171,16 @@ def post(
headers=self.headers,
timeout=(CONNECT_TIMEOUT, read_timeout),
)
except TokenExpiredError:
# This also catches if there is no token at first try.
# Retry with new token
self._fetch_token()
except OAuth2Error:
# Handle expired refresh token
self._authorize()

r = self._oauth2_session.post(
f"{self.config.app_url}/{endpoint}",
json=body.model_dump(),
headers=self.headers,
timeout=(CONNECT_TIMEOUT, read_timeout),
)

return r

def wait_for_job(self, job_uid: str, n_retry: int = 1800, sleep_sec: float = 1) -> Job:
Expand All @@ -108,9 +190,10 @@ def wait_for_job(self, job_uid: str, n_retry: int = 1800, sleep_sec: float = 1)
job_query = self._oauth2_session.get(
f"{self.config.app_url}/status/{job_uid}", headers=self.headers, timeout=(CONNECT_TIMEOUT)
).json()
except TokenExpiredError:
# This also catches if there is no token at first try.
self._fetch_token()
except OAuth2Error:
# Handle expired refresh token
self._authorize()

job_query = self._oauth2_session.get(
f"{self.config.app_url}/status/{job_uid}", headers=self.headers, timeout=(CONNECT_TIMEOUT)
).json()
Expand All @@ -119,11 +202,6 @@ def wait_for_job(self, job_uid: str, n_retry: int = 1800, sleep_sec: float = 1)
if "status" in job_query and job_query["status"] in {"complete", "failed"}:
return Job.model_validate(job_query)

if "type" in job_query and job_query["type"] == "UnauthorizedAccessException":
# Handle unauthorized specifically
self._fetch_token() # refresh token
continue # retry the request

sleep(sleep_sec)
time.sleep(sleep_sec)

raise TimeoutError(f"Job {job_uid} didn't complete in time ({sleep_sec * n_retry})")
7 changes: 4 additions & 3 deletions client/lomas_client/models/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,11 @@ class ClientConfig(BaseSettings):
"""The base URL for the API server."""
dataset_name: str
"""The name of the dataset to be accessed or manipulated."""
user_name: str
use_password_flow: bool = False
"""If true, uses the legacy password auth flow."""
user_name: str | None = None
"""User name."""
# TODO add option for devide auth flow.
user_password: str | None
user_password: str | None = None
"""User password."""
oidc_discovery_url: HttpUrl
"""The oidc provier discovery Url."""
Expand Down
102 changes: 100 additions & 2 deletions client/lomas_client/tests/test_integrations.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,19 @@
import io
import re
import sys
import time
from dataclasses import dataclass
from urllib.parse import urljoin

import numpy as np
import opendp.prelude as dp
import pandas as pd
import polars as pl
import pytest
import requests
from authlib.integrations.base_client.errors import OAuthError
from bs4 import BeautifulSoup
from diffprivlib import models
from oauthlib import oauth2
from sklearn.pipeline import Pipeline

from lomas_client import Client
Expand Down Expand Up @@ -65,7 +72,7 @@ def test_missing_configs() -> None:


def test_oauth2(aria, dex_config) -> None:
with pytest.raises(oauth2.AccessDeniedError, match=r"Invalid username or password"):
with pytest.raises(OAuthError, match=r"Invalid username or password"):
aria.as_client()

# Add a user
Expand All @@ -77,6 +84,97 @@ def test_oauth2(aria, dex_config) -> None:
client.get_dataset_metadata()


class DeviceAuthorizationBot(io.StringIO):
def __init__(self, user_name, user_password, *args, **kwargs):
super().__init__(*args, **kwargs)
self.user_name = user_name
self.user_password = user_password
self.found = False
self.terminal = sys.stdout

def verify_device(self, url):
# url is something like http://localhost:4445/dex/device&user_code=ABDC-EFGH

session = requests.Session()

# 1. Access the verification page
resp = session.get(url)
resp.raise_for_status()

# 2. Find the form and submit the user_code
# User code should already be filled out.
soup = BeautifulSoup(resp.text, "html.parser")
form_action = soup.find("form")["action"]
# Handle relative paths
form_action = urljoin(resp.url, form_action)
user_code = soup.find("input", {"name": "user_code"})["value"]
resp = session.post(form_action, data={"user_code": user_code})
resp.raise_for_status()

# 3. Handle the Login Page (Email/Password)
soup = BeautifulSoup(resp.text, "html.parser")
login_form = soup.find("form")
login_url = urljoin(resp.url, login_form["action"])
login_data = {"login": self.user_name, "password": self.user_password}

resp = session.post(login_url, data=login_data)
resp.raise_for_status()

# 4. Approve
soup = BeautifulSoup(resp.text, "html.parser")
approve_form = soup.find("form")
data = {}
for hidden_input in approve_form.find_all("input", type="hidden"):
data[hidden_input.get("name")] = hidden_input.get("value")

resp = session.post(resp.url, data=data)
resp.raise_for_status()

def write(self, s):
# Still print to console
self.terminal.write(s)
if not self.found:
uri_match = re.search(r"http[s]?:[^\s]+", s)
if uri_match:
# Find verification url and authorize device
self.found = True
uri = uri_match.group(0)
self.verify_device(uri)


@pytest.mark.long
@pytest.mark.timeout(15)
def test_device_flow(demo_setup) -> None:
# Setup authorization bot
user_name = "jack"
bot = DeviceAuthorizationBot(
user_name=f"{user_name}@example.com",
user_password=user_name,
)
old_stdout = sys.stdout
sys.stdout = bot

client = Client(dataset_name="TITANIC", use_password_flow=False)

# Reset stdout
sys.stdout = old_stdout

init_budget = client.get_initial_budget()
assert init_budget.initial_delta == 0.2

# Test refresh token works (our dex config sets lifetime of 10sec for access token)
time.sleep(10)

init_budget = client.get_initial_budget()
assert init_budget.initial_delta == 0.2

# Check new client uses saved token (in tempfile)
client = Client(dataset_name="TITANIC", use_password_flow=False)

init_budget = client.get_initial_budget()
assert init_budget.initial_delta == 0.2


def test_oauth2_demo(dex_config, demo_setup) -> None:
user_name = "Jack"
client = Client(
Expand Down
8 changes: 6 additions & 2 deletions client/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,16 @@ classifiers = [
dependencies = [
"lomas-core==0.4.1",
"jupyter>=1.1,<2",
"requests-oauthlib>=2.0",
"oauthlib>=3.2",
"authlib>=1.6.7",
"opentelemetry-instrumentation-logging>=0.50b0",
"opentelemetry-instrumentation-requests>=0.50b0",
"seaborn>=0.13",
]

[project.optional-dependencies]
test = [
"beautifulsoup4>=4.14.2",
]

[project.urls]
Homepage = "https://github.com/dscc-admin/lomas/"
11 changes: 11 additions & 0 deletions core/lomas_core/models/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,14 @@ class OIDCConfig(BaseModel):
userinfo_endpoint: HttpUrl
device_authorization_endpoint: HttpUrl
introspection_endpoint: HttpUrl


class OIDCDeviceCodeResponse(BaseModel):
"""Base model for oidc device code response."""

model_config = ConfigDict(extra="ignore")

user_code: str
device_code: str
verification_uri: str
verification_uri_complete: str | None = None
Loading
Loading