Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .devcontainer/devcontainer.json
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,8 @@
"python.analysis.autoImportCompletions": true,
"ruff.enable": true,
"ruff.organizeImports": true,
"ruff.fixAll": true
"ruff.fixAll": true,
"ruff.lineLength": 100
}
}
}
Expand Down
4 changes: 2 additions & 2 deletions .devcontainer/post_create.sh
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
#!/bin/bash

cd ../frontend
npm ci --verbose
npm i --verbose

cd ../backend
python -m script.create_db
python -m script.reset_dev
python -m script.reset_dev
2 changes: 1 addition & 1 deletion backend/pytest.ini
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@ python_classes = Test*
python_functions = test_*
asyncio_mode = auto
asyncio_default_fixture_loop_scope = function
addopts = --color=yes
addopts = --color=yes
1 change: 0 additions & 1 deletion backend/script/reset_dev.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,6 @@ async def reset_dev():
data = json.load(f)

police = PoliceEntity(
id=1,
email=data["police"]["email"],
hashed_password=data["police"]["hashed_password"],
)
Expand Down
24 changes: 12 additions & 12 deletions backend/src/core/authentication.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,20 @@
from typing import Literal

from fastapi import Depends, HTTPException, Request, status
from fastapi import Depends, Request
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from src.core.exceptions import CredentialsException, ForbiddenException
from src.modules.account.account_model import Account, AccountRole
from src.modules.police.police_model import PoliceAccount

StringRole = Literal["student", "admin", "staff", "police"]


class HTTPBearer401(HTTPBearer):
async def __call__(self, request: Request):
try:
return await super().__call__(request)
except Exception:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Not authenticated",
headers={"WWW-Authenticate": "Bearer"},
)
raise CredentialsException()


