From 17c4534f6d8aacfdfddd3497eac19a5c5aa1e017 Mon Sep 17 00:00:00 2001
From: Jason Moggridge
Date: Sun, 25 Jan 2026 17:49:10 -0500
Subject: [PATCH] FIX: Add postgres storage adapter for samod

---
 Cargo.lock                                    |   1 +
 packages/backend/Cargo.toml                   |   1 +
 packages/backend/src/app.rs                   |  13 +-
 packages/backend/src/lib.rs                   |   1 +
 packages/backend/src/main.rs                  |  80 ++------
 packages/backend/src/storage/mod.rs           |   4 +
 packages/backend/src/storage/postgres.rs      | 110 ++++++++++
 packages/backend/src/storage/testing.rs       | 190 ++++++++++++++++++
 .../backend/tests/postgres_storage_tests.rs   |  50 +++++
 9 files changed, 375 insertions(+), 75 deletions(-)
 create mode 100644 packages/backend/src/lib.rs
 create mode 100644 packages/backend/src/storage/mod.rs
 create mode 100644 packages/backend/src/storage/postgres.rs
 create mode 100644 packages/backend/src/storage/testing.rs
 create mode 100644 packages/backend/tests/postgres_storage_tests.rs

diff --git a/Cargo.lock b/Cargo.lock
index 182939e74..0936f191d 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -277,6 +277,7 @@ dependencies = [
  "migrator",
  "notebook-types",
  "qubit",
+ "rand 0.8.5",
  "regex",
  "samod",
  "sd-notify",
diff --git a/packages/backend/Cargo.toml b/packages/backend/Cargo.toml
index 8cbcd282f..09b694c88 100644
--- a/packages/backend/Cargo.toml
+++ b/packages/backend/Cargo.toml
@@ -22,6 +22,7 @@ jsonrpsee = "0.24.6"
 jsonrpsee-server = "0.24.6"
 notebook-types = { version = "0.1.0", path = "../notebook-types" }
 qubit = { version = "1.0.0-beta.0", features = ["ts-serde-json", "ts-uuid"] }
+rand = "0.8"
 regex = "1.11.1"
 samod = { git = "https://github.com/alexjg/samod", features = [
     "tokio",
diff --git a/packages/backend/src/app.rs b/packages/backend/src/app.rs
index 7f737d02d..1eba187ef 100644
--- a/packages/backend/src/app.rs
+++ b/packages/backend/src/app.rs
@@ -4,7 +4,7 @@ use sqlx::PgPool;
 use std::collections::HashSet;
 use std::sync::Arc;
 use thiserror::Error;
-use tokio::sync::{RwLock, watch};
+use tokio::sync::RwLock;
 use ts_rs::TS;
 use uuid::Uuid;
 
@@ -19,21 +19,10 @@ pub struct AppState {
     /// Automerge-repo provider
     pub repo: samod::Repo,
 
-    pub app_status: watch::Receiver<AppStatus>,
-
     /// Tracks which ref_ids have active autosave listeners to prevent duplicates
     pub active_listeners: Arc<RwLock<HashSet<Uuid>>>,
 }
 
-#[derive(Clone, Debug, PartialEq, Eq)]
-pub enum AppStatus {
-    Starting,
-    Migrating,
-    Running,
-    #[allow(dead_code)]
-    Failed(String),
-}
-
 /// Context available to RPC procedures.
 #[derive(Clone)]
 pub struct AppCtx {
diff --git a/packages/backend/src/lib.rs b/packages/backend/src/lib.rs
new file mode 100644
index 000000000..30f61eb69
--- /dev/null
+++ b/packages/backend/src/lib.rs
@@ -0,0 +1 @@
+pub mod storage;
diff --git a/packages/backend/src/main.rs b/packages/backend/src/main.rs
index 0e8dae887..e13bdfe4e 100644
--- a/packages/backend/src/main.rs
+++ b/packages/backend/src/main.rs
@@ -1,21 +1,19 @@
 use axum::extract::Request;
 use axum::extract::ws::WebSocketUpgrade;
-use axum::middleware::{Next, from_fn_with_state};
+use axum::middleware::from_fn_with_state;
 use axum::{Router, routing::get};
 use axum::{extract::State, response::IntoResponse};
 use clap::{Parser, Subcommand};
 use firebase_auth::FirebaseAuth;
-use http::StatusCode;
 use sqlx::postgres::PgPoolOptions;
-use sqlx::{PgPool, Postgres};
 use sqlx_migrator::cli::MigrationCommand;
 use sqlx_migrator::migrator::{Migrate, Migrator};
 use sqlx_migrator::{Info, Plan};
 use std::collections::HashSet;
 use std::path::Path;
 use std::sync::Arc;
-use tokio::sync::{RwLock, watch};
+use tokio::sync::RwLock;
 use tower::ServiceBuilder;
 use tower_http::cors::CorsLayer;
 use tracing::{error, info};
@@ -26,10 +24,9 @@ mod auth;
 mod automerge_json;
 mod document;
 mod rpc;
+mod storage;
 mod user;
 
-use app::AppStatus;
-
 /// Port for the web server providing the RPC API.
 fn web_port() -> String {
     dotenvy::var("PORT").unwrap_or("8000".to_string())
@@ -100,17 +97,21 @@ async fn main() {
         }
 
         Command::Serve => {
-            let (status_tx, status_rx) = watch::channel(AppStatus::Starting);
+            info!("Applying database migrations...");
+            let mut conn = db.acquire().await.expect("Failed to acquire DB connection");
+            migrator
+                .run(&mut conn, &Plan::apply_all())
+                .await
+                .expect("Failed to run migrations");
+            info!("Migrations complete");
 
-            // Create samod repo
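+            // Create the samod repo, now backed by Postgres so automerge
+            // documents persist across server restarts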
             let repo = samod::Repo::builder(tokio::runtime::Handle::current())
-                .with_storage(samod::storage::InMemoryStorage::new())
+                .with_storage(storage::PostgresStorage::new(db.clone()))
                 .load()
                 .await;
 
             let state = app::AppState {
                 db: db.clone(),
-                app_status: status_rx.clone(),
                 repo,
                 active_listeners: Arc::new(RwLock::new(HashSet::new())),
             };
@@ -127,48 +128,10 @@
                 .await,
             );
 
-            tokio::try_join!(
-                run_migrator_apply(db.clone(), migrator, status_tx.clone()),
-                run_web_server(state.clone(), firebase_auth.clone()),
-            )
-            .unwrap();
-        }
-    }
-}
-
-async fn run_migrator_apply(
-    db: PgPool,
-    migrator: Migrator,
-    status_tx: watch::Sender<AppStatus>,
-) -> Result<(), Box<dyn std::error::Error>> {
-    status_tx.send(AppStatus::Migrating)?;
-    info!("Applying database migrations...");
-
-    let mut conn = db.acquire().await?;
-    migrator.run(&mut conn, &Plan::apply_all()).await.unwrap();
+            // Notify systemd we're ready
+            sd_notify::notify(false, &[sd_notify::NotifyState::Ready]).ok();
 
-    status_tx.send(AppStatus::Running)?;
-    sd_notify::notify(false, &[sd_notify::NotifyState::Ready])?;
-    info!("Migrations complete");
-
-    Ok(())
-}
-
-async fn app_status_gate(
-    State(status_rx): State<watch::Receiver<AppStatus>>,
-    req: Request,
-    next: Next,
-) -> impl IntoResponse {
-    // Combining the following 2 lines will anger the rust gods
-    let status = status_rx.borrow().clone();
-    match status {
-        AppStatus::Running => next.run(req).await,
-        AppStatus::Failed(reason) => {
-            (StatusCode::INTERNAL_SERVER_ERROR, format!("App failed to start: {reason}"))
-                .into_response()
-        }
-        AppStatus::Starting | AppStatus::Migrating => {
-            (StatusCode::SERVICE_UNAVAILABLE, "Server not ready yet").into_response()
+            run_web_server(state.clone(), firebase_auth.clone()).await.unwrap();
         }
     }
 }
@@ -191,13 +154,8 @@ async fn auth_middleware(
     next.run(req).await
 }
 
-async fn status_handler(State(status_rx): State<watch::Receiver<AppStatus>>) -> String {
-    match status_rx.borrow().clone() {
-        AppStatus::Starting => "Starting".into(),
-        AppStatus::Migrating => "Migrating".into(),
-        AppStatus::Running => "Running".into(),
-        AppStatus::Failed(reason) => format!("Failed: {reason}"),
-    }
+async fn status_handler() -> &'static str {
+    "Running"
 }
 
 async fn websocket_handler(
@@ -223,20 +181,16 @@ async fn run_web_server(
     let (qubit_service, qubit_handle) = rpc_router.as_rpc(state.clone()).into_service();
 
     let rpc_with_mw = ServiceBuilder::new()
-        .layer(from_fn_with_state(state.app_status.clone(), app_status_gate))
         .layer(from_fn_with_state(firebase_auth.clone(), auth_middleware))
         .service(qubit_service);
 
     let samod_router = Router::new()
         .layer(from_fn_with_state(firebase_auth, auth_middleware))
-        .layer(from_fn_with_state(state.app_status.clone(), app_status_gate))
         .route("/repo-ws", get(websocket_handler))
         .with_state(state.repo.clone());
 
     // used by tests to tell when the backend is ready
-    let status_router = Router::new()
-        .route("/status", get(status_handler))
-        .with_state(state.app_status.clone());
+    let status_router = Router::new().route("/status", get(status_handler));
 
     let mut app = Router::new()
         .merge(status_router)
diff --git a/packages/backend/src/storage/mod.rs b/packages/backend/src/storage/mod.rs
new file mode 100644
index 000000000..e3d5c4549
--- /dev/null
+++ b/packages/backend/src/storage/mod.rs
@@ -0,0 +1,4 @@
+mod postgres;
+pub mod testing;
+
+pub use postgres::PostgresStorage;
diff --git a/packages/backend/src/storage/postgres.rs b/packages/backend/src/storage/postgres.rs
new file mode 100644
index 000000000..3b80dff91
--- /dev/null
+++ b/packages/backend/src/storage/postgres.rs
@@ -0,0 +1,110 @@
+use samod::storage::{Storage, StorageKey};
+use sqlx::PgPool;
+use std::collections::HashMap;
+
+/// A PostgreSQL-backed storage adapter for samod.
+///
+/// ## Database Schema
+///
+/// The adapter requires a table with the following structure:
+/// ```sql
+/// CREATE TABLE storage (
+///     key text[] PRIMARY KEY,
+///     data bytea NOT NULL
+/// );
+/// ```
+#[derive(Clone)]
+pub struct PostgresStorage {
+    pool: PgPool,
+}
+
+impl PostgresStorage {
+    pub fn new(pool: PgPool) -> Self {
+        Self { pool }
+    }
+}
+
+impl Storage for PostgresStorage {
+    async fn load(&self, key: StorageKey) -> Option<Vec<u8>> {
+        let key_parts: Vec<String> = key.into_iter().collect();
+
+        let result = sqlx::query_scalar::<_, Vec<u8>>("SELECT data FROM storage WHERE key = $1")
+            .bind(&key_parts)
+            .fetch_optional(&self.pool)
+            .await;
+
+        match result {
+            Ok(data) => data,
+            Err(e) => {
+                tracing::error!("Failed to load from storage: {}", e);
+                None
+            }
+        }
+    }
+
+    async fn load_range(&self, prefix: StorageKey) -> HashMap<StorageKey, Vec<u8>> {
+        let prefix_parts: Vec<String> = prefix.into_iter().collect();
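+
+        // Prefix matching below uses Postgres array slicing: `key[1:n]` is
+        // the first n elements of the key array, so comparing that slice to
+        // the bound prefix selects exactly the keys that start with it.
+        // Illustrative example (placeholder values):
+        //   key = {AAAAA, sync-state, xxxxx}, prefix = {AAAAA}
+        //   key[1:1] = {AAAAA}  -> row matches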
+        let result = if prefix_parts.is_empty() {
+            sqlx::query_as::<_, (Vec<String>, Vec<u8>)>("SELECT key, data FROM storage")
+                .fetch_all(&self.pool)
+                .await
+        } else {
+            sqlx::query_as::<_, (Vec<String>, Vec<u8>)>(
+                "SELECT key, data FROM storage WHERE key[1:cardinality($1::text[])] = $1::text[]",
+            )
+            .bind(&prefix_parts)
+            .fetch_all(&self.pool)
+            .await
+        };
+
+        match result {
+            Ok(rows) => {
+                let mut map = HashMap::new();
+                for (key_parts, data) in rows {
+                    if let Ok(storage_key) = StorageKey::from_parts(key_parts) {
+                        map.insert(storage_key, data);
+                    }
+                }
+                map
+            }
+            Err(e) => {
+                tracing::error!("Failed to load range from storage: {}", e);
+                HashMap::new()
+            }
+        }
+    }
+
+    async fn put(&self, key: StorageKey, data: Vec<u8>) {
+        let key_parts: Vec<String> = key.into_iter().collect();
+
+        let result = sqlx::query(
+            "
+            INSERT INTO storage (key, data)
+            VALUES ($1, $2)
+            ON CONFLICT (key) DO UPDATE SET data = $2
+            ",
+        )
+        .bind(&key_parts)
+        .bind(&data)
+        .execute(&self.pool)
+        .await;
+
+        if let Err(e) = result {
+            tracing::error!("Failed to put to storage: {}", e);
+        }
+    }
+
+    async fn delete(&self, key: StorageKey) {
+        let key_parts: Vec<String> = key.into_iter().collect();
+
+        let result = sqlx::query("DELETE FROM storage WHERE key = $1")
+            .bind(&key_parts)
+            .execute(&self.pool)
+            .await;
+
+        if let Err(e) = result {
+            tracing::error!("Failed to delete from storage: {}", e);
+        }
+    }
+}
diff --git a/packages/backend/src/storage/testing.rs b/packages/backend/src/storage/testing.rs
new file mode 100644
index 000000000..7aebe5cd1
--- /dev/null
+++ b/packages/backend/src/storage/testing.rs
@@ -0,0 +1,190 @@
+//! Storage adapter testing utilities.
+//!
+//! Rewritten from the TypeScript `runStorageAdapterTests` suite in
+//! automerge-repo/packages/automerge-repo/src/helpers/tests/storage-adapter-tests.ts
+//!
+//! Provides an acceptance test suite for any implementation of the
+//! `Storage` trait.
+
+#![allow(dead_code)]
+
+use rand::Rng;
+use samod::storage::{Storage, StorageKey};
+use std::future::Future;
+use std::pin::Pin;
+use std::sync::LazyLock;
+
+pub fn payload_a() -> Vec<u8> {
+    vec![0, 1, 127, 99, 154, 235]
+}
+
+pub fn payload_b() -> Vec<u8> {
+    vec![1, 76, 160, 53, 57, 10, 230]
+}
+
+pub fn payload_c() -> Vec<u8> {
+    vec![2, 111, 74, 131, 236, 96, 142, 193]
+}
+
+static LARGE_PAYLOAD: LazyLock<Vec<u8>> = LazyLock::new(|| {
+    let mut vec = vec![0u8; 100_000];
+    rand::thread_rng().fill(&mut vec[..]);
+    vec
+});
+
+pub fn large_payload() -> Vec<u8> {
+    LARGE_PAYLOAD.clone()
+}
+
+/// Trait for storage test fixtures
+pub trait StorageTestFixture: Sized + Send {
+    /// The storage type being tested
+    type Storage: Storage + Send + Sync + 'static;
+
+    /// Set up the test fixture
+    fn setup() -> impl std::future::Future<Output = Self> + Send;
+
+    /// Get a reference to the storage adapter
+    fn storage(&self) -> &Self::Storage;
+
+    /// Optional cleanup
+    fn teardown(self) -> impl std::future::Future<Output = ()> + Send {
+        async {}
+    }
+}
+
+/// Helper to run a single test with setup and teardown
+async fn run_test<F, TestFn>(test_fn: TestFn)
+where
+    F: StorageTestFixture,
+    TestFn: for<'a> FnOnce(&'a F::Storage) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> + Send,
+{
+    let fixture = F::setup().await;
+    test_fn(fixture.storage()).await;
+    fixture.teardown().await;
+}
+
+/// Run all storage adapter acceptance tests
+pub async fn run_storage_adapter_tests<F: StorageTestFixture>() {
+    run_test::<F, _>(|a| Box::pin(test_load_should_return_none_if_no_data(a))).await;
+    run_test::<F, _>(|a| Box::pin(test_save_and_load_should_return_data_that_was_saved(a))).await;
+    run_test::<F, _>(|a| Box::pin(test_save_and_load_should_work_with_composite_keys(a))).await;
+    run_test::<F, _>(|a| Box::pin(test_save_and_load_should_work_with_large_payload(a))).await;
+    run_test::<F, _>(|a| Box::pin(test_load_range_should_return_empty_if_no_data(a))).await;
+    run_test::<F, _>(|a| Box::pin(test_save_and_load_range_should_return_all_matching_data(a)))
+        .await;
+    run_test::<F, _>(|a| Box::pin(test_save_and_load_range_should_only_load_matching_values(a)))
+        .await;
+    run_test::<F, _>(|a| Box::pin(test_save_and_remove_should_be_empty_after_removing(a))).await;
+    run_test::<F, _>(|a| Box::pin(test_save_and_save_should_overwrite(a))).await;
+}
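+
+// Example wiring (a sketch; `MyFixture` stands in for any concrete
+// `StorageTestFixture`, e.g. the Postgres fixture in
+// `tests/postgres_storage_tests.rs`):
+//
+//     #[tokio::test]
+//     async fn my_storage_adapter_tests() {
+//         run_storage_adapter_tests::<MyFixture>().await;
+//     }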
+
+// describe("load")
+pub async fn test_load_should_return_none_if_no_data<S: Storage>(adapter: &S) {
+    let actual = adapter
+        .load(StorageKey::from_parts(["AAAAA", "sync-state", "xxxxx"]).unwrap())
+        .await;
+
+    assert_eq!(actual, None);
+}
+
+// describe("save and load")
+pub async fn test_save_and_load_should_return_data_that_was_saved<S: Storage>(adapter: &S) {
+    let key = StorageKey::from_parts(["storage-adapter-id"]).unwrap();
+    adapter.put(key.clone(), payload_a()).await;
+
+    let actual = adapter.load(key).await;
+
+    assert_eq!(actual, Some(payload_a()));
+}
+
+pub async fn test_save_and_load_should_work_with_composite_keys<S: Storage>(adapter: &S) {
+    let key = StorageKey::from_parts(["AAAAA", "sync-state", "xxxxx"]).unwrap();
+    adapter.put(key.clone(), payload_a()).await;
+
+    let actual = adapter.load(key).await;
+
+    assert_eq!(actual, Some(payload_a()));
+}
+
+pub async fn test_save_and_load_should_work_with_large_payload<S: Storage>(adapter: &S) {
+    let key = StorageKey::from_parts(["AAAAA", "sync-state", "xxxxx"]).unwrap();
+    adapter.put(key.clone(), large_payload()).await;
+
+    let actual = adapter.load(key).await;
+
+    assert_eq!(actual, Some(large_payload()));
+}
+
+// describe("loadRange")
+pub async fn test_load_range_should_return_empty_if_no_data<S: Storage>(adapter: &S) {
+    let result = adapter.load_range(StorageKey::from_parts(["AAAAA"]).unwrap()).await;
+
+    assert_eq!(result.len(), 0);
+}
+
+// describe("save and loadRange")
+pub async fn test_save_and_load_range_should_return_all_matching_data<S: Storage>(adapter: &S) {
+    let key_a = StorageKey::from_parts(["AAAAA", "sync-state", "xxxxx"]).unwrap();
+    let key_b = StorageKey::from_parts(["AAAAA", "snapshot", "yyyyy"]).unwrap();
+    let key_c = StorageKey::from_parts(["AAAAA", "sync-state", "zzzzz"]).unwrap();
+
+    adapter.put(key_a.clone(), payload_a()).await;
+    adapter.put(key_b.clone(), payload_b()).await;
+    adapter.put(key_c.clone(), payload_c()).await;
+
+    let result = adapter.load_range(StorageKey::from_parts(["AAAAA"]).unwrap()).await;
+
+    assert_eq!(result.len(), 3);
+    assert_eq!(result.get(&key_a), Some(&payload_a()));
+    assert_eq!(result.get(&key_b), Some(&payload_b()));
+    assert_eq!(result.get(&key_c), Some(&payload_c()));
+
+    let sync_result = adapter
+        .load_range(StorageKey::from_parts(["AAAAA", "sync-state"]).unwrap())
+        .await;
+
+    assert_eq!(sync_result.len(), 2);
+    assert_eq!(sync_result.get(&key_a), Some(&payload_a()));
+    assert_eq!(sync_result.get(&key_c), Some(&payload_c()));
+}
+
+pub async fn test_save_and_load_range_should_only_load_matching_values<S: Storage>(adapter: &S) {
+    let key_a = StorageKey::from_parts(["AAAAA", "sync-state", "xxxxx"]).unwrap();
+    let key_c = StorageKey::from_parts(["BBBBB", "sync-state", "zzzzz"]).unwrap();
+
+    adapter.put(key_a.clone(), payload_a()).await;
+    adapter.put(key_c.clone(), payload_c()).await;
+
+    let actual = adapter.load_range(StorageKey::from_parts(["AAAAA"]).unwrap()).await;
+
+    assert_eq!(actual.len(), 1);
+    assert_eq!(actual.get(&key_a), Some(&payload_a()));
+}
+
+// describe("save and remove")
+pub async fn test_save_and_remove_should_be_empty_after_removing<S: Storage>(adapter: &S) {
+    let key = StorageKey::from_parts(["AAAAA", "snapshot", "xxxxx"]).unwrap();
+    adapter.put(key.clone(), payload_a()).await;
+    adapter.delete(key.clone()).await;
+
+    let range_result = adapter.load_range(StorageKey::from_parts(["AAAAA"]).unwrap()).await;
+    assert_eq!(range_result.len(), 0);
+
+    let load_result = adapter.load(key).await;
+    assert_eq!(load_result, None);
+}
+
+// describe("save and save")
+pub async fn test_save_and_save_should_overwrite<S: Storage>(adapter: &S) {
+    let key = StorageKey::from_parts(["AAAAA", "sync-state", "xxxxx"]).unwrap();
StorageKey::from_parts(["AAAAA", "sync-state", "xxxxx"]).unwrap(); + adapter.put(key.clone(), payload_a()).await; + adapter.put(key.clone(), payload_b()).await; + + let result = adapter + .load_range(StorageKey::from_parts(["AAAAA", "sync-state"]).unwrap()) + .await; + + assert_eq!(result.len(), 1); + assert_eq!(result.get(&key), Some(&payload_b())); +} diff --git a/packages/backend/tests/postgres_storage_tests.rs b/packages/backend/tests/postgres_storage_tests.rs new file mode 100644 index 000000000..50ff95e1f --- /dev/null +++ b/packages/backend/tests/postgres_storage_tests.rs @@ -0,0 +1,50 @@ +use backend::storage::{PostgresStorage, testing}; +use sqlx::PgPool; + +async fn cleanup_test_data(pool: &PgPool) { + let _ = sqlx::query("DELETE FROM storage WHERE key[1] = ANY($1)") + .bind(["AAAAA", "BBBBB", "storage-adapter-id"]) + .execute(pool) + .await; +} + +struct PostgresTestFixture { + storage: PostgresStorage, + pool: PgPool, +} + +impl testing::StorageTestFixture for PostgresTestFixture { + type Storage = PostgresStorage; + + async fn setup() -> Self { + let database_url = + std::env::var("DATABASE_URL").expect("DATABASE_URL must be set for tests"); + + let pool = PgPool::connect(&database_url).await.expect("Failed to connect to database"); + + cleanup_test_data(&pool).await; + + let storage = PostgresStorage::new(pool.clone()); + + Self { storage, pool } + } + + fn storage(&self) -> &PostgresStorage { + &self.storage + } + + async fn teardown(self) { + cleanup_test_data(&self.pool).await; + } +} + +#[tokio::test] +async fn postgres_storage_adapter_tests() { + // Skip test if DATABASE_URL is not set (e.g., in CI without postgres) + if std::env::var("DATABASE_URL").is_err() { + eprintln!("Skipping postgres_storage_adapter_tests: DATABASE_URL not set"); + return; + } + + testing::run_storage_adapter_tests::().await; +}