Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions lib/db-helpers.test.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import { describe, it, expect, vi, beforeEach } from "vitest";
import { db } from "@/lib/database";
import { success } from "@/lib/server-action-result";

// Mock the database module
vi.mock("@/lib/database", () => ({
Expand Down Expand Up @@ -47,3 +48,39 @@ describe("withRLS", () => {
);
});
});

describe("withRLSAction", () => {
const mockExecute = vi.fn();

beforeEach(() => {
vi.clearAllMocks();

vi.mocked(db.transaction).mockReturnValue({
execute: mockExecute,
} as any);
});

it("should start a transaction", async () => {
const { withRLSAction } = await import("./db-helpers");

mockExecute.mockRejectedValue(new Error("Expected test error"));

try {
await withRLSAction(123, async () => success("result"));
} catch {
// Expected - sql template compilation fails in test environment
}

expect(db.transaction).toHaveBeenCalled();
});

it("should propagate transaction-level errors", async () => {
const { withRLSAction } = await import("./db-helpers");

mockExecute.mockRejectedValue(new Error("Connection refused"));

await expect(
withRLSAction(123, async () => success("result")),
).rejects.toThrow("Connection refused");
});
});
59 changes: 59 additions & 0 deletions lib/db-helpers.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import { Database } from "@/types/db_types";
import { db } from "@/lib/database";
import { Transaction, sql } from "kysely";
import type { ServerActionResult } from "@/lib/server-action-result";

/**
* Executes a database operation within a transaction with Row Level Security (RLS) context.
Expand Down Expand Up @@ -30,3 +31,61 @@ export async function withRLS<T>(
return fn(trx);
});
}

/**
* Sentinel error used to trigger transaction rollback when the callback
* returns a ServerActionResult with success: false.
*/
class RollbackWithResult<T> extends Error {
constructor(public result: ServerActionResult<T>) {
super("rollback");
}
}

/**
* Like `withRLS`, but the callback returns a `ServerActionResult<T>` directly.
*
* This allows clean early `return error(...)` inside the transaction while
* preserving RLS context and transactional atomicity. If the callback returns
* an error result, the transaction is automatically rolled back.
*
* @example
* const result = await withRLSAction(currentUser.id, async (trx) => {
* const membership = await trx.selectFrom("competition_members")
* .select("role")
* .where("competition_id", "=", competitionId)
* .where("user_id", "=", currentUser.id)
* .executeTakeFirst();
*
* if (membership?.role !== "admin") {
* return error("Only admins can do this", ERROR_CODES.UNAUTHORIZED);
* }
*
* const inserted = await trx.insertInto("competition_members").values(...).returningAll().executeTakeFirstOrThrow();
* return success(inserted);
* });
*/
export async function withRLSAction<T>(
userId: number | undefined,
fn: (trx: Transaction<Database>) => Promise<ServerActionResult<T>>,
): Promise<ServerActionResult<T>> {
try {
return await db.transaction().execute(async (trx) => {
await trx.executeQuery(
sql`SELECT set_config('app.current_user_id', ${userId}, true);`.compile(
db,
),
);
const result = await fn(trx);
if (!result.success) {
throw new RollbackWithResult(result);
}
return result;
});
} catch (err) {
if (err instanceof RollbackWithResult) {
return err.result as ServerActionResult<T>;
}
throw err;
}
}
Loading
Loading