diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index e80f58d..31bddad 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -41,7 +41,7 @@ jobs:
sudo mariadb -e "FLUSH PRIVILEGES";
- name: Setup Rust
- uses: dtolnay/rust-toolchain@1.89.0
+ uses: dtolnay/rust-toolchain@1.92.0
- name: Configure cache
uses: Swatinem/rust-cache@v2
@@ -50,7 +50,8 @@ jobs:
env:
DATABASE_URL: mysql://api:${{ secrets.MARIADB_PW }}@localhost:3306/master_db
REDIS_URL: redis://localhost:6379
- run: RUST_BACKTRACE=1 cargo test -F mysql -F test
+ GQL_API_CURSOR_SECRET_KEY: ${{ secrets.GQL_API_CURSOR_SECRET_KEY }}
+ run: RUST_BACKTRACE=1 cargo test -F mysql
lint:
name: Lint
@@ -61,13 +62,16 @@ jobs:
uses: actions/checkout@v4
- name: Setup Rust
- uses: dtolnay/rust-toolchain@1.89.0
+ uses: dtolnay/rust-toolchain@1.92.0
- name: Setup Clippy
run: rustup component add clippy
- name: Run clippy
run: |
- cargo clippy --no-default-features -- -D warnings
- cargo clippy -- -D warnings
+ cargo clippy --no-default-features -F mysql -- -D warnings
+ cargo clippy -r --no-default-features -F mysql -- -D warnings
+ cargo clippy -F mysql -- -D warnings
+ cargo clippy -r -F mysql -- -D warnings
cargo clippy --all-features -- -D warnings
+ cargo clippy -r --all-features -- -D warnings
diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml
index 5909c1a..7c02e2c 100644
--- a/.github/workflows/docs.yml
+++ b/.github/workflows/docs.yml
@@ -17,7 +17,7 @@ jobs:
- name: Checkout repository
uses: actions/checkout@v4
- name: Setup Rust
- uses: dtolnay/rust-toolchain@1.89.0
+ uses: dtolnay/rust-toolchain@1.92.0
- name: Configure cache
uses: Swatinem/rust-cache@v2
- name: Setup pages
@@ -26,7 +26,7 @@ jobs:
- name: Clean docs folder
run: cargo clean --doc
- name: Build docs
- run: cargo +nightly doc --all-features --no-deps --workspace --exclude clear_redis_mappacks --exclude admin
+ run: cargo +nightly doc --all-features --no-deps --workspace --exclude admin
- name: Add redirect
run: echo '' > target/doc/index.html
- name: Remove lock file
diff --git a/Cargo.toml b/Cargo.toml
index f9910dd..7382ba0 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -13,40 +13,41 @@ members = [
"crates/graphql-schema-generator",
"crates/player-map-ranking",
"crates/compute-player-map-ranking",
+ "crates/test-env",
]
[workspace.dependencies]
-thiserror = "2.0.12"
-tokio = "1.44.1"
-async-graphql = "7.0.15"
-sqlx = { version = "0.8.3", features = ["chrono", "mysql", "runtime-tokio"] }
-chrono = { version = "0.4.40", features = ["serde"] }
-deadpool = { version = "0.12.2", features = ["managed", "rt_tokio_1"] }
+thiserror = "2.0.17"
+tokio = "1.49.0"
+async-graphql = "7.1.0"
+sqlx = { version = "0.8.6", features = ["chrono", "mysql", "runtime-tokio"] }
+chrono = { version = "0.4.42", features = ["serde"] }
+deadpool = { version = "0.12.3", features = ["managed", "rt_tokio_1"] }
deadpool-redis = { version = "0.20.0", features = ["rt_tokio_1"] }
-serde = "1.0.219"
-tracing = "0.1.41"
-tracing-subscriber = "0.3.19"
-actix-web = "4.10.2"
+serde = "1.0.228"
+tracing = "0.1.44"
+tracing-subscriber = "0.3.22"
+actix-web = "4.12.1"
actix-cors = "0.7.1"
-async-graphql-actix-web = "7.0.15"
-tracing-actix-web = "0.7.16"
-reqwest = { version = "0.12.14", features = ["json"] }
-rand = "0.9.0"
-futures = "0.3.27"
+async-graphql-actix-web = "7.1.0"
+tracing-actix-web = "0.7.20"
+reqwest = { version = "0.12.28", features = ["json"] }
+rand = "0.9.2"
+futures = "0.3.31"
sha256 = "1.6.0"
actix-session = { version = "0.10.1", features = ["cookie-session"] }
-anyhow = "1.0.97"
+anyhow = "1.0.100"
dotenvy = "0.15.7"
itertools = "0.14.0"
-once_cell = "1.21.1"
-csv = "1.3.0"
-mkenv = "0.1.6"
+once_cell = "1.21.3"
+csv = "1.4.0"
+mkenv = "1.0.2"
nom = "8.0.0"
pin-project-lite = "0.2.16"
-sea-orm = { version = "1.1.14", features = [
+sea-orm = { version = "1.1.19", features = [
"runtime-tokio",
"macros",
"with-chrono",
] }
-serde_json = "1.0.143"
+serde_json = "1.0.149"
prettytable = { version = "0.10.0", default-features = false }
diff --git a/crates/admin/Cargo.toml b/crates/admin/Cargo.toml
index a50bc98..fe7716c 100644
--- a/crates/admin/Cargo.toml
+++ b/crates/admin/Cargo.toml
@@ -5,7 +5,7 @@ edition = "2024"
[dependencies]
anyhow = { workspace = true }
-clap = { version = "4.5.16", features = ["derive"] }
+clap = { version = "4.5.54", features = ["derive"] }
csv = { workspace = true }
deadpool-redis = { workspace = true }
dotenvy = { workspace = true }
diff --git a/crates/admin/src/main.rs b/crates/admin/src/main.rs
index 3c05e4c..26a7207 100644
--- a/crates/admin/src/main.rs
+++ b/crates/admin/src/main.rs
@@ -1,5 +1,5 @@
use clap::Parser;
-use mkenv::Env as _;
+use mkenv::prelude::*;
use records_lib::{Database, DbEnv, LibEnv};
use self::{clear::ClearCommand, leaderboard::LbCommand, populate::PopulateCommand};
@@ -24,7 +24,12 @@ enum EventCommand {
Clear(ClearCommand),
}
-mkenv::make_env!(Env includes [DbEnv as db_env, LibEnv as lib_env]:);
+mkenv::make_config! {
+ struct Env {
+ db_env: { DbEnv },
+ lib_env: { LibEnv },
+ }
+}
#[tokio::main]
async fn main() -> anyhow::Result<()> {
@@ -33,11 +38,15 @@ async fn main() -> anyhow::Result<()> {
.compact()
.try_init()
.map_err(|e| anyhow::anyhow!("unable to init tracing_subscriber: {e}"))?;
- let env = Env::try_get()?;
+ let env = Env::define();
+ env.init();
records_lib::init_env(env.lib_env);
- let db =
- Database::from_db_url(env.db_env.db_url.db_url, env.db_env.redis_url.redis_url).await?;
+ let db = Database::from_db_url(
+ env.db_env.db_url.db_url.get(),
+ env.db_env.redis_url.redis_url.get(),
+ )
+ .await?;
let cmd = Command::parse();
diff --git a/crates/compute-player-map-ranking/Cargo.toml b/crates/compute-player-map-ranking/Cargo.toml
index 1d298c6..ebfe68c 100644
--- a/crates/compute-player-map-ranking/Cargo.toml
+++ b/crates/compute-player-map-ranking/Cargo.toml
@@ -5,7 +5,7 @@ edition = "2024"
[dependencies]
anyhow = { workspace = true }
-clap = { version = "4.5.48", features = ["derive"] }
+clap = { version = "4.5.54", features = ["derive"] }
dotenvy = { workspace = true }
mkenv = { workspace = true }
player-map-ranking = { path = "../player-map-ranking" }
diff --git a/crates/compute-player-map-ranking/src/main.rs b/crates/compute-player-map-ranking/src/main.rs
index 67ed3a9..ea9f523 100644
--- a/crates/compute-player-map-ranking/src/main.rs
+++ b/crates/compute-player-map-ranking/src/main.rs
@@ -11,11 +11,15 @@ use std::{
use anyhow::Context as _;
use chrono::{DateTime, Days, Months, Utc};
use clap::Parser as _;
-use mkenv::Env as _;
+use mkenv::prelude::*;
use records_lib::{DbUrlEnv, time::Time};
use sea_orm::Database;
-mkenv::make_env! {AppEnv includes [DbUrlEnv as db_env]:}
+mkenv::make_config! {
+ struct AppEnv {
+ db_env: { DbUrlEnv },
+ }
+}
#[derive(Clone)]
struct SinceDuration {
@@ -96,10 +100,11 @@ async fn main() -> anyhow::Result<()> {
})
.context("couldn't write header to map ranking file")?;
- let db_url = AppEnv::try_get()
- .context("couldn't initialize environment")?
- .db_env
- .db_url;
+ let app_config = AppEnv::define();
+ app_config
+ .try_init()
+ .map_err(|e| anyhow::anyhow!("couldn't initialize environment: {e}"))?;
+ let db_url = app_config.db_env.db_url.get();
let db = Database::connect(db_url)
.await
.context("couldn't connect to database")?;
diff --git a/crates/entity/src/entities/maps.rs b/crates/entity/src/entities/maps.rs
index 30d4aeb..c983b04 100644
--- a/crates/entity/src/entities/maps.rs
+++ b/crates/entity/src/entities/maps.rs
@@ -1,7 +1,7 @@
use sea_orm::entity::prelude::*;
/// A ShootMania Obstacle map in the database.
-#[derive(Clone, Debug, PartialEq, DeriveEntityModel, Eq)]
+#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
#[sea_orm(table_name = "maps")]
pub struct Model {
/// The map ID.
@@ -32,6 +32,8 @@ pub struct Model {
pub gold_time: Option<i32>,
    /// The author time of the map.
    pub author_time: Option<i32>,
+ /// The score of the player, calculated periodically.
+ pub score: f64,
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
diff --git a/crates/entity/src/entities/players.rs b/crates/entity/src/entities/players.rs
index c026a6f..d79e2fa 100644
--- a/crates/entity/src/entities/players.rs
+++ b/crates/entity/src/entities/players.rs
@@ -1,7 +1,7 @@
use sea_orm::entity::prelude::*;
/// A player in the database.
-#[derive(Clone, Debug, PartialEq, DeriveEntityModel, Eq)]
+#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
#[sea_orm(table_name = "players")]
pub struct Model {
/// The player ID.
@@ -20,6 +20,8 @@ pub struct Model {
pub admins_note: Option<String>,
/// The player role.
pub role: u8,
+ /// The score of the player, calculated periodically.
+ pub score: f64,
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
diff --git a/crates/game_api/Cargo.toml b/crates/game_api/Cargo.toml
index 96ee5c7..36305d1 100644
--- a/crates/game_api/Cargo.toml
+++ b/crates/game_api/Cargo.toml
@@ -19,7 +19,7 @@ tracing = { workspace = true }
tracing-subscriber = { workspace = true, features = ["env-filter"] }
actix-web = { workspace = true }
actix-cors = { workspace = true }
-actix-http = "3.11.0"
+actix-http = "3.11.2"
async-graphql-actix-web = { workspace = true }
tracing-actix-web = { workspace = true }
reqwest = { workspace = true, features = ["multipart"] }
@@ -44,12 +44,12 @@ graphql-api = { path = "../graphql-api" }
serde_json = { workspace = true }
[dev-dependencies]
-records-lib = { path = "../records_lib" }
+test-env = { path = "../test-env" }
+records-lib = { path = "../records_lib", features = ["tracing", "mock"] }
[features]
default = ["request_filter"]
request_filter = ["dep:pin-project-lite", "dep:request_filter"]
auth = []
-mysql = ["records-lib/mysql"]
-postgres = ["records-lib/postgres"]
-test = ["records-lib/test"]
+mysql = ["records-lib/mysql", "test-env/mysql"]
+postgres = ["records-lib/postgres", "test-env/postgres"]
diff --git a/crates/game_api/build.rs b/crates/game_api/build.rs
index fb8d5e7..8bd602c 100644
--- a/crates/game_api/build.rs
+++ b/crates/game_api/build.rs
@@ -2,6 +2,4 @@ fn main() {
#[cfg(all(feature = "auth", not(test)))]
println!("cargo:rustc-cfg=auth");
println!("cargo:rustc-check-cfg=cfg(auth)");
-
- println!("cargo:rustc-check-cfg=cfg(test_force_db_deletion)");
}
diff --git a/crates/game_api/src/auth/gen_token.rs b/crates/game_api/src/auth/gen_token.rs
index ba4c140..d98e2ba 100644
--- a/crates/game_api/src/auth/gen_token.rs
+++ b/crates/game_api/src/auth/gen_token.rs
@@ -1,4 +1,5 @@
use deadpool_redis::redis::AsyncCommands as _;
+use mkenv::Layer as _;
use records_lib::{
RedisConnection, gen_random_str,
redis_key::{mp_token_key, web_token_key},
@@ -23,7 +24,7 @@ pub async fn gen_token_for(
let mp_key = mp_token_key(login);
let web_key = web_token_key(login);
- let ex = crate::env().auth_token_ttl as _;
+ let ex = crate::env().auth_token_ttl.get() as _;
let mp_token_hash = digest(&*mp_token);
let web_token_hash = digest(&*web_token);
diff --git a/crates/game_api/src/configure.rs b/crates/game_api/src/configure.rs
index 0c305a0..29b8992 100644
--- a/crates/game_api/src/configure.rs
+++ b/crates/game_api/src/configure.rs
@@ -9,6 +9,7 @@ use actix_web::{
web,
};
use dsc_webhook::{FormattedRequestHead, WebhookBody, WebhookBodyEmbed, WebhookBodyEmbedField};
+use mkenv::prelude::*;
use records_lib::{Database, pool::clone_dbconn};
use tracing_actix_web::{DefaultRootSpanBuilder, RequestId};
@@ -163,7 +164,7 @@ pub(crate) fn send_internal_err_msg_detached(
tokio::task::spawn(async move {
if let Err(e) = client
- .post(&crate::env().wh_report_url)
+ .post(crate::env().wh_report_url.get())
.json(&wh_msg)
.send()
.await
diff --git a/crates/game_api/src/env.rs b/crates/game_api/src/env.rs
index 1c5f571..9e1b987 100644
--- a/crates/game_api/src/env.rs
+++ b/crates/game_api/src/env.rs
@@ -1,163 +1,160 @@
-use mkenv::{Env as _, EnvSplitIncluded as _};
+use mkenv::{make_config, prelude::*};
use once_cell::sync::OnceCell;
use records_lib::{DbEnv, LibEnv};
-#[cfg(feature = "test")]
-const DEFAULT_SESSION_KEY: &str = "";
-
-mkenv::make_env! {pub ApiEnvUsedOnce:
- #[cfg(not(feature = "test"))]
- sess_key: {
- id: SessKey(String),
- kind: file,
- var: "RECORDS_API_SESSION_KEY_FILE",
- desc: "The path to the file containing the session key used by the API",
- },
-
- #[cfg(feature = "test")]
- sess_key: {
- id: SessKey(String),
- kind: normal,
- var: "RECORDS_API_SESSION_KEY",
- desc: "The session key used by the API",
- default: DEFAULT_SESSION_KEY,
- },
-
- wh_invalid_req_url: {
- id: WebhookInvalidReqUrl(String),
- kind: normal,
- var: "WEBHOOK_INVALID_REQ_URL",
- desc: "The URL to the Discord webhook used to flag invalid requests",
- default: DEFAULT_WH_INVALID_REQ_URL,
- },
+#[cfg(not(debug_assertions))]
+mkenv::make_config! {
+ pub struct DynamicApiEnv {
+ pub sess_key: {
+ var_name: "RECORDS_API_SESSION_KEY_FILE",
+ layers: [file_read()],
+ description: "The path to the file containing the session key used by the API",
+ },
+
+ pub mp_client_id: {
+ var_name: "RECORDS_MP_APP_CLIENT_ID_FILE",
+ layers: [file_read()],
+ description: "The path to the file containing the Obstacle ManiaPlanet client ID",
+ },
+
+ pub mp_client_secret: {
+ var_name: "RECORDS_MP_APP_CLIENT_SECRET_FILE",
+ layers: [file_read()],
+ description: "The path to the file containing the Obstacle ManiaPlanet client secret",
+ },
+ }
}
-const DEFAULT_PORT: u16 = 3000;
-const DEFAULT_TOKEN_TTL: u32 = 15_552_000;
-#[cfg(feature = "test")]
-const DEFAULT_MP_CLIENT_ID: &str = "";
-#[cfg(feature = "test")]
-const DEFAULT_MP_CLIENT_SECRET: &str = "";
-const DEFAULT_WH_REPORT_URL: &str = "";
-const DEFAULT_WH_AC_URL: &str = "";
-const DEFAULT_WH_INVALID_REQ_URL: &str = "";
-const DEFAULT_WH_RANK_COMPUTE_ERROR: &str = "";
-const DEFAULT_GQL_ENDPOINT: &str = "/graphql";
-
-mkenv::make_env! {pub ApiEnv includes [
- DbEnv as db_env,
- LibEnv as lib_env,
- ApiEnvUsedOnce as used_once
-]:
- port: {
- id: Port(u16),
- kind: parse,
- var: "RECORDS_API_PORT",
- desc: "The port used to expose the API",
- default: DEFAULT_PORT,
- },
-
- #[cfg(not(debug_assertions))]
- host: {
- id: Host(String),
- kind: normal,
- var: "RECORDS_API_HOST",
- desc: "The hostname of the server where the API is running (e.g. https://obstacle.titlepack.io)",
- },
-
- auth_token_ttl: {
- id: AuthTokenTtl(u32),
- kind: parse,
- var: "RECORDS_API_TOKEN_TTL",
- desc: "The TTL (time-to-live) of an authentication token or anything related to it (in seconds)",
- default: DEFAULT_TOKEN_TTL,
- },
-
- #[cfg(not(feature = "test"))]
- mp_client_id: {
- id: MpClientId(String),
- kind: file,
- var: "RECORDS_MP_APP_CLIENT_ID_FILE",
- desc: "The path to the file containing the Obstacle ManiaPlanet client ID",
- },
- #[cfg(feature = "test")]
- mp_client_id: {
- id: MpClientId(String),
- kind: normal,
- var: "RECORDS_MP_APP_CLIENT_ID",
- desc: "The Obstacle ManiaPlanet client ID",
- default: DEFAULT_MP_CLIENT_ID,
- },
-
- #[cfg(not(feature = "test"))]
- mp_client_secret: {
- id: MpClientSecret(String),
- kind: file,
- var: "RECORDS_MP_APP_CLIENT_SECRET_FILE",
- desc: "The path to the file containing the Obstacle ManiaPlanet client secret",
- },
- #[cfg(feature = "test")]
- mp_client_secret: {
- id: MpClientSecret(String),
- kind: normal,
- var: "RECORDS_MP_APP_CLIENT_SECRET",
- desc: "The Obstacle ManiaPlanet client secret",
- default: DEFAULT_MP_CLIENT_SECRET,
- },
-
- wh_report_url: {
- id: WebhookReportUrl(String),
- kind: normal,
- var: "WEBHOOK_REPORT_URL",
- desc: "The URL to the Discord webhook used to report errors",
- default: DEFAULT_WH_REPORT_URL,
- },
-
- wh_ac_url: {
- id: WebhookAcUrl(String),
- kind: normal,
- var: "WEBHOOK_AC_URL",
- desc: "The URL to the Discord webhook used to share in-game statistics",
- default: DEFAULT_WH_AC_URL,
- },
-
- gql_endpoint: {
- id: GqlEndpoint(String),
- kind: normal,
- var: "GQL_ENDPOINT",
- desc: "The route to the GraphQL endpoint (e.g. /graphql)",
- default: DEFAULT_GQL_ENDPOINT,
- },
-
- wh_rank_compute_err: {
- id: WebhookRankComputeError(String),
- kind: normal,
- var: "WEBHOOK_RANK_COMPUTE_ERROR",
- desc: "The URL to the Discord webhook used to send rank compute errors",
- default: DEFAULT_WH_RANK_COMPUTE_ERROR,
- },
+#[cfg(debug_assertions)]
+mkenv::make_config! {
+ pub struct DynamicApiEnv {
+ pub sess_key: {
+ var_name: "RECORDS_API_SESSION_KEY",
+ layers: [or_default()],
+ description: "The session key used by the API",
+ default_val_fmt: "empty",
+ },
+
+ pub mp_client_id: {
+ var_name: "RECORDS_MP_APP_CLIENT_ID",
+ layers: [or_default()],
+ description: "The Obstacle ManiaPlanet client ID",
+ default_val_fmt: "empty",
+ },
+
+ pub mp_client_secret: {
+ var_name: "RECORDS_MP_APP_CLIENT_SECRET",
+ layers: [or_default()],
+ description: "The Obstacle ManiaPlanet client secret",
+ default_val_fmt: "empty",
+ },
+ }
}
-pub struct InitEnvOut {
- pub db_env: DbEnv,
- pub used_once: ApiEnvUsedOnce,
+#[cfg(debug_assertions)]
+mkenv::make_config! {
+ pub struct Hostname {}
}
-static ENV: OnceCell = OnceCell::new();
+#[cfg(not(debug_assertions))]
+mkenv::make_config! {
+ pub struct Hostname {
+ pub host: {
+ var_name: "RECORDS_API_HOST",
+ description: "The hostname of the server where the API is running (e.g. https://obstacle.titlepack.io)",
+ }
+ }
+}
+
+mkenv::make_config! {
+ pub struct ApiEnv {
+ pub db_env: { DbEnv },
+
+ pub dynamic: { DynamicApiEnv },
+
+ pub wh_invalid_req_url: {
+ var_name: "WEBHOOK_INVALID_REQ_URL",
+ layers: [or_default()],
+ description: "The URL to the Discord webhook used to flag invalid requests",
+ default_val_fmt: "empty",
+ },
+
+ pub port: {
+ var_name: "RECORDS_API_PORT",
+ layers: [
+ parsed_from_str(),
+ or_default(),
+ ],
+ description: "The port used to expose the API",
+ default_val_fmt: "3000",
+ },
+
+ pub host: { Hostname },
+
+ pub auth_token_ttl: {
+ var_name: "RECORDS_API_TOKEN_TTL",
+ layers: [
+ parsed_from_str(),
+ or_default_val(|| 180 * 24 * 3600),
+ ],
+ description: "The TTL (time-to-live) of an authentication token or anything related to it (in seconds)",
+ default_val_fmt: "180 days",
+ },
+
+ pub wh_report_url: {
+ var_name: "WEBHOOK_REPORT_URL",
+ layers: [or_default()],
+ description: "The URL to the Discord webhook used to report errors",
+ default_val_fmt: "empty",
+ },
+
+ pub wh_ac_url: {
+ var_name: "WEBHOOK_AC_URL",
+ layers: [or_default()],
+ description: "The URL to the Discord webhook used to share in-game statistics",
+ default_val_fmt: "empty",
+ },
+
+
+ pub gql_endpoint: {
+ var_name: "GQL_ENDPOINT",
+ layers: [
+ or_default_val(|| "/graphql".to_owned()),
+ ],
+ description: "The route to the GraphQL endpoint (e.g. /graphql)",
+ default_val_fmt: "/graphql",
+ },
+
+ pub wh_rank_compute_err: {
+ var_name: "WEBHOOK_RANK_COMPUTE_ERROR",
+ layers: [or_default()],
+ description: "The URL to the Discord webhook used to send rank compute errors",
+ default_val_fmt: "empty",
+ },
+ }
+}
-pub fn env() -> &'static mkenv::init_env!(ApiEnv) {
- // SAFETY: this function is always called when the `init_env()` is called at the start.
- unsafe { ENV.get_unchecked() }
+make_config! {
+ struct All {
+ api_env: { ApiEnv },
+ lib_env: { LibEnv },
+ gql_env: { graphql_api::config::ApiConfig },
+ }
}
-pub fn init_env() -> anyhow::Result<InitEnvOut> {
- let env = ApiEnv::try_get()?;
- let (included, rest) = env.split();
- records_lib::init_env(included.lib_env);
- let _ = ENV.set(rest);
+static ENV: OnceCell<ApiEnv> = OnceCell::new();
+
+pub fn env() -> &'static ApiEnv {
+ ENV.get().unwrap()
+}
+
+pub fn init_env() -> anyhow::Result<()> {
+ let env = All::define();
+ env.try_init().map_err(|e| anyhow::anyhow!("{e}"))?;
+
+ records_lib::init_env(env.lib_env);
+ let _ = graphql_api::set_config(env.gql_env);
+ let _ = ENV.set(env.api_env);
- Ok(InitEnvOut {
- db_env: included.db_env,
- used_once: included.used_once,
- })
+ Ok(())
}
diff --git a/crates/game_api/src/graphql.rs b/crates/game_api/src/graphql.rs
index 425fd0e..4a7f778 100644
--- a/crates/game_api/src/graphql.rs
+++ b/crates/game_api/src/graphql.rs
@@ -6,6 +6,7 @@ use async_graphql::http::{GraphQLPlaygroundConfig, playground_source};
use async_graphql_actix_web::GraphQLRequest;
use graphql_api::error::{ApiGqlError, ApiGqlErrorKind};
use graphql_api::schema::{Schema, create_schema};
+use mkenv::prelude::*;
use records_lib::Database;
use reqwest::Client;
use tracing_actix_web::RequestId;
@@ -79,7 +80,7 @@ async fn index_playground() -> impl Responder {
HttpResponse::Ok()
.content_type("text/html; charset=utf-8")
.body(playground_source(GraphQLPlaygroundConfig::new(
- &crate::env().gql_endpoint,
+ &crate::env().gql_endpoint.get(),
)))
}
diff --git a/crates/game_api/src/http.rs b/crates/game_api/src/http.rs
index d858c06..e58839f 100644
--- a/crates/game_api/src/http.rs
+++ b/crates/game_api/src/http.rs
@@ -9,6 +9,8 @@ pub mod player;
use std::fmt;
+use mkenv::prelude::*;
+
use actix_web::body::BoxBody;
use actix_web::dev::{ServiceFactory, ServiceRequest, ServiceResponse};
use actix_web::web::{JsonConfig, Query};
@@ -131,7 +133,7 @@ async fn report_error(
}
client
- .post(&crate::env().wh_report_url)
+ .post(crate::env().wh_report_url.get())
.json(&WebhookBody {
content: format!("Error reported (mode version: {mode_vers})"),
embeds: vec![
diff --git a/crates/game_api/src/http/event.rs b/crates/game_api/src/http/event.rs
index 8f08205..f319f80 100644
--- a/crates/game_api/src/http/event.rs
+++ b/crates/game_api/src/http/event.rs
@@ -12,6 +12,7 @@ use entity::{
};
use futures::TryStreamExt;
use itertools::Itertools;
+use mkenv::prelude::*;
use records_lib::{
Database, Expirable as _, NullableInteger, NullableReal, NullableText, RedisPool,
error::RecordsError,
@@ -215,7 +216,7 @@ fn db_align_to_mp_align(
) -> NullableText {
match alignment {
None if pos_x.is_none() && pos_y.is_none() => {
- Some(records_lib::env().ingame_default_titles_align)
+ Some(records_lib::env().ingame_default_titles_align.get())
}
Some(_) if pos_x.is_some() || pos_y.is_some() => None,
other => other,
@@ -245,7 +246,7 @@ impl From for EventEditionInGameParams {
put_subtitle_on_newline: value
.put_subtitle_on_newline
.map(|b| b != 0)
- .unwrap_or_else(|| records_lib::env().ingame_default_subtitle_on_newline),
+ .unwrap_or_else(|| records_lib::env().ingame_default_subtitle_on_newline.get()),
titles_pos_x: value.titles_pos_x.into(),
titles_pos_y: value.titles_pos_y.into(),
lb_link_pos_x: value.lb_link_pos_x.into(),
@@ -262,22 +263,25 @@ impl Default for EventEditionInGameParams {
titles_align: NullableText(Some(
records_lib::env()
.ingame_default_titles_align
+ .get()
.to_char()
.to_string(),
)),
lb_link_align: NullableText(Some(
records_lib::env()
.ingame_default_lb_link_align
+ .get()
.to_char()
.to_string(),
)),
authors_align: NullableText(Some(
records_lib::env()
.ingame_default_authors_align
+ .get()
.to_char()
.to_string(),
)),
- put_subtitle_on_newline: records_lib::env().ingame_default_subtitle_on_newline,
+ put_subtitle_on_newline: records_lib::env().ingame_default_subtitle_on_newline.get(),
titles_pos_x: NullableReal(None),
titles_pos_y: NullableReal(None),
lb_link_pos_x: NullableReal(None),
diff --git a/crates/game_api/src/http/overview.rs b/crates/game_api/src/http/overview.rs
index 715eaa9..f2eaeca 100644
--- a/crates/game_api/src/http/overview.rs
+++ b/crates/game_api/src/http/overview.rs
@@ -169,7 +169,8 @@ async fn get_rank(
match min_time {
Some(time) => {
- let rank = ranks::get_rank(redis_pool, map.id, p.id, time, event)
+ let mut redis_conn = redis_pool.get().await.with_api_err()?;
+ let rank = ranks::get_rank(&mut redis_conn, map.id, time, event)
.await
.with_api_err()?;
Ok(Some(rank))
diff --git a/crates/game_api/src/http/player.rs b/crates/game_api/src/http/player.rs
index aeceba6..58238af 100644
--- a/crates/game_api/src/http/player.rs
+++ b/crates/game_api/src/http/player.rs
@@ -8,6 +8,7 @@ use actix_web::{
use deadpool_redis::redis::AsyncCommands as _;
use entity::{banishments, current_bans, maps, players, records, role, types};
use futures::TryStreamExt;
+use mkenv::prelude::*;
use records_lib::{Database, RedisPool, must, player, redis_key::alone_map_key, sync};
use reqwest::Client;
use sea_orm::{
@@ -475,7 +476,7 @@ async fn report_error(
};
client
- .post(&crate::env().wh_report_url)
+ .post(crate::env().wh_report_url.get())
.json(&WebhookBody {
content,
embeds: vec![
@@ -518,7 +519,7 @@ struct ACBody {
async fn ac(Res(client): Res<Client>, Json(body): Json<ACBody>) -> RecordsResult<impl Responder> {
client
- .post(&crate::env().wh_ac_url)
+ .post(crate::env().wh_ac_url.get())
.json(&WebhookBody {
content: format!("Map has been finished in {}", body.run_time),
embeds: vec![WebhookBodyEmbed {
diff --git a/crates/game_api/src/http/player/auth/with_auth.rs b/crates/game_api/src/http/player/auth/with_auth.rs
index 3453fc9..11a708b 100644
--- a/crates/game_api/src/http/player/auth/with_auth.rs
+++ b/crates/game_api/src/http/player/auth/with_auth.rs
@@ -1,5 +1,6 @@
use actix_session::Session;
use actix_web::{HttpResponse, Responder, web};
+use mkenv::Layer as _;
use records_lib::Database;
use reqwest::{Client, StatusCode};
use tokio::time::timeout;
@@ -44,8 +45,8 @@ async fn test_access_token(
.post("https://prod.live.maniaplanet.com/login/oauth2/access_token")
.form(&MPAccessTokenBody {
grant_type: "authorization_code",
- client_id: &crate::env().mp_client_id,
- client_secret: &crate::env().mp_client_secret,
+ client_id: &crate::env().dynamic.mp_client_id.get(),
+ client_secret: &crate::env().dynamic.mp_client_secret.get(),
code,
redirect_uri,
})
diff --git a/crates/game_api/src/http/player_finished.rs b/crates/game_api/src/http/player_finished.rs
index ca4ab23..4ff8fc7 100644
--- a/crates/game_api/src/http/player_finished.rs
+++ b/crates/game_api/src/http/player_finished.rs
@@ -232,12 +232,14 @@ where
})
.await?;
+ let mut redis_conn = redis_pool.get().await.with_api_err()?;
+
let (old, new, has_improved, old_rank) = match result.old_record {
Some(records::Model { time: old, .. }) => (
old,
params.body.time,
params.body.time < old,
- Some(ranks::get_rank(redis_pool, map.id, player_id, old, params.event).await?),
+ Some(ranks::get_rank(&mut redis_conn, map.id, old, params.event).await?),
),
None => (params.body.time, params.body.time, true, None),
};
@@ -261,7 +263,7 @@ where
let (count,): (i32,) = pipe.query_async(&mut redis_conn).await.with_api_err()?;
count + 1
} else {
- ranks::get_rank(redis_pool, map.id, player_id, old, params.event)
+ ranks::get_rank(&mut redis_conn, map.id, old, params.event)
.await
.with_api_err()?
};
diff --git a/crates/game_api/src/main.rs b/crates/game_api/src/main.rs
index e76e8bc..7633e40 100644
--- a/crates/game_api/src/main.rs
+++ b/crates/game_api/src/main.rs
@@ -17,6 +17,7 @@ use actix_web::{
use anyhow::Context;
use game_api_lib::configure;
use migration::MigratorTrait;
+use mkenv::prelude::*;
use records_lib::Database;
use tracing::level_filters::LevelFilter;
use tracing_actix_web::TracingLogger;
@@ -26,15 +27,18 @@ use tracing_subscriber::{EnvFilter, fmt::format::FmtSpan};
#[tokio::main]
async fn main() -> anyhow::Result<()> {
dotenvy::dotenv()?;
- let env = game_api_lib::init_env()?;
+ game_api_lib::init_env()?;
#[cfg(feature = "request_filter")]
- request_filter::init_wh_url(env.used_once.wh_invalid_req_url).map_err(|_| {
+ request_filter::init_wh_url(game_api_lib::env().wh_invalid_req_url.get()).map_err(|_| {
game_api_lib::internal!("Invalid request WH URL isn't supposed to be set twice")
})?;
- let db = Database::from_db_url(env.db_env.db_url.db_url, env.db_env.redis_url.redis_url)
- .await
- .context("Cannot initialize database connection")?;
+ let db = Database::from_db_url(
+ game_api_lib::env().db_env.db_url.db_url.get(),
+ game_api_lib::env().db_env.redis_url.redis_url.get(),
+ )
+ .await
+ .context("Cannot initialize database connection")?;
migration::Migrator::up(&db.sql_conn, None)
.await
@@ -72,8 +76,7 @@ async fn main() -> anyhow::Result<()> {
tracing::info!("Using max connections: {max_connections}");
- let sess_key = Key::from(env.used_once.sess_key.as_bytes());
- drop(env.used_once.sess_key);
+ let sess_key = Key::from(game_api_lib::env().dynamic.sess_key.get().as_bytes());
HttpServer::new(move || {
let cors = Cors::default()
@@ -84,7 +87,7 @@ async fn main() -> anyhow::Result<()> {
#[cfg(debug_assertions)]
let cors = cors.allow_any_origin();
#[cfg(not(debug_assertions))]
- let cors = cors.allowed_origin(&game_api_lib::env().host);
+ let cors = cors.allowed_origin(&game_api_lib::env().host.host.get());
App::new()
.wrap(cors)
@@ -96,13 +99,13 @@ async fn main() -> anyhow::Result<()> {
.cookie_secure(cfg!(not(debug_assertions)))
.cookie_content_security(CookieContentSecurity::Private)
.session_lifecycle(PersistentSession::default().session_ttl(
- CookieDuration::seconds(game_api_lib::env().auth_token_ttl as i64),
+ CookieDuration::seconds(game_api_lib::env().auth_token_ttl.get() as i64),
))
.build(),
)
.configure(|cfg| configure::configure(cfg, db.clone()))
})
- .bind(("0.0.0.0", game_api_lib::env().port))
+ .bind(("0.0.0.0", game_api_lib::env().port.get()))
.context("Cannot bind address")?
.run()
.await
diff --git a/crates/game_api/tests/base.rs b/crates/game_api/tests/base.rs
index b1ba4ef..811b287 100644
--- a/crates/game_api/tests/base.rs
+++ b/crates/game_api/tests/base.rs
@@ -1,6 +1,6 @@
#![allow(dead_code)]
-use std::{fmt, panic};
+use std::fmt;
use actix_http::Request;
use actix_web::{
@@ -9,15 +9,11 @@ use actix_web::{
dev::{Service, ServiceResponse},
middleware, test,
};
-use anyhow::Context;
-use futures::FutureExt;
-use migration::MigratorTrait as _;
-use records_lib::{Database, pool::get_redis_pool};
-use sea_orm::{ConnectionTrait, DbConn};
+use records_lib::Database;
+use test_env::IntoResult;
use tracing_actix_web::TracingLogger;
use game_api_lib::{configure, init_env};
-use tracing_subscriber::fmt::TestWriter;
#[derive(Debug, serde::Deserialize)]
pub struct ErrorResponse {
@@ -26,17 +22,16 @@ pub struct ErrorResponse {
pub message: String,
}
-pub fn get_env() -> anyhow::Result<InitEnvOut> {
- match dotenvy::dotenv() {
- Err(err) if !err.not_found() => return Err(err).context("retrieving .env files"),
- _ => (),
- }
-
- let _ = tracing_subscriber::fmt()
- .with_writer(TestWriter::new())
- .try_init();
-
- init_env()
pub async fn with_db<F, R>(test: F) -> anyhow::Result<R::Out>
+where
+ F: AsyncFnOnce(Database) -> R,
+ R: IntoResult,
+{
+ test_env::wrap(async |db| {
+ init_env()?;
+ test(db).await.into_result()
+ })
+ .await
}
pub async fn get_app(
@@ -61,7 +56,7 @@ pub enum ApiError {
impl fmt::Display for ApiError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
- ApiError::InvalidJson(raw, deser_err) => match str::from_utf8(&raw) {
+ ApiError::InvalidJson(raw, deser_err) => match str::from_utf8(raw) {
Ok(s) => write!(
f,
"Invalid JSON returned by the API: {s}\nError when deserializing: {deser_err}"
@@ -90,125 +85,6 @@ impl fmt::Display for ApiError {
impl std::error::Error for ApiError {}
-pub trait IntoResult {
- type Out;
-
- fn into_result(self) -> anyhow::Result;
-}
-
-impl IntoResult for () {
- type Out = ();
-
- fn into_result(self) -> anyhow::Result {
- Ok(())
- }
-}
-
-impl IntoResult for Result
-where
- anyhow::Error: From,
-{
- type Out = T;
-
- fn into_result(self) -> anyhow::Result {
- self.map_err(From::from)
- }
-}
-
-pub async fn with_db(test: F) -> anyhow::Result<::Out>
-where
- F: AsyncFnOnce(Database) -> R,
- R: IntoResult,
-{
- let env = get_env()?;
- wrap(env.db_env.db_url.db_url, async |sql_conn| {
- let db = Database {
- sql_conn,
- redis_pool: get_redis_pool(env.db_env.redis_url.redis_url)?,
- };
- test(db).await.into_result()
- })
- .await
-}
-
-pub async fn wrap(db_url: String, test: F) -> anyhow::Result<::Out>
-where
- F: AsyncFnOnce(DbConn) -> R,
- R: IntoResult,
-{
- let master_db = sea_orm::Database::connect(&db_url).await?;
-
- // For some reasons, on MySQL/MariaDB, using a schema name with some capital letters
- // may produce the error code 1932 (42S02) "Table 'X' doesn't exist in engine" when
- // doing a query.
- let db_name = format!(
- "_test_db_{}",
- records_lib::gen_random_str(10).to_lowercase()
- );
-
- master_db
- .execute_unprepared(&format!("create database {db_name}"))
- .await?;
- println!("Created database {db_name}");
-
- let db = match master_db {
- #[cfg(feature = "mysql")]
- sea_orm::DatabaseConnection::SqlxMySqlPoolConnection(_) => {
- let connect_options = master_db.get_mysql_connection_pool().connect_options();
- let connect_options = (*connect_options).clone();
- let options = connect_options.database(&db_name);
- let db = sqlx::mysql::MySqlPool::connect_with(options).await?;
- DbConn::from(db)
- }
- #[cfg(feature = "postgres")]
- sea_orm::DatabaseConnection::SqlxPostgresPoolConnection(_) => {
- let connect_options = master_db.get_postgres_connection_pool().connect_options();
- let connect_options = (*connect_options).clone();
- let options = connect_options.database(&db_name);
- let db = sqlx::postgres::PgPool::connect_with(options).await?;
- DbConn::from(db)
- }
- _ => unreachable!(),
- };
-
- migration::Migrator::up(&db, None).await?;
-
- let r = panic::AssertUnwindSafe(test(db)).catch_unwind().await;
- #[cfg(test_force_db_deletion)]
- {
- master_db
- .execute_unprepared(&format!("drop database {db_name}"))
- .await?;
- println!("Database {db_name} force-deleted");
- match r {
- Ok(r) => r.into_result(),
- Err(e) => panic::resume_unwind(e),
- }
- }
- #[cfg(not(test_force_db_deletion))]
- {
- match r.map(IntoResult::into_result) {
- Ok(Ok(out)) => {
- master_db
- .execute_unprepared(&format!("drop database {db_name}"))
- .await?;
- Ok(out)
- }
- other => {
- println!(
- "Test failed, leaving database {db_name} as-is. \
- Run with cfg `test_force_db_deletion` to drop the database everytime."
- );
- match other {
- Ok(Err(e)) => Err(e),
- Err(e) => panic::resume_unwind(e),
- _ => unreachable!(),
- }
- }
- }
- }
-}
-
pub fn try_from_slice<'de, T>(slice: &'de [u8]) -> Result
where
T: serde::Deserialize<'de>,
diff --git a/crates/game_api/tests/overview.rs b/crates/game_api/tests/overview.rs
index 9811389..2cc7738 100644
--- a/crates/game_api/tests/overview.rs
+++ b/crates/game_api/tests/overview.rs
@@ -78,7 +78,7 @@ fn player_id_to_row(player_id: i32) -> Row {
}
async fn insert_sample_map(conn: &C) -> anyhow::Result {
- let map_id = rand::random_range(..100000);
+ let map_id = test_env::get_map_id();
maps::Entity::insert(maps::ActiveModel {
id: Set(map_id),
@@ -372,9 +372,7 @@ async fn competition_ranking() -> anyhow::Result<()> {
let app = base::get_app(db.clone()).await;
let req = test::TestRequest::get()
- .uri(&format!(
- "/overview?mapId=test_map_uid&playerId=player_1_login"
- ))
+ .uri("/overview?mapId=test_map_uid&playerId=player_1_login")
.to_request();
let resp = test::call_service(&app, req).await;
@@ -448,8 +446,8 @@ async fn overview_event_version_map_non_empty() -> anyhow::Result<()> {
.await
.context("couldn't insert players")?;
- let map_id = rand::random_range(1..=100000);
- let event_map_id = rand::random_range(1..=100000);
+ let map_id = test_env::get_map_id();
+ let event_map_id = test_env::get_map_id();
let map = maps::ActiveModel {
id: Set(map_id),
diff --git a/crates/game_api/tests/overview_event.rs b/crates/game_api/tests/overview_event.rs
index 96cef38..8c8705a 100644
--- a/crates/game_api/tests/overview_event.rs
+++ b/crates/game_api/tests/overview_event.rs
@@ -40,9 +40,9 @@ async fn event_overview_original_maps_diff() -> anyhow::Result<()> {
..Default::default()
};
- let map_id = rand::random_range(1..=100000);
+ let map_id = test_env::get_map_id();
let original_map_id = loop {
- let val = rand::random_range(1..=100000);
+ let val = test_env::get_map_id();
if val != map_id {
break val;
}
@@ -185,7 +185,7 @@ async fn event_overview_transparent_equal() -> anyhow::Result<()> {
..Default::default()
};
- let map_id = rand::random_range(1..=100000);
+ let map_id = test_env::get_map_id();
let map = maps::ActiveModel {
id: Set(map_id),
@@ -307,7 +307,7 @@ async fn event_overview_original_maps_empty() -> anyhow::Result<()> {
..Default::default()
};
- let map_id = rand::random_range(1..=100000);
+ let map_id = test_env::get_map_id();
let map = maps::ActiveModel {
id: Set(map_id),
diff --git a/crates/game_api/tests/player_finished.rs b/crates/game_api/tests/player_finished.rs
index 90bcafa..f9ac572 100644
--- a/crates/game_api/tests/player_finished.rs
+++ b/crates/game_api/tests/player_finished.rs
@@ -27,7 +27,7 @@ async fn single_try() -> anyhow::Result<()> {
..Default::default()
};
- let map_id = rand::random_range(1..=100000);
+ let map_id = test_env::get_map_id();
let map = maps::ActiveModel {
id: Set(map_id),
@@ -192,7 +192,7 @@ async fn many_tries() -> anyhow::Result<()> {
..Default::default()
};
- let map_id = rand::random_range(1..=100000);
+ let map_id = test_env::get_map_id();
let map = maps::ActiveModel {
id: Set(map_id),
@@ -338,7 +338,7 @@ async fn with_mode_version() -> anyhow::Result<()> {
..Default::default()
};
- let map_id = rand::random_range(1..=100000);
+ let map_id = test_env::get_map_id();
let map = maps::ActiveModel {
id: Set(map_id),
@@ -428,7 +428,7 @@ async fn many_records() -> anyhow::Result<()> {
..Default::default()
});
- let map_id = rand::random_range(1..=100000);
+ let map_id = test_env::get_map_id();
let map = maps::ActiveModel {
id: Set(map_id),
@@ -569,7 +569,7 @@ async fn save_record_for_related_non_transparent_event() -> anyhow::Result<()> {
..Default::default()
};
- let map_id = rand::random_range(1..=100000);
+ let map_id = test_env::get_map_id();
let map = maps::ActiveModel {
id: Set(map_id),
diff --git a/crates/game_api/tests/player_finished_event.rs b/crates/game_api/tests/player_finished_event.rs
index 10a32a7..dff6e5f 100644
--- a/crates/game_api/tests/player_finished_event.rs
+++ b/crates/game_api/tests/player_finished_event.rs
@@ -38,7 +38,7 @@ async fn finished_on_transparent_event() -> anyhow::Result<()> {
..Default::default()
};
- let map_id = rand::random_range(1..=100000);
+ let map_id = test_env::get_map_id();
let map = maps::ActiveModel {
id: Set(map_id),
@@ -166,9 +166,9 @@ async fn event_finish_transitive_save() -> anyhow::Result<()> {
..Default::default()
};
- let original_map_id = rand::random_range(1..=100000);
+ let original_map_id = test_env::get_map_id();
let map_id = loop {
- let val = rand::random_range(1..=100000);
+ let val = test_env::get_map_id();
if val != original_map_id {
break val;
}
@@ -313,9 +313,9 @@ async fn event_finish_save_to_original() -> anyhow::Result<()> {
..Default::default()
};
- let original_map_id = rand::random_range(1..=100000);
+ let original_map_id = test_env::get_map_id();
let map_id = loop {
- let val = rand::random_range(1..=100000);
+ let val = test_env::get_map_id();
if val != original_map_id {
break val;
}
@@ -459,9 +459,9 @@ async fn event_finish_non_transitive_save() -> anyhow::Result<()> {
..Default::default()
};
- let original_map_id = rand::random_range(1..=100000);
+ let original_map_id = test_env::get_map_id();
let map_id = loop {
- let val = rand::random_range(1..=100000);
+ let val = test_env::get_map_id();
if val != original_map_id {
break val;
}
@@ -576,9 +576,9 @@ async fn event_finish_save_non_event_record() -> anyhow::Result<()> {
..Default::default()
};
- let original_map_id = rand::random_range(1..=100000);
+ let original_map_id = test_env::get_map_id();
let map_id = loop {
- let val = rand::random_range(1..=100000);
+ let val = test_env::get_map_id();
if val != original_map_id {
break val;
}
diff --git a/crates/game_api/tests/root.rs b/crates/game_api/tests/root.rs
index 3e906c8..a8173a6 100644
--- a/crates/game_api/tests/root.rs
+++ b/crates/game_api/tests/root.rs
@@ -3,15 +3,18 @@ mod base;
use actix_web::test;
#[tokio::test]
-#[cfg(feature = "test")]
+#[cfg(test)]
async fn test_not_found() -> anyhow::Result<()> {
use actix_http::StatusCode;
use game_api_lib::TracedError;
- use records_lib::Database;
+ use mkenv::prelude::*;
+ use records_lib::{Database, DbEnv};
use sea_orm::DbBackend;
- let env = base::get_env()?;
- let db = Database::from_mock_db(DbBackend::MySql, env.db_env.redis_url.redis_url)?;
+ test_env::init_env()?;
+
+ let env = DbEnv::define();
+ let db = Database::from_mock_db(DbBackend::MySql, env.redis_url.redis_url.get())?;
let app = base::get_app(db).await;
let req = test::TestRequest::get().uri("/").to_request();
diff --git a/crates/graphql-api/Cargo.toml b/crates/graphql-api/Cargo.toml
index 2f06753..c195a4b 100644
--- a/crates/graphql-api/Cargo.toml
+++ b/crates/graphql-api/Cargo.toml
@@ -10,8 +10,25 @@ chrono = { workspace = true }
deadpool-redis = { workspace = true }
entity = { path = "../entity" }
futures = { workspace = true }
+hmac = "0.12.1"
+mkenv = { workspace = true }
records-lib = { path = "../records_lib" }
reqwest = { workspace = true }
sea-orm = { workspace = true }
serde = { workspace = true }
+sha2 = "0.10.9"
tokio = { workspace = true, features = ["macros"] }
+itertools.workspace = true
+serde_json.workspace = true
+
+[dev-dependencies]
+test-env = { path = "../test-env" }
+anyhow = { workspace = true }
+tracing.workspace = true
+rand.workspace = true
+
+[features]
+default = []
+mysql = ["records-lib/mysql", "test-env/mysql"]
+postgres = ["records-lib/postgres", "test-env/postgres"]
+sqlite = ["records-lib/sqlite"]
diff --git a/crates/graphql-api/src/config.rs b/crates/graphql-api/src/config.rs
new file mode 100644
index 0000000..523d04e
--- /dev/null
+++ b/crates/graphql-api/src/config.rs
@@ -0,0 +1,104 @@
+use hmac::Hmac;
+use sha2::{Sha224, digest::KeyInit};
+use std::{error::Error, fmt, sync::OnceLock};
+
+use mkenv::{ConfigDescriptor, exec::ConfigInitializer};
+
+pub type SecretKeyType = Hmac<Sha224>;
+
+fn parse_secret_key(input: &str) -> Result<SecretKeyType, Box<dyn Error>> {
+ SecretKeyType::new_from_slice(input.as_bytes()).map_err(From::from)
+}
+
+#[cfg(debug_assertions)]
+mkenv::make_config! {
+ pub struct SecretKey {
+ pub(crate) cursor_secret_key: {
+ var_name: "GQL_API_CURSOR_SECRET_KEY",
+ layers: [
+ parsed(parse_secret_key),
+ ],
+ description: "The secret key used for signing cursors for pagination \
+ in the GraphQL API",
+ }
+ }
+}
+
+#[cfg(not(debug_assertions))]
+mkenv::make_config! {
+ pub struct SecretKey {
+ pub(crate) cursor_secret_key: {
+ var_name: "GQL_API_CURSOR_SECRET_KEY_FILE",
+ layers: [
+ file_read(),
+ parsed(parse_secret_key),
+ ],
+ description: "The path to the file containing the secret key used for signing cursors \
+ for pagination in the GraphQL API",
+ }
+ }
+}
+
+mkenv::make_config! {
+ pub struct ApiConfig {
+ pub(crate) cursor_max_limit: {
+ var_name: "GQL_API_CURSOR_MAX_LIMIT",
+ layers: [
+ parsed_from_str(),
+ or_default_val(|| 100),
+ ],
+ },
+
+ pub(crate) cursor_default_limit: {
+ var_name: "GQL_API_CURSOR_DEFAULT_LIMIT",
+ layers: [
+ parsed_from_str(),
+ or_default_val(|| 50),
+ ],
+ },
+
+ pub(crate) cursor_secret_key: { SecretKey },
+ }
+}
+
+static CONFIG: OnceLock<ApiConfig> = OnceLock::new();
+
+#[derive(Debug)]
+pub enum InitError {
+ ConfigAlreadySet,
+ Config(Box),
+}
+
+impl fmt::Display for InitError {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ match self {
+ InitError::ConfigAlreadySet => f.write_str("config already set"),
+ InitError::Config(error) => write!(f, "error during config init: {error}"),
+ }
+ }
+}
+
+impl Error for InitError {
+ fn source(&self) -> Option<&(dyn Error + 'static)> {
+ match self {
+ InitError::ConfigAlreadySet => None,
+ InitError::Config(error) => Some(&**error),
+ }
+ }
+}
+
+pub fn set_config(config: ApiConfig) -> Result<(), Box> {
+ CONFIG.set(config).map_err(Box::new)
+}
+
+pub fn init_config() -> Result<(), InitError> {
+ let config = ApiConfig::define();
+ config
+ .try_init()
+ .map_err(|e| InitError::Config(format!("{e}").into()))?;
+ CONFIG.set(config).map_err(|_| InitError::ConfigAlreadySet)
+}
+
+pub(crate) fn config() -> &'static ApiConfig {
+ CONFIG.get().unwrap()
+}
diff --git a/crates/graphql-api/src/cursors.rs b/crates/graphql-api/src/cursors.rs
index 740ae60..ce7255a 100644
--- a/crates/graphql-api/src/cursors.rs
+++ b/crates/graphql-api/src/cursors.rs
@@ -1,141 +1,533 @@
-use std::{ops::RangeInclusive, str};
+pub mod expr_tuple;
+pub mod query_builder;
+pub mod query_trait;
+
+use std::str;
use async_graphql::{ID, connection::CursorType};
-use base64::{Engine as _, engine::general_purpose::STANDARD as BASE64};
+use base64::{Engine as _, prelude::BASE64_URL_SAFE};
use chrono::{DateTime, Utc};
+use hmac::Mac;
+use mkenv::Layer;
+use sea_orm::{
+ Value,
+ sea_query::{IntoValueTuple, SimpleExpr, ValueTuple},
+};
-use crate::error::CursorDecodeErrorKind;
-
-pub const CURSOR_LIMIT_RANGE: RangeInclusive = 1..=100;
-pub const CURSOR_DEFAULT_LIMIT: usize = 50;
+use crate::{
+ cursors::expr_tuple::{ExprTuple, IntoExprTuple},
+ error::{CursorDecodeError, CursorDecodeErrorKind},
+};
-fn decode_base64(s: &str) -> Result {
- let decoded = BASE64
+fn decode_base64(s: &str) -> Result {
+ let decoded = BASE64_URL_SAFE
.decode(s)
- .map_err(|_| CursorDecodeErrorKind::NotBase64)?;
+ .map_err(|_| CursorDecodeError::from(CursorDecodeErrorKind::NotBase64))?;
+ // TODO: replace this with `slice::split_once` once it's stabilized
+ // https://github.com/rust-lang/rust/issues/112811
+ let idx = decoded
+ .iter()
+ .position(|b| *b == b'$')
+ .ok_or(CursorDecodeError::from(CursorDecodeErrorKind::NoSignature))?;
+ let (content, signature) = (&decoded[..idx], &decoded[idx + 1..]);
+
+ let mut mac = crate::config().cursor_secret_key.cursor_secret_key.get();
+ mac.update(content);
+ mac.verify_slice(signature)
+ .map_err(|e| CursorDecodeError::from(CursorDecodeErrorKind::InvalidSignature(e)))?;
- String::from_utf8(decoded).map_err(|_| CursorDecodeErrorKind::NotUtf8)
+ let content = str::from_utf8(content)
+ .map_err(|_| CursorDecodeError::from(CursorDecodeErrorKind::NotUtf8))?;
+
+ Ok(content.to_owned())
}
-fn check_prefix(splitted: I) -> Result<(), CursorDecodeErrorKind>
-where
- I: IntoIterator- ,
- S: AsRef,
-{
- match splitted
- .into_iter()
- .next()
- .filter(|s| s.as_ref() == "record")
+fn encode_base64(s: String) -> String {
+ let mut mac = crate::config().cursor_secret_key.cursor_secret_key.get();
+ mac.update(s.as_bytes());
+ let signature = mac.finalize().into_bytes();
+
+ let mut output = s.into_bytes();
+ output.push(b'$');
+ output.extend_from_slice(&signature);
+
+ BASE64_URL_SAFE.encode(&output)
+}
+
+mod datetime_serde {
+ use std::fmt;
+
+ use chrono::{DateTime, Utc};
+ use serde::de::{Unexpected, Visitor};
+
+ pub fn serialize
(datetime: &DateTime, serializer: S) -> Result
+ where
+ S: serde::Serializer,
{
- Some(_) => Ok(()),
- None => Err(CursorDecodeErrorKind::WrongPrefix),
+ let time = datetime.timestamp_millis();
+ serializer.serialize_i64(time)
}
+
+ pub fn deserialize<'de, D>(deser: D) -> Result, D::Error>
+ where
+ D: serde::Deserializer<'de>,
+ {
+ struct TimestampVisitor;
+ impl<'de> Visitor<'de> for TimestampVisitor {
+ type Value = DateTime;
+
+ fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ f.write_str("an integer")
+ }
+
+ fn visit_u64(self, v: u64) -> Result
+ where
+ E: serde::de::Error,
+ {
+ self.visit_i64(v as _)
+ }
+
+ fn visit_i64(self, v: i64) -> Result
+ where
+ E: serde::de::Error,
+ {
+ DateTime::from_timestamp_millis(v).ok_or_else(|| {
+ serde::de::Error::invalid_value(Unexpected::Signed(v), &"a valid timestamp")
+ })
+ }
+ }
+
+ deser.deserialize_i64(TimestampVisitor)
+ }
+}
+
+#[cold]
+#[inline(never)]
+fn serialization_failed<T>(_obj: &T, e: serde_json::Error) -> ! {
+ panic!(
+ "serialization of `{}` failed: {e}",
+ std::any::type_name::<T>()
+ )
}
-fn check_timestamp(splitted: I) -> Result, CursorDecodeErrorKind>
+fn encode_cursor(prefix: &str, obj: &T) -> String
where
- I: IntoIterator- ,
- S: AsRef,
+ T: serde::Serialize,
{
- splitted
- .into_iter()
- .next()
- .and_then(|t| t.as_ref().parse().ok())
- .ok_or(CursorDecodeErrorKind::NoTimestamp)
- .and_then(|t| {
- DateTime::from_timestamp_millis(t).ok_or(CursorDecodeErrorKind::InvalidTimestamp(t))
- })
+ encode_base64(format!(
+ "{prefix}:{}",
+ serde_json::to_string(obj).unwrap_or_else(|e| serialization_failed(obj, e))
+ ))
}
-fn check_time(splitted: I) -> Result
+fn decode_cursor(prefix: &str, input: &str) -> Result
where
- I: IntoIterator
- ,
- S: AsRef,
+ T: serde::de::DeserializeOwned,
{
- splitted
- .into_iter()
- .next()
- .and_then(|t| t.as_ref().parse().ok())
- .ok_or(CursorDecodeErrorKind::NoTime)
+ let decoded = decode_base64(input)?;
+
+ let Some((input_prefix, input)) = decoded.split_once(':') else {
+ return Err(CursorDecodeErrorKind::MissingPrefix.into());
+ };
+
+ if input_prefix != prefix {
+ return Err(CursorDecodeErrorKind::InvalidPrefix.into());
+ }
+
+ serde_json::from_str(input).map_err(|_| CursorDecodeErrorKind::InvalidData.into())
}
-pub struct RecordDateCursor(pub DateTime);
+#[derive(PartialEq, Debug, serde::Serialize, serde::Deserialize)]
+pub struct RecordDateCursor {
+ #[serde(with = "datetime_serde")]
+ pub record_date: DateTime,
+ pub data: T,
+}
-impl CursorType for RecordDateCursor {
- type Error = CursorDecodeErrorKind;
+impl CursorType for RecordDateCursor
+where
+ T: serde::Serialize + for<'de> serde::Deserialize<'de>,
+{
+ type Error = CursorDecodeError;
fn decode_cursor(s: &str) -> Result {
- let decoded = decode_base64(s)?;
- let mut splitted = decoded.split(':');
- check_prefix(&mut splitted)?;
- let record_date = check_timestamp(&mut splitted)?;
- Ok(Self(record_date))
+ decode_cursor("record_date", s)
}
fn encode_cursor(&self) -> String {
- let timestamp = self.0.timestamp_millis();
- BASE64.encode(format!("record:{}", timestamp))
+ encode_cursor("record_date", self)
}
}
-pub struct RecordRankCursor {
+impl IntoExprTuple for &RecordDateCursor
+where
+ T: Into + Clone,
+{
+ fn into_expr_tuple(self) -> ExprTuple {
+ (self.record_date, self.data.clone()).into_expr_tuple()
+ }
+}
+
+impl IntoValueTuple for &RecordDateCursor
+where
+ T: Into + Clone,
+{
+ fn into_value_tuple(self) -> ValueTuple {
+ (self.record_date, self.data.clone()).into_value_tuple()
+ }
+}
+
+#[derive(Debug, PartialEq, serde::Serialize, serde::Deserialize)]
+pub struct RecordRankCursor {
+ #[serde(with = "datetime_serde")]
pub record_date: DateTime,
pub time: i32,
+ pub data: T,
}
-impl CursorType for RecordRankCursor {
- type Error = CursorDecodeErrorKind;
+impl CursorType for RecordRankCursor
+where
+ T: serde::Serialize + for<'de> serde::Deserialize<'de>,
+{
+ type Error = CursorDecodeError;
fn decode_cursor(s: &str) -> Result {
- let decoded = decode_base64(s)?;
- let mut splitted = decoded.split(':');
- check_prefix(&mut splitted)?;
- let record_date = check_timestamp(&mut splitted)?;
- let time = check_time(&mut splitted)?;
- Ok(Self { record_date, time })
+ decode_cursor("record_rank", s)
}
fn encode_cursor(&self) -> String {
- let timestamp = self.record_date.timestamp_millis();
- BASE64.encode(format!("record:{timestamp}:{}", self.time))
+ encode_cursor("record_rank", self)
+ }
+}
+
+impl IntoExprTuple for &RecordRankCursor
+where
+ T: Into + Clone,
+{
+ fn into_expr_tuple(self) -> ExprTuple {
+ (self.time, self.record_date, self.data.clone()).into_expr_tuple()
+ }
+}
+
+impl IntoValueTuple for &RecordRankCursor
+where
+ T: Into + Clone,
+{
+ fn into_value_tuple(self) -> sea_orm::sea_query::ValueTuple {
+ (self.time, self.record_date, self.data.clone()).into_value_tuple()
}
}
-pub struct TextCursor(pub String);
+#[derive(Debug, PartialEq, serde::Serialize, serde::Deserialize)]
+pub struct TextCursor {
+ pub text: String,
+ pub data: T,
+}
-impl CursorType for TextCursor {
- type Error = CursorDecodeErrorKind;
+impl CursorType for TextCursor
+where
+ T: serde::Serialize + for<'de> serde::Deserialize<'de>,
+{
+ type Error = CursorDecodeError;
fn decode_cursor(s: &str) -> Result {
- decode_base64(s).map(Self)
+ decode_cursor("text", s)
}
fn encode_cursor(&self) -> String {
- BASE64.encode(&self.0)
+ encode_cursor("text", self)
+ }
+}
+
+impl IntoExprTuple for &TextCursor
+where
+ T: Into + Clone,
+{
+ fn into_expr_tuple(self) -> ExprTuple {
+ (self.text.clone(), self.data.clone()).into_expr_tuple()
}
}
-pub struct F64Cursor(pub f64);
+impl IntoValueTuple for &TextCursor
+where
+ T: Into + Clone,
+{
+ fn into_value_tuple(self) -> ValueTuple {
+ (self.text.clone(), self.data.clone()).into_value_tuple()
+ }
+}
-impl CursorType for F64Cursor {
- type Error = CursorDecodeErrorKind;
+#[derive(Debug, PartialEq, serde::Serialize, serde::Deserialize)]
+pub struct F64Cursor {
+ pub score: f64,
+ pub data: T,
+}
+
+impl CursorType for F64Cursor
+where
+ T: serde::Serialize + for<'de> serde::Deserialize<'de>,
+{
+ type Error = CursorDecodeError;
fn decode_cursor(s: &str) -> Result {
- let decoded = decode_base64(s)?;
- let parsed = decoded
- .parse()
- .map_err(|_| CursorDecodeErrorKind::NoScore)?;
- Ok(Self(parsed))
+ decode_cursor("score", s)
}
fn encode_cursor(&self) -> String {
- BASE64.encode(self.0.to_string())
+ encode_cursor("score", self)
+ }
+}
+
+impl IntoExprTuple for &F64Cursor
+where
+ T: Into + Clone,
+{
+ fn into_expr_tuple(self) -> ExprTuple {
+ (self.score, self.data.clone()).into_expr_tuple()
+ }
+}
+
+impl IntoValueTuple for &F64Cursor
+where
+ T: Into + Clone,
+{
+ fn into_value_tuple(self) -> ValueTuple {
+ (self.score, self.data.clone()).into_value_tuple()
}
}
-pub struct ConnectionParameters {
- pub before: Option,
- pub after: Option,
+pub struct ConnectionParameters {
+ pub before: Option,
+ pub after: Option,
pub first: Option,
pub last: Option,
}
+
+impl Default for ConnectionParameters {
+ fn default() -> Self {
+ Self {
+ before: Default::default(),
+ after: Default::default(),
+ first: Default::default(),
+ last: Default::default(),
+ }
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use std::fmt;
+
+ use async_graphql::connection::CursorType;
+ use base64::{Engine as _, prelude::BASE64_URL_SAFE};
+ use chrono::{DateTime, SubsecRound, Utc};
+ use sha2::digest::MacError;
+
+ use crate::{
+ config::InitError,
+ cursors::{F64Cursor, TextCursor},
+ error::{CursorDecodeError, CursorDecodeErrorKind},
+ };
+
+ use super::{RecordDateCursor, RecordRankCursor};
+
+ fn setup() {
+ match crate::init_config() {
+ Ok(_) | Err(InitError::ConfigAlreadySet) => (),
+ Err(InitError::Config(e)) => {
+ panic!("test setup error: {e}");
+ }
+ }
+ }
+
+ fn test_cursor_round_trip(source: &C, expected: &C)
+ where
+ C: CursorType + fmt::Debug + PartialEq,
+ {
+ let cursor = source.encode_cursor();
+
+ let decoded = C::decode_cursor(&cursor);
+ assert_eq!(decoded.as_ref(), Ok(expected));
+ }
+
+ fn test_decode_cursor_errors()
+ where
+ C: CursorType,
+ {
+ let decoded = ::decode_cursor("foobar").err();
+ assert_eq!(
+ decoded.map(|e| e.kind),
+ Some(CursorDecodeErrorKind::NotBase64)
+ );
+
+ let encoded = BASE64_URL_SAFE.encode("foobar");
+ let decoded = ::decode_cursor(&encoded).err();
+ assert_eq!(
+ decoded.map(|e| e.kind),
+ Some(CursorDecodeErrorKind::NoSignature)
+ );
+
+ let encoded = BASE64_URL_SAFE.encode("foobar$signature");
+ let decoded = ::decode_cursor(&encoded).err();
+ assert_eq!(
+ decoded.map(|e| e.kind),
+ Some(CursorDecodeErrorKind::InvalidSignature(MacError))
+ );
+ }
+
+ fn test_encode_cursor(cursor: &C, expected: &str) {
+ let cursor = cursor.encode_cursor();
+
+ let decoded = BASE64_URL_SAFE
+ .decode(&cursor)
+ .expect("cursor should be encoded as base64");
+ let expected = format!("{expected}$");
+ assert!(decoded.starts_with(expected.as_bytes()));
+ }
+
+ #[test]
+ fn encode_date_cursor() {
+ setup();
+
+ test_encode_cursor(
+ &RecordDateCursor {
+ record_date: DateTime::from_timestamp_millis(10).unwrap(),
+ data: 0,
+ },
+ r#"record_date:{"record_date":10,"data":0}"#,
+ );
+ }
+
+ #[test]
+ fn date_cursor_round_trip() {
+ setup();
+
+ let now = Utc::now();
+ test_cursor_round_trip(
+ &RecordDateCursor {
+ record_date: now,
+ data: 0,
+ },
+ &RecordDateCursor {
+ record_date: now.trunc_subsecs(3),
+ data: 0,
+ },
+ );
+ }
+
+ #[test]
+ fn decode_date_cursor_errors() {
+ setup();
+ test_decode_cursor_errors::();
+ }
+
+ #[test]
+ fn decode_rank_cursor_errors() {
+ setup();
+ test_decode_cursor_errors::();
+ }
+
+ #[test]
+ fn decode_text_cursor_errors() {
+ setup();
+ test_decode_cursor_errors::();
+ }
+
+ #[test]
+ fn decode_score_cursor_errors() {
+ setup();
+ test_decode_cursor_errors::();
+ }
+
+ #[test]
+ fn encode_rank_cursor() {
+ setup();
+
+ test_encode_cursor(
+ &RecordRankCursor {
+ record_date: DateTime::from_timestamp_millis(26).unwrap(),
+ time: 1000,
+ data: 24,
+ },
+ r#"record_rank:{"record_date":26,"time":1000,"data":24}"#,
+ );
+ }
+
+ #[test]
+ fn rank_cursor_round_trip() {
+ setup();
+
+ let now = Utc::now();
+ test_cursor_round_trip(
+ &RecordRankCursor {
+ record_date: now,
+ time: 2000,
+ data: 34,
+ },
+ &RecordRankCursor {
+ record_date: now.trunc_subsecs(3),
+ time: 2000,
+ data: 34,
+ },
+ );
+ }
+
+ #[test]
+ fn encode_text_cursor() {
+ setup();
+ test_encode_cursor(
+ &TextCursor {
+ text: "hello".to_owned(),
+ data: 123,
+ },
+ r#"text:{"text":"hello","data":123}"#,
+ );
+ }
+
+ #[test]
+ fn text_cursor_round_trip() {
+ setup();
+ test_cursor_round_trip(
+ &TextCursor {
+ text: "booga".to_owned(),
+ data: 232,
+ },
+ &TextCursor {
+ text: "booga".to_owned(),
+ data: 232,
+ },
+ );
+ }
+
+ #[test]
+ fn encode_score_cursor() {
+ setup();
+ test_encode_cursor(
+ &F64Cursor {
+ score: 12.34,
+ data: 2445,
+ },
+ r#"score:{"score":12.34,"data":2445}"#,
+ );
+ }
+
+ #[test]
+ fn decode_score_cursor() {
+ setup();
+ }
+
+ #[test]
+ fn score_cursor_round_trip() {
+ setup();
+ test_cursor_round_trip(
+ &F64Cursor {
+ score: 24.6,
+ data: 123,
+ },
+ &F64Cursor {
+ score: 24.6,
+ data: 123,
+ },
+ );
+ }
+}
diff --git a/crates/graphql-api/src/cursors/expr_tuple.rs b/crates/graphql-api/src/cursors/expr_tuple.rs
new file mode 100644
index 0000000..2420823
--- /dev/null
+++ b/crates/graphql-api/src/cursors/expr_tuple.rs
@@ -0,0 +1,69 @@
+use sea_orm::sea_query::SimpleExpr;
+
+#[derive(Debug, Clone)]
+pub enum ExprTuple {
+ One(SimpleExpr),
+ Two(SimpleExpr, SimpleExpr),
+ Three(SimpleExpr, SimpleExpr, SimpleExpr),
+ Many(Vec),
+}
+
+pub trait IntoExprTuple {
+ fn into_expr_tuple(self) -> ExprTuple;
+}
+
+impl> IntoExprTuple for (A,) {
+ fn into_expr_tuple(self) -> ExprTuple {
+ let (a,) = self;
+ ExprTuple::One(a.into())
+ }
+}
+
+impl IntoExprTuple for (A, B)
+where
+ A: Into,
+ B: Into,
+{
+ fn into_expr_tuple(self) -> ExprTuple {
+ let (a, b) = self;
+ ExprTuple::Two(a.into(), b.into())
+ }
+}
+
+impl IntoExprTuple for (A, B, C)
+where
+ A: Into,
+ B: Into,
+ C: Into,
+{
+ fn into_expr_tuple(self) -> ExprTuple {
+ let (a, b, c) = self;
+ ExprTuple::Three(a.into(), b.into(), c.into())
+ }
+}
+
+macro_rules! impl_into_expr_tuple {
+ ($($gen:ident),*) => {
+ impl<$($gen),*> IntoExprTuple for ($($gen),*)
+ where
+ $(
+ $gen: Into
+ ),*
+ {
+ fn into_expr_tuple(self) -> ExprTuple {
+ #[allow(non_snake_case)]
+ let ($($gen),*) = self;
+ ExprTuple::Many(vec![$($gen.into()),*])
+ }
+ }
+
+ };
+}
+
+impl_into_expr_tuple!(A, B, C, D);
+impl_into_expr_tuple!(A, B, C, D, E);
+impl_into_expr_tuple!(A, B, C, D, E, F);
+impl_into_expr_tuple!(A, B, C, D, E, F, G);
+impl_into_expr_tuple!(A, B, C, D, E, F, G, H);
+impl_into_expr_tuple!(A, B, C, D, E, F, G, H, I);
+impl_into_expr_tuple!(A, B, C, D, E, F, G, H, I, K);
diff --git a/crates/graphql-api/src/cursors/query_builder.rs b/crates/graphql-api/src/cursors/query_builder.rs
new file mode 100644
index 0000000..4e423d8
--- /dev/null
+++ b/crates/graphql-api/src/cursors/query_builder.rs
@@ -0,0 +1,266 @@
+use std::marker::PhantomData;
+
+use sea_orm::{
+ Condition, ConnectionTrait, DbErr, DynIden, FromQueryResult, Identity, IntoIdentity, Order,
+ PartialModelTrait, QuerySelect, SelectModel, SelectorTrait,
+ prelude::{Expr, SeaRc},
+ sea_query::{ExprTrait as _, SelectStatement, SimpleExpr},
+};
+
+use crate::cursors::expr_tuple::{ExprTuple, IntoExprTuple};
+
+/// An enhanced version of [`Cursor`][sea_orm::Cursor].
+///
+/// It allows specifying expressions for pagination values, and uses tuple syntax when building the
+/// SQL statement, instead of chaining inner `AND` and `OR` operations.
+///
+/// The last change is the most important, because there seems to be an issue with MariaDB/MySQL,
+/// and beside that, the tuple syntax seems more efficient in every DB engine.
+#[derive(Debug, Clone)]
+pub struct CursorQueryBuilder<S> {
+ query: SelectStatement,
+ table: DynIden,
+ order_columns: Identity,
+ first: Option<u64>,
+ last: Option<u64>,
+ before: Option<ExprTuple>,
+ after: Option<ExprTuple>,
+ sort_asc: bool,
+ is_result_reversed: bool,
+ phantom: PhantomData<S>,
+}
+
+impl<S> CursorQueryBuilder<S> {
+ /// Create a new cursor
+ pub fn new<C>(query: SelectStatement, table: DynIden, order_columns: C) -> Self
+ where
+ C: IntoIdentity,
+ {
+ Self {
+ query,
+ table,
+ order_columns: order_columns.into_identity(),
+ last: None,
+ first: None,
+ after: None,
+ before: None,
+ sort_asc: true,
+ is_result_reversed: false,
+ phantom: PhantomData,
+ }
+ }
+
+ /// Filter paginated result with corresponding column less than the input value
+ pub fn before<V>(&mut self, values: V) -> &mut Self
+ where
+ V: IntoExprTuple,
+ {
+ self.before = Some(values.into_expr_tuple());
+ self
+ }
+
+ /// Filter paginated result with corresponding column greater than the input value
+ pub fn after<V>(&mut self, values: V) -> &mut Self
+ where
+ V: IntoExprTuple,
+ {
+ self.after = Some(values.into_expr_tuple());
+ self
+ }
+
+ fn apply_filters(&mut self) -> &mut Self {
+ if let Some(values) = self.after.clone() {
+ let condition =
+ self.apply_filter(values, |c, v| if self.sort_asc { c.gt(v) } else { c.lt(v) });
+ self.query.cond_where(condition);
+ }
+
+ if let Some(values) = self.before.clone() {
+ let condition =
+ self.apply_filter(values, |c, v| if self.sort_asc { c.lt(v) } else { c.gt(v) });
+ self.query.cond_where(condition);
+ }
+
+ self
+ }
+
+ fn apply_filter<F>(&self, values: ExprTuple, f: F) -> Condition
+ where
+ F: Fn(SimpleExpr, SimpleExpr) -> SimpleExpr,
+ {
+ match (&self.order_columns, values) {
+ (Identity::Unary(c1), ExprTuple::One(v1)) => {
+ let exp = Expr::col((SeaRc::clone(&self.table), SeaRc::clone(c1)));
+ Condition::all().add(f(exp.into(), v1))
+ }
+ (Identity::Binary(c1, c2), ExprTuple::Two(v1, v2)) => {
+ let c1 = Expr::col((SeaRc::clone(&self.table), SeaRc::clone(c1))).into();
+ let c2 = Expr::col((SeaRc::clone(&self.table), SeaRc::clone(c2))).into();
+ let columns = Expr::tuple([c1, c2]).into();
+ let values = Expr::tuple([v1, v2]).into();
+ Condition::all().add(f(columns, values))
+ }
+ (Identity::Ternary(c1, c2, c3), ExprTuple::Three(v1, v2, v3)) => {
+ let c1 = Expr::col((SeaRc::clone(&self.table), SeaRc::clone(c1))).into();
+ let c2 = Expr::col((SeaRc::clone(&self.table), SeaRc::clone(c2))).into();
+ let c3 = Expr::col((SeaRc::clone(&self.table), SeaRc::clone(c3))).into();
+ let columns = Expr::tuple([c1, c2, c3]).into();
+ let values = Expr::tuple([v1, v2, v3]).into();
+ Condition::all().add(f(columns, values))
+ }
+ (Identity::Many(col_vec), ExprTuple::Many(val_vec))
+ if col_vec.len() == val_vec.len() =>
+ {
+ let columns = Expr::tuple(
+ col_vec
+ .iter()
+ .map(|c| Expr::col((SeaRc::clone(&self.table), SeaRc::clone(c))).into())
+ .collect::<Vec<_>>(),
+ )
+ .into();
+ let values = Expr::tuple(val_vec).into();
+ Condition::all().add(f(columns, values))
+ }
+ _ => panic!("column arity mismatch"),
+ }
+ }
+
+ /// Use ascending sort order
+ pub fn asc(&mut self) -> &mut Self {
+ self.sort_asc = true;
+ self
+ }
+
+ /// Use descending sort order
+ pub fn desc(&mut self) -> &mut Self {
+ self.sort_asc = false;
+ self
+ }
+
+ /// Limit result set to only first N rows in ascending order of the order by column
+ pub fn first(&mut self, num_rows: u64) -> &mut Self {
+ self.last = None;
+ self.first = Some(num_rows);
+ self
+ }
+
+ /// Limit result set to only last N rows in ascending order of the order by column
+ pub fn last(&mut self, num_rows: u64) -> &mut Self {
+ self.first = None;
+ self.last = Some(num_rows);
+ self
+ }
+
+ fn resolve_sort_order(&mut self) -> Order {
+ let should_reverse_order = self.last.is_some();
+ self.is_result_reversed = should_reverse_order;
+
+ if (self.sort_asc && !should_reverse_order) || (!self.sort_asc && should_reverse_order) {
+ Order::Asc
+ } else {
+ Order::Desc
+ }
+ }
+
+ fn apply_limit(&mut self) -> &mut Self {
+ if let Some(num_rows) = self.first {
+ self.query.limit(num_rows);
+ } else if let Some(num_rows) = self.last {
+ self.query.limit(num_rows);
+ }
+
+ self
+ }
+
+ fn apply_order_by(&mut self) -> &mut Self {
+ self.query.clear_order_by();
+ let ord = self.resolve_sort_order();
+
+ let query = &mut self.query;
+ let order = |query: &mut SelectStatement, col| {
+ query.order_by((SeaRc::clone(&self.table), SeaRc::clone(col)), ord.clone());
+ };
+ match &self.order_columns {
+ Identity::Unary(c1) => {
+ order(query, c1);
+ }
+ Identity::Binary(c1, c2) => {
+ order(query, c1);
+ order(query, c2);
+ }
+ Identity::Ternary(c1, c2, c3) => {
+ order(query, c1);
+ order(query, c2);
+ order(query, c3);
+ }
+ Identity::Many(vec) => {
+ for col in vec.iter() {
+ order(query, col);
+ }
+ }
+ }
+
+ self
+ }
+
+ /// Construct a [Cursor] that fetch any custom struct
+ pub fn into_model<M>(self) -> CursorQueryBuilder<SelectModel<M>>
+ where
+ M: FromQueryResult,
+ {
+ CursorQueryBuilder {
+ query: self.query,
+ table: self.table,
+ order_columns: self.order_columns,
+ last: self.last,
+ first: self.first,
+ after: self.after,
+ before: self.before,
+ sort_asc: self.sort_asc,
+ is_result_reversed: self.is_result_reversed,
+ phantom: PhantomData,
+ }
+ }
+
+ /// Return a [Selector] from `Self` that wraps a [SelectModel] with a [PartialModel](PartialModelTrait)
+ pub fn into_partial_model<M>(self) -> CursorQueryBuilder<SelectModel<M>>
+ where
+ M: PartialModelTrait,
+ {
+ M::select_cols(QuerySelect::select_only(self)).into_model::<M>()
+ }
+}
+
+impl<S> CursorQueryBuilder<S>
+where
+ S: SelectorTrait,
+{
+ /// Fetch the paginated result
+ pub async fn all<C>(&mut self, db: &C) -> Result<Vec<S::Item>, DbErr>
+ where
+ C: ConnectionTrait,
+ {
+ self.apply_limit();
+ self.apply_order_by();
+ self.apply_filters();
+
+ let stmt = db.get_database_backend().build(&self.query);
+ let rows = db.query_all(stmt).await?;
+ let mut buffer = Vec::with_capacity(rows.len());
+ for row in rows.into_iter() {
+ buffer.push(S::from_raw_query_result(row)?);
+ }
+ if self.is_result_reversed {
+ buffer.reverse()
+ }
+ Ok(buffer)
+ }
+}
+
+impl<S> QuerySelect for CursorQueryBuilder<S> {
+ type QueryStatement = SelectStatement;
+
+ fn query(&mut self) -> &mut SelectStatement {
+ &mut self.query
+ }
+}
diff --git a/crates/graphql-api/src/cursors/query_trait.rs b/crates/graphql-api/src/cursors/query_trait.rs
new file mode 100644
index 0000000..eb71c65
--- /dev/null
+++ b/crates/graphql-api/src/cursors/query_trait.rs
@@ -0,0 +1,27 @@
+use sea_orm::{EntityTrait, IntoIdentity, QuerySelect, Select, SelectModel, prelude::SeaRc};
+
+use crate::cursors::query_builder::CursorQueryBuilder;
+
+pub trait CursorPaginable {
+ type Selector;
+
+ fn paginate_cursor_by<C: IntoIdentity>(
+ self,
+ order_columns: C,
+ ) -> CursorQueryBuilder<Self::Selector>;
+}
+
+impl<E: EntityTrait> CursorPaginable for Select<E> {
+ type Selector = SelectModel<<E as EntityTrait>::Model>;
+
+ fn paginate_cursor_by<C: IntoIdentity>(
+ mut self,
+ order_columns: C,
+ ) -> CursorQueryBuilder<Self::Selector> {
+ CursorQueryBuilder::new(
+ QuerySelect::query(&mut self).take(),
+ SeaRc::new(E::default()),
+ order_columns,
+ )
+ }
+}
diff --git a/crates/graphql-api/src/error.rs b/crates/graphql-api/src/error.rs
index e536ce8..f72f7ac 100644
--- a/crates/graphql-api/src/error.rs
+++ b/crates/graphql-api/src/error.rs
@@ -1,6 +1,7 @@
-use std::{fmt, ops::RangeInclusive, sync::Arc};
+use std::{error::Error, fmt, sync::Arc};
use records_lib::error::RecordsError;
+use sha2::digest::MacError;
pub(crate) fn map_gql_err(e: async_graphql::Error) -> ApiGqlError {
match e
@@ -13,15 +14,15 @@ pub(crate) fn map_gql_err(e: async_graphql::Error) -> ApiGqlError {
}
}
-#[derive(Debug)]
+#[derive(Debug, PartialEq)]
pub enum CursorDecodeErrorKind {
NotBase64,
NotUtf8,
- WrongPrefix,
- NoTimestamp,
- NoTime,
- NoScore,
- InvalidTimestamp(i64),
+ MissingPrefix,
+ InvalidPrefix,
+ NoSignature,
+ InvalidSignature(MacError),
+ InvalidData,
}
impl fmt::Display for CursorDecodeErrorKind {
@@ -29,74 +30,114 @@ impl fmt::Display for CursorDecodeErrorKind {
match self {
CursorDecodeErrorKind::NotBase64 => f.write_str("not base64"),
CursorDecodeErrorKind::NotUtf8 => f.write_str("not UTF-8"),
- CursorDecodeErrorKind::WrongPrefix => f.write_str("wrong prefix"),
- CursorDecodeErrorKind::NoTimestamp => f.write_str("no timestamp"),
- CursorDecodeErrorKind::NoTime => f.write_str("no time"),
- CursorDecodeErrorKind::NoScore => f.write_str("no score"),
- CursorDecodeErrorKind::InvalidTimestamp(t) => {
- f.write_str("invalid timestamp: ")?;
- fmt::Display::fmt(t, f)
- }
+ CursorDecodeErrorKind::MissingPrefix => f.write_str("missing prefix"),
+ CursorDecodeErrorKind::InvalidPrefix => f.write_str("invalid prefix"),
+ CursorDecodeErrorKind::NoSignature => f.write_str("no signature"),
+ CursorDecodeErrorKind::InvalidSignature(e) => write!(f, "invalid signature: {e}"),
+ CursorDecodeErrorKind::InvalidData => f.write_str("invalid data"),
}
}
}
-#[derive(Debug)]
+impl Error for CursorDecodeErrorKind {
+ fn source(&self) -> Option<&(dyn Error + 'static)> {
+ match self {
+ CursorDecodeErrorKind::NotBase64
+ | CursorDecodeErrorKind::NotUtf8
+ | CursorDecodeErrorKind::MissingPrefix
+ | CursorDecodeErrorKind::InvalidPrefix
+ | CursorDecodeErrorKind::NoSignature
+ | CursorDecodeErrorKind::InvalidData => None,
+ CursorDecodeErrorKind::InvalidSignature(mac_error) => Some(mac_error),
+ }
+ }
+}
+
+#[derive(Debug, PartialEq)]
pub struct CursorDecodeError {
- arg_name: &'static str,
- value: String,
- kind: CursorDecodeErrorKind,
+ pub kind: CursorDecodeErrorKind,
}
-#[derive(Debug)]
-pub struct CursorRangeError {
- arg_name: &'static str,
- value: usize,
- range: RangeInclusive<usize>,
+impl From<CursorDecodeErrorKind> for CursorDecodeError {
+ fn from(kind: CursorDecodeErrorKind) -> Self {
+ Self { kind }
+ }
+}
+
+impl fmt::Display for CursorDecodeError {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ f.write_str("cursor decode error: ")?;
+ fmt::Display::fmt(&self.kind, f)
+ }
+}
+
+impl Error for CursorDecodeError {
+ fn source(&self) -> Option<&(dyn Error + 'static)> {
+ Some(&self.kind)
+ }
}
#[derive(Debug)]
pub enum ApiGqlErrorKind {
Lib(RecordsError),
- CursorRange(CursorRangeError),
- CursorDecode(CursorDecodeError),
+ PaginationInput,
GqlError(async_graphql::Error),
RecordNotFound { record_id: u32 },
MapNotFound { map_uid: String },
PlayerNotFound { login: String },
}
+impl fmt::Display for ApiGqlErrorKind {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ match self {
+ ApiGqlErrorKind::Lib(records_error) => fmt::Display::fmt(records_error, f),
+ ApiGqlErrorKind::PaginationInput => f.write_str(
+ "cursor pagination input is invalid. \
+ must provide either: `after`, `after` with `first`, \
+ `before`, or `before` with `last`.",
+ ),
+ ApiGqlErrorKind::GqlError(error) => f.write_str(&error.message),
+ ApiGqlErrorKind::RecordNotFound { record_id } => {
+ write!(f, "record `{record_id}` not found")
+ }
+ ApiGqlErrorKind::MapNotFound { map_uid } => {
+ write!(f, "map with UID `{map_uid}` not found")
+ }
+ ApiGqlErrorKind::PlayerNotFound { login } => {
+ write!(f, "player with login `{login}` not found")
+ }
+ }
+ }
+}
+
+impl std::error::Error for ApiGqlErrorKind {
+ fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
+ match self {
+ ApiGqlErrorKind::Lib(records_error) => Some(records_error),
+ ApiGqlErrorKind::PaginationInput => None,
+ ApiGqlErrorKind::GqlError(_) => None,
+ ApiGqlErrorKind::RecordNotFound { .. } => None,
+ ApiGqlErrorKind::MapNotFound { .. } => None,
+ ApiGqlErrorKind::PlayerNotFound { .. } => None,
+ }
+ }
+}
+
#[derive(Debug, Clone)]
pub struct ApiGqlError {
inner: Arc<ApiGqlErrorKind>,
}
-impl ApiGqlError {
- pub(crate) fn from_cursor_range_error(
- arg_name: &'static str,
- expected_range: RangeInclusive<usize>,
- value: usize,
- ) -> Self {
- Self {
- inner: Arc::new(ApiGqlErrorKind::CursorRange(CursorRangeError {
- arg_name,
- value,
- range: expected_range,
- })),
- }
+impl std::error::Error for ApiGqlError {
+ fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
+ Some(&self.inner)
}
+}
- pub(crate) fn from_cursor_decode_error(
- arg_name: &'static str,
- value: String,
- decode_error: CursorDecodeErrorKind,
- ) -> Self {
+impl ApiGqlError {
+ pub(crate) fn from_pagination_input_error() -> Self {
Self {
- inner: Arc::new(ApiGqlErrorKind::CursorDecode(CursorDecodeError {
- arg_name,
- value,
- kind: decode_error,
- })),
+ inner: Arc::new(ApiGqlErrorKind::PaginationInput),
}
}
@@ -145,36 +186,7 @@ where
impl fmt::Display for ApiGqlError {
#[inline]
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
- match &*self.inner {
- ApiGqlErrorKind::Lib(records_error) => fmt::Display::fmt(records_error, f),
- ApiGqlErrorKind::CursorRange(cursor_error) => {
- write!(
- f,
- "`{}` must be between {} and {} included, got {}",
- cursor_error.arg_name,
- cursor_error.range.start(),
- cursor_error.range.end(),
- cursor_error.value
- )
- }
- ApiGqlErrorKind::CursorDecode(decode_error) => {
- write!(
- f,
- "cursor argument `{}` couldn't be decoded: {}. got `{}`",
- decode_error.arg_name, decode_error.kind, decode_error.value
- )
- }
- ApiGqlErrorKind::GqlError(error) => f.write_str(&error.message),
- ApiGqlErrorKind::RecordNotFound { record_id } => {
- write!(f, "record `{record_id}` not found")
- }
- ApiGqlErrorKind::MapNotFound { map_uid } => {
- write!(f, "map with UID `{map_uid}` not found")
- }
- ApiGqlErrorKind::PlayerNotFound { login } => {
- write!(f, "player with login `{login}` not found")
- }
- }
+ fmt::Display::fmt(&self.inner, f)
}
}
@@ -185,4 +197,4 @@ impl From for async_graphql::Error {
}
}
-pub type GqlResult = Result;
+pub type GqlResult = Result;
diff --git a/crates/graphql-api/src/lib.rs b/crates/graphql-api/src/lib.rs
index b85452f..e1229a5 100644
--- a/crates/graphql-api/src/lib.rs
+++ b/crates/graphql-api/src/lib.rs
@@ -4,3 +4,12 @@ pub mod objects;
pub mod schema;
pub mod cursors;
+
+pub mod config;
+pub(crate) use config::config;
+pub use config::{init_config, set_config};
+
+pub(crate) mod utils;
+
+#[cfg(test)]
+mod tests;
diff --git a/crates/graphql-api/src/objects/event_edition_map.rs b/crates/graphql-api/src/objects/event_edition_map.rs
index 5147c7b..3547fad 100644
--- a/crates/graphql-api/src/objects/event_edition_map.rs
+++ b/crates/graphql-api/src/objects/event_edition_map.rs
@@ -8,8 +8,8 @@ use crate::{
loaders::map::MapLoader,
objects::{
event_edition::EventEdition, map::Map, medal_times::MedalTimes,
- ranked_record::RankedRecord, records_filter::RecordsFilter, sort_state::SortState,
- sortable_fields::MapRecordSortableField,
+ ranked_record::RankedRecord, records_filter::RecordsFilter, sort::MapRecordSort,
+ sort_state::SortState,
},
};
@@ -91,7 +91,7 @@ impl EventEditionMap<'_> {
before: Option,
#[graphql(desc = "Number of records to fetch (default: 50, max: 100)")] first: Option,
#[graphql(desc = "Number of records to fetch from the end (for backward pagination)")] last: Option,
- sort_field: Option,
+ sort: Option,
filter: Option,
) -> GqlResult<connection::Connection<ID, RankedRecord>> {
self.map
@@ -102,7 +102,7 @@ impl EventEditionMap<'_> {
before,
first,
last,
- sort_field,
+ sort,
filter,
)
.await
diff --git a/crates/graphql-api/src/objects/map.rs b/crates/graphql-api/src/objects/map.rs
index e38e1ad..46edc50 100644
--- a/crates/graphql-api/src/objects/map.rs
+++ b/crates/graphql-api/src/objects/map.rs
@@ -6,7 +6,7 @@ use async_graphql::{
use deadpool_redis::redis::AsyncCommands as _;
use entity::{
event_edition, event_edition_maps, global_event_records, global_records, maps, player_rating,
- players, records,
+ records,
};
use records_lib::{
Database, RedisPool, internal,
@@ -16,24 +16,29 @@ use records_lib::{
sync,
};
use sea_orm::{
- ColumnTrait as _, ConnectionTrait, DbConn, EntityTrait as _, FromQueryResult, JoinType, Order,
- QueryFilter as _, QueryOrder as _, QuerySelect as _, StreamTrait,
+ ColumnTrait as _, ConnectionTrait, DbConn, EntityTrait as _, FromQueryResult, QueryFilter as _,
+ QueryOrder as _, QuerySelect as _, StreamTrait,
prelude::Expr,
- sea_query::{Asterisk, ExprTrait as _, Func, Query},
+ sea_query::{Asterisk, ExprTrait as _, Func, IntoValueTuple, Query},
};
use crate::{
cursors::{
- CURSOR_DEFAULT_LIMIT, CURSOR_LIMIT_RANGE, ConnectionParameters, RecordDateCursor,
- RecordRankCursor,
+ ConnectionParameters, RecordDateCursor, RecordRankCursor, expr_tuple::IntoExprTuple,
+ query_trait::CursorPaginable,
},
- error::{self, ApiGqlError, GqlResult},
+ error::{self, ApiGqlError, CursorDecodeError, CursorDecodeErrorKind, GqlResult},
loaders::{map::MapLoader, player::PlayerLoader},
objects::{
event_edition::EventEdition, player::Player, player_rating::PlayerRating,
ranked_record::RankedRecord, records_filter::RecordsFilter,
- related_edition::RelatedEdition, sort_state::SortState,
- sortable_fields::MapRecordSortableField,
+ related_edition::RelatedEdition, sort::MapRecordSort, sort_order::SortOrder,
+ sort_state::SortState, sortable_fields::MapRecordSortableField,
+ },
+ utils::{
+ page_input::{PaginationInput, apply_cursor_input},
+ pagination_result::{PaginationResult, get_paginated},
+ records_filter::apply_filter,
},
};
@@ -136,256 +141,186 @@ async fn get_map_records(
let mut ranked_records = Vec::with_capacity(records.len());
- let mut ranking_session = ranks::RankingSession::try_from_pool(redis_pool).await?;
+ let mut redis_conn = redis_pool.get().await?;
for record in records {
- let rank = ranks::get_rank_in_session(
- &mut ranking_session,
- map_id,
- record.record_player_id,
- record.time,
- event,
- )
- .await?;
+ let rank = ranks::get_rank(&mut redis_conn, map_id, record.time, event).await?;
ranked_records.push(records::RankedRecord { rank, record }.into());
}
Ok(ranked_records)
}
-enum MapRecordCursor {
+pub(crate) enum MapRecordCursor {
Date(RecordDateCursor),
Rank(RecordRankCursor),
}
-fn encode_map_cursor(cursor: &MapRecordCursor) -> String {
- match cursor {
- MapRecordCursor::Date(date_cursor) => date_cursor.encode_cursor(),
- MapRecordCursor::Rank(rank_cursor) => rank_cursor.encode_cursor(),
+impl From<RecordDateCursor> for MapRecordCursor {
+ #[inline]
+ fn from(value: RecordDateCursor) -> Self {
+ Self::Date(value)
}
}
-async fn get_map_records_connection(
- conn: &C,
- redis_pool: &RedisPool,
- map_id: u32,
- event: OptEvent<'_>,
- ConnectionParameters {
- before,
- after,
- first,
- last,
- }: ConnectionParameters,
- sort_field: Option,
- filter: Option,
-) -> GqlResult> {
- let limit = if let Some(first) = first {
- if !CURSOR_LIMIT_RANGE.contains(&first) {
- return Err(ApiGqlError::from_cursor_range_error(
- "first",
- CURSOR_LIMIT_RANGE,
- first,
- ));
- }
- first
- } else if let Some(last) = last {
- if !CURSOR_LIMIT_RANGE.contains(&last) {
- return Err(ApiGqlError::from_cursor_range_error(
- "last",
- CURSOR_LIMIT_RANGE,
- last,
- ));
- }
- last
- } else {
- CURSOR_DEFAULT_LIMIT
- };
-
- let has_next_page = after.is_some();
-
- // Decode cursors if provided
- let after = match after {
- Some(cursor) => {
- let cursor = match sort_field {
- Some(MapRecordSortableField::Date) => {
- CursorType::decode_cursor(&cursor).map(MapRecordCursor::Date)
- }
- Some(MapRecordSortableField::Rank) | None => {
- CursorType::decode_cursor(&cursor).map(MapRecordCursor::Rank)
- }
- }
- .map_err(|e| ApiGqlError::from_cursor_decode_error("after", cursor.0, e))?;
+impl From<RecordRankCursor> for MapRecordCursor {
+ #[inline]
+ fn from(value: RecordRankCursor) -> Self {
+ Self::Rank(value)
+ }
+}
- Some(cursor)
+impl CursorType for MapRecordCursor {
+ type Error = CursorDecodeError;
+
+ fn decode_cursor(s: &str) -> Result<Self, Self::Error> {
+ match RecordRankCursor::decode_cursor(s) {
+ Ok(rank) => Ok(rank.into()),
+ Err(CursorDecodeError {
+ kind: CursorDecodeErrorKind::InvalidPrefix,
+ }) => match RecordDateCursor::decode_cursor(s) {
+ Ok(date) => Ok(date.into()),
+ Err(e) => Err(e),
+ },
+ Err(e) => Err(e),
}
- None => None,
- };
-
- let before = match before {
- Some(cursor) => {
- let cursor = match sort_field {
- Some(MapRecordSortableField::Date) => {
- CursorType::decode_cursor(&cursor).map(MapRecordCursor::Date)
- }
- Some(MapRecordSortableField::Rank) | None => {
- CursorType::decode_cursor(&cursor).map(MapRecordCursor::Rank)
- }
- }
- .map_err(|e| ApiGqlError::from_cursor_decode_error("before", cursor.0, e))?;
+ }
- Some(cursor)
+ fn encode_cursor(&self) -> String {
+ match self {
+ MapRecordCursor::Date(record_date_cursor) => record_date_cursor.encode_cursor(),
+ MapRecordCursor::Rank(record_rank_cursor) => record_rank_cursor.encode_cursor(),
}
- None => None,
- };
-
- update_leaderboard(conn, redis_pool, map_id, event).await?;
-
- let mut select = Query::select();
-
- let select = match event.get() {
- Some((ev, ed)) => select.from_as(global_event_records::Entity, "r").and_where(
- Expr::col(("r", global_event_records::Column::EventId))
- .eq(ev.id)
- .and(Expr::col(("r", global_event_records::Column::EditionId)).eq(ed.id)),
- ),
- None => select.from_as(global_records::Entity, "r"),
}
- .column(Asterisk)
- .and_where(Expr::col(("r", records::Column::MapId)).eq(map_id));
-
- // Apply filters if provided
- if let Some(filter) = filter {
- // For player filters, we need to join with players table
- if let Some(filter) = filter.player {
- select.join_as(
- JoinType::InnerJoin,
- players::Entity,
- "p",
- Expr::col(("r", records::Column::RecordPlayerId))
- .equals(("p", players::Column::Id)),
- );
+}
- if let Some(ref login) = filter.player_login {
- select
- .and_where(Expr::col(("p", players::Column::Login)).like(format!("%{login}%")));
+impl IntoExprTuple for &MapRecordCursor {
+ fn into_expr_tuple(self) -> crate::cursors::expr_tuple::ExprTuple {
+ match self {
+ MapRecordCursor::Date(record_date_cursor) => {
+ IntoExprTuple::into_expr_tuple(record_date_cursor)
}
-
- if let Some(ref name) = filter.player_name {
- select.and_where(
- Func::cust("rm_mp_style")
- .arg(Expr::col(("p", players::Column::Name)))
- .like(format!("%{name}%")),
- );
+ MapRecordCursor::Rank(record_rank_cursor) => {
+ IntoExprTuple::into_expr_tuple(record_rank_cursor)
}
}
+ }
+}
- // Apply date filters
- if let Some(before_date) = filter.before_date {
- select.and_where(Expr::col(("r", records::Column::RecordDate)).lt(before_date));
- }
-
- if let Some(after_date) = filter.after_date {
- select.and_where(Expr::col(("r", records::Column::RecordDate)).gt(after_date));
- }
-
- // Apply time filters
- if let Some(time_gt) = filter.time_gt {
- select.and_where(Expr::col(("r", records::Column::Time)).gt(time_gt));
- }
-
- if let Some(time_lt) = filter.time_lt {
- select.and_where(Expr::col(("r", records::Column::Time)).lt(time_lt));
+impl IntoValueTuple for &MapRecordCursor {
+ fn into_value_tuple(self) -> sea_orm::sea_query::ValueTuple {
+ match self {
+ MapRecordCursor::Date(record_date_cursor) => {
+ IntoValueTuple::into_value_tuple(record_date_cursor)
+ }
+ MapRecordCursor::Rank(record_rank_cursor) => {
+ IntoValueTuple::into_value_tuple(record_rank_cursor)
+ }
}
}
+}
- // Apply cursor filters
- if let Some(cursor) = after {
- let (date, time) = match cursor {
- MapRecordCursor::Date(date) => (date.0, None),
- MapRecordCursor::Rank(rank) => (rank.record_date, Some(rank.time)),
- };
-
- select.and_where(Expr::col(("r", records::Column::RecordDate)).gt(date));
+pub(crate) async fn get_map_records_connection<C: ConnectionTrait>(
+ conn: &C,
+ redis_pool: &RedisPool,
+ map_id: u32,
+ event: OptEvent<'_>,
+ connection_parameters: ConnectionParameters,
+ sort: Option,
+ filter: Option,
+) -> GqlResult<connection::Connection<ID, RankedRecord>> {
+ let pagination_input = PaginationInput::try_from_input(connection_parameters)?;
+ let cursor_encoder = match sort.map(|s| s.field) {
+ Some(MapRecordSortableField::Date) => |record: &records::Model| {
+ RecordDateCursor {
+ record_date: record.record_date.and_utc(),
+ data: record.record_id,
+ }
+ .encode_cursor()
+ },
+ _ => |record: &records::Model| {
+ RecordRankCursor {
+ time: record.time,
+ record_date: record.record_date.and_utc(),
+ data: record.record_id,
+ }
+ .encode_cursor()
+ },
+ };
- if let Some(time) = time {
- select.and_where(Expr::col(("r", records::Column::Time)).gte(time));
- }
- }
+ let mut query =
+ match event.get() {
+ Some((ev, ed)) => {
+ let base_query = apply_filter(
+ global_event_records::Entity::find().filter(
+ global_event_records::Column::MapId
+ .eq(map_id)
+ .and(global_event_records::Column::EventId.eq(ev.id))
+ .and(global_event_records::Column::EditionId.eq(ed.id)),
+ ),
+ filter.as_ref(),
+ );
- if let Some(cursor) = before {
- let (date, time) = match cursor {
- MapRecordCursor::Date(date) => (date.0, None),
- MapRecordCursor::Rank(rank) => (rank.record_date, Some(rank.time)),
- };
+ match (pagination_input.get_cursor(), sort.map(|s| s.field)) {
+ (Some(MapRecordCursor::Date(_)), _)
+ | (_, Some(MapRecordSortableField::Date)) => base_query.paginate_cursor_by((
+ global_event_records::Column::RecordDate,
+ global_event_records::Column::RecordId,
+ )),
+ _ => base_query.paginate_cursor_by((
+ global_event_records::Column::Time,
+ global_event_records::Column::RecordDate,
+ global_event_records::Column::RecordId,
+ )),
+ }
+ .into_model::<records::Model>()
+ }
- select.and_where(Expr::col(("r", records::Column::RecordDate)).lt(date));
+ None => {
+ let base_query = apply_filter(
+ global_records::Entity::find().filter(global_records::Column::MapId.eq(map_id)),
+ filter.as_ref(),
+ );
- if let Some(time) = time {
- select.and_where(Expr::col(("r", records::Column::Time)).lt(time));
- }
- }
+ match (pagination_input.get_cursor(), sort.map(|s| s.field)) {
+ (Some(MapRecordCursor::Date(_)), _)
+ | (_, Some(MapRecordSortableField::Date)) => base_query.paginate_cursor_by((
+ global_event_records::Column::RecordDate,
+ global_event_records::Column::RecordId,
+ )),
+ _ => base_query.paginate_cursor_by((
+ global_records::Column::Time,
+ global_records::Column::RecordDate,
+ global_records::Column::RecordId,
+ )),
+ }
+ .into_model::<records::Model>()
+ }
+ };
- // Apply ordering
- if let Some(MapRecordSortableField::Rank) | None = sort_field {
- select.order_by_expr(
- Expr::col(("r", records::Column::Time)).into(),
- if last.is_some() {
- Order::Desc
- } else {
- Order::Asc
- },
- );
- }
- select.order_by_expr(
- Expr::col(("r", records::Column::RecordDate)).into(),
- if last.is_some() {
- Order::Desc
- } else {
- Order::Asc
- },
- );
+ apply_cursor_input(&mut query, &pagination_input);
- // Fetch one extra to determine if there's a next page
- select.limit((limit + 1) as u64);
+ match sort.and_then(|s| s.order) {
+ Some(SortOrder::Descending) => query.desc(),
+ _ => query.asc(),
+ };
- let stmt = conn.get_database_backend().build(&*select);
- let records = conn
- .query_all(stmt)
- .await?
- .into_iter()
- .map(|result| records::Model::from_query_result(&result, ""))
- .collect::<Result<Vec<_>, _>>()?;
+ let PaginationResult {
+ mut connection,
+ iter: records,
+ } = get_paginated(conn, query, &pagination_input).await?;
- let mut connection = connection::Connection::new(has_next_page, records.len() > limit);
+ connection.edges.reserve(records.len());
- let encode_cursor_fn = match sort_field {
- Some(MapRecordSortableField::Date) => |record: &records::Model| {
- encode_map_cursor(&MapRecordCursor::Date(RecordDateCursor(
- record.record_date.and_utc(),
- )))
- },
- Some(MapRecordSortableField::Rank) | None => |record: &records::Model| {
- encode_map_cursor(&MapRecordCursor::Rank(RecordRankCursor {
- record_date: record.record_date.and_utc(),
- time: record.time,
- }))
- },
- };
+ ranks::update_leaderboard(conn, redis_pool, map_id, event).await?;
- let mut ranking_session = ranks::RankingSession::try_from_pool(redis_pool).await?;
+ let mut redis_conn = redis_pool.get().await?;
- for record in records.into_iter().take(limit) {
- let rank = ranks::get_rank_in_session(
- &mut ranking_session,
- map_id,
- record.record_player_id,
- record.time,
- event,
- )
- .await?;
+ for record in records {
+ let rank = ranks::get_rank(&mut redis_conn, record.map_id, record.time, event).await?;
connection.edges.push(connection::Edge::new(
- ID(encode_cursor_fn(&record)),
+ ID((cursor_encoder)(&record)),
records::RankedRecord { rank, record }.into(),
));
}
@@ -426,39 +361,36 @@ impl Map {
before: Option,
first: Option,
last: Option,
- sort_field: Option,
+ sort: Option,
filter: Option,
) -> GqlResult<connection::Connection<ID, RankedRecord>> {
let db = gql_ctx.data_unchecked::();
- records_lib::assert_future_send(sync::transaction(&db.sql_conn, async |txn| {
- connection::query(
- after,
- before,
- first,
- last,
- |after, before, first, last| async move {
- get_map_records_connection(
- txn,
- &db.redis_pool,
- self.inner.id,
- event,
- ConnectionParameters {
- after,
- before,
- first,
- last,
- },
- sort_field,
- filter,
- )
- .await
- },
- )
- .await
- .map_err(error::map_gql_err)
- }))
+ connection::query_with(
+ after,
+ before,
+ first,
+ last,
+ |after, before, first, last| async move {
+ get_map_records_connection(
+ &db.sql_conn,
+ &db.redis_pool,
+ self.inner.id,
+ event,
+ ConnectionParameters {
+ after,
+ before,
+ first,
+ last,
+ },
+ sort,
+ filter,
+ )
+ .await
+ },
+ )
.await
+ .map_err(error::map_gql_err)
}
}
@@ -497,6 +429,10 @@ impl Map {
&self.inner.name
}
+ async fn score(&self) -> f64 {
+ self.inner.score
+ }
+
async fn related_event_editions(
&self,
ctx: &async_graphql::Context<'_>,
@@ -586,7 +522,7 @@ impl Map {
before: Option,
#[graphql(desc = "Number of records to fetch (default: 50, max: 100)")] first: Option,
#[graphql(desc = "Number of records to fetch from the end (for backward pagination)")] last: Option,
- sort_field: Option,
+ sort: Option,
filter: Option,
) -> GqlResult<connection::Connection<ID, RankedRecord>> {
self.get_records_connection(
@@ -596,7 +532,7 @@ impl Map {
before,
first,
last,
- sort_field,
+ sort,
filter,
)
.await
diff --git a/crates/graphql-api/src/objects/map_with_score.rs b/crates/graphql-api/src/objects/map_with_score.rs
index 20407d3..2950a3f 100644
--- a/crates/graphql-api/src/objects/map_with_score.rs
+++ b/crates/graphql-api/src/objects/map_with_score.rs
@@ -6,5 +6,4 @@ use crate::objects::map::Map;
pub struct MapWithScore {
pub rank: i32,
pub map: Map,
- pub score: f64,
}
diff --git a/crates/graphql-api/src/objects/mappack.rs b/crates/graphql-api/src/objects/mappack.rs
index a102e29..7146720 100644
--- a/crates/graphql-api/src/objects/mappack.rs
+++ b/crates/graphql-api/src/objects/mappack.rs
@@ -1,3 +1,4 @@
+use mkenv::prelude::*;
use std::time::SystemTime;
use deadpool_redis::redis::{self, AsyncCommands as _};
@@ -192,7 +193,7 @@ impl Mappack {
.get(mappack_time_key(AnyMappackId::Id(&self.mappack_id)))
.await?;
Ok(last_upd_time
- .map(|last| last + records_lib::env().event_scores_interval.as_secs())
+ .map(|last| last + records_lib::env().event_scores_interval.get().as_secs())
.and_then(|last| {
SystemTime::UNIX_EPOCH
.elapsed()
diff --git a/crates/graphql-api/src/objects/player.rs b/crates/graphql-api/src/objects/player.rs
index 2b8274e..cc890e6 100644
--- a/crates/graphql-api/src/objects/player.rs
+++ b/crates/graphql-api/src/objects/player.rs
@@ -1,27 +1,24 @@
-use async_graphql::connection::CursorType as _;
use async_graphql::{Enum, ID, connection};
-use entity::{global_records, maps, players, records, role};
+use entity::{global_records, players, records, role};
use records_lib::{Database, ranks};
use records_lib::{RedisPool, error::RecordsError, internal, opt_event::OptEvent, sync};
-use sea_orm::Order;
use sea_orm::{
- ColumnTrait as _, ConnectionTrait, DbConn, EntityTrait as _, FromQueryResult, JoinType,
- QueryFilter as _, QueryOrder as _, QuerySelect as _, RelationTrait, StreamTrait,
- prelude::Expr,
- sea_query::{ExprTrait as _, Func},
+ ColumnTrait as _, ConnectionTrait, DbConn, EntityTrait as _, FromQueryResult, QueryFilter as _,
+ QueryOrder as _, QuerySelect as _, StreamTrait, TransactionTrait,
};
+use crate::cursors::RecordDateCursor;
use crate::objects::records_filter::RecordsFilter;
+use crate::objects::root::get_records_connection_impl;
+use crate::objects::sort::UnorderedRecordSort;
+use crate::utils::records_filter::apply_filter;
use crate::{
cursors::ConnectionParameters,
- error::{ApiGqlError, GqlResult},
+ error::GqlResult,
objects::{ranked_record::RankedRecord, sort_state::SortState},
};
-use crate::{
- cursors::{CURSOR_DEFAULT_LIMIT, CURSOR_LIMIT_RANGE, RecordDateCursor},
- error,
-};
+use crate::error;
#[derive(Copy, Clone, Eq, PartialEq, Enum)]
#[repr(u8)]
@@ -74,6 +71,10 @@ impl Player {
self.inner.zone_path.as_deref()
}
+ async fn score(&self) -> f64 {
+ self.inner.score
+ }
+
async fn role(&self, ctx: &async_graphql::Context<'_>) -> GqlResult {
let conn = ctx.data_unchecked::();
@@ -117,37 +118,36 @@ impl Player {
before: Option,
#[graphql(desc = "Number of records to fetch (default: 50, max: 100)")] first: Option,
#[graphql(desc = "Number of records to fetch from the end (for backward pagination)")] last: Option,
+ sort: Option,
filter: Option,
) -> GqlResult> {
let db = ctx.data_unchecked::();
- records_lib::assert_future_send(sync::transaction(&db.sql_conn, async |txn| {
- connection::query(
- after,
- before,
- first,
- last,
- |after, before, first, last| async move {
- get_player_records_connection(
- txn,
- &db.redis_pool,
- self.inner.id,
- Default::default(),
- ConnectionParameters {
- after,
- before,
- first,
- last,
- },
- filter,
- )
- .await
- },
- )
- .await
- .map_err(error::map_gql_err)
- }))
+ connection::query_with(
+ after,
+ before,
+ first,
+ last,
+ |after, before, first, last| async move {
+ get_player_records_connection(
+ &db.sql_conn,
+ &db.redis_pool,
+ self.inner.id,
+ Default::default(),
+ ConnectionParameters {
+ after,
+ before,
+ first,
+ last,
+ },
+ sort,
+ filter,
+ )
+ .await
+ },
+ )
.await
+ .map_err(error::map_gql_err)
}
}
@@ -175,17 +175,10 @@ async fn get_player_records(
let mut ranked_records = Vec::with_capacity(records.len());
- let mut ranking_session = ranks::RankingSession::try_from_pool(redis_pool).await?;
+ let mut redis_conn = redis_pool.get().await?;
for record in records {
- let rank = ranks::get_rank_in_session(
- &mut ranking_session,
- record.map_id,
- record.record_player_id,
- record.time,
- event,
- )
- .await?;
+ let rank = ranks::get_rank(&mut redis_conn, record.map_id, record.time, event).await?;
ranked_records.push(
records::RankedRecord {
@@ -199,158 +192,30 @@ async fn get_player_records(
Ok(ranked_records)
}
-async fn get_player_records_connection(
+pub(crate) async fn get_player_records_connection(
conn: &C,
redis_pool: &RedisPool,
player_id: u32,
event: OptEvent<'_>,
- ConnectionParameters {
- after,
- before,
- first,
- last,
- }: ConnectionParameters,
+ connection_parameters: ConnectionParameters,
+ sort: Option,
filter: Option,
-) -> GqlResult> {
- let limit = if let Some(first) = first {
- if !CURSOR_LIMIT_RANGE.contains(&first) {
- return Err(ApiGqlError::from_cursor_range_error(
- "first",
- CURSOR_LIMIT_RANGE,
- first,
- ));
- }
- first
- } else if let Some(last) = last {
- if !CURSOR_LIMIT_RANGE.contains(&last) {
- return Err(ApiGqlError::from_cursor_range_error(
- "last",
- CURSOR_LIMIT_RANGE,
- last,
- ));
- }
- last
- } else {
- CURSOR_DEFAULT_LIMIT
- };
-
- let has_previous_page = after.is_some();
-
- // Decode cursors if provided
- let after_timestamp = match after {
- Some(cursor) => {
- let decoded = RecordDateCursor::decode_cursor(&cursor)
- .map_err(|e| ApiGqlError::from_cursor_decode_error("after", cursor.0, e))?;
- Some(decoded)
- }
- None => None,
- };
-
- let before_timestamp = match before {
- Some(cursor) => {
- let decoded = RecordDateCursor::decode_cursor(&cursor)
- .map_err(|e| ApiGqlError::from_cursor_decode_error("before", cursor.0, e))?;
- Some(decoded)
- }
- None => None,
- };
-
- // Build query with appropriate ordering
- let mut query =
- global_records::Entity::find().filter(global_records::Column::RecordPlayerId.eq(player_id));
-
- // Apply filters if provided
- if let Some(filter) = filter {
- // Join with maps table if needed for map filters
- if let Some(filter) = filter.map {
- query = query.join_as(
- JoinType::InnerJoin,
- global_records::Relation::Maps.def(),
- "m",
- );
-
- // Apply map UID filter
- if let Some(uid) = filter.map_uid {
- query =
- query.filter(Expr::col(("m", maps::Column::GameId)).like(format!("%{uid}%")));
- }
-
- // Apply map name filter
- if let Some(name) = filter.map_name {
- query = query.filter(
- Func::cust("rm_mp_style")
- .arg(Expr::col(("m", maps::Column::Name)))
- .like(format!("%{name}%")),
- );
- }
- }
-
- // Apply date filters
- if let Some(before_date) = filter.before_date {
- query = query.filter(global_records::Column::RecordDate.lt(before_date));
- }
-
- if let Some(after_date) = filter.after_date {
- query = query.filter(global_records::Column::RecordDate.gt(after_date));
- }
-
- // Apply time filters
- if let Some(time_gt) = filter.time_gt {
- query = query.filter(global_records::Column::Time.gt(time_gt));
- }
-
- if let Some(time_lt) = filter.time_lt {
- query = query.filter(global_records::Column::Time.lt(time_lt));
- }
- }
-
- // Apply cursor filters
- if let Some(timestamp) = after_timestamp {
- query = query.filter(global_records::Column::RecordDate.lt(timestamp.0));
- }
-
- if let Some(timestamp) = before_timestamp {
- query = query.filter(global_records::Column::RecordDate.gt(timestamp.0));
- }
-
- // Apply ordering
- query = query.order_by(
- global_records::Column::RecordDate,
- if last.is_some() {
- Order::Asc
- } else {
- Order::Desc
- },
+) -> GqlResult>
+where
+ C: ConnectionTrait + TransactionTrait,
+{
+ let base_query = apply_filter(
+ global_records::Entity::find().filter(global_records::Column::RecordPlayerId.eq(player_id)),
+ filter.as_ref(),
);
- // Fetch one extra to determine if there's a next page
- query = query.limit((limit + 1) as u64);
-
- let records = query.all(conn).await?;
-
- let mut connection = connection::Connection::new(has_previous_page, records.len() > limit);
-
- let mut ranking_session = ranks::RankingSession::try_from_pool(redis_pool).await?;
-
- for record in records.into_iter().take(limit) {
- let rank = ranks::get_rank_in_session(
- &mut ranking_session,
- record.map_id,
- record.record_player_id,
- record.time,
- event,
- )
- .await?;
-
- connection.edges.push(connection::Edge::new(
- ID(RecordDateCursor(record.record_date.and_utc()).encode_cursor()),
- records::RankedRecord {
- rank,
- record: record.into(),
- }
- .into(),
- ));
- }
-
- Ok(connection)
+ get_records_connection_impl(
+ conn,
+ redis_pool,
+ connection_parameters,
+ event,
+ sort,
+ base_query,
+ )
+ .await
}
diff --git a/crates/graphql-api/src/objects/player_with_score.rs b/crates/graphql-api/src/objects/player_with_score.rs
index 5e53e92..fb806a7 100644
--- a/crates/graphql-api/src/objects/player_with_score.rs
+++ b/crates/graphql-api/src/objects/player_with_score.rs
@@ -6,5 +6,4 @@ use crate::objects::player::Player;
pub struct PlayerWithScore {
pub rank: i32,
pub player: Player,
- pub score: f64,
}
diff --git a/crates/graphql-api/src/objects/root.rs b/crates/graphql-api/src/objects/root.rs
index a8c2d9f..cf08b74 100644
--- a/crates/graphql-api/src/objects/root.rs
+++ b/crates/graphql-api/src/objects/root.rs
@@ -1,31 +1,29 @@
-use std::{collections::HashMap, fmt};
-
use async_graphql::{
- ID, OutputType,
- connection::{self, CursorType as _},
+ ID,
+ connection::{self, CursorType},
};
use deadpool_redis::redis::{AsyncCommands, ToRedisArgs};
use entity::{event as event_entity, event_edition, global_records, maps, players, records};
use records_lib::{
- Database, RedisConnection, RedisPool, internal, must,
+ Database, RedisConnection, RedisPool, must,
opt_event::OptEvent,
ranks,
- redis_key::{map_ranking, player_ranking},
+ redis_key::{MapRanking, PlayerRanking, map_ranking, player_ranking},
sync,
};
use sea_orm::{
- ColumnTrait as _, ConnectionTrait, DbConn, EntityTrait as _, JoinType, Order, QueryFilter as _,
- QueryOrder as _, QuerySelect as _, QueryTrait, RelationTrait, StreamTrait,
+ ColumnTrait, ConnectionTrait, DbConn, EntityTrait, FromQueryResult, Identity, QueryFilter as _,
+ QueryOrder as _, QuerySelect, Select, SelectModel, StreamTrait, TransactionTrait,
prelude::Expr,
- sea_query::{ExprTrait as _, Func},
+ sea_query::{Asterisk, ExprTrait as _, Func, IntoIden, IntoValueTuple, SelectStatement},
};
use crate::{
cursors::{
- CURSOR_DEFAULT_LIMIT, CURSOR_LIMIT_RANGE, ConnectionParameters, F64Cursor,
- RecordDateCursor, TextCursor,
+ ConnectionParameters, F64Cursor, RecordDateCursor, TextCursor, expr_tuple::IntoExprTuple,
+ query_builder::CursorQueryBuilder, query_trait::CursorPaginable,
},
- error::{self, ApiGqlError, GqlResult},
+ error::{self, ApiGqlError, CursorDecodeError, CursorDecodeErrorKind, GqlResult},
objects::{
event::Event,
event_edition::EventEdition,
@@ -38,7 +36,15 @@ use crate::{
player_with_score::PlayerWithScore,
ranked_record::RankedRecord,
records_filter::RecordsFilter,
+ sort::{PlayerMapRankingSort, UnorderedRecordSort},
+ sort_order::SortOrder,
sort_state::SortState,
+ sortable_fields::{PlayerMapRankingSortableField, UnorderedRecordSortableField},
+ },
+ utils::{
+ page_input::{PaginationInput, apply_cursor_input},
+ pagination_result::{PaginationResult, get_paginated},
+ records_filter::apply_filter,
},
};
@@ -56,17 +62,10 @@ async fn get_record(
return Err(ApiGqlError::from_record_not_found_error(record_id));
};
- let mut ranking_session = ranks::RankingSession::try_from_pool(redis_pool).await?;
+ let mut redis_conn = redis_pool.get().await?;
let out = records::RankedRecord {
- rank: ranks::get_rank_in_session(
- &mut ranking_session,
- record.map_id,
- record.record_player_id,
- record.time,
- event,
- )
- .await?,
+ rank: ranks::get_rank(&mut redis_conn, record.map_id, record.time, event).await?,
record,
}
.into();
@@ -94,17 +93,10 @@ async fn get_records(
let mut ranked_records = Vec::with_capacity(records.len());
- let mut ranking_session = ranks::RankingSession::try_from_pool(redis_pool).await?;
+ let mut redis_conn = redis_pool.get().await?;
for record in records {
- let rank = ranks::get_rank_in_session(
- &mut ranking_session,
- record.map_id,
- record.record_player_id,
- record.time,
- event,
- )
- .await?;
+ let rank = ranks::get_rank(&mut redis_conn, record.map_id, record.time, event).await?;
ranked_records.push(
records::RankedRecord {
@@ -118,201 +110,57 @@ async fn get_records(
Ok(ranked_records)
}
-async fn get_records_connection(
+pub(crate) async fn get_records_connection_impl(
conn: &C,
redis_pool: &RedisPool,
- ConnectionParameters {
- after,
- before,
- first,
- last,
- }: ConnectionParameters,
+ connection_parameters: ConnectionParameters,
event: OptEvent<'_>,
- filter: Option,
+ sort: Option,
+ base_query: Select,
) -> GqlResult> {
- let limit = if let Some(first) = first {
- if !CURSOR_LIMIT_RANGE.contains(&first) {
- return Err(ApiGqlError::from_cursor_range_error(
- "first",
- CURSOR_LIMIT_RANGE,
- first,
- ));
- }
- first
- } else if let Some(last) = last {
- if !CURSOR_LIMIT_RANGE.contains(&last) {
- return Err(ApiGqlError::from_cursor_range_error(
- "last",
- CURSOR_LIMIT_RANGE,
- last,
- ));
- }
- last
- } else {
- CURSOR_DEFAULT_LIMIT
- };
-
- let has_previous_page = after.is_some();
-
- // Decode cursors if provided
- let after_timestamp = match after {
- Some(cursor) => {
- let decoded = RecordDateCursor::decode_cursor(&cursor)
- .map_err(|e| ApiGqlError::from_cursor_decode_error("after", cursor.0, e))?;
- Some(decoded)
- }
- None => None,
- };
-
- let before_timestamp = match before {
- Some(cursor) => {
- let decoded = RecordDateCursor::decode_cursor(&cursor)
- .map_err(|e| ApiGqlError::from_cursor_decode_error("before", cursor.0, e))?;
- Some(decoded)
- }
- None => None,
- };
-
- // Build query with appropriate ordering
- let mut query = global_records::Entity::find();
-
- // Apply filters if provided
- if let Some(filter) = filter {
- // Join with players table if needed for player filters
- if filter.player.is_some() {
- query = query.join_as(
- JoinType::InnerJoin,
- global_records::Relation::Players.def(),
- "p",
- );
- }
-
- // Join with maps table if needed for map filters
- if let Some(m) = &filter.map {
- query = query.join_as(
- JoinType::InnerJoin,
- global_records::Relation::Maps.def(),
- "m",
- );
-
- // Join again with players table if filtering on map author
- if m.author.is_some() {
- query = query.join_as(JoinType::InnerJoin, maps::Relation::Players.def(), "p2");
- }
- }
-
- if let Some(filter) = filter.player {
- // Apply player login filter
- if let Some(login) = filter.player_login {
- query = query
- .filter(Expr::col(("p", players::Column::Login)).like(format!("%{login}%")));
- }
-
- // Apply player name filter
- if let Some(name) = filter.player_name {
- query = query.filter(
- Func::cust("rm_mp_style")
- .arg(Expr::col(("p", players::Column::Name)))
- .like(format!("%{name}%")),
- );
- }
- }
-
- if let Some(filter) = filter.map {
- // Apply map UID filter
- if let Some(uid) = filter.map_uid {
- query =
- query.filter(Expr::col(("m", maps::Column::GameId)).like(format!("%{uid}%")));
- }
-
- // Apply map name filter
- if let Some(name) = filter.map_name {
- query = query.filter(
- Func::cust("rm_mp_style")
- .arg(Expr::col(("m", maps::Column::Name)))
- .like(format!("%{name}%")),
- );
- }
-
- if let Some(filter) = filter.author {
- // Apply player login filter
- if let Some(login) = filter.player_login {
- query = query.filter(
- Expr::col(("p2", players::Column::Login)).like(format!("%{login}%")),
- );
- }
-
- // Apply player name filter
- if let Some(name) = filter.player_name {
- query = query.filter(
- Func::cust("rm_mp_style")
- .arg(Expr::col(("p2", players::Column::Name)))
- .like(format!("%{name}%")),
- );
- }
- }
- }
-
- // Apply date filters
- if let Some(before_date) = filter.before_date {
- query = query.filter(global_records::Column::RecordDate.lt(before_date));
- }
-
- if let Some(after_date) = filter.after_date {
- query = query.filter(global_records::Column::RecordDate.gt(after_date));
- }
-
- // Apply time filters
- if let Some(time_gt) = filter.time_gt {
- query = query.filter(global_records::Column::Time.gt(time_gt));
- }
+ let pagination_input = PaginationInput::try_from_input(connection_parameters)?;
- if let Some(time_lt) = filter.time_lt {
- query = query.filter(global_records::Column::Time.lt(time_lt));
- }
- }
-
- // Apply cursor filters
- if let Some(timestamp) = after_timestamp {
- query = query.filter(global_records::Column::RecordDate.lt(timestamp.0));
- }
+ let mut query = base_query.paginate_cursor_by((
+ global_records::Column::RecordDate,
+ global_records::Column::RecordId,
+ ));
- if let Some(timestamp) = before_timestamp {
- query = query.filter(global_records::Column::RecordDate.gt(timestamp.0));
- }
+ apply_cursor_input(&mut query, &pagination_input);
- // Apply ordering
- query = query.order_by(
- global_records::Column::RecordDate,
- if last.is_some() {
- Order::Asc
- } else {
- Order::Desc
- },
+ // Record dates are ordered by desc by default
+ let is_sort_asc = matches!(
+ sort,
+ Some(UnorderedRecordSort {
+ field: UnorderedRecordSortableField::Date,
+ order: Some(SortOrder::Descending),
+ })
);
- // Fetch one extra to determine if there's a next page
- query = query.limit((limit + 1) as u64);
+ if is_sort_asc {
+ query.asc();
+ } else {
+ query.desc();
+ }
- let records = query.all(conn).await?;
+ // inline get_paginated
+ let PaginationResult {
+ mut connection,
+ iter: records,
+ } = get_paginated(conn, query, &pagination_input).await?;
- let mut connection = connection::Connection::new(has_previous_page, records.len() > limit);
connection.edges.reserve(records.len());
- let mut ranking_session = ranks::RankingSession::try_from_pool(redis_pool).await?;
+ let mut redis_conn = redis_pool.get().await?;
- for record in records.into_iter().take(limit) {
- let rank = ranks::get_rank_in_session(
- &mut ranking_session,
- record.map_id,
- record.record_player_id,
- record.time,
- event,
- )
- .await?;
+ for record in records {
+ let rank = ranks::get_rank(&mut redis_conn, record.map_id, record.time, event).await?;
connection.edges.push(connection::Edge::new(
- ID(RecordDateCursor(record.record_date.and_utc()).encode_cursor()),
+ ID(RecordDateCursor {
+ record_date: record.record_date.and_utc(),
+ data: record.record_id,
+ }
+ .encode_cursor()),
records::RankedRecord {
rank,
record: record.into(),
@@ -324,6 +172,27 @@ async fn get_records_connection(
Ok(connection)
}
+pub(crate) async fn get_records_connection(
+ conn: &C,
+ redis_pool: &RedisPool,
+ connection_parameters: ConnectionParameters,
+ event: OptEvent<'_>,
+ sort: Option,
+ filter: Option,
+) -> GqlResult> {
+ let base_query = apply_filter(global_records::Entity::find(), filter.as_ref());
+
+ get_records_connection_impl(
+ conn,
+ redis_pool,
+ connection_parameters,
+ event,
+ sort,
+ base_query,
+ )
+ .await
+}
+
#[async_graphql::Object]
impl QueryRoot {
async fn event_edition_from_mx_id(
@@ -436,6 +305,7 @@ impl QueryRoot {
.await
}
+ #[allow(clippy::too_many_arguments)]
async fn players(
&self,
ctx: &async_graphql::Context<'_>,
@@ -444,34 +314,34 @@ impl QueryRoot {
first: Option,
last: Option