Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 101 additions & 1 deletion backend/add_users_to_db.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import asyncio
import os
from datetime import datetime, timezone
from typing import Union

from redis import asyncio as aioredis
from sqlalchemy import select
from sqlalchemy.orm import Session

from app.config import REDIS_HOST
from app.database import get_session
Expand Down Expand Up @@ -33,7 +36,7 @@ async def async_redis_operations(key: str, value: int | None) -> None:
await redis.aclose()


def run_redis_async_tasks(key: str, value: int | str) -> None:
def run_redis_async_tasks(key: str, value: Union[int, str]) -> None:
"""
Run asynchronous Redis operations to set the remaining API calls for a user.
"""
Expand All @@ -43,6 +46,103 @@ def run_redis_async_tasks(key: str, value: int | str) -> None:
loop.run_until_complete(async_redis_operations(key, value_int))


def ensure_default_workspace(db_session: Session, user_db: UserDB) -> None:
"""
Ensure that a user has a default workspace.

Parameters
----------
db_session
The database session.
user_db
The user DB record.
"""
# Check if user already has a workspace
stmt = select(UserWorkspaceDB).where(UserWorkspaceDB.user_id == user_db.user_id)
result = db_session.execute(stmt)
existing_workspace = result.scalar_one_or_none()

if existing_workspace:
logger.info(
f"User {user_db.username} already has workspace relationship: "
f"{existing_workspace.workspace_id}"
)
# Check if any workspace is set as default
stmt = select(UserWorkspaceDB).where(
UserWorkspaceDB.user_id == user_db.user_id,
UserWorkspaceDB.default_workspace,
)
result = db_session.execute(stmt)
default_workspace = result.scalar_one_or_none()

if default_workspace:
logger.info(
f"User {user_db.username} already has default workspace: "
f"{default_workspace.workspace_id}"
)
return
else:
# Set first workspace as default
existing_workspace.default_workspace = True
db_session.add(existing_workspace)
db_session.commit()
logger.info(
f"Set workspace {existing_workspace.workspace_id} as default for "
f"{user_db.username}"
)
return

# Create a default workspace for the user
workspace_name = f"{user_db.username}'s Workspace"

# Check if workspace with this name already exists
stmt = select(WorkspaceDB).where(WorkspaceDB.workspace_name == workspace_name)
result = db_session.execute(stmt)
existing_workspace_db = result.scalar_one_or_none()

if existing_workspace_db:
workspace_db = existing_workspace_db
logger.info(
f"Workspace '{workspace_name}' already exists with ID "
f"{workspace_db.workspace_id}"
)
else:
# Create new workspace
workspace_db = WorkspaceDB(
workspace_name=workspace_name,
api_daily_quota=100,
content_quota=10,
created_datetime_utc=datetime.now(timezone.utc),
updated_datetime_utc=datetime.now(timezone.utc),
is_default=True,
hashed_api_key=get_key_hash("workspace-api-key-" + workspace_name),
api_key_first_characters="works",
api_key_updated_datetime_utc=datetime.now(timezone.utc),
api_key_rotated_by_user_id=user_db.user_id,
)
db_session.add(workspace_db)
db_session.commit()
logger.info(
f"Created workspace '{workspace_name}' with ID {workspace_db.workspace_id}"
)

# Create user-workspace relationship
user_workspace = UserWorkspaceDB(
user_id=user_db.user_id,
workspace_id=workspace_db.workspace_id,
user_role=UserRoles.ADMIN,
default_workspace=True,
created_datetime_utc=datetime.now(timezone.utc),
updated_datetime_utc=datetime.now(timezone.utc),
)
db_session.add(user_workspace)
db_session.commit()
logger.info(
f"Created workspace relationship for user {user_db.username} with workspace "
f"{workspace_db.workspace_id}"
)


if __name__ == "__main__":
db_session = next(get_session())

Expand Down
7 changes: 3 additions & 4 deletions backend/app/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
from fastapi.middleware.cors import CORSMiddleware
from redis import asyncio as aioredis

from . import auth, bayes_ab, contextual_mab, mab, messages
from . import auth, messages
from .config import BACKEND_ROOT_PATH, DOMAIN, REDIS_HOST
from .experiments.routers import router as experiments_router
from .users.routers import (
router as users_router,
) # to avoid circular imports
Expand Down Expand Up @@ -56,9 +57,7 @@ def create_app() -> FastAPI:
expose_headers=["*"],
)

app.include_router(mab.router)
app.include_router(contextual_mab.router)
app.include_router(bayes_ab.router)
app.include_router(experiments_router)
app.include_router(auth.router)
app.include_router(users_router)
app.include_router(messages.router)
Expand Down
1 change: 0 additions & 1 deletion backend/app/bayes_ab/__init__.py

This file was deleted.

Loading