Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion app/core/auth/endpoints_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,12 @@ async def login_for_access_token(
)
# We put the user id in the subject field of the token.
# The subject `sub` is a JWT registered claim name, see https://datatracker.ietf.org/doc/html/rfc7519#section-4.1
data = schemas_auth.TokenData(sub=user.id, scopes=ScopeType.auth)
data = schemas_auth.TokenData(
sub=user.id,
scopes=ScopeType.auth,
account_type=user.account_type,
group_ids=[group.id for group in user.groups],
)
access_token = create_access_token(settings=settings, data=data)
return {"access_token": access_token, "token_type": "bearer"}

Expand Down Expand Up @@ -908,6 +913,10 @@ async def create_response_body(
status_code=500,
detail="Could not find user when trying the get userinfo but it should exist",
)
id_token_data.account_type = user.account_type
access_token_data.account_type = user.account_type
id_token_data.group_ids = [group.id for group in user.groups]
access_token_data.group_ids = [group.id for group in user.groups]
additional_data = auth_client.get_userinfo(user=user)

id_token = create_access_token_RS256(
Expand Down
3 changes: 3 additions & 0 deletions app/core/auth/schemas_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from fastapi import Form
from pydantic import BaseModel, field_validator

from app.core.groups.groups_type import AccountType
from app.utils import validators
from app.utils.examples import examples_auth

Expand Down Expand Up @@ -93,6 +94,8 @@ class AccessToken(BaseModel):

class TokenData(BaseModel):
sub: str # Subject: the user id
account_type: AccountType | None = None
group_ids: list[str] | None = None
iss: str | None = None
aud: str | None = None
cid: str | None = None # The client_id of the service which receives the token
Expand Down
2 changes: 1 addition & 1 deletion app/core/utils/security.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def create_access_token(
if expires_delta is None:
# We use the default value
expires_delta = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
to_encode = data.model_dump(exclude_none=True)
to_encode = data.model_dump(exclude_none=True, mode="json")
iat = datetime.now(UTC)
expire_on = datetime.now(UTC) + expires_delta
to_encode.update({"exp": expire_on, "iat": iat})
Expand Down
7 changes: 6 additions & 1 deletion tests/commons.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,12 @@ def create_api_access_token(
Create a JWT access token for the `user` with the scope `API`
"""

access_token_data = schemas_auth.TokenData(sub=user.id, scopes="API")
access_token_data = schemas_auth.TokenData(
sub=user.id,
scopes="API",
account_type=user.account_type,
group_ids=[group.id for group in user.groups],
)
return security.create_access_token(
data=access_token_data,
settings=override_get_settings(),
Expand Down