bearer_scheme = HTTPBearer401()
Expand Down Expand Up @@ -69,7 +67,7 @@ async def authenticate_user(
return user


def authenticate_by_role(*roles: Literal["police", "student", "admin", "staff"]):
def authenticate_by_role(*roles: StringRole):
"""
Middleware factory to ensure the authenticated user has one of the specified roles.
"""
Expand All @@ -79,8 +77,12 @@ async def _authenticate(
) -> Account | PoliceAccount:
token = authorization.credentials.lower()

if "police" in roles and token == "police":
return PoliceAccount(email="police@example.com")
# Check if police token and police is allowed
if token == "police":
if "police" in roles:
return PoliceAccount(email="police@example.com")
else:
raise ForbiddenException(detail="Insufficient privileges")

role_map = {
"student": AccountRole.STUDENT,
Expand Down Expand Up @@ -116,9 +118,7 @@ async def authenticate_staff_or_admin(


async def authenticate_student_or_admin(
account: Account | PoliceAccount = Depends(
authenticate_by_role("student", "admin")
),
account: Account | PoliceAccount = Depends(authenticate_by_role("student", "admin")),
) -> Account:
if not isinstance(account, Account):
raise ForbiddenException(detail="Insufficient privileges")
Expand Down
13 changes: 13 additions & 0 deletions backend/src/core/bcrypt_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import bcrypt


def hash_password(password: str) -> str:
"""Hash a password using bcrypt."""
salt = bcrypt.gensalt()
hashed = bcrypt.hashpw(password.encode("utf-8"), salt)
return hashed.decode("utf-8")


def verify_password(password: str, hashed_password: str) -> bool:
"""Verify a password against its hash."""
return bcrypt.checkpw(password.encode("utf-8"), hashed_password.encode("utf-8"))
1 change: 1 addition & 0 deletions backend/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def handle_http_exception(req: Request, exc: HTTPException):
return JSONResponse(
status_code=exc.status_code,
content={"message": exc.detail},
headers=exc.headers,
)


Expand Down
35 changes: 25 additions & 10 deletions backend/src/modules/account/account_entity.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,15 @@
import enum
from typing import Self

from sqlalchemy import CheckConstraint, Enum, Integer, String
from sqlalchemy.orm import Mapped, mapped_column
from sqlalchemy.orm import Mapped, MappedAsDataclass, mapped_column
from src.core.database import EntityBase
from src.modules.account.account_model import Account, AccountData, AccountRole


class AccountRole(enum.Enum):
STUDENT = "student"
STAFF = "staff"
ADMIN = "admin"


class AccountEntity(EntityBase):
class AccountEntity(MappedAsDataclass, EntityBase):
__tablename__ = "accounts"

id: Mapped[int] = mapped_column(Integer, primary_key=True)
id: Mapped[int] = mapped_column(Integer, primary_key=True, init=False)
email: Mapped[str] = mapped_column(String, unique=True, index=True, nullable=False)
first_name: Mapped[str] = mapped_column(String, nullable=False)
last_name: Mapped[str] = mapped_column(String, nullable=False)
Expand All @@ -27,3 +22,23 @@ class AccountEntity(EntityBase):
nullable=False,
)
role: Mapped[AccountRole] = mapped_column(Enum(AccountRole), nullable=False)

@classmethod
def from_model(cls, data: "AccountData") -> Self:
return cls(
email=data.email,
first_name=data.first_name,
last_name=data.last_name,
pid=data.pid,
role=AccountRole(data.role),
)

def to_model(self) -> "Account":
return Account(
id=self.id,
email=self.email,
first_name=self.first_name,
last_name=self.last_name,
pid=self.pid,
role=self.role,
)
20 changes: 7 additions & 13 deletions backend/src/modules/account/account_model.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
from typing import Self
from enum import Enum

from pydantic import BaseModel, EmailStr, Field
from src.modules.account.account_entity import AccountEntity, AccountRole


class AccountRole(Enum):
STUDENT = "student"
STAFF = "staff"
ADMIN = "admin"


class AccountData(BaseModel):
Expand All @@ -23,14 +28,3 @@ class Account(BaseModel):
last_name: str
pid: str
role: AccountRole

@classmethod
def from_entity(cls, account_entity: AccountEntity) -> Self:
return cls(
id=account_entity.id,
email=account_entity.email,
first_name=account_entity.first_name,
last_name=account_entity.last_name,
pid=account_entity.pid,
role=AccountRole(account_entity.role.value),
)
20 changes: 9 additions & 11 deletions backend/src/modules/account/account_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,27 +48,25 @@ async def _get_account_entity_by_email(self, email: str) -> AccountEntity:
async def get_accounts(self) -> list[Account]:
result = await self.session.execute(select(AccountEntity))
accounts = result.scalars().all()
return [Account.from_entity(account) for account in accounts]
return [account.to_model() for account in accounts]

async def get_accounts_by_roles(
self, roles: list[AccountRole] | None = None
) -> list[Account]:
async def get_accounts_by_roles(self, roles: list[AccountRole] | None = None) -> list[Account]:
if not roles:
return await self.get_accounts()

result = await self.session.execute(
select(AccountEntity).where(AccountEntity.role.in_(roles))
)
accounts = result.scalars().all()
return [Account.from_entity(account) for account in accounts]
return [account.to_model() for account in accounts]

async def get_account_by_id(self, account_id: int) -> Account:
account_entity = await self._get_account_entity_by_id(account_id)
return Account.from_entity(account_entity)
return account_entity.to_model()

async def get_account_by_email(self, email: str) -> Account:
account_entity = await self._get_account_entity_by_email(email)
return Account.from_entity(account_entity)
return account_entity.to_model()

async def create_account(self, data: AccountData) -> Account:
try:
Expand All @@ -93,7 +91,7 @@ async def create_account(self, data: AccountData) -> Account:
# handle race condition where another session inserted the same email
raise AccountConflictException(data.email)
await self.session.refresh(new_account)
return Account.from_entity(new_account)
return new_account.to_model()

async def update_account(self, account_id: int, data: AccountData) -> Account:
account_entity = await self._get_account_entity_by_id(account_id)
Expand All @@ -120,11 +118,11 @@ async def update_account(self, account_id: int, data: AccountData) -> Account:
except IntegrityError:
raise AccountConflictException(data.email)
await self.session.refresh(account_entity)
return Account.from_entity(account_entity)
return account_entity.to_model()

async def delete_account(self, account_id: int) -> Account:
account_entity = await self._get_account_entity_by_id(account_id)
account = Account.from_entity(account_entity)
account = account_entity.to_model()
await self.session.delete(account_entity)
await self.session.commit()
return account
return account
15 changes: 7 additions & 8 deletions backend/src/modules/complaint/complaint_entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,18 @@
from typing import TYPE_CHECKING, Self

from sqlalchemy import DateTime, ForeignKey, Integer, String
from sqlalchemy.orm import Mapped, mapped_column, relationship
from sqlalchemy.orm import Mapped, MappedAsDataclass, mapped_column, relationship
from src.core.database import EntityBase

from .complaint_model import Complaint
from src.modules.complaint.complaint_model import Complaint, ComplaintData

if TYPE_CHECKING:
from ..location.location_entity import LocationEntity
from src.modules.location.location_entity import LocationEntity


class ComplaintEntity(EntityBase):
class ComplaintEntity(MappedAsDataclass, EntityBase):
__tablename__ = "complaints"

id: Mapped[int] = mapped_column(Integer, primary_key=True)
id: Mapped[int] = mapped_column(Integer, primary_key=True, init=False)
location_id: Mapped[int] = mapped_column(
Integer, ForeignKey("locations.id", ondelete="CASCADE"), nullable=False
)
Expand All @@ -23,11 +22,11 @@ class ComplaintEntity(EntityBase):

# Relationships
location: Mapped["LocationEntity"] = relationship(
"LocationEntity", passive_deletes=True
"LocationEntity", passive_deletes=True, init=False
)

@classmethod
def from_model(cls, data: Complaint) -> Self:
def from_model(cls, data: ComplaintData) -> Self:
return cls(
location_id=data.location_id,
complaint_datetime=data.complaint_datetime,
Expand Down
43 changes: 22 additions & 21 deletions backend/src/modules/location/location_entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,51 +3,52 @@

from sqlalchemy import DECIMAL, DateTime, Index, Integer, String
from sqlalchemy.inspection import inspect
from sqlalchemy.orm import Mapped, mapped_column, relationship
from sqlalchemy.orm import Mapped, MappedAsDataclass, mapped_column, relationship
from src.core.database import EntityBase
from src.modules.complaint.complaint_entity import ComplaintEntity

from .location_model import Location, LocationData


class LocationEntity(EntityBase):
class LocationEntity(MappedAsDataclass, EntityBase):
__tablename__ = "locations"

id: Mapped[int] = mapped_column(Integer, primary_key=True)
id: Mapped[int] = mapped_column(Integer, primary_key=True, init=False)

# OCSL Data
warning_count: Mapped[int] = mapped_column(Integer, default=0)
citation_count: Mapped[int] = mapped_column(Integer, default=0)
hold_expiration: Mapped[datetime | None] = mapped_column(
DateTime(timezone=True), nullable=True
)

# Google Maps Data
# Google Maps Data (required fields first)
google_place_id: Mapped[str] = mapped_column(
String(255), unique=True, nullable=False, index=True
)
formatted_address: Mapped[str] = mapped_column(String(500), nullable=False)

# Geography
# Geography (required fields)
latitude: Mapped[float] = mapped_column(DECIMAL(10, 8), nullable=False)
longitude: Mapped[float] = mapped_column(DECIMAL(11, 8), nullable=False)

# OCSL Data (fields with defaults)
warning_count: Mapped[int] = mapped_column(Integer, default=0)
citation_count: Mapped[int] = mapped_column(Integer, default=0)
hold_expiration: Mapped[datetime | None] = mapped_column(
DateTime(timezone=True), nullable=True, default=None
)

# Address Components
street_number: Mapped[str | None] = mapped_column(String(20)) # e.g. "123"
street_name: Mapped[str | None] = mapped_column(String(255)) # e.g. "Main St"
unit: Mapped[str | None] = mapped_column(String(50)) # e.g. "Apt 4B"
city: Mapped[str | None] = mapped_column(String(100)) # e.g. "Chapel Hill"
county: Mapped[str | None] = mapped_column(String(100)) # e.g. "Orange County"
state: Mapped[str | None] = mapped_column(String(50)) # e.g. "NC"
country: Mapped[str | None] = mapped_column(String(2)) # e.g. "US"
zip_code: Mapped[str | None] = mapped_column(String(10)) # e.g. "27514"
street_number: Mapped[str | None] = mapped_column(String(20), default=None) # e.g. "123"
street_name: Mapped[str | None] = mapped_column(String(255), default=None) # e.g. "Main St"
unit: Mapped[str | None] = mapped_column(String(50), default=None) # e.g. "Apt 4B"
city: Mapped[str | None] = mapped_column(String(100), default=None) # e.g. "Chapel Hill"
county: Mapped[str | None] = mapped_column(String(100), default=None) # e.g. "Orange County"
state: Mapped[str | None] = mapped_column(String(50), default=None) # e.g. "NC"
country: Mapped[str | None] = mapped_column(String(2), default=None) # e.g. "US"
zip_code: Mapped[str | None] = mapped_column(String(10), default=None) # e.g. "27514"

# Relationships
complaints: Mapped[list["ComplaintEntity"]] = relationship(
"ComplaintEntity",
back_populates="location",
cascade="all, delete-orphan",
lazy="selectin", # Use selectin loading to avoid N+1 queries
init=False,
)

__table_args__ = (Index("idx_lat_lng", "latitude", "longitude"),)
Expand All @@ -57,7 +58,7 @@ def to_model(self) -> Location:
# This prevents issues when LocationEntity is created without loading relationships
insp = inspect(self)
complaints_loaded = "complaints" not in insp.unloaded

hold_exp = self.hold_expiration
if hold_exp is not None and hold_exp.tzinfo is None:
hold_exp = hold_exp.replace(tzinfo=timezone.utc)
Expand Down
Loading