1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions packages/backend/Cargo.toml
@@ -22,6 +22,7 @@ jsonrpsee = "0.24.6"
 jsonrpsee-server = "0.24.6"
 notebook-types = { version = "0.1.0", path = "../notebook-types" }
 qubit = { version = "1.0.0-beta.0", features = ["ts-serde-json", "ts-uuid"] }
+rand = "0.8"
 regex = "1.11.1"
 samod = { git = "https://github.com/alexjg/samod", features = [
     "tokio",
13 changes: 1 addition & 12 deletions packages/backend/src/app.rs
@@ -4,7 +4,7 @@ use sqlx::PgPool;
 use std::collections::HashSet;
 use std::sync::Arc;
 use thiserror::Error;
-use tokio::sync::{RwLock, watch};
+use tokio::sync::RwLock;
 use ts_rs::TS;
 use uuid::Uuid;
 
@@ -19,21 +19,10 @@ pub struct AppState {
     /// Automerge-repo provider
     pub repo: samod::Repo,
 
-    pub app_status: watch::Receiver<AppStatus>,
-
     /// Tracks which ref_ids have active autosave listeners to prevent duplicates
     pub active_listeners: Arc<RwLock<HashSet<Uuid>>>,
 }
 
-#[derive(Clone, Debug, PartialEq, Eq)]
-pub enum AppStatus {
-    Starting,
-    Migrating,
-    Running,
-    #[allow(dead_code)]
-    Failed(String),
-}
-
 /// Context available to RPC procedures.
 #[derive(Clone)]
 pub struct AppCtx {
1 change: 1 addition & 0 deletions packages/backend/src/lib.rs
@@ -0,0 +1 @@
+pub mod storage;
80 changes: 17 additions & 63 deletions packages/backend/src/main.rs
@@ -1,21 +1,19 @@
 use axum::extract::Request;
 
 use axum::extract::ws::WebSocketUpgrade;
-use axum::middleware::{Next, from_fn_with_state};
+use axum::middleware::from_fn_with_state;
 use axum::{Router, routing::get};
 use axum::{extract::State, response::IntoResponse};
 use clap::{Parser, Subcommand};
 use firebase_auth::FirebaseAuth;
-use http::StatusCode;
 use sqlx::postgres::PgPoolOptions;
-use sqlx::{PgPool, Postgres};
 use sqlx_migrator::cli::MigrationCommand;
 use sqlx_migrator::migrator::{Migrate, Migrator};
 use sqlx_migrator::{Info, Plan};
 use std::collections::HashSet;
 use std::path::Path;
 use std::sync::Arc;
-use tokio::sync::{RwLock, watch};
+use tokio::sync::RwLock;
 use tower::ServiceBuilder;
 use tower_http::cors::CorsLayer;
 use tracing::{error, info};
@@ -26,10 +24,9 @@ mod auth;
 mod automerge_json;
 mod document;
 mod rpc;
+mod storage;
 mod user;
 
-use app::AppStatus;
-
 /// Port for the web server providing the RPC API.
 fn web_port() -> String {
     dotenvy::var("PORT").unwrap_or("8000".to_string())
@@ -100,17 +97,21 @@ async fn main() {
         }
 
         Command::Serve => {
-            let (status_tx, status_rx) = watch::channel(AppStatus::Starting);
+            info!("Applying database migrations...");
+            let mut conn = db.acquire().await.expect("Failed to acquire DB connection");
+            migrator
+                .run(&mut conn, &Plan::apply_all())
+                .await
+                .expect("Failed to run migrations");
+            info!("Migrations complete");
-
             // Create samod repo
             let repo = samod::Repo::builder(tokio::runtime::Handle::current())
-                .with_storage(samod::storage::InMemoryStorage::new())
+                .with_storage(storage::PostgresStorage::new(db.clone()))
                 .load()
                 .await;
 
             let state = app::AppState {
                 db: db.clone(),
-                app_status: status_rx.clone(),
                 repo,
                 active_listeners: Arc::new(RwLock::new(HashSet::new())),
             };
@@ -127,48 +128,10 @@ async fn main() {
                 .await,
             );
 
-            tokio::try_join!(
-                run_migrator_apply(db.clone(), migrator, status_tx.clone()),
-                run_web_server(state.clone(), firebase_auth.clone()),
-            )
-            .unwrap();
-        }
-    }
-}
-
-async fn run_migrator_apply(
-    db: PgPool,
-    migrator: Migrator<Postgres>,
-    status_tx: watch::Sender<AppStatus>,
-) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
-    status_tx.send(AppStatus::Migrating)?;
-    info!("Applying database migrations...");
-
-    let mut conn = db.acquire().await?;
-    migrator.run(&mut conn, &Plan::apply_all()).await.unwrap();
+            // Notify systemd we're ready
+            sd_notify::notify(false, &[sd_notify::NotifyState::Ready]).ok();
 
-    status_tx.send(AppStatus::Running)?;
-    sd_notify::notify(false, &[sd_notify::NotifyState::Ready])?;
-    info!("Migrations complete");
-
-    Ok(())
-}
-
-async fn app_status_gate(
-    State(status_rx): State<watch::Receiver<AppStatus>>,
-    req: Request,
-    next: Next,
-) -> impl IntoResponse {
-    // Combining the following 2 lines will anger the rust gods
-    let status = status_rx.borrow().clone();
-    match status {
-        AppStatus::Running => next.run(req).await,
-        AppStatus::Failed(reason) => {
-            (StatusCode::INTERNAL_SERVER_ERROR, format!("App failed to start: {reason}"))
-                .into_response()
-        }
-        AppStatus::Starting | AppStatus::Migrating => {
-            (StatusCode::SERVICE_UNAVAILABLE, "Server not ready yet").into_response()
+            run_web_server(state.clone(), firebase_auth.clone()).await.unwrap();
         }
     }
 }
@@ -191,13 +154,8 @@ async fn auth_middleware(
     next.run(req).await
 }
 
-async fn status_handler(State(status_rx): State<watch::Receiver<AppStatus>>) -> String {
-    match status_rx.borrow().clone() {
-        AppStatus::Starting => "Starting".into(),
-        AppStatus::Migrating => "Migrating".into(),
-        AppStatus::Running => "Running".into(),
-        AppStatus::Failed(reason) => format!("Failed: {reason}"),
-    }
+async fn status_handler() -> &'static str {
+    "Running"
 }
 
 async fn websocket_handler(
@@ -223,20 +181,16 @@ async fn run_web_server(
     let (qubit_service, qubit_handle) = rpc_router.as_rpc(state.clone()).into_service();
 
     let rpc_with_mw = ServiceBuilder::new()
-        .layer(from_fn_with_state(state.app_status.clone(), app_status_gate))
         .layer(from_fn_with_state(firebase_auth.clone(), auth_middleware))
         .service(qubit_service);
 
     let samod_router = Router::new()
         .layer(from_fn_with_state(firebase_auth, auth_middleware))
-        .layer(from_fn_with_state(state.app_status.clone(), app_status_gate))
        .route("/repo-ws", get(websocket_handler))
        .with_state(state.repo.clone());
 
     // used by tests to tell when the backend is ready
-    let status_router = Router::new()
-        .route("/status", get(status_handler))
-        .with_state(state.app_status.clone());
+    let status_router = Router::new().route("/status", get(status_handler));
 
     let mut app = Router::new()
         .merge(status_router)
4 changes: 4 additions & 0 deletions packages/backend/src/storage/mod.rs
@@ -0,0 +1,4 @@
+mod postgres;
+pub mod testing;
+
+pub use postgres::PostgresStorage;
110 changes: 110 additions & 0 deletions packages/backend/src/storage/postgres.rs
@@ -0,0 +1,110 @@
+use samod::storage::{Storage, StorageKey};
+use sqlx::PgPool;
+use std::collections::HashMap;
+
+/// A PostgreSQL-backed storage adapter for samod
+///
+/// ## Database Schema
+///
+/// The adapter requires a table with the following structure:
+/// ```sql
+/// CREATE TABLE storage (
+///     key text[] PRIMARY KEY,
+///     data bytea NOT NULL
+/// );
+/// ```
+#[derive(Clone)]
+pub struct PostgresStorage {
+    pool: PgPool,
+}
+
+impl PostgresStorage {
+    pub fn new(pool: PgPool) -> Self {
+        Self { pool }
+    }
+}
+
+impl Storage for PostgresStorage {
+    async fn load(&self, key: StorageKey) -> Option<Vec<u8>> {
+        let key_parts: Vec<String> = key.into_iter().collect();
+
+        let result = sqlx::query_scalar::<_, Vec<u8>>("SELECT data FROM storage WHERE key = $1")
+            .bind(&key_parts)
+            .fetch_optional(&self.pool)
+            .await;
+
+        match result {
+            Ok(data) => data,
+            Err(e) => {
+                tracing::error!("Failed to load from storage: {}", e);
+                None
+            }
+        }
+    }
+
+    async fn load_range(&self, prefix: StorageKey) -> HashMap<StorageKey, Vec<u8>> {
+        let prefix_parts: Vec<String> = prefix.into_iter().collect();
+
+        let result = if prefix_parts.is_empty() {
+            sqlx::query_as::<_, (Vec<String>, Vec<u8>)>("SELECT key, data FROM storage")
+                .fetch_all(&self.pool)
+                .await
+        } else {
+            sqlx::query_as::<_, (Vec<String>, Vec<u8>)>(
+                "SELECT key, data FROM storage WHERE key[1:cardinality($1::text[])] = $1::text[]",
+            )
+            .bind(&prefix_parts)
+            .fetch_all(&self.pool)
+            .await
+        };
+
+        match result {
+            Ok(rows) => {
+                let mut map = HashMap::new();
+                for (key_parts, data) in rows {
+                    if let Ok(storage_key) = StorageKey::from_parts(key_parts) {
+                        map.insert(storage_key, data);
+                    }
+                }
+                map
+            }
+            Err(e) => {
+                tracing::error!("Failed to load range from storage: {}", e);
+                HashMap::new()
+            }
+        }
+    }
+
+    async fn put(&self, key: StorageKey, data: Vec<u8>) {
+        let key_parts: Vec<String> = key.into_iter().collect();
+
+        let result = sqlx::query(
+            "
+            INSERT INTO storage (key, data)
+            VALUES ($1, $2)
+            ON CONFLICT (key) DO UPDATE SET data = $2
+            ",
+        )
+        .bind(&key_parts)
+        .bind(&data)
+        .execute(&self.pool)
+        .await;
+
+        if let Err(e) = result {
+            tracing::error!("Failed to put to storage: {}", e);
+        }
+    }
+
+    async fn delete(&self, key: StorageKey) {
+        let key_parts: Vec<String> = key.into_iter().collect();
+
+        let result = sqlx::query("DELETE FROM storage WHERE key = $1")
+            .bind(&key_parts)
+            .execute(&self.pool)
+            .await;
+
+        if let Err(e) = result {
+            tracing::error!("Failed to delete from storage: {}", e);
+        }
+    }
+}
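
The prefix-matching behavior of `load_range` (the `key[1:cardinality($1::text[])] = $1::text[]` comparison) can be exercised end to end. Below is a minimal sketch, not part of this PR, assuming sqlx's `#[sqlx::test]` fixture provisions a database whose migrations create the `storage` table from the schema above; the test name and key parts are made up for illustration.

```rust
// Sketch only, not part of this PR: assumes #[sqlx::test] hands us a
// pool for a freshly migrated database containing the storage table.
#[sqlx::test]
async fn load_range_matches_by_key_prefix(pool: sqlx::PgPool) {
    use samod::storage::{Storage, StorageKey};

    let storage = PostgresStorage::new(pool);

    // Store a value under a three-part key.
    let key = StorageKey::from_parts(vec![
        "doc".to_string(),
        "abc123".to_string(),
        "snapshot".to_string(),
    ])
    .unwrap();
    storage.put(key, vec![1, 2, 3]).await;

    // A two-part prefix matches it: the SQL compares the first
    // cardinality($1) elements of each stored key against the prefix.
    let prefix = StorageKey::from_parts(vec!["doc".to_string(), "abc123".to_string()]).unwrap();
    let range = storage.load_range(prefix).await;
    assert_eq!(range.len(), 1);
    assert_eq!(range.values().next(), Some(&vec![1, 2, 3]));
}
```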