Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions src/rbac/core/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@
import json


def _utcnow() -> datetime:
"""Timezone-aware UTC now — used as a default_factory in dataclasses."""
return datetime.now(timezone.utc)


class EntityStatus(Enum):
"""Status of an entity in the system."""
ACTIVE = "active"
Expand Down Expand Up @@ -55,8 +60,8 @@ class User:
attributes: Dict[str, Any] = field(default_factory=dict)
status: EntityStatus = EntityStatus.ACTIVE
domain: Optional[str] = None
created_at: datetime = field(default_factory=datetime.utcnow)
updated_at: datetime = field(default_factory=datetime.utcnow)
created_at: datetime = field(default_factory=_utcnow)
updated_at: datetime = field(default_factory=_utcnow)

def __post_init__(self):
"""Validate user data after initialization."""
Expand Down Expand Up @@ -167,7 +172,7 @@ class Permission:
action: str
description: Optional[str] = None
conditions: Dict[str, Any] = field(default_factory=dict)
created_at: datetime = field(default_factory=datetime.utcnow)
created_at: datetime = field(default_factory=_utcnow)

def __post_init__(self):
"""Validate permission data."""
Expand Down Expand Up @@ -286,8 +291,8 @@ class Resource:
parent_id: Optional[str] = None
status: EntityStatus = EntityStatus.ACTIVE
domain: Optional[str] = None
created_at: datetime = field(default_factory=datetime.utcnow)
updated_at: datetime = field(default_factory=datetime.utcnow)
created_at: datetime = field(default_factory=_utcnow)
updated_at: datetime = field(default_factory=_utcnow)

def __post_init__(self):
"""Validate resource data."""
Expand Down
14 changes: 9 additions & 5 deletions src/rbac/storage/sqlalchemy_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,10 @@
# ORM declarations
# ---------------------------------------------------------------------------

# FK target constant – avoids repeating the table.column string literal
_FK_ROLES_ID = "rbac_roles.id"


class _Base(DeclarativeBase):
"""Shared declarative base for all RBAC ORM models."""

Expand All @@ -83,7 +87,7 @@ class _Base(DeclarativeBase):
_role_permissions = Table(
"rbac_role_permissions",
_Base.metadata,
Column("role_id", String(255), ForeignKey("rbac_roles.id", ondelete="CASCADE"), nullable=False),
Column("role_id", String(255), ForeignKey(_FK_ROLES_ID, ondelete="CASCADE"), nullable=False),
Column("permission_id", String(255), ForeignKey("rbac_permissions.id", ondelete="CASCADE"), nullable=False),
UniqueConstraint("role_id", "permission_id", name="uq_role_permission"),
)
Expand Down Expand Up @@ -119,7 +123,7 @@ class _RoleRow(_Base):
id = Column(String(255), primary_key=True)
name = Column(String(255), nullable=False)
description = Column(Text, nullable=True)
parent_id = Column(String(255), ForeignKey("rbac_roles.id", ondelete="SET NULL"), nullable=True)
parent_id = Column(String(255), ForeignKey(_FK_ROLES_ID, ondelete="SET NULL"), nullable=True)
domain = Column(String(255), nullable=True)
status = Column(String(50), nullable=False, default="active")
metadata_json = Column(Text, nullable=False, default="{}")
Expand Down Expand Up @@ -175,7 +179,7 @@ class _RoleAssignmentRow(_Base):

id = Column(String(255), primary_key=True) # composite surrogate
user_id = Column(String(255), ForeignKey("rbac_users.id", ondelete="CASCADE"), nullable=False)
role_id = Column(String(255), ForeignKey("rbac_roles.id", ondelete="CASCADE"), nullable=False)
role_id = Column(String(255), ForeignKey(_FK_ROLES_ID, ondelete="CASCADE"), nullable=False)
domain = Column(String(255), nullable=True)
granted_by = Column(String(255), nullable=True)
granted_at = Column(DateTime(timezone=True), nullable=False)
Expand Down Expand Up @@ -398,7 +402,7 @@ def __init__(
kwargs["connect_args"] = {"check_same_thread": False}

self._engine = create_engine(database_url, **kwargs)
self._Session = sessionmaker(bind=self._engine, expire_on_commit=False)
self._session_factory = sessionmaker(bind=self._engine, expire_on_commit=False)

# ------------------------------------------------------------------
# Lifecycle
Expand Down Expand Up @@ -433,7 +437,7 @@ def _validate_role_assignment(self, assignment: RoleAssignment) -> None:
@contextmanager
def _session(self) -> Generator[Session, None, None]:
"""Provide a transactional session scope."""
session: Session = self._Session()
session: Session = self._session_factory()
try:
yield session
session.commit()
Expand Down
Loading
Loading