diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index e80f58d..31bddad 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -41,7 +41,7 @@ jobs: sudo mariadb -e "FLUSH PRIVILEGES"; - name: Setup Rust - uses: dtolnay/rust-toolchain@1.89.0 + uses: dtolnay/rust-toolchain@1.92.0 - name: Configure cache uses: Swatinem/rust-cache@v2 @@ -50,7 +50,8 @@ jobs: env: DATABASE_URL: mysql://api:${{ secrets.MARIADB_PW }}@localhost:3306/master_db REDIS_URL: redis://localhost:6379 - run: RUST_BACKTRACE=1 cargo test -F mysql -F test + GQL_API_CURSOR_SECRET_KEY: ${{ secrets.GQL_API_CURSOR_SECRET_KEY }} + run: RUST_BACKTRACE=1 cargo test -F mysql lint: name: Lint @@ -61,13 +62,16 @@ jobs: uses: actions/checkout@v4 - name: Setup Rust - uses: dtolnay/rust-toolchain@1.89.0 + uses: dtolnay/rust-toolchain@1.92.0 - name: Setup Clippy run: rustup component add clippy - name: Run clippy run: | - cargo clippy --no-default-features -- -D warnings - cargo clippy -- -D warnings + cargo clippy --no-default-features -F mysql -- -D warnings + cargo clippy -r --no-default-features -F mysql -- -D warnings + cargo clippy -F mysql -- -D warnings + cargo clippy -r -F mysql -- -D warnings cargo clippy --all-features -- -D warnings + cargo clippy -r --all-features -- -D warnings diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 5909c1a..7c02e2c 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -17,7 +17,7 @@ jobs: - name: Checkout repository uses: actions/checkout@v4 - name: Setup Rust - uses: dtolnay/rust-toolchain@1.89.0 + uses: dtolnay/rust-toolchain@1.92.0 - name: Configure cache uses: Swatinem/rust-cache@v2 - name: Setup pages @@ -26,7 +26,7 @@ jobs: - name: Clean docs folder run: cargo clean --doc - name: Build docs - run: cargo +nightly doc --all-features --no-deps --workspace --exclude clear_redis_mappacks --exclude admin + run: cargo +nightly doc --all-features --no-deps --workspace --exclude admin - name: Add redirect run: echo '' > target/doc/index.html - name: Remove lock file diff --git a/Cargo.toml b/Cargo.toml index f9910dd..7382ba0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,40 +13,41 @@ members = [ "crates/graphql-schema-generator", "crates/player-map-ranking", "crates/compute-player-map-ranking", + "crates/test-env", ] [workspace.dependencies] -thiserror = "2.0.12" -tokio = "1.44.1" -async-graphql = "7.0.15" -sqlx = { version = "0.8.3", features = ["chrono", "mysql", "runtime-tokio"] } -chrono = { version = "0.4.40", features = ["serde"] } -deadpool = { version = "0.12.2", features = ["managed", "rt_tokio_1"] } +thiserror = "2.0.17" +tokio = "1.49.0" +async-graphql = "7.1.0" +sqlx = { version = "0.8.6", features = ["chrono", "mysql", "runtime-tokio"] } +chrono = { version = "0.4.42", features = ["serde"] } +deadpool = { version = "0.12.3", features = ["managed", "rt_tokio_1"] } deadpool-redis = { version = "0.20.0", features = ["rt_tokio_1"] } -serde = "1.0.219" -tracing = "0.1.41" -tracing-subscriber = "0.3.19" -actix-web = "4.10.2" +serde = "1.0.228" +tracing = "0.1.44" +tracing-subscriber = "0.3.22" +actix-web = "4.12.1" actix-cors = "0.7.1" -async-graphql-actix-web = "7.0.15" -tracing-actix-web = "0.7.16" -reqwest = { version = "0.12.14", features = ["json"] } -rand = "0.9.0" -futures = "0.3.27" +async-graphql-actix-web = "7.1.0" +tracing-actix-web = "0.7.20" +reqwest = { version = "0.12.28", features = ["json"] } +rand = "0.9.2" +futures = "0.3.31" sha256 = "1.6.0" actix-session = { version = "0.10.1", features = ["cookie-session"] } -anyhow = "1.0.97" +anyhow = "1.0.100" dotenvy = "0.15.7" itertools = "0.14.0" -once_cell = "1.21.1" -csv = "1.3.0" -mkenv = "0.1.6" +once_cell = "1.21.3" +csv = "1.4.0" +mkenv = "1.0.2" nom = "8.0.0" pin-project-lite = "0.2.16" -sea-orm = { version = "1.1.14", features = [ +sea-orm = { version = "1.1.19", features = [ "runtime-tokio", "macros", "with-chrono", ] } -serde_json = "1.0.143" +serde_json = "1.0.149" prettytable = { version = "0.10.0", default-features = false } diff --git a/crates/admin/Cargo.toml b/crates/admin/Cargo.toml index a50bc98..fe7716c 100644 --- a/crates/admin/Cargo.toml +++ b/crates/admin/Cargo.toml @@ -5,7 +5,7 @@ edition = "2024" [dependencies] anyhow = { workspace = true } -clap = { version = "4.5.16", features = ["derive"] } +clap = { version = "4.5.54", features = ["derive"] } csv = { workspace = true } deadpool-redis = { workspace = true } dotenvy = { workspace = true } diff --git a/crates/admin/src/main.rs b/crates/admin/src/main.rs index 3c05e4c..26a7207 100644 --- a/crates/admin/src/main.rs +++ b/crates/admin/src/main.rs @@ -1,5 +1,5 @@ use clap::Parser; -use mkenv::Env as _; +use mkenv::prelude::*; use records_lib::{Database, DbEnv, LibEnv}; use self::{clear::ClearCommand, leaderboard::LbCommand, populate::PopulateCommand}; @@ -24,7 +24,12 @@ enum EventCommand { Clear(ClearCommand), } -mkenv::make_env!(Env includes [DbEnv as db_env, LibEnv as lib_env]:); +mkenv::make_config! { + struct Env { + db_env: { DbEnv }, + lib_env: { LibEnv }, + } +} #[tokio::main] async fn main() -> anyhow::Result<()> { @@ -33,11 +38,15 @@ async fn main() -> anyhow::Result<()> { .compact() .try_init() .map_err(|e| anyhow::anyhow!("unable to init tracing_subscriber: {e}"))?; - let env = Env::try_get()?; + let env = Env::define(); + env.init(); records_lib::init_env(env.lib_env); - let db = - Database::from_db_url(env.db_env.db_url.db_url, env.db_env.redis_url.redis_url).await?; + let db = Database::from_db_url( + env.db_env.db_url.db_url.get(), + env.db_env.redis_url.redis_url.get(), + ) + .await?; let cmd = Command::parse(); diff --git a/crates/compute-player-map-ranking/Cargo.toml b/crates/compute-player-map-ranking/Cargo.toml index 1d298c6..ebfe68c 100644 --- a/crates/compute-player-map-ranking/Cargo.toml +++ b/crates/compute-player-map-ranking/Cargo.toml @@ -5,7 +5,7 @@ edition = "2024" [dependencies] anyhow = { workspace = true } -clap = { version = "4.5.48", features = ["derive"] } +clap = { version = "4.5.54", features = ["derive"] } dotenvy = { workspace = true } mkenv = { workspace = true } player-map-ranking = { path = "../player-map-ranking" } diff --git a/crates/compute-player-map-ranking/src/main.rs b/crates/compute-player-map-ranking/src/main.rs index 67ed3a9..ea9f523 100644 --- a/crates/compute-player-map-ranking/src/main.rs +++ b/crates/compute-player-map-ranking/src/main.rs @@ -11,11 +11,15 @@ use std::{ use anyhow::Context as _; use chrono::{DateTime, Days, Months, Utc}; use clap::Parser as _; -use mkenv::Env as _; +use mkenv::prelude::*; use records_lib::{DbUrlEnv, time::Time}; use sea_orm::Database; -mkenv::make_env! {AppEnv includes [DbUrlEnv as db_env]:} +mkenv::make_config! { + struct AppEnv { + db_env: { DbUrlEnv }, + } +} #[derive(Clone)] struct SinceDuration { @@ -96,10 +100,11 @@ async fn main() -> anyhow::Result<()> { }) .context("couldn't write header to map ranking file")?; - let db_url = AppEnv::try_get() - .context("couldn't initialize environment")? - .db_env - .db_url; + let app_config = AppEnv::define(); + app_config + .try_init() + .map_err(|e| anyhow::anyhow!("couldn't initialize environment: {e}"))?; + let db_url = app_config.db_env.db_url.get(); let db = Database::connect(db_url) .await .context("couldn't connect to database")?; diff --git a/crates/entity/src/entities/maps.rs b/crates/entity/src/entities/maps.rs index 30d4aeb..c983b04 100644 --- a/crates/entity/src/entities/maps.rs +++ b/crates/entity/src/entities/maps.rs @@ -1,7 +1,7 @@ use sea_orm::entity::prelude::*; /// A ShootMania Obstacle map in the database. -#[derive(Clone, Debug, PartialEq, DeriveEntityModel, Eq)] +#[derive(Clone, Debug, PartialEq, DeriveEntityModel)] #[sea_orm(table_name = "maps")] pub struct Model { /// The map ID. @@ -32,6 +32,8 @@ pub struct Model { pub gold_time: Option, /// The author time of the map. pub author_time: Option, + /// The score of the player, calculated periodically. + pub score: f64, } #[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] diff --git a/crates/entity/src/entities/players.rs b/crates/entity/src/entities/players.rs index c026a6f..d79e2fa 100644 --- a/crates/entity/src/entities/players.rs +++ b/crates/entity/src/entities/players.rs @@ -1,7 +1,7 @@ use sea_orm::entity::prelude::*; /// A player in the database. -#[derive(Clone, Debug, PartialEq, DeriveEntityModel, Eq)] +#[derive(Clone, Debug, PartialEq, DeriveEntityModel)] #[sea_orm(table_name = "players")] pub struct Model { /// The player ID. @@ -20,6 +20,8 @@ pub struct Model { pub admins_note: Option, /// The player role. pub role: u8, + /// The score of the player, calculated periodically. + pub score: f64, } #[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] diff --git a/crates/game_api/Cargo.toml b/crates/game_api/Cargo.toml index 96ee5c7..36305d1 100644 --- a/crates/game_api/Cargo.toml +++ b/crates/game_api/Cargo.toml @@ -19,7 +19,7 @@ tracing = { workspace = true } tracing-subscriber = { workspace = true, features = ["env-filter"] } actix-web = { workspace = true } actix-cors = { workspace = true } -actix-http = "3.11.0" +actix-http = "3.11.2" async-graphql-actix-web = { workspace = true } tracing-actix-web = { workspace = true } reqwest = { workspace = true, features = ["multipart"] } @@ -44,12 +44,12 @@ graphql-api = { path = "../graphql-api" } serde_json = { workspace = true } [dev-dependencies] -records-lib = { path = "../records_lib" } +test-env = { path = "../test-env" } +records-lib = { path = "../records_lib", features = ["tracing", "mock"] } [features] default = ["request_filter"] request_filter = ["dep:pin-project-lite", "dep:request_filter"] auth = [] -mysql = ["records-lib/mysql"] -postgres = ["records-lib/postgres"] -test = ["records-lib/test"] +mysql = ["records-lib/mysql", "test-env/mysql"] +postgres = ["records-lib/postgres", "test-env/postgres"] diff --git a/crates/game_api/build.rs b/crates/game_api/build.rs index fb8d5e7..8bd602c 100644 --- a/crates/game_api/build.rs +++ b/crates/game_api/build.rs @@ -2,6 +2,4 @@ fn main() { #[cfg(all(feature = "auth", not(test)))] println!("cargo:rustc-cfg=auth"); println!("cargo:rustc-check-cfg=cfg(auth)"); - - println!("cargo:rustc-check-cfg=cfg(test_force_db_deletion)"); } diff --git a/crates/game_api/src/auth/gen_token.rs b/crates/game_api/src/auth/gen_token.rs index ba4c140..d98e2ba 100644 --- a/crates/game_api/src/auth/gen_token.rs +++ b/crates/game_api/src/auth/gen_token.rs @@ -1,4 +1,5 @@ use deadpool_redis::redis::AsyncCommands as _; +use mkenv::Layer as _; use records_lib::{ RedisConnection, gen_random_str, redis_key::{mp_token_key, web_token_key}, @@ -23,7 +24,7 @@ pub async fn gen_token_for( let mp_key = mp_token_key(login); let web_key = web_token_key(login); - let ex = crate::env().auth_token_ttl as _; + let ex = crate::env().auth_token_ttl.get() as _; let mp_token_hash = digest(&*mp_token); let web_token_hash = digest(&*web_token); diff --git a/crates/game_api/src/configure.rs b/crates/game_api/src/configure.rs index 0c305a0..29b8992 100644 --- a/crates/game_api/src/configure.rs +++ b/crates/game_api/src/configure.rs @@ -9,6 +9,7 @@ use actix_web::{ web, }; use dsc_webhook::{FormattedRequestHead, WebhookBody, WebhookBodyEmbed, WebhookBodyEmbedField}; +use mkenv::prelude::*; use records_lib::{Database, pool::clone_dbconn}; use tracing_actix_web::{DefaultRootSpanBuilder, RequestId}; @@ -163,7 +164,7 @@ pub(crate) fn send_internal_err_msg_detached( tokio::task::spawn(async move { if let Err(e) = client - .post(&crate::env().wh_report_url) + .post(crate::env().wh_report_url.get()) .json(&wh_msg) .send() .await diff --git a/crates/game_api/src/env.rs b/crates/game_api/src/env.rs index 1c5f571..9e1b987 100644 --- a/crates/game_api/src/env.rs +++ b/crates/game_api/src/env.rs @@ -1,163 +1,160 @@ -use mkenv::{Env as _, EnvSplitIncluded as _}; +use mkenv::{make_config, prelude::*}; use once_cell::sync::OnceCell; use records_lib::{DbEnv, LibEnv}; -#[cfg(feature = "test")] -const DEFAULT_SESSION_KEY: &str = ""; - -mkenv::make_env! {pub ApiEnvUsedOnce: - #[cfg(not(feature = "test"))] - sess_key: { - id: SessKey(String), - kind: file, - var: "RECORDS_API_SESSION_KEY_FILE", - desc: "The path to the file containing the session key used by the API", - }, - - #[cfg(feature = "test")] - sess_key: { - id: SessKey(String), - kind: normal, - var: "RECORDS_API_SESSION_KEY", - desc: "The session key used by the API", - default: DEFAULT_SESSION_KEY, - }, - - wh_invalid_req_url: { - id: WebhookInvalidReqUrl(String), - kind: normal, - var: "WEBHOOK_INVALID_REQ_URL", - desc: "The URL to the Discord webhook used to flag invalid requests", - default: DEFAULT_WH_INVALID_REQ_URL, - }, +#[cfg(not(debug_assertions))] +mkenv::make_config! { + pub struct DynamicApiEnv { + pub sess_key: { + var_name: "RECORDS_API_SESSION_KEY_FILE", + layers: [file_read()], + description: "The path to the file containing the session key used by the API", + }, + + pub mp_client_id: { + var_name: "RECORDS_MP_APP_CLIENT_ID_FILE", + layers: [file_read()], + description: "The path to the file containing the Obstacle ManiaPlanet client ID", + }, + + pub mp_client_secret: { + var_name: "RECORDS_MP_APP_CLIENT_SECRET_FILE", + layers: [file_read()], + description: "The path to the file containing the Obstacle ManiaPlanet client secret", + }, + } } -const DEFAULT_PORT: u16 = 3000; -const DEFAULT_TOKEN_TTL: u32 = 15_552_000; -#[cfg(feature = "test")] -const DEFAULT_MP_CLIENT_ID: &str = ""; -#[cfg(feature = "test")] -const DEFAULT_MP_CLIENT_SECRET: &str = ""; -const DEFAULT_WH_REPORT_URL: &str = ""; -const DEFAULT_WH_AC_URL: &str = ""; -const DEFAULT_WH_INVALID_REQ_URL: &str = ""; -const DEFAULT_WH_RANK_COMPUTE_ERROR: &str = ""; -const DEFAULT_GQL_ENDPOINT: &str = "/graphql"; - -mkenv::make_env! {pub ApiEnv includes [ - DbEnv as db_env, - LibEnv as lib_env, - ApiEnvUsedOnce as used_once -]: - port: { - id: Port(u16), - kind: parse, - var: "RECORDS_API_PORT", - desc: "The port used to expose the API", - default: DEFAULT_PORT, - }, - - #[cfg(not(debug_assertions))] - host: { - id: Host(String), - kind: normal, - var: "RECORDS_API_HOST", - desc: "The hostname of the server where the API is running (e.g. https://obstacle.titlepack.io)", - }, - - auth_token_ttl: { - id: AuthTokenTtl(u32), - kind: parse, - var: "RECORDS_API_TOKEN_TTL", - desc: "The TTL (time-to-live) of an authentication token or anything related to it (in seconds)", - default: DEFAULT_TOKEN_TTL, - }, - - #[cfg(not(feature = "test"))] - mp_client_id: { - id: MpClientId(String), - kind: file, - var: "RECORDS_MP_APP_CLIENT_ID_FILE", - desc: "The path to the file containing the Obstacle ManiaPlanet client ID", - }, - #[cfg(feature = "test")] - mp_client_id: { - id: MpClientId(String), - kind: normal, - var: "RECORDS_MP_APP_CLIENT_ID", - desc: "The Obstacle ManiaPlanet client ID", - default: DEFAULT_MP_CLIENT_ID, - }, - - #[cfg(not(feature = "test"))] - mp_client_secret: { - id: MpClientSecret(String), - kind: file, - var: "RECORDS_MP_APP_CLIENT_SECRET_FILE", - desc: "The path to the file containing the Obstacle ManiaPlanet client secret", - }, - #[cfg(feature = "test")] - mp_client_secret: { - id: MpClientSecret(String), - kind: normal, - var: "RECORDS_MP_APP_CLIENT_SECRET", - desc: "The Obstacle ManiaPlanet client secret", - default: DEFAULT_MP_CLIENT_SECRET, - }, - - wh_report_url: { - id: WebhookReportUrl(String), - kind: normal, - var: "WEBHOOK_REPORT_URL", - desc: "The URL to the Discord webhook used to report errors", - default: DEFAULT_WH_REPORT_URL, - }, - - wh_ac_url: { - id: WebhookAcUrl(String), - kind: normal, - var: "WEBHOOK_AC_URL", - desc: "The URL to the Discord webhook used to share in-game statistics", - default: DEFAULT_WH_AC_URL, - }, - - gql_endpoint: { - id: GqlEndpoint(String), - kind: normal, - var: "GQL_ENDPOINT", - desc: "The route to the GraphQL endpoint (e.g. /graphql)", - default: DEFAULT_GQL_ENDPOINT, - }, - - wh_rank_compute_err: { - id: WebhookRankComputeError(String), - kind: normal, - var: "WEBHOOK_RANK_COMPUTE_ERROR", - desc: "The URL to the Discord webhook used to send rank compute errors", - default: DEFAULT_WH_RANK_COMPUTE_ERROR, - }, +#[cfg(debug_assertions)] +mkenv::make_config! { + pub struct DynamicApiEnv { + pub sess_key: { + var_name: "RECORDS_API_SESSION_KEY", + layers: [or_default()], + description: "The session key used by the API", + default_val_fmt: "empty", + }, + + pub mp_client_id: { + var_name: "RECORDS_MP_APP_CLIENT_ID", + layers: [or_default()], + description: "The Obstacle ManiaPlanet client ID", + default_val_fmt: "empty", + }, + + pub mp_client_secret: { + var_name: "RECORDS_MP_APP_CLIENT_SECRET", + layers: [or_default()], + description: "The Obstacle ManiaPlanet client secret", + default_val_fmt: "empty", + }, + } } -pub struct InitEnvOut { - pub db_env: DbEnv, - pub used_once: ApiEnvUsedOnce, +#[cfg(debug_assertions)] +mkenv::make_config! { + pub struct Hostname {} } -static ENV: OnceCell = OnceCell::new(); +#[cfg(not(debug_assertions))] +mkenv::make_config! { + pub struct Hostname { + pub host: { + var_name: "RECORDS_API_HOST", + description: "The hostname of the server where the API is running (e.g. https://obstacle.titlepack.io)", + } + } +} + +mkenv::make_config! { + pub struct ApiEnv { + pub db_env: { DbEnv }, + + pub dynamic: { DynamicApiEnv }, + + pub wh_invalid_req_url: { + var_name: "WEBHOOK_INVALID_REQ_URL", + layers: [or_default()], + description: "The URL to the Discord webhook used to flag invalid requests", + default_val_fmt: "empty", + }, + + pub port: { + var_name: "RECORDS_API_PORT", + layers: [ + parsed_from_str(), + or_default(), + ], + description: "The port used to expose the API", + default_val_fmt: "3000", + }, + + pub host: { Hostname }, + + pub auth_token_ttl: { + var_name: "RECORDS_API_TOKEN_TTL", + layers: [ + parsed_from_str(), + or_default_val(|| 180 * 24 * 3600), + ], + description: "The TTL (time-to-live) of an authentication token or anything related to it (in seconds)", + default_val_fmt: "180 days", + }, + + pub wh_report_url: { + var_name: "WEBHOOK_REPORT_URL", + layers: [or_default()], + description: "The URL to the Discord webhook used to report errors", + default_val_fmt: "empty", + }, + + pub wh_ac_url: { + var_name: "WEBHOOK_AC_URL", + layers: [or_default()], + description: "The URL to the Discord webhook used to share in-game statistics", + default_val_fmt: "empty", + }, + + + pub gql_endpoint: { + var_name: "GQL_ENDPOINT", + layers: [ + or_default_val(|| "/graphql".to_owned()), + ], + description: "The route to the GraphQL endpoint (e.g. /graphql)", + default_val_fmt: "/graphql", + }, + + pub wh_rank_compute_err: { + var_name: "WEBHOOK_RANK_COMPUTE_ERROR", + layers: [or_default()], + description: "The URL to the Discord webhook used to send rank compute errors", + default_val_fmt: "empty", + }, + } +} -pub fn env() -> &'static mkenv::init_env!(ApiEnv) { - // SAFETY: this function is always called when the `init_env()` is called at the start. - unsafe { ENV.get_unchecked() } +make_config! { + struct All { + api_env: { ApiEnv }, + lib_env: { LibEnv }, + gql_env: { graphql_api::config::ApiConfig }, + } } -pub fn init_env() -> anyhow::Result { - let env = ApiEnv::try_get()?; - let (included, rest) = env.split(); - records_lib::init_env(included.lib_env); - let _ = ENV.set(rest); +static ENV: OnceCell = OnceCell::new(); + +pub fn env() -> &'static ApiEnv { + ENV.get().unwrap() +} + +pub fn init_env() -> anyhow::Result<()> { + let env = All::define(); + env.try_init().map_err(|e| anyhow::anyhow!("{e}"))?; + + records_lib::init_env(env.lib_env); + let _ = graphql_api::set_config(env.gql_env); + let _ = ENV.set(env.api_env); - Ok(InitEnvOut { - db_env: included.db_env, - used_once: included.used_once, - }) + Ok(()) } diff --git a/crates/game_api/src/graphql.rs b/crates/game_api/src/graphql.rs index 425fd0e..4a7f778 100644 --- a/crates/game_api/src/graphql.rs +++ b/crates/game_api/src/graphql.rs @@ -6,6 +6,7 @@ use async_graphql::http::{GraphQLPlaygroundConfig, playground_source}; use async_graphql_actix_web::GraphQLRequest; use graphql_api::error::{ApiGqlError, ApiGqlErrorKind}; use graphql_api::schema::{Schema, create_schema}; +use mkenv::prelude::*; use records_lib::Database; use reqwest::Client; use tracing_actix_web::RequestId; @@ -79,7 +80,7 @@ async fn index_playground() -> impl Responder { HttpResponse::Ok() .content_type("text/html; charset=utf-8") .body(playground_source(GraphQLPlaygroundConfig::new( - &crate::env().gql_endpoint, + &crate::env().gql_endpoint.get(), ))) } diff --git a/crates/game_api/src/http.rs b/crates/game_api/src/http.rs index d858c06..e58839f 100644 --- a/crates/game_api/src/http.rs +++ b/crates/game_api/src/http.rs @@ -9,6 +9,8 @@ pub mod player; use std::fmt; +use mkenv::prelude::*; + use actix_web::body::BoxBody; use actix_web::dev::{ServiceFactory, ServiceRequest, ServiceResponse}; use actix_web::web::{JsonConfig, Query}; @@ -131,7 +133,7 @@ async fn report_error( } client - .post(&crate::env().wh_report_url) + .post(crate::env().wh_report_url.get()) .json(&WebhookBody { content: format!("Error reported (mode version: {mode_vers})"), embeds: vec![ diff --git a/crates/game_api/src/http/event.rs b/crates/game_api/src/http/event.rs index 8f08205..f319f80 100644 --- a/crates/game_api/src/http/event.rs +++ b/crates/game_api/src/http/event.rs @@ -12,6 +12,7 @@ use entity::{ }; use futures::TryStreamExt; use itertools::Itertools; +use mkenv::prelude::*; use records_lib::{ Database, Expirable as _, NullableInteger, NullableReal, NullableText, RedisPool, error::RecordsError, @@ -215,7 +216,7 @@ fn db_align_to_mp_align( ) -> NullableText { match alignment { None if pos_x.is_none() && pos_y.is_none() => { - Some(records_lib::env().ingame_default_titles_align) + Some(records_lib::env().ingame_default_titles_align.get()) } Some(_) if pos_x.is_some() || pos_y.is_some() => None, other => other, @@ -245,7 +246,7 @@ impl From for EventEditionInGameParams { put_subtitle_on_newline: value .put_subtitle_on_newline .map(|b| b != 0) - .unwrap_or_else(|| records_lib::env().ingame_default_subtitle_on_newline), + .unwrap_or_else(|| records_lib::env().ingame_default_subtitle_on_newline.get()), titles_pos_x: value.titles_pos_x.into(), titles_pos_y: value.titles_pos_y.into(), lb_link_pos_x: value.lb_link_pos_x.into(), @@ -262,22 +263,25 @@ impl Default for EventEditionInGameParams { titles_align: NullableText(Some( records_lib::env() .ingame_default_titles_align + .get() .to_char() .to_string(), )), lb_link_align: NullableText(Some( records_lib::env() .ingame_default_lb_link_align + .get() .to_char() .to_string(), )), authors_align: NullableText(Some( records_lib::env() .ingame_default_authors_align + .get() .to_char() .to_string(), )), - put_subtitle_on_newline: records_lib::env().ingame_default_subtitle_on_newline, + put_subtitle_on_newline: records_lib::env().ingame_default_subtitle_on_newline.get(), titles_pos_x: NullableReal(None), titles_pos_y: NullableReal(None), lb_link_pos_x: NullableReal(None), diff --git a/crates/game_api/src/http/overview.rs b/crates/game_api/src/http/overview.rs index 715eaa9..f2eaeca 100644 --- a/crates/game_api/src/http/overview.rs +++ b/crates/game_api/src/http/overview.rs @@ -169,7 +169,8 @@ async fn get_rank( match min_time { Some(time) => { - let rank = ranks::get_rank(redis_pool, map.id, p.id, time, event) + let mut redis_conn = redis_pool.get().await.with_api_err()?; + let rank = ranks::get_rank(&mut redis_conn, map.id, time, event) .await .with_api_err()?; Ok(Some(rank)) diff --git a/crates/game_api/src/http/player.rs b/crates/game_api/src/http/player.rs index aeceba6..58238af 100644 --- a/crates/game_api/src/http/player.rs +++ b/crates/game_api/src/http/player.rs @@ -8,6 +8,7 @@ use actix_web::{ use deadpool_redis::redis::AsyncCommands as _; use entity::{banishments, current_bans, maps, players, records, role, types}; use futures::TryStreamExt; +use mkenv::prelude::*; use records_lib::{Database, RedisPool, must, player, redis_key::alone_map_key, sync}; use reqwest::Client; use sea_orm::{ @@ -475,7 +476,7 @@ async fn report_error( }; client - .post(&crate::env().wh_report_url) + .post(crate::env().wh_report_url.get()) .json(&WebhookBody { content, embeds: vec![ @@ -518,7 +519,7 @@ struct ACBody { async fn ac(Res(client): Res, Json(body): Json) -> RecordsResult { client - .post(&crate::env().wh_ac_url) + .post(crate::env().wh_ac_url.get()) .json(&WebhookBody { content: format!("Map has been finished in {}", body.run_time), embeds: vec![WebhookBodyEmbed { diff --git a/crates/game_api/src/http/player/auth/with_auth.rs b/crates/game_api/src/http/player/auth/with_auth.rs index 3453fc9..11a708b 100644 --- a/crates/game_api/src/http/player/auth/with_auth.rs +++ b/crates/game_api/src/http/player/auth/with_auth.rs @@ -1,5 +1,6 @@ use actix_session::Session; use actix_web::{HttpResponse, Responder, web}; +use mkenv::Layer as _; use records_lib::Database; use reqwest::{Client, StatusCode}; use tokio::time::timeout; @@ -44,8 +45,8 @@ async fn test_access_token( .post("https://prod.live.maniaplanet.com/login/oauth2/access_token") .form(&MPAccessTokenBody { grant_type: "authorization_code", - client_id: &crate::env().mp_client_id, - client_secret: &crate::env().mp_client_secret, + client_id: &crate::env().dynamic.mp_client_id.get(), + client_secret: &crate::env().dynamic.mp_client_secret.get(), code, redirect_uri, }) diff --git a/crates/game_api/src/http/player_finished.rs b/crates/game_api/src/http/player_finished.rs index ca4ab23..4ff8fc7 100644 --- a/crates/game_api/src/http/player_finished.rs +++ b/crates/game_api/src/http/player_finished.rs @@ -232,12 +232,14 @@ where }) .await?; + let mut redis_conn = redis_pool.get().await.with_api_err()?; + let (old, new, has_improved, old_rank) = match result.old_record { Some(records::Model { time: old, .. }) => ( old, params.body.time, params.body.time < old, - Some(ranks::get_rank(redis_pool, map.id, player_id, old, params.event).await?), + Some(ranks::get_rank(&mut redis_conn, map.id, old, params.event).await?), ), None => (params.body.time, params.body.time, true, None), }; @@ -261,7 +263,7 @@ where let (count,): (i32,) = pipe.query_async(&mut redis_conn).await.with_api_err()?; count + 1 } else { - ranks::get_rank(redis_pool, map.id, player_id, old, params.event) + ranks::get_rank(&mut redis_conn, map.id, old, params.event) .await .with_api_err()? }; diff --git a/crates/game_api/src/main.rs b/crates/game_api/src/main.rs index e76e8bc..7633e40 100644 --- a/crates/game_api/src/main.rs +++ b/crates/game_api/src/main.rs @@ -17,6 +17,7 @@ use actix_web::{ use anyhow::Context; use game_api_lib::configure; use migration::MigratorTrait; +use mkenv::prelude::*; use records_lib::Database; use tracing::level_filters::LevelFilter; use tracing_actix_web::TracingLogger; @@ -26,15 +27,18 @@ use tracing_subscriber::{EnvFilter, fmt::format::FmtSpan}; #[tokio::main] async fn main() -> anyhow::Result<()> { dotenvy::dotenv()?; - let env = game_api_lib::init_env()?; + game_api_lib::init_env()?; #[cfg(feature = "request_filter")] - request_filter::init_wh_url(env.used_once.wh_invalid_req_url).map_err(|_| { + request_filter::init_wh_url(game_api_lib::env().wh_invalid_req_url.get()).map_err(|_| { game_api_lib::internal!("Invalid request WH URL isn't supposed to be set twice") })?; - let db = Database::from_db_url(env.db_env.db_url.db_url, env.db_env.redis_url.redis_url) - .await - .context("Cannot initialize database connection")?; + let db = Database::from_db_url( + game_api_lib::env().db_env.db_url.db_url.get(), + game_api_lib::env().db_env.redis_url.redis_url.get(), + ) + .await + .context("Cannot initialize database connection")?; migration::Migrator::up(&db.sql_conn, None) .await @@ -72,8 +76,7 @@ async fn main() -> anyhow::Result<()> { tracing::info!("Using max connections: {max_connections}"); - let sess_key = Key::from(env.used_once.sess_key.as_bytes()); - drop(env.used_once.sess_key); + let sess_key = Key::from(game_api_lib::env().dynamic.sess_key.get().as_bytes()); HttpServer::new(move || { let cors = Cors::default() @@ -84,7 +87,7 @@ async fn main() -> anyhow::Result<()> { #[cfg(debug_assertions)] let cors = cors.allow_any_origin(); #[cfg(not(debug_assertions))] - let cors = cors.allowed_origin(&game_api_lib::env().host); + let cors = cors.allowed_origin(&game_api_lib::env().host.host.get()); App::new() .wrap(cors) @@ -96,13 +99,13 @@ async fn main() -> anyhow::Result<()> { .cookie_secure(cfg!(not(debug_assertions))) .cookie_content_security(CookieContentSecurity::Private) .session_lifecycle(PersistentSession::default().session_ttl( - CookieDuration::seconds(game_api_lib::env().auth_token_ttl as i64), + CookieDuration::seconds(game_api_lib::env().auth_token_ttl.get() as i64), )) .build(), ) .configure(|cfg| configure::configure(cfg, db.clone())) }) - .bind(("0.0.0.0", game_api_lib::env().port)) + .bind(("0.0.0.0", game_api_lib::env().port.get())) .context("Cannot bind address")? .run() .await diff --git a/crates/game_api/tests/base.rs b/crates/game_api/tests/base.rs index b1ba4ef..811b287 100644 --- a/crates/game_api/tests/base.rs +++ b/crates/game_api/tests/base.rs @@ -1,6 +1,6 @@ #![allow(dead_code)] -use std::{fmt, panic}; +use std::fmt; use actix_http::Request; use actix_web::{ @@ -9,15 +9,11 @@ use actix_web::{ dev::{Service, ServiceResponse}, middleware, test, }; -use anyhow::Context; -use futures::FutureExt; -use migration::MigratorTrait as _; -use records_lib::{Database, pool::get_redis_pool}; -use sea_orm::{ConnectionTrait, DbConn}; +use records_lib::Database; +use test_env::IntoResult; use tracing_actix_web::TracingLogger; use game_api_lib::{configure, init_env}; -use tracing_subscriber::fmt::TestWriter; #[derive(Debug, serde::Deserialize)] pub struct ErrorResponse { @@ -26,17 +22,16 @@ pub struct ErrorResponse { pub message: String, } -pub fn get_env() -> anyhow::Result { - match dotenvy::dotenv() { - Err(err) if !err.not_found() => return Err(err).context("retrieving .env files"), - _ => (), - } - - let _ = tracing_subscriber::fmt() - .with_writer(TestWriter::new()) - .try_init(); - - init_env() +pub async fn with_db(test: F) -> anyhow::Result<::Out> +where + F: AsyncFnOnce(Database) -> R, + R: IntoResult, +{ + test_env::wrap(async |db| { + init_env()?; + test(db).await.into_result() + }) + .await } pub async fn get_app( @@ -61,7 +56,7 @@ pub enum ApiError { impl fmt::Display for ApiError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - ApiError::InvalidJson(raw, deser_err) => match str::from_utf8(&raw) { + ApiError::InvalidJson(raw, deser_err) => match str::from_utf8(raw) { Ok(s) => write!( f, "Invalid JSON returned by the API: {s}\nError when deserializing: {deser_err}" @@ -90,125 +85,6 @@ impl fmt::Display for ApiError { impl std::error::Error for ApiError {} -pub trait IntoResult { - type Out; - - fn into_result(self) -> anyhow::Result; -} - -impl IntoResult for () { - type Out = (); - - fn into_result(self) -> anyhow::Result { - Ok(()) - } -} - -impl IntoResult for Result -where - anyhow::Error: From, -{ - type Out = T; - - fn into_result(self) -> anyhow::Result { - self.map_err(From::from) - } -} - -pub async fn with_db(test: F) -> anyhow::Result<::Out> -where - F: AsyncFnOnce(Database) -> R, - R: IntoResult, -{ - let env = get_env()?; - wrap(env.db_env.db_url.db_url, async |sql_conn| { - let db = Database { - sql_conn, - redis_pool: get_redis_pool(env.db_env.redis_url.redis_url)?, - }; - test(db).await.into_result() - }) - .await -} - -pub async fn wrap(db_url: String, test: F) -> anyhow::Result<::Out> -where - F: AsyncFnOnce(DbConn) -> R, - R: IntoResult, -{ - let master_db = sea_orm::Database::connect(&db_url).await?; - - // For some reasons, on MySQL/MariaDB, using a schema name with some capital letters - // may produce the error code 1932 (42S02) "Table 'X' doesn't exist in engine" when - // doing a query. - let db_name = format!( - "_test_db_{}", - records_lib::gen_random_str(10).to_lowercase() - ); - - master_db - .execute_unprepared(&format!("create database {db_name}")) - .await?; - println!("Created database {db_name}"); - - let db = match master_db { - #[cfg(feature = "mysql")] - sea_orm::DatabaseConnection::SqlxMySqlPoolConnection(_) => { - let connect_options = master_db.get_mysql_connection_pool().connect_options(); - let connect_options = (*connect_options).clone(); - let options = connect_options.database(&db_name); - let db = sqlx::mysql::MySqlPool::connect_with(options).await?; - DbConn::from(db) - } - #[cfg(feature = "postgres")] - sea_orm::DatabaseConnection::SqlxPostgresPoolConnection(_) => { - let connect_options = master_db.get_postgres_connection_pool().connect_options(); - let connect_options = (*connect_options).clone(); - let options = connect_options.database(&db_name); - let db = sqlx::postgres::PgPool::connect_with(options).await?; - DbConn::from(db) - } - _ => unreachable!(), - }; - - migration::Migrator::up(&db, None).await?; - - let r = panic::AssertUnwindSafe(test(db)).catch_unwind().await; - #[cfg(test_force_db_deletion)] - { - master_db - .execute_unprepared(&format!("drop database {db_name}")) - .await?; - println!("Database {db_name} force-deleted"); - match r { - Ok(r) => r.into_result(), - Err(e) => panic::resume_unwind(e), - } - } - #[cfg(not(test_force_db_deletion))] - { - match r.map(IntoResult::into_result) { - Ok(Ok(out)) => { - master_db - .execute_unprepared(&format!("drop database {db_name}")) - .await?; - Ok(out) - } - other => { - println!( - "Test failed, leaving database {db_name} as-is. \ - Run with cfg `test_force_db_deletion` to drop the database everytime." - ); - match other { - Ok(Err(e)) => Err(e), - Err(e) => panic::resume_unwind(e), - _ => unreachable!(), - } - } - } - } -} - pub fn try_from_slice<'de, T>(slice: &'de [u8]) -> Result where T: serde::Deserialize<'de>, diff --git a/crates/game_api/tests/overview.rs b/crates/game_api/tests/overview.rs index 9811389..2cc7738 100644 --- a/crates/game_api/tests/overview.rs +++ b/crates/game_api/tests/overview.rs @@ -78,7 +78,7 @@ fn player_id_to_row(player_id: i32) -> Row { } async fn insert_sample_map(conn: &C) -> anyhow::Result { - let map_id = rand::random_range(..100000); + let map_id = test_env::get_map_id(); maps::Entity::insert(maps::ActiveModel { id: Set(map_id), @@ -372,9 +372,7 @@ async fn competition_ranking() -> anyhow::Result<()> { let app = base::get_app(db.clone()).await; let req = test::TestRequest::get() - .uri(&format!( - "/overview?mapId=test_map_uid&playerId=player_1_login" - )) + .uri("/overview?mapId=test_map_uid&playerId=player_1_login") .to_request(); let resp = test::call_service(&app, req).await; @@ -448,8 +446,8 @@ async fn overview_event_version_map_non_empty() -> anyhow::Result<()> { .await .context("couldn't insert players")?; - let map_id = rand::random_range(1..=100000); - let event_map_id = rand::random_range(1..=100000); + let map_id = test_env::get_map_id(); + let event_map_id = test_env::get_map_id(); let map = maps::ActiveModel { id: Set(map_id), diff --git a/crates/game_api/tests/overview_event.rs b/crates/game_api/tests/overview_event.rs index 96cef38..8c8705a 100644 --- a/crates/game_api/tests/overview_event.rs +++ b/crates/game_api/tests/overview_event.rs @@ -40,9 +40,9 @@ async fn event_overview_original_maps_diff() -> anyhow::Result<()> { ..Default::default() }; - let map_id = rand::random_range(1..=100000); + let map_id = test_env::get_map_id(); let original_map_id = loop { - let val = rand::random_range(1..=100000); + let val = test_env::get_map_id(); if val != map_id { break val; } @@ -185,7 +185,7 @@ async fn event_overview_transparent_equal() -> anyhow::Result<()> { ..Default::default() }; - let map_id = rand::random_range(1..=100000); + let map_id = test_env::get_map_id(); let map = maps::ActiveModel { id: Set(map_id), @@ -307,7 +307,7 @@ async fn event_overview_original_maps_empty() -> anyhow::Result<()> { ..Default::default() }; - let map_id = rand::random_range(1..=100000); + let map_id = test_env::get_map_id(); let map = maps::ActiveModel { id: Set(map_id), diff --git a/crates/game_api/tests/player_finished.rs b/crates/game_api/tests/player_finished.rs index 90bcafa..f9ac572 100644 --- a/crates/game_api/tests/player_finished.rs +++ b/crates/game_api/tests/player_finished.rs @@ -27,7 +27,7 @@ async fn single_try() -> anyhow::Result<()> { ..Default::default() }; - let map_id = rand::random_range(1..=100000); + let map_id = test_env::get_map_id(); let map = maps::ActiveModel { id: Set(map_id), @@ -192,7 +192,7 @@ async fn many_tries() -> anyhow::Result<()> { ..Default::default() }; - let map_id = rand::random_range(1..=100000); + let map_id = test_env::get_map_id(); let map = maps::ActiveModel { id: Set(map_id), @@ -338,7 +338,7 @@ async fn with_mode_version() -> anyhow::Result<()> { ..Default::default() }; - let map_id = rand::random_range(1..=100000); + let map_id = test_env::get_map_id(); let map = maps::ActiveModel { id: Set(map_id), @@ -428,7 +428,7 @@ async fn many_records() -> anyhow::Result<()> { ..Default::default() }); - let map_id = rand::random_range(1..=100000); + let map_id = test_env::get_map_id(); let map = maps::ActiveModel { id: Set(map_id), @@ -569,7 +569,7 @@ async fn save_record_for_related_non_transparent_event() -> anyhow::Result<()> { ..Default::default() }; - let map_id = rand::random_range(1..=100000); + let map_id = test_env::get_map_id(); let map = maps::ActiveModel { id: Set(map_id), diff --git a/crates/game_api/tests/player_finished_event.rs b/crates/game_api/tests/player_finished_event.rs index 10a32a7..dff6e5f 100644 --- a/crates/game_api/tests/player_finished_event.rs +++ b/crates/game_api/tests/player_finished_event.rs @@ -38,7 +38,7 @@ async fn finished_on_transparent_event() -> anyhow::Result<()> { ..Default::default() }; - let map_id = rand::random_range(1..=100000); + let map_id = test_env::get_map_id(); let map = maps::ActiveModel { id: Set(map_id), @@ -166,9 +166,9 @@ async fn event_finish_transitive_save() -> anyhow::Result<()> { ..Default::default() }; - let original_map_id = rand::random_range(1..=100000); + let original_map_id = test_env::get_map_id(); let map_id = loop { - let val = rand::random_range(1..=100000); + let val = test_env::get_map_id(); if val != original_map_id { break val; } @@ -313,9 +313,9 @@ async fn event_finish_save_to_original() -> anyhow::Result<()> { ..Default::default() }; - let original_map_id = rand::random_range(1..=100000); + let original_map_id = test_env::get_map_id(); let map_id = loop { - let val = rand::random_range(1..=100000); + let val = test_env::get_map_id(); if val != original_map_id { break val; } @@ -459,9 +459,9 @@ async fn event_finish_non_transitive_save() -> anyhow::Result<()> { ..Default::default() }; - let original_map_id = rand::random_range(1..=100000); + let original_map_id = test_env::get_map_id(); let map_id = loop { - let val = rand::random_range(1..=100000); + let val = test_env::get_map_id(); if val != original_map_id { break val; } @@ -576,9 +576,9 @@ async fn event_finish_save_non_event_record() -> anyhow::Result<()> { ..Default::default() }; - let original_map_id = rand::random_range(1..=100000); + let original_map_id = test_env::get_map_id(); let map_id = loop { - let val = rand::random_range(1..=100000); + let val = test_env::get_map_id(); if val != original_map_id { break val; } diff --git a/crates/game_api/tests/root.rs b/crates/game_api/tests/root.rs index 3e906c8..a8173a6 100644 --- a/crates/game_api/tests/root.rs +++ b/crates/game_api/tests/root.rs @@ -3,15 +3,18 @@ mod base; use actix_web::test; #[tokio::test] -#[cfg(feature = "test")] +#[cfg(test)] async fn test_not_found() -> anyhow::Result<()> { use actix_http::StatusCode; use game_api_lib::TracedError; - use records_lib::Database; + use mkenv::prelude::*; + use records_lib::{Database, DbEnv}; use sea_orm::DbBackend; - let env = base::get_env()?; - let db = Database::from_mock_db(DbBackend::MySql, env.db_env.redis_url.redis_url)?; + test_env::init_env()?; + + let env = DbEnv::define(); + let db = Database::from_mock_db(DbBackend::MySql, env.redis_url.redis_url.get())?; let app = base::get_app(db).await; let req = test::TestRequest::get().uri("/").to_request(); diff --git a/crates/graphql-api/Cargo.toml b/crates/graphql-api/Cargo.toml index 2f06753..c195a4b 100644 --- a/crates/graphql-api/Cargo.toml +++ b/crates/graphql-api/Cargo.toml @@ -10,8 +10,25 @@ chrono = { workspace = true } deadpool-redis = { workspace = true } entity = { path = "../entity" } futures = { workspace = true } +hmac = "0.12.1" +mkenv = { workspace = true } records-lib = { path = "../records_lib" } reqwest = { workspace = true } sea-orm = { workspace = true } serde = { workspace = true } +sha2 = "0.10.9" tokio = { workspace = true, features = ["macros"] } +itertools.workspace = true +serde_json.workspace = true + +[dev-dependencies] +test-env = { path = "../test-env" } +anyhow = { workspace = true } +tracing.workspace = true +rand.workspace = true + +[features] +default = [] +mysql = ["records-lib/mysql", "test-env/mysql"] +postgres = ["records-lib/postgres", "test-env/postgres"] +sqlite = ["records-lib/sqlite"] diff --git a/crates/graphql-api/src/config.rs b/crates/graphql-api/src/config.rs new file mode 100644 index 0000000..523d04e --- /dev/null +++ b/crates/graphql-api/src/config.rs @@ -0,0 +1,104 @@ +use hmac::Hmac; +use sha2::{Sha224, digest::KeyInit}; +use std::{error::Error, fmt, sync::OnceLock}; + +use mkenv::{ConfigDescriptor, exec::ConfigInitializer}; + +pub type SecretKeyType = Hmac; + +fn parse_secret_key(input: &str) -> Result> { + SecretKeyType::new_from_slice(input.as_bytes()).map_err(From::from) +} + +#[cfg(debug_assertions)] +mkenv::make_config! { + pub struct SecretKey { + pub(crate) cursor_secret_key: { + var_name: "GQL_API_CURSOR_SECRET_KEY", + layers: [ + parsed(parse_secret_key), + ], + description: "The secret key used for signing cursors for pagination \ + in the GraphQL API", + } + } +} + +#[cfg(not(debug_assertions))] +mkenv::make_config! { + pub struct SecretKey { + pub(crate) cursor_secret_key: { + var_name: "GQL_API_CURSOR_SECRET_KEY_FILE", + layers: [ + file_read(), + parsed(parse_secret_key), + ], + description: "The path to the file containing The secret key used for signing cursors \ + for pagination in the GraphQL API", + } + } +} + +mkenv::make_config! { + pub struct ApiConfig { + pub(crate) cursor_max_limit: { + var_name: "GQL_API_CURSOR_MAX_LIMIT", + layers: [ + parsed_from_str(), + or_default_val(|| 100), + ], + }, + + pub(crate) cursor_default_limit: { + var_name: "GQL_API_CURSOR_DEFAULT_LIMIT", + layers: [ + parsed_from_str(), + or_default_val(|| 50), + ], + }, + + pub(crate) cursor_secret_key: { SecretKey }, + } +} + +static CONFIG: OnceLock = OnceLock::new(); + +#[derive(Debug)] +pub enum InitError { + ConfigAlreadySet, + Config(Box), +} + +impl fmt::Display for InitError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + InitError::ConfigAlreadySet => f.write_str("config already set"), + InitError::Config(error) => write!(f, "error during config init: {error}"), + } + } +} + +impl Error for InitError { + fn source(&self) -> Option<&(dyn Error + 'static)> { + match self { + InitError::ConfigAlreadySet => None, + InitError::Config(error) => Some(&**error), + } + } +} + +pub fn set_config(config: ApiConfig) -> Result<(), Box> { + CONFIG.set(config).map_err(Box::new) +} + +pub fn init_config() -> Result<(), InitError> { + let config = ApiConfig::define(); + config + .try_init() + .map_err(|e| InitError::Config(format!("{e}").into()))?; + CONFIG.set(config).map_err(|_| InitError::ConfigAlreadySet) +} + +pub(crate) fn config() -> &'static ApiConfig { + CONFIG.get().unwrap() +} diff --git a/crates/graphql-api/src/cursors.rs b/crates/graphql-api/src/cursors.rs index 740ae60..ce7255a 100644 --- a/crates/graphql-api/src/cursors.rs +++ b/crates/graphql-api/src/cursors.rs @@ -1,141 +1,533 @@ -use std::{ops::RangeInclusive, str}; +pub mod expr_tuple; +pub mod query_builder; +pub mod query_trait; + +use std::str; use async_graphql::{ID, connection::CursorType}; -use base64::{Engine as _, engine::general_purpose::STANDARD as BASE64}; +use base64::{Engine as _, prelude::BASE64_URL_SAFE}; use chrono::{DateTime, Utc}; +use hmac::Mac; +use mkenv::Layer; +use sea_orm::{ + Value, + sea_query::{IntoValueTuple, SimpleExpr, ValueTuple}, +}; -use crate::error::CursorDecodeErrorKind; - -pub const CURSOR_LIMIT_RANGE: RangeInclusive = 1..=100; -pub const CURSOR_DEFAULT_LIMIT: usize = 50; +use crate::{ + cursors::expr_tuple::{ExprTuple, IntoExprTuple}, + error::{CursorDecodeError, CursorDecodeErrorKind}, +}; -fn decode_base64(s: &str) -> Result { - let decoded = BASE64 +fn decode_base64(s: &str) -> Result { + let decoded = BASE64_URL_SAFE .decode(s) - .map_err(|_| CursorDecodeErrorKind::NotBase64)?; + .map_err(|_| CursorDecodeError::from(CursorDecodeErrorKind::NotBase64))?; + // TODO: replace this with `slice::split_once` once it's stabilized + // https://github.com/rust-lang/rust/issues/112811 + let idx = decoded + .iter() + .position(|b| *b == b'$') + .ok_or(CursorDecodeError::from(CursorDecodeErrorKind::NoSignature))?; + let (content, signature) = (&decoded[..idx], &decoded[idx + 1..]); + + let mut mac = crate::config().cursor_secret_key.cursor_secret_key.get(); + mac.update(content); + mac.verify_slice(signature) + .map_err(|e| CursorDecodeError::from(CursorDecodeErrorKind::InvalidSignature(e)))?; - String::from_utf8(decoded).map_err(|_| CursorDecodeErrorKind::NotUtf8) + let content = str::from_utf8(content) + .map_err(|_| CursorDecodeError::from(CursorDecodeErrorKind::NotUtf8))?; + + Ok(content.to_owned()) } -fn check_prefix(splitted: I) -> Result<(), CursorDecodeErrorKind> -where - I: IntoIterator, - S: AsRef, -{ - match splitted - .into_iter() - .next() - .filter(|s| s.as_ref() == "record") +fn encode_base64(s: String) -> String { + let mut mac = crate::config().cursor_secret_key.cursor_secret_key.get(); + mac.update(s.as_bytes()); + let signature = mac.finalize().into_bytes(); + + let mut output = s.into_bytes(); + output.push(b'$'); + output.extend_from_slice(&signature); + + BASE64_URL_SAFE.encode(&output) +} + +mod datetime_serde { + use std::fmt; + + use chrono::{DateTime, Utc}; + use serde::de::{Unexpected, Visitor}; + + pub fn serialize(datetime: &DateTime, serializer: S) -> Result + where + S: serde::Serializer, { - Some(_) => Ok(()), - None => Err(CursorDecodeErrorKind::WrongPrefix), + let time = datetime.timestamp_millis(); + serializer.serialize_i64(time) } + + pub fn deserialize<'de, D>(deser: D) -> Result, D::Error> + where + D: serde::Deserializer<'de>, + { + struct TimestampVisitor; + impl<'de> Visitor<'de> for TimestampVisitor { + type Value = DateTime; + + fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.write_str("an integer") + } + + fn visit_u64(self, v: u64) -> Result + where + E: serde::de::Error, + { + self.visit_i64(v as _) + } + + fn visit_i64(self, v: i64) -> Result + where + E: serde::de::Error, + { + DateTime::from_timestamp_millis(v).ok_or_else(|| { + serde::de::Error::invalid_value(Unexpected::Signed(v), &"a valid timestamp") + }) + } + } + + deser.deserialize_i64(TimestampVisitor) + } +} + +#[cold] +#[inline(never)] +fn serialization_failed(_obj: &T, e: serde_json::Error) -> ! { + panic!( + "serialization of `{}` failed: {e}", + std::any::type_name::() + ) } -fn check_timestamp(splitted: I) -> Result, CursorDecodeErrorKind> +fn encode_cursor(prefix: &str, obj: &T) -> String where - I: IntoIterator, - S: AsRef, + T: serde::Serialize, { - splitted - .into_iter() - .next() - .and_then(|t| t.as_ref().parse().ok()) - .ok_or(CursorDecodeErrorKind::NoTimestamp) - .and_then(|t| { - DateTime::from_timestamp_millis(t).ok_or(CursorDecodeErrorKind::InvalidTimestamp(t)) - }) + encode_base64(format!( + "{prefix}:{}", + serde_json::to_string(obj).unwrap_or_else(|e| serialization_failed(obj, e)) + )) } -fn check_time(splitted: I) -> Result +fn decode_cursor(prefix: &str, input: &str) -> Result where - I: IntoIterator, - S: AsRef, + T: serde::de::DeserializeOwned, { - splitted - .into_iter() - .next() - .and_then(|t| t.as_ref().parse().ok()) - .ok_or(CursorDecodeErrorKind::NoTime) + let decoded = decode_base64(input)?; + + let Some((input_prefix, input)) = decoded.split_once(':') else { + return Err(CursorDecodeErrorKind::MissingPrefix.into()); + }; + + if input_prefix != prefix { + return Err(CursorDecodeErrorKind::InvalidPrefix.into()); + } + + serde_json::from_str(input).map_err(|_| CursorDecodeErrorKind::InvalidData.into()) } -pub struct RecordDateCursor(pub DateTime); +#[derive(PartialEq, Debug, serde::Serialize, serde::Deserialize)] +pub struct RecordDateCursor { + #[serde(with = "datetime_serde")] + pub record_date: DateTime, + pub data: T, +} -impl CursorType for RecordDateCursor { - type Error = CursorDecodeErrorKind; +impl CursorType for RecordDateCursor +where + T: serde::Serialize + for<'de> serde::Deserialize<'de>, +{ + type Error = CursorDecodeError; fn decode_cursor(s: &str) -> Result { - let decoded = decode_base64(s)?; - let mut splitted = decoded.split(':'); - check_prefix(&mut splitted)?; - let record_date = check_timestamp(&mut splitted)?; - Ok(Self(record_date)) + decode_cursor("record_date", s) } fn encode_cursor(&self) -> String { - let timestamp = self.0.timestamp_millis(); - BASE64.encode(format!("record:{}", timestamp)) + encode_cursor("record_date", self) } } -pub struct RecordRankCursor { +impl IntoExprTuple for &RecordDateCursor +where + T: Into + Clone, +{ + fn into_expr_tuple(self) -> ExprTuple { + (self.record_date, self.data.clone()).into_expr_tuple() + } +} + +impl IntoValueTuple for &RecordDateCursor +where + T: Into + Clone, +{ + fn into_value_tuple(self) -> ValueTuple { + (self.record_date, self.data.clone()).into_value_tuple() + } +} + +#[derive(Debug, PartialEq, serde::Serialize, serde::Deserialize)] +pub struct RecordRankCursor { + #[serde(with = "datetime_serde")] pub record_date: DateTime, pub time: i32, + pub data: T, } -impl CursorType for RecordRankCursor { - type Error = CursorDecodeErrorKind; +impl CursorType for RecordRankCursor +where + T: serde::Serialize + for<'de> serde::Deserialize<'de>, +{ + type Error = CursorDecodeError; fn decode_cursor(s: &str) -> Result { - let decoded = decode_base64(s)?; - let mut splitted = decoded.split(':'); - check_prefix(&mut splitted)?; - let record_date = check_timestamp(&mut splitted)?; - let time = check_time(&mut splitted)?; - Ok(Self { record_date, time }) + decode_cursor("record_rank", s) } fn encode_cursor(&self) -> String { - let timestamp = self.record_date.timestamp_millis(); - BASE64.encode(format!("record:{timestamp}:{}", self.time)) + encode_cursor("record_rank", self) + } +} + +impl IntoExprTuple for &RecordRankCursor +where + T: Into + Clone, +{ + fn into_expr_tuple(self) -> ExprTuple { + (self.time, self.record_date, self.data.clone()).into_expr_tuple() + } +} + +impl IntoValueTuple for &RecordRankCursor +where + T: Into + Clone, +{ + fn into_value_tuple(self) -> sea_orm::sea_query::ValueTuple { + (self.time, self.record_date, self.data.clone()).into_value_tuple() } } -pub struct TextCursor(pub String); +#[derive(Debug, PartialEq, serde::Serialize, serde::Deserialize)] +pub struct TextCursor { + pub text: String, + pub data: T, +} -impl CursorType for TextCursor { - type Error = CursorDecodeErrorKind; +impl CursorType for TextCursor +where + T: serde::Serialize + for<'de> serde::Deserialize<'de>, +{ + type Error = CursorDecodeError; fn decode_cursor(s: &str) -> Result { - decode_base64(s).map(Self) + decode_cursor("text", s) } fn encode_cursor(&self) -> String { - BASE64.encode(&self.0) + encode_cursor("text", self) + } +} + +impl IntoExprTuple for &TextCursor +where + T: Into + Clone, +{ + fn into_expr_tuple(self) -> ExprTuple { + (self.text.clone(), self.data.clone()).into_expr_tuple() } } -pub struct F64Cursor(pub f64); +impl IntoValueTuple for &TextCursor +where + T: Into + Clone, +{ + fn into_value_tuple(self) -> ValueTuple { + (self.text.clone(), self.data.clone()).into_value_tuple() + } +} -impl CursorType for F64Cursor { - type Error = CursorDecodeErrorKind; +#[derive(Debug, PartialEq, serde::Serialize, serde::Deserialize)] +pub struct F64Cursor { + pub score: f64, + pub data: T, +} + +impl CursorType for F64Cursor +where + T: serde::Serialize + for<'de> serde::Deserialize<'de>, +{ + type Error = CursorDecodeError; fn decode_cursor(s: &str) -> Result { - let decoded = decode_base64(s)?; - let parsed = decoded - .parse() - .map_err(|_| CursorDecodeErrorKind::NoScore)?; - Ok(Self(parsed)) + decode_cursor("score", s) } fn encode_cursor(&self) -> String { - BASE64.encode(self.0.to_string()) + encode_cursor("score", self) + } +} + +impl IntoExprTuple for &F64Cursor +where + T: Into + Clone, +{ + fn into_expr_tuple(self) -> ExprTuple { + (self.score, self.data.clone()).into_expr_tuple() + } +} + +impl IntoValueTuple for &F64Cursor +where + T: Into + Clone, +{ + fn into_value_tuple(self) -> ValueTuple { + (self.score, self.data.clone()).into_value_tuple() } } -pub struct ConnectionParameters { - pub before: Option, - pub after: Option, +pub struct ConnectionParameters { + pub before: Option, + pub after: Option, pub first: Option, pub last: Option, } + +impl Default for ConnectionParameters { + fn default() -> Self { + Self { + before: Default::default(), + after: Default::default(), + first: Default::default(), + last: Default::default(), + } + } +} + +#[cfg(test)] +mod tests { + use std::fmt; + + use async_graphql::connection::CursorType; + use base64::{Engine as _, prelude::BASE64_URL_SAFE}; + use chrono::{DateTime, SubsecRound, Utc}; + use sha2::digest::MacError; + + use crate::{ + config::InitError, + cursors::{F64Cursor, TextCursor}, + error::{CursorDecodeError, CursorDecodeErrorKind}, + }; + + use super::{RecordDateCursor, RecordRankCursor}; + + fn setup() { + match crate::init_config() { + Ok(_) | Err(InitError::ConfigAlreadySet) => (), + Err(InitError::Config(e)) => { + panic!("test setup error: {e}"); + } + } + } + + fn test_cursor_round_trip(source: &C, expected: &C) + where + C: CursorType + fmt::Debug + PartialEq, + { + let cursor = source.encode_cursor(); + + let decoded = C::decode_cursor(&cursor); + assert_eq!(decoded.as_ref(), Ok(expected)); + } + + fn test_decode_cursor_errors() + where + C: CursorType, + { + let decoded = ::decode_cursor("foobar").err(); + assert_eq!( + decoded.map(|e| e.kind), + Some(CursorDecodeErrorKind::NotBase64) + ); + + let encoded = BASE64_URL_SAFE.encode("foobar"); + let decoded = ::decode_cursor(&encoded).err(); + assert_eq!( + decoded.map(|e| e.kind), + Some(CursorDecodeErrorKind::NoSignature) + ); + + let encoded = BASE64_URL_SAFE.encode("foobar$signature"); + let decoded = ::decode_cursor(&encoded).err(); + assert_eq!( + decoded.map(|e| e.kind), + Some(CursorDecodeErrorKind::InvalidSignature(MacError)) + ); + } + + fn test_encode_cursor(cursor: &C, expected: &str) { + let cursor = cursor.encode_cursor(); + + let decoded = BASE64_URL_SAFE + .decode(&cursor) + .expect("cursor should be encoded as base64"); + let expected = format!("{expected}$"); + assert!(decoded.starts_with(expected.as_bytes())); + } + + #[test] + fn encode_date_cursor() { + setup(); + + test_encode_cursor( + &RecordDateCursor { + record_date: DateTime::from_timestamp_millis(10).unwrap(), + data: 0, + }, + r#"record_date:{"record_date":10,"data":0}"#, + ); + } + + #[test] + fn date_cursor_round_trip() { + setup(); + + let now = Utc::now(); + test_cursor_round_trip( + &RecordDateCursor { + record_date: now, + data: 0, + }, + &RecordDateCursor { + record_date: now.trunc_subsecs(3), + data: 0, + }, + ); + } + + #[test] + fn decode_date_cursor_errors() { + setup(); + test_decode_cursor_errors::(); + } + + #[test] + fn decode_rank_cursor_errors() { + setup(); + test_decode_cursor_errors::(); + } + + #[test] + fn decode_text_cursor_errors() { + setup(); + test_decode_cursor_errors::(); + } + + #[test] + fn decode_score_cursor_errors() { + setup(); + test_decode_cursor_errors::(); + } + + #[test] + fn encode_rank_cursor() { + setup(); + + test_encode_cursor( + &RecordRankCursor { + record_date: DateTime::from_timestamp_millis(26).unwrap(), + time: 1000, + data: 24, + }, + r#"record_rank:{"record_date":26,"time":1000,"data":24}"#, + ); + } + + #[test] + fn rank_cursor_round_trip() { + setup(); + + let now = Utc::now(); + test_cursor_round_trip( + &RecordRankCursor { + record_date: now, + time: 2000, + data: 34, + }, + &RecordRankCursor { + record_date: now.trunc_subsecs(3), + time: 2000, + data: 34, + }, + ); + } + + #[test] + fn encode_text_cursor() { + setup(); + test_encode_cursor( + &TextCursor { + text: "hello".to_owned(), + data: 123, + }, + r#"text:{"text":"hello","data":123}"#, + ); + } + + #[test] + fn text_cursor_round_trip() { + setup(); + test_cursor_round_trip( + &TextCursor { + text: "booga".to_owned(), + data: 232, + }, + &TextCursor { + text: "booga".to_owned(), + data: 232, + }, + ); + } + + #[test] + fn encode_score_cursor() { + setup(); + test_encode_cursor( + &F64Cursor { + score: 12.34, + data: 2445, + }, + r#"score:{"score":12.34,"data":2445}"#, + ); + } + + #[test] + fn decode_score_cursor() { + setup(); + } + + #[test] + fn score_cursor_round_trip() { + setup(); + test_cursor_round_trip( + &F64Cursor { + score: 24.6, + data: 123, + }, + &F64Cursor { + score: 24.6, + data: 123, + }, + ); + } +} diff --git a/crates/graphql-api/src/cursors/expr_tuple.rs b/crates/graphql-api/src/cursors/expr_tuple.rs new file mode 100644 index 0000000..2420823 --- /dev/null +++ b/crates/graphql-api/src/cursors/expr_tuple.rs @@ -0,0 +1,69 @@ +use sea_orm::sea_query::SimpleExpr; + +#[derive(Debug, Clone)] +pub enum ExprTuple { + One(SimpleExpr), + Two(SimpleExpr, SimpleExpr), + Three(SimpleExpr, SimpleExpr, SimpleExpr), + Many(Vec), +} + +pub trait IntoExprTuple { + fn into_expr_tuple(self) -> ExprTuple; +} + +impl> IntoExprTuple for (A,) { + fn into_expr_tuple(self) -> ExprTuple { + let (a,) = self; + ExprTuple::One(a.into()) + } +} + +impl IntoExprTuple for (A, B) +where + A: Into, + B: Into, +{ + fn into_expr_tuple(self) -> ExprTuple { + let (a, b) = self; + ExprTuple::Two(a.into(), b.into()) + } +} + +impl IntoExprTuple for (A, B, C) +where + A: Into, + B: Into, + C: Into, +{ + fn into_expr_tuple(self) -> ExprTuple { + let (a, b, c) = self; + ExprTuple::Three(a.into(), b.into(), c.into()) + } +} + +macro_rules! impl_into_expr_tuple { + ($($gen:ident),*) => { + impl<$($gen),*> IntoExprTuple for ($($gen),*) + where + $( + $gen: Into + ),* + { + fn into_expr_tuple(self) -> ExprTuple { + #[allow(non_snake_case)] + let ($($gen),*) = self; + ExprTuple::Many(vec![$($gen.into()),*]) + } + } + + }; +} + +impl_into_expr_tuple!(A, B, C, D); +impl_into_expr_tuple!(A, B, C, D, E); +impl_into_expr_tuple!(A, B, C, D, E, F); +impl_into_expr_tuple!(A, B, C, D, E, F, G); +impl_into_expr_tuple!(A, B, C, D, E, F, G, H); +impl_into_expr_tuple!(A, B, C, D, E, F, G, H, I); +impl_into_expr_tuple!(A, B, C, D, E, F, G, H, I, K); diff --git a/crates/graphql-api/src/cursors/query_builder.rs b/crates/graphql-api/src/cursors/query_builder.rs new file mode 100644 index 0000000..4e423d8 --- /dev/null +++ b/crates/graphql-api/src/cursors/query_builder.rs @@ -0,0 +1,266 @@ +use std::marker::PhantomData; + +use sea_orm::{ + Condition, ConnectionTrait, DbErr, DynIden, FromQueryResult, Identity, IntoIdentity, Order, + PartialModelTrait, QuerySelect, SelectModel, SelectorTrait, + prelude::{Expr, SeaRc}, + sea_query::{ExprTrait as _, SelectStatement, SimpleExpr}, +}; + +use crate::cursors::expr_tuple::{ExprTuple, IntoExprTuple}; + +/// An enhanced version of [`Cursor`][sea_orm::Cursor]. +/// +/// It allows specifying expressions for pagination values, and uses tuple syntax when building the +/// SQL statement, instead of chaining inner `AND` and `OR` operations. +/// +/// The last change is the most important, because there seems to be an issue with MariaDB/MySQL, +/// and beside that, the tuple syntax seems more efficient in every DB engine. +#[derive(Debug, Clone)] +pub struct CursorQueryBuilder { + query: SelectStatement, + table: DynIden, + order_columns: Identity, + first: Option, + last: Option, + before: Option, + after: Option, + sort_asc: bool, + is_result_reversed: bool, + phantom: PhantomData, +} + +impl CursorQueryBuilder { + /// Create a new cursor + pub fn new(query: SelectStatement, table: DynIden, order_columns: C) -> Self + where + C: IntoIdentity, + { + Self { + query, + table, + order_columns: order_columns.into_identity(), + last: None, + first: None, + after: None, + before: None, + sort_asc: true, + is_result_reversed: false, + phantom: PhantomData, + } + } + + /// Filter paginated result with corresponding column less than the input value + pub fn before(&mut self, values: V) -> &mut Self + where + V: IntoExprTuple, + { + self.before = Some(values.into_expr_tuple()); + self + } + + /// Filter paginated result with corresponding column greater than the input value + pub fn after(&mut self, values: V) -> &mut Self + where + V: IntoExprTuple, + { + self.after = Some(values.into_expr_tuple()); + self + } + + fn apply_filters(&mut self) -> &mut Self { + if let Some(values) = self.after.clone() { + let condition = + self.apply_filter(values, |c, v| if self.sort_asc { c.gt(v) } else { c.lt(v) }); + self.query.cond_where(condition); + } + + if let Some(values) = self.before.clone() { + let condition = + self.apply_filter(values, |c, v| if self.sort_asc { c.lt(v) } else { c.gt(v) }); + self.query.cond_where(condition); + } + + self + } + + fn apply_filter(&self, values: ExprTuple, f: F) -> Condition + where + F: Fn(SimpleExpr, SimpleExpr) -> SimpleExpr, + { + match (&self.order_columns, values) { + (Identity::Unary(c1), ExprTuple::One(v1)) => { + let exp = Expr::col((SeaRc::clone(&self.table), SeaRc::clone(c1))); + Condition::all().add(f(exp.into(), v1)) + } + (Identity::Binary(c1, c2), ExprTuple::Two(v1, v2)) => { + let c1 = Expr::col((SeaRc::clone(&self.table), SeaRc::clone(c1))).into(); + let c2 = Expr::col((SeaRc::clone(&self.table), SeaRc::clone(c2))).into(); + let columns = Expr::tuple([c1, c2]).into(); + let values = Expr::tuple([v1, v2]).into(); + Condition::all().add(f(columns, values)) + } + (Identity::Ternary(c1, c2, c3), ExprTuple::Three(v1, v2, v3)) => { + let c1 = Expr::col((SeaRc::clone(&self.table), SeaRc::clone(c1))).into(); + let c2 = Expr::col((SeaRc::clone(&self.table), SeaRc::clone(c2))).into(); + let c3 = Expr::col((SeaRc::clone(&self.table), SeaRc::clone(c3))).into(); + let columns = Expr::tuple([c1, c2, c3]).into(); + let values = Expr::tuple([v1, v2, v3]).into(); + Condition::all().add(f(columns, values)) + } + (Identity::Many(col_vec), ExprTuple::Many(val_vec)) + if col_vec.len() == val_vec.len() => + { + let columns = Expr::tuple( + col_vec + .iter() + .map(|c| Expr::col((SeaRc::clone(&self.table), SeaRc::clone(c))).into()) + .collect::>(), + ) + .into(); + let values = Expr::tuple(val_vec).into(); + Condition::all().add(f(columns, values)) + } + _ => panic!("column arity mismatch"), + } + } + + /// Use ascending sort order + pub fn asc(&mut self) -> &mut Self { + self.sort_asc = true; + self + } + + /// Use descending sort order + pub fn desc(&mut self) -> &mut Self { + self.sort_asc = false; + self + } + + /// Limit result set to only first N rows in ascending order of the order by column + pub fn first(&mut self, num_rows: u64) -> &mut Self { + self.last = None; + self.first = Some(num_rows); + self + } + + /// Limit result set to only last N rows in ascending order of the order by column + pub fn last(&mut self, num_rows: u64) -> &mut Self { + self.first = None; + self.last = Some(num_rows); + self + } + + fn resolve_sort_order(&mut self) -> Order { + let should_reverse_order = self.last.is_some(); + self.is_result_reversed = should_reverse_order; + + if (self.sort_asc && !should_reverse_order) || (!self.sort_asc && should_reverse_order) { + Order::Asc + } else { + Order::Desc + } + } + + fn apply_limit(&mut self) -> &mut Self { + if let Some(num_rows) = self.first { + self.query.limit(num_rows); + } else if let Some(num_rows) = self.last { + self.query.limit(num_rows); + } + + self + } + + fn apply_order_by(&mut self) -> &mut Self { + self.query.clear_order_by(); + let ord = self.resolve_sort_order(); + + let query = &mut self.query; + let order = |query: &mut SelectStatement, col| { + query.order_by((SeaRc::clone(&self.table), SeaRc::clone(col)), ord.clone()); + }; + match &self.order_columns { + Identity::Unary(c1) => { + order(query, c1); + } + Identity::Binary(c1, c2) => { + order(query, c1); + order(query, c2); + } + Identity::Ternary(c1, c2, c3) => { + order(query, c1); + order(query, c2); + order(query, c3); + } + Identity::Many(vec) => { + for col in vec.iter() { + order(query, col); + } + } + } + + self + } + + /// Construct a [Cursor] that fetch any custom struct + pub fn into_model(self) -> CursorQueryBuilder> + where + M: FromQueryResult, + { + CursorQueryBuilder { + query: self.query, + table: self.table, + order_columns: self.order_columns, + last: self.last, + first: self.first, + after: self.after, + before: self.before, + sort_asc: self.sort_asc, + is_result_reversed: self.is_result_reversed, + phantom: PhantomData, + } + } + + /// Return a [Selector] from `Self` that wraps a [SelectModel] with a [PartialModel](PartialModelTrait) + pub fn into_partial_model(self) -> CursorQueryBuilder> + where + M: PartialModelTrait, + { + M::select_cols(QuerySelect::select_only(self)).into_model::() + } +} + +impl CursorQueryBuilder +where + S: SelectorTrait, +{ + /// Fetch the paginated result + pub async fn all(&mut self, db: &C) -> Result, DbErr> + where + C: ConnectionTrait, + { + self.apply_limit(); + self.apply_order_by(); + self.apply_filters(); + + let stmt = db.get_database_backend().build(&self.query); + let rows = db.query_all(stmt).await?; + let mut buffer = Vec::with_capacity(rows.len()); + for row in rows.into_iter() { + buffer.push(S::from_raw_query_result(row)?); + } + if self.is_result_reversed { + buffer.reverse() + } + Ok(buffer) + } +} + +impl QuerySelect for CursorQueryBuilder { + type QueryStatement = SelectStatement; + + fn query(&mut self) -> &mut SelectStatement { + &mut self.query + } +} diff --git a/crates/graphql-api/src/cursors/query_trait.rs b/crates/graphql-api/src/cursors/query_trait.rs new file mode 100644 index 0000000..eb71c65 --- /dev/null +++ b/crates/graphql-api/src/cursors/query_trait.rs @@ -0,0 +1,27 @@ +use sea_orm::{EntityTrait, IntoIdentity, QuerySelect, Select, SelectModel, prelude::SeaRc}; + +use crate::cursors::query_builder::CursorQueryBuilder; + +pub trait CursorPaginable { + type Selector; + + fn paginate_cursor_by( + self, + order_columns: C, + ) -> CursorQueryBuilder; +} + +impl CursorPaginable for Select { + type Selector = SelectModel<::Model>; + + fn paginate_cursor_by( + mut self, + order_columns: C, + ) -> CursorQueryBuilder { + CursorQueryBuilder::new( + QuerySelect::query(&mut self).take(), + SeaRc::new(E::default()), + order_columns, + ) + } +} diff --git a/crates/graphql-api/src/error.rs b/crates/graphql-api/src/error.rs index e536ce8..f72f7ac 100644 --- a/crates/graphql-api/src/error.rs +++ b/crates/graphql-api/src/error.rs @@ -1,6 +1,7 @@ -use std::{fmt, ops::RangeInclusive, sync::Arc}; +use std::{error::Error, fmt, sync::Arc}; use records_lib::error::RecordsError; +use sha2::digest::MacError; pub(crate) fn map_gql_err(e: async_graphql::Error) -> ApiGqlError { match e @@ -13,15 +14,15 @@ pub(crate) fn map_gql_err(e: async_graphql::Error) -> ApiGqlError { } } -#[derive(Debug)] +#[derive(Debug, PartialEq)] pub enum CursorDecodeErrorKind { NotBase64, NotUtf8, - WrongPrefix, - NoTimestamp, - NoTime, - NoScore, - InvalidTimestamp(i64), + MissingPrefix, + InvalidPrefix, + NoSignature, + InvalidSignature(MacError), + InvalidData, } impl fmt::Display for CursorDecodeErrorKind { @@ -29,74 +30,114 @@ impl fmt::Display for CursorDecodeErrorKind { match self { CursorDecodeErrorKind::NotBase64 => f.write_str("not base64"), CursorDecodeErrorKind::NotUtf8 => f.write_str("not UTF-8"), - CursorDecodeErrorKind::WrongPrefix => f.write_str("wrong prefix"), - CursorDecodeErrorKind::NoTimestamp => f.write_str("no timestamp"), - CursorDecodeErrorKind::NoTime => f.write_str("no time"), - CursorDecodeErrorKind::NoScore => f.write_str("no score"), - CursorDecodeErrorKind::InvalidTimestamp(t) => { - f.write_str("invalid timestamp: ")?; - fmt::Display::fmt(t, f) - } + CursorDecodeErrorKind::MissingPrefix => f.write_str("missing prefix"), + CursorDecodeErrorKind::InvalidPrefix => f.write_str("invalid prefix"), + CursorDecodeErrorKind::NoSignature => f.write_str("no signature"), + CursorDecodeErrorKind::InvalidSignature(e) => write!(f, "invalid signature: {e}"), + CursorDecodeErrorKind::InvalidData => f.write_str("invalid data"), } } } -#[derive(Debug)] +impl Error for CursorDecodeErrorKind { + fn source(&self) -> Option<&(dyn Error + 'static)> { + match self { + CursorDecodeErrorKind::NotBase64 + | CursorDecodeErrorKind::NotUtf8 + | CursorDecodeErrorKind::MissingPrefix + | CursorDecodeErrorKind::InvalidPrefix + | CursorDecodeErrorKind::NoSignature + | CursorDecodeErrorKind::InvalidData => None, + CursorDecodeErrorKind::InvalidSignature(mac_error) => Some(mac_error), + } + } +} + +#[derive(Debug, PartialEq)] pub struct CursorDecodeError { - arg_name: &'static str, - value: String, - kind: CursorDecodeErrorKind, + pub kind: CursorDecodeErrorKind, } -#[derive(Debug)] -pub struct CursorRangeError { - arg_name: &'static str, - value: usize, - range: RangeInclusive, +impl From for CursorDecodeError { + fn from(kind: CursorDecodeErrorKind) -> Self { + Self { kind } + } +} + +impl fmt::Display for CursorDecodeError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("cursor decode error: ")?; + fmt::Display::fmt(&self.kind, f) + } +} + +impl Error for CursorDecodeError { + fn source(&self) -> Option<&(dyn Error + 'static)> { + Some(&self.kind) + } } #[derive(Debug)] pub enum ApiGqlErrorKind { Lib(RecordsError), - CursorRange(CursorRangeError), - CursorDecode(CursorDecodeError), + PaginationInput, GqlError(async_graphql::Error), RecordNotFound { record_id: u32 }, MapNotFound { map_uid: String }, PlayerNotFound { login: String }, } +impl fmt::Display for ApiGqlErrorKind { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + ApiGqlErrorKind::Lib(records_error) => fmt::Display::fmt(records_error, f), + ApiGqlErrorKind::PaginationInput => f.write_str( + "cursor pagination input is invalid. \ + must provide either: `after`, `after` with `first`, \ + `before`, or `before` with `last`.", + ), + ApiGqlErrorKind::GqlError(error) => f.write_str(&error.message), + ApiGqlErrorKind::RecordNotFound { record_id } => { + write!(f, "record `{record_id}` not found") + } + ApiGqlErrorKind::MapNotFound { map_uid } => { + write!(f, "map with UID `{map_uid}` not found") + } + ApiGqlErrorKind::PlayerNotFound { login } => { + write!(f, "player with login `{login}` not found") + } + } + } +} + +impl std::error::Error for ApiGqlErrorKind { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + ApiGqlErrorKind::Lib(records_error) => Some(records_error), + ApiGqlErrorKind::PaginationInput => None, + ApiGqlErrorKind::GqlError(_) => None, + ApiGqlErrorKind::RecordNotFound { .. } => None, + ApiGqlErrorKind::MapNotFound { .. } => None, + ApiGqlErrorKind::PlayerNotFound { .. } => None, + } + } +} + #[derive(Debug, Clone)] pub struct ApiGqlError { inner: Arc, } -impl ApiGqlError { - pub(crate) fn from_cursor_range_error( - arg_name: &'static str, - expected_range: RangeInclusive, - value: usize, - ) -> Self { - Self { - inner: Arc::new(ApiGqlErrorKind::CursorRange(CursorRangeError { - arg_name, - value, - range: expected_range, - })), - } +impl std::error::Error for ApiGqlError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + Some(&self.inner) } +} - pub(crate) fn from_cursor_decode_error( - arg_name: &'static str, - value: String, - decode_error: CursorDecodeErrorKind, - ) -> Self { +impl ApiGqlError { + pub(crate) fn from_pagination_input_error() -> Self { Self { - inner: Arc::new(ApiGqlErrorKind::CursorDecode(CursorDecodeError { - arg_name, - value, - kind: decode_error, - })), + inner: Arc::new(ApiGqlErrorKind::PaginationInput), } } @@ -145,36 +186,7 @@ where impl fmt::Display for ApiGqlError { #[inline] fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match &*self.inner { - ApiGqlErrorKind::Lib(records_error) => fmt::Display::fmt(records_error, f), - ApiGqlErrorKind::CursorRange(cursor_error) => { - write!( - f, - "`{}` must be between {} and {} included, got {}", - cursor_error.arg_name, - cursor_error.range.start(), - cursor_error.range.end(), - cursor_error.value - ) - } - ApiGqlErrorKind::CursorDecode(decode_error) => { - write!( - f, - "cursor argument `{}` couldn't be decoded: {}. got `{}`", - decode_error.arg_name, decode_error.kind, decode_error.value - ) - } - ApiGqlErrorKind::GqlError(error) => f.write_str(&error.message), - ApiGqlErrorKind::RecordNotFound { record_id } => { - write!(f, "record `{record_id}` not found") - } - ApiGqlErrorKind::MapNotFound { map_uid } => { - write!(f, "map with UID `{map_uid}` not found") - } - ApiGqlErrorKind::PlayerNotFound { login } => { - write!(f, "player with login `{login}` not found") - } - } + fmt::Display::fmt(&self.inner, f) } } @@ -185,4 +197,4 @@ impl From for async_graphql::Error { } } -pub type GqlResult = Result; +pub type GqlResult = Result; diff --git a/crates/graphql-api/src/lib.rs b/crates/graphql-api/src/lib.rs index b85452f..e1229a5 100644 --- a/crates/graphql-api/src/lib.rs +++ b/crates/graphql-api/src/lib.rs @@ -4,3 +4,12 @@ pub mod objects; pub mod schema; pub mod cursors; + +pub mod config; +pub(crate) use config::config; +pub use config::{init_config, set_config}; + +pub(crate) mod utils; + +#[cfg(test)] +mod tests; diff --git a/crates/graphql-api/src/objects/event_edition_map.rs b/crates/graphql-api/src/objects/event_edition_map.rs index 5147c7b..3547fad 100644 --- a/crates/graphql-api/src/objects/event_edition_map.rs +++ b/crates/graphql-api/src/objects/event_edition_map.rs @@ -8,8 +8,8 @@ use crate::{ loaders::map::MapLoader, objects::{ event_edition::EventEdition, map::Map, medal_times::MedalTimes, - ranked_record::RankedRecord, records_filter::RecordsFilter, sort_state::SortState, - sortable_fields::MapRecordSortableField, + ranked_record::RankedRecord, records_filter::RecordsFilter, sort::MapRecordSort, + sort_state::SortState, }, }; @@ -91,7 +91,7 @@ impl EventEditionMap<'_> { before: Option, #[graphql(desc = "Number of records to fetch (default: 50, max: 100)")] first: Option, #[graphql(desc = "Number of records to fetch from the end (for backward pagination)")] last: Option, - sort_field: Option, + sort: Option, filter: Option, ) -> GqlResult> { self.map @@ -102,7 +102,7 @@ impl EventEditionMap<'_> { before, first, last, - sort_field, + sort, filter, ) .await diff --git a/crates/graphql-api/src/objects/map.rs b/crates/graphql-api/src/objects/map.rs index e38e1ad..46edc50 100644 --- a/crates/graphql-api/src/objects/map.rs +++ b/crates/graphql-api/src/objects/map.rs @@ -6,7 +6,7 @@ use async_graphql::{ use deadpool_redis::redis::AsyncCommands as _; use entity::{ event_edition, event_edition_maps, global_event_records, global_records, maps, player_rating, - players, records, + records, }; use records_lib::{ Database, RedisPool, internal, @@ -16,24 +16,29 @@ use records_lib::{ sync, }; use sea_orm::{ - ColumnTrait as _, ConnectionTrait, DbConn, EntityTrait as _, FromQueryResult, JoinType, Order, - QueryFilter as _, QueryOrder as _, QuerySelect as _, StreamTrait, + ColumnTrait as _, ConnectionTrait, DbConn, EntityTrait as _, FromQueryResult, QueryFilter as _, + QueryOrder as _, QuerySelect as _, StreamTrait, prelude::Expr, - sea_query::{Asterisk, ExprTrait as _, Func, Query}, + sea_query::{Asterisk, ExprTrait as _, Func, IntoValueTuple, Query}, }; use crate::{ cursors::{ - CURSOR_DEFAULT_LIMIT, CURSOR_LIMIT_RANGE, ConnectionParameters, RecordDateCursor, - RecordRankCursor, + ConnectionParameters, RecordDateCursor, RecordRankCursor, expr_tuple::IntoExprTuple, + query_trait::CursorPaginable, }, - error::{self, ApiGqlError, GqlResult}, + error::{self, ApiGqlError, CursorDecodeError, CursorDecodeErrorKind, GqlResult}, loaders::{map::MapLoader, player::PlayerLoader}, objects::{ event_edition::EventEdition, player::Player, player_rating::PlayerRating, ranked_record::RankedRecord, records_filter::RecordsFilter, - related_edition::RelatedEdition, sort_state::SortState, - sortable_fields::MapRecordSortableField, + related_edition::RelatedEdition, sort::MapRecordSort, sort_order::SortOrder, + sort_state::SortState, sortable_fields::MapRecordSortableField, + }, + utils::{ + page_input::{PaginationInput, apply_cursor_input}, + pagination_result::{PaginationResult, get_paginated}, + records_filter::apply_filter, }, }; @@ -136,256 +141,186 @@ async fn get_map_records( let mut ranked_records = Vec::with_capacity(records.len()); - let mut ranking_session = ranks::RankingSession::try_from_pool(redis_pool).await?; + let mut redis_conn = redis_pool.get().await?; for record in records { - let rank = ranks::get_rank_in_session( - &mut ranking_session, - map_id, - record.record_player_id, - record.time, - event, - ) - .await?; + let rank = ranks::get_rank(&mut redis_conn, map_id, record.time, event).await?; ranked_records.push(records::RankedRecord { rank, record }.into()); } Ok(ranked_records) } -enum MapRecordCursor { +pub(crate) enum MapRecordCursor { Date(RecordDateCursor), Rank(RecordRankCursor), } -fn encode_map_cursor(cursor: &MapRecordCursor) -> String { - match cursor { - MapRecordCursor::Date(date_cursor) => date_cursor.encode_cursor(), - MapRecordCursor::Rank(rank_cursor) => rank_cursor.encode_cursor(), +impl From for MapRecordCursor { + #[inline] + fn from(value: RecordDateCursor) -> Self { + Self::Date(value) } } -async fn get_map_records_connection( - conn: &C, - redis_pool: &RedisPool, - map_id: u32, - event: OptEvent<'_>, - ConnectionParameters { - before, - after, - first, - last, - }: ConnectionParameters, - sort_field: Option, - filter: Option, -) -> GqlResult> { - let limit = if let Some(first) = first { - if !CURSOR_LIMIT_RANGE.contains(&first) { - return Err(ApiGqlError::from_cursor_range_error( - "first", - CURSOR_LIMIT_RANGE, - first, - )); - } - first - } else if let Some(last) = last { - if !CURSOR_LIMIT_RANGE.contains(&last) { - return Err(ApiGqlError::from_cursor_range_error( - "last", - CURSOR_LIMIT_RANGE, - last, - )); - } - last - } else { - CURSOR_DEFAULT_LIMIT - }; - - let has_next_page = after.is_some(); - - // Decode cursors if provided - let after = match after { - Some(cursor) => { - let cursor = match sort_field { - Some(MapRecordSortableField::Date) => { - CursorType::decode_cursor(&cursor).map(MapRecordCursor::Date) - } - Some(MapRecordSortableField::Rank) | None => { - CursorType::decode_cursor(&cursor).map(MapRecordCursor::Rank) - } - } - .map_err(|e| ApiGqlError::from_cursor_decode_error("after", cursor.0, e))?; +impl From for MapRecordCursor { + #[inline] + fn from(value: RecordRankCursor) -> Self { + Self::Rank(value) + } +} - Some(cursor) +impl CursorType for MapRecordCursor { + type Error = CursorDecodeError; + + fn decode_cursor(s: &str) -> Result { + match RecordRankCursor::decode_cursor(s) { + Ok(rank) => Ok(rank.into()), + Err(CursorDecodeError { + kind: CursorDecodeErrorKind::InvalidPrefix, + }) => match RecordDateCursor::decode_cursor(s) { + Ok(date) => Ok(date.into()), + Err(e) => Err(e), + }, + Err(e) => Err(e), } - None => None, - }; - - let before = match before { - Some(cursor) => { - let cursor = match sort_field { - Some(MapRecordSortableField::Date) => { - CursorType::decode_cursor(&cursor).map(MapRecordCursor::Date) - } - Some(MapRecordSortableField::Rank) | None => { - CursorType::decode_cursor(&cursor).map(MapRecordCursor::Rank) - } - } - .map_err(|e| ApiGqlError::from_cursor_decode_error("before", cursor.0, e))?; + } - Some(cursor) + fn encode_cursor(&self) -> String { + match self { + MapRecordCursor::Date(record_date_cursor) => record_date_cursor.encode_cursor(), + MapRecordCursor::Rank(record_rank_cursor) => record_rank_cursor.encode_cursor(), } - None => None, - }; - - update_leaderboard(conn, redis_pool, map_id, event).await?; - - let mut select = Query::select(); - - let select = match event.get() { - Some((ev, ed)) => select.from_as(global_event_records::Entity, "r").and_where( - Expr::col(("r", global_event_records::Column::EventId)) - .eq(ev.id) - .and(Expr::col(("r", global_event_records::Column::EditionId)).eq(ed.id)), - ), - None => select.from_as(global_records::Entity, "r"), } - .column(Asterisk) - .and_where(Expr::col(("r", records::Column::MapId)).eq(map_id)); - - // Apply filters if provided - if let Some(filter) = filter { - // For player filters, we need to join with players table - if let Some(filter) = filter.player { - select.join_as( - JoinType::InnerJoin, - players::Entity, - "p", - Expr::col(("r", records::Column::RecordPlayerId)) - .equals(("p", players::Column::Id)), - ); +} - if let Some(ref login) = filter.player_login { - select - .and_where(Expr::col(("p", players::Column::Login)).like(format!("%{login}%"))); +impl IntoExprTuple for &MapRecordCursor { + fn into_expr_tuple(self) -> crate::cursors::expr_tuple::ExprTuple { + match self { + MapRecordCursor::Date(record_date_cursor) => { + IntoExprTuple::into_expr_tuple(record_date_cursor) } - - if let Some(ref name) = filter.player_name { - select.and_where( - Func::cust("rm_mp_style") - .arg(Expr::col(("p", players::Column::Name))) - .like(format!("%{name}%")), - ); + MapRecordCursor::Rank(record_rank_cursor) => { + IntoExprTuple::into_expr_tuple(record_rank_cursor) } } + } +} - // Apply date filters - if let Some(before_date) = filter.before_date { - select.and_where(Expr::col(("r", records::Column::RecordDate)).lt(before_date)); - } - - if let Some(after_date) = filter.after_date { - select.and_where(Expr::col(("r", records::Column::RecordDate)).gt(after_date)); - } - - // Apply time filters - if let Some(time_gt) = filter.time_gt { - select.and_where(Expr::col(("r", records::Column::Time)).gt(time_gt)); - } - - if let Some(time_lt) = filter.time_lt { - select.and_where(Expr::col(("r", records::Column::Time)).lt(time_lt)); +impl IntoValueTuple for &MapRecordCursor { + fn into_value_tuple(self) -> sea_orm::sea_query::ValueTuple { + match self { + MapRecordCursor::Date(record_date_cursor) => { + IntoValueTuple::into_value_tuple(record_date_cursor) + } + MapRecordCursor::Rank(record_rank_cursor) => { + IntoValueTuple::into_value_tuple(record_rank_cursor) + } } } +} - // Apply cursor filters - if let Some(cursor) = after { - let (date, time) = match cursor { - MapRecordCursor::Date(date) => (date.0, None), - MapRecordCursor::Rank(rank) => (rank.record_date, Some(rank.time)), - }; - - select.and_where(Expr::col(("r", records::Column::RecordDate)).gt(date)); +pub(crate) async fn get_map_records_connection( + conn: &C, + redis_pool: &RedisPool, + map_id: u32, + event: OptEvent<'_>, + connection_parameters: ConnectionParameters, + sort: Option, + filter: Option, +) -> GqlResult> { + let pagination_input = PaginationInput::try_from_input(connection_parameters)?; + let cursor_encoder = match sort.map(|s| s.field) { + Some(MapRecordSortableField::Date) => |record: &records::Model| { + RecordDateCursor { + record_date: record.record_date.and_utc(), + data: record.record_id, + } + .encode_cursor() + }, + _ => |record: &records::Model| { + RecordRankCursor { + time: record.time, + record_date: record.record_date.and_utc(), + data: record.record_id, + } + .encode_cursor() + }, + }; - if let Some(time) = time { - select.and_where(Expr::col(("r", records::Column::Time)).gte(time)); - } - } + let mut query = + match event.get() { + Some((ev, ed)) => { + let base_query = apply_filter( + global_event_records::Entity::find().filter( + global_event_records::Column::MapId + .eq(map_id) + .and(global_event_records::Column::EventId.eq(ev.id)) + .and(global_event_records::Column::EditionId.eq(ed.id)), + ), + filter.as_ref(), + ); - if let Some(cursor) = before { - let (date, time) = match cursor { - MapRecordCursor::Date(date) => (date.0, None), - MapRecordCursor::Rank(rank) => (rank.record_date, Some(rank.time)), - }; + match (pagination_input.get_cursor(), sort.map(|s| s.field)) { + (Some(MapRecordCursor::Date(_)), _) + | (_, Some(MapRecordSortableField::Date)) => base_query.paginate_cursor_by(( + global_event_records::Column::RecordDate, + global_event_records::Column::RecordId, + )), + _ => base_query.paginate_cursor_by(( + global_event_records::Column::Time, + global_event_records::Column::RecordDate, + global_event_records::Column::RecordId, + )), + } + .into_model::() + } - select.and_where(Expr::col(("r", records::Column::RecordDate)).lt(date)); + None => { + let base_query = apply_filter( + global_records::Entity::find().filter(global_records::Column::MapId.eq(map_id)), + filter.as_ref(), + ); - if let Some(time) = time { - select.and_where(Expr::col(("r", records::Column::Time)).lt(time)); - } - } + match (pagination_input.get_cursor(), sort.map(|s| s.field)) { + (Some(MapRecordCursor::Date(_)), _) + | (_, Some(MapRecordSortableField::Date)) => base_query.paginate_cursor_by(( + global_event_records::Column::RecordDate, + global_event_records::Column::RecordId, + )), + _ => base_query.paginate_cursor_by(( + global_records::Column::Time, + global_records::Column::RecordDate, + global_records::Column::RecordId, + )), + } + .into_model::() + } + }; - // Apply ordering - if let Some(MapRecordSortableField::Rank) | None = sort_field { - select.order_by_expr( - Expr::col(("r", records::Column::Time)).into(), - if last.is_some() { - Order::Desc - } else { - Order::Asc - }, - ); - } - select.order_by_expr( - Expr::col(("r", records::Column::RecordDate)).into(), - if last.is_some() { - Order::Desc - } else { - Order::Asc - }, - ); + apply_cursor_input(&mut query, &pagination_input); - // Fetch one extra to determine if there's a next page - select.limit((limit + 1) as u64); + match sort.and_then(|s| s.order) { + Some(SortOrder::Descending) => query.desc(), + _ => query.asc(), + }; - let stmt = conn.get_database_backend().build(&*select); - let records = conn - .query_all(stmt) - .await? - .into_iter() - .map(|result| records::Model::from_query_result(&result, "")) - .collect::, _>>()?; + let PaginationResult { + mut connection, + iter: records, + } = get_paginated(conn, query, &pagination_input).await?; - let mut connection = connection::Connection::new(has_next_page, records.len() > limit); + connection.edges.reserve(records.len()); - let encode_cursor_fn = match sort_field { - Some(MapRecordSortableField::Date) => |record: &records::Model| { - encode_map_cursor(&MapRecordCursor::Date(RecordDateCursor( - record.record_date.and_utc(), - ))) - }, - Some(MapRecordSortableField::Rank) | None => |record: &records::Model| { - encode_map_cursor(&MapRecordCursor::Rank(RecordRankCursor { - record_date: record.record_date.and_utc(), - time: record.time, - })) - }, - }; + ranks::update_leaderboard(conn, redis_pool, map_id, event).await?; - let mut ranking_session = ranks::RankingSession::try_from_pool(redis_pool).await?; + let mut redis_conn = redis_pool.get().await?; - for record in records.into_iter().take(limit) { - let rank = ranks::get_rank_in_session( - &mut ranking_session, - map_id, - record.record_player_id, - record.time, - event, - ) - .await?; + for record in records { + let rank = ranks::get_rank(&mut redis_conn, record.map_id, record.time, event).await?; connection.edges.push(connection::Edge::new( - ID(encode_cursor_fn(&record)), + ID((cursor_encoder)(&record)), records::RankedRecord { rank, record }.into(), )); } @@ -426,39 +361,36 @@ impl Map { before: Option, first: Option, last: Option, - sort_field: Option, + sort: Option, filter: Option, ) -> GqlResult> { let db = gql_ctx.data_unchecked::(); - records_lib::assert_future_send(sync::transaction(&db.sql_conn, async |txn| { - connection::query( - after, - before, - first, - last, - |after, before, first, last| async move { - get_map_records_connection( - txn, - &db.redis_pool, - self.inner.id, - event, - ConnectionParameters { - after, - before, - first, - last, - }, - sort_field, - filter, - ) - .await - }, - ) - .await - .map_err(error::map_gql_err) - })) + connection::query_with( + after, + before, + first, + last, + |after, before, first, last| async move { + get_map_records_connection( + &db.sql_conn, + &db.redis_pool, + self.inner.id, + event, + ConnectionParameters { + after, + before, + first, + last, + }, + sort, + filter, + ) + .await + }, + ) .await + .map_err(error::map_gql_err) } } @@ -497,6 +429,10 @@ impl Map { &self.inner.name } + async fn score(&self) -> f64 { + self.inner.score + } + async fn related_event_editions( &self, ctx: &async_graphql::Context<'_>, @@ -586,7 +522,7 @@ impl Map { before: Option, #[graphql(desc = "Number of records to fetch (default: 50, max: 100)")] first: Option, #[graphql(desc = "Number of records to fetch from the end (for backward pagination)")] last: Option, - sort_field: Option, + sort: Option, filter: Option, ) -> GqlResult> { self.get_records_connection( @@ -596,7 +532,7 @@ impl Map { before, first, last, - sort_field, + sort, filter, ) .await diff --git a/crates/graphql-api/src/objects/map_with_score.rs b/crates/graphql-api/src/objects/map_with_score.rs index 20407d3..2950a3f 100644 --- a/crates/graphql-api/src/objects/map_with_score.rs +++ b/crates/graphql-api/src/objects/map_with_score.rs @@ -6,5 +6,4 @@ use crate::objects::map::Map; pub struct MapWithScore { pub rank: i32, pub map: Map, - pub score: f64, } diff --git a/crates/graphql-api/src/objects/mappack.rs b/crates/graphql-api/src/objects/mappack.rs index a102e29..7146720 100644 --- a/crates/graphql-api/src/objects/mappack.rs +++ b/crates/graphql-api/src/objects/mappack.rs @@ -1,3 +1,4 @@ +use mkenv::prelude::*; use std::time::SystemTime; use deadpool_redis::redis::{self, AsyncCommands as _}; @@ -192,7 +193,7 @@ impl Mappack { .get(mappack_time_key(AnyMappackId::Id(&self.mappack_id))) .await?; Ok(last_upd_time - .map(|last| last + records_lib::env().event_scores_interval.as_secs()) + .map(|last| last + records_lib::env().event_scores_interval.get().as_secs()) .and_then(|last| { SystemTime::UNIX_EPOCH .elapsed() diff --git a/crates/graphql-api/src/objects/player.rs b/crates/graphql-api/src/objects/player.rs index 2b8274e..cc890e6 100644 --- a/crates/graphql-api/src/objects/player.rs +++ b/crates/graphql-api/src/objects/player.rs @@ -1,27 +1,24 @@ -use async_graphql::connection::CursorType as _; use async_graphql::{Enum, ID, connection}; -use entity::{global_records, maps, players, records, role}; +use entity::{global_records, players, records, role}; use records_lib::{Database, ranks}; use records_lib::{RedisPool, error::RecordsError, internal, opt_event::OptEvent, sync}; -use sea_orm::Order; use sea_orm::{ - ColumnTrait as _, ConnectionTrait, DbConn, EntityTrait as _, FromQueryResult, JoinType, - QueryFilter as _, QueryOrder as _, QuerySelect as _, RelationTrait, StreamTrait, - prelude::Expr, - sea_query::{ExprTrait as _, Func}, + ColumnTrait as _, ConnectionTrait, DbConn, EntityTrait as _, FromQueryResult, QueryFilter as _, + QueryOrder as _, QuerySelect as _, StreamTrait, TransactionTrait, }; +use crate::cursors::RecordDateCursor; use crate::objects::records_filter::RecordsFilter; +use crate::objects::root::get_records_connection_impl; +use crate::objects::sort::UnorderedRecordSort; +use crate::utils::records_filter::apply_filter; use crate::{ cursors::ConnectionParameters, - error::{ApiGqlError, GqlResult}, + error::GqlResult, objects::{ranked_record::RankedRecord, sort_state::SortState}, }; -use crate::{ - cursors::{CURSOR_DEFAULT_LIMIT, CURSOR_LIMIT_RANGE, RecordDateCursor}, - error, -}; +use crate::error; #[derive(Copy, Clone, Eq, PartialEq, Enum)] #[repr(u8)] @@ -74,6 +71,10 @@ impl Player { self.inner.zone_path.as_deref() } + async fn score(&self) -> f64 { + self.inner.score + } + async fn role(&self, ctx: &async_graphql::Context<'_>) -> GqlResult { let conn = ctx.data_unchecked::(); @@ -117,37 +118,36 @@ impl Player { before: Option, #[graphql(desc = "Number of records to fetch (default: 50, max: 100)")] first: Option, #[graphql(desc = "Number of records to fetch from the end (for backward pagination)")] last: Option, + sort: Option, filter: Option, ) -> GqlResult> { let db = ctx.data_unchecked::(); - records_lib::assert_future_send(sync::transaction(&db.sql_conn, async |txn| { - connection::query( - after, - before, - first, - last, - |after, before, first, last| async move { - get_player_records_connection( - txn, - &db.redis_pool, - self.inner.id, - Default::default(), - ConnectionParameters { - after, - before, - first, - last, - }, - filter, - ) - .await - }, - ) - .await - .map_err(error::map_gql_err) - })) + connection::query_with( + after, + before, + first, + last, + |after, before, first, last| async move { + get_player_records_connection( + &db.sql_conn, + &db.redis_pool, + self.inner.id, + Default::default(), + ConnectionParameters { + after, + before, + first, + last, + }, + sort, + filter, + ) + .await + }, + ) .await + .map_err(error::map_gql_err) } } @@ -175,17 +175,10 @@ async fn get_player_records( let mut ranked_records = Vec::with_capacity(records.len()); - let mut ranking_session = ranks::RankingSession::try_from_pool(redis_pool).await?; + let mut redis_conn = redis_pool.get().await?; for record in records { - let rank = ranks::get_rank_in_session( - &mut ranking_session, - record.map_id, - record.record_player_id, - record.time, - event, - ) - .await?; + let rank = ranks::get_rank(&mut redis_conn, record.map_id, record.time, event).await?; ranked_records.push( records::RankedRecord { @@ -199,158 +192,30 @@ async fn get_player_records( Ok(ranked_records) } -async fn get_player_records_connection( +pub(crate) async fn get_player_records_connection( conn: &C, redis_pool: &RedisPool, player_id: u32, event: OptEvent<'_>, - ConnectionParameters { - after, - before, - first, - last, - }: ConnectionParameters, + connection_parameters: ConnectionParameters, + sort: Option, filter: Option, -) -> GqlResult> { - let limit = if let Some(first) = first { - if !CURSOR_LIMIT_RANGE.contains(&first) { - return Err(ApiGqlError::from_cursor_range_error( - "first", - CURSOR_LIMIT_RANGE, - first, - )); - } - first - } else if let Some(last) = last { - if !CURSOR_LIMIT_RANGE.contains(&last) { - return Err(ApiGqlError::from_cursor_range_error( - "last", - CURSOR_LIMIT_RANGE, - last, - )); - } - last - } else { - CURSOR_DEFAULT_LIMIT - }; - - let has_previous_page = after.is_some(); - - // Decode cursors if provided - let after_timestamp = match after { - Some(cursor) => { - let decoded = RecordDateCursor::decode_cursor(&cursor) - .map_err(|e| ApiGqlError::from_cursor_decode_error("after", cursor.0, e))?; - Some(decoded) - } - None => None, - }; - - let before_timestamp = match before { - Some(cursor) => { - let decoded = RecordDateCursor::decode_cursor(&cursor) - .map_err(|e| ApiGqlError::from_cursor_decode_error("before", cursor.0, e))?; - Some(decoded) - } - None => None, - }; - - // Build query with appropriate ordering - let mut query = - global_records::Entity::find().filter(global_records::Column::RecordPlayerId.eq(player_id)); - - // Apply filters if provided - if let Some(filter) = filter { - // Join with maps table if needed for map filters - if let Some(filter) = filter.map { - query = query.join_as( - JoinType::InnerJoin, - global_records::Relation::Maps.def(), - "m", - ); - - // Apply map UID filter - if let Some(uid) = filter.map_uid { - query = - query.filter(Expr::col(("m", maps::Column::GameId)).like(format!("%{uid}%"))); - } - - // Apply map name filter - if let Some(name) = filter.map_name { - query = query.filter( - Func::cust("rm_mp_style") - .arg(Expr::col(("m", maps::Column::Name))) - .like(format!("%{name}%")), - ); - } - } - - // Apply date filters - if let Some(before_date) = filter.before_date { - query = query.filter(global_records::Column::RecordDate.lt(before_date)); - } - - if let Some(after_date) = filter.after_date { - query = query.filter(global_records::Column::RecordDate.gt(after_date)); - } - - // Apply time filters - if let Some(time_gt) = filter.time_gt { - query = query.filter(global_records::Column::Time.gt(time_gt)); - } - - if let Some(time_lt) = filter.time_lt { - query = query.filter(global_records::Column::Time.lt(time_lt)); - } - } - - // Apply cursor filters - if let Some(timestamp) = after_timestamp { - query = query.filter(global_records::Column::RecordDate.lt(timestamp.0)); - } - - if let Some(timestamp) = before_timestamp { - query = query.filter(global_records::Column::RecordDate.gt(timestamp.0)); - } - - // Apply ordering - query = query.order_by( - global_records::Column::RecordDate, - if last.is_some() { - Order::Asc - } else { - Order::Desc - }, +) -> GqlResult> +where + C: ConnectionTrait + TransactionTrait, +{ + let base_query = apply_filter( + global_records::Entity::find().filter(global_records::Column::RecordPlayerId.eq(player_id)), + filter.as_ref(), ); - // Fetch one extra to determine if there's a next page - query = query.limit((limit + 1) as u64); - - let records = query.all(conn).await?; - - let mut connection = connection::Connection::new(has_previous_page, records.len() > limit); - - let mut ranking_session = ranks::RankingSession::try_from_pool(redis_pool).await?; - - for record in records.into_iter().take(limit) { - let rank = ranks::get_rank_in_session( - &mut ranking_session, - record.map_id, - record.record_player_id, - record.time, - event, - ) - .await?; - - connection.edges.push(connection::Edge::new( - ID(RecordDateCursor(record.record_date.and_utc()).encode_cursor()), - records::RankedRecord { - rank, - record: record.into(), - } - .into(), - )); - } - - Ok(connection) + get_records_connection_impl( + conn, + redis_pool, + connection_parameters, + event, + sort, + base_query, + ) + .await } diff --git a/crates/graphql-api/src/objects/player_with_score.rs b/crates/graphql-api/src/objects/player_with_score.rs index 5e53e92..fb806a7 100644 --- a/crates/graphql-api/src/objects/player_with_score.rs +++ b/crates/graphql-api/src/objects/player_with_score.rs @@ -6,5 +6,4 @@ use crate::objects::player::Player; pub struct PlayerWithScore { pub rank: i32, pub player: Player, - pub score: f64, } diff --git a/crates/graphql-api/src/objects/root.rs b/crates/graphql-api/src/objects/root.rs index a8c2d9f..cf08b74 100644 --- a/crates/graphql-api/src/objects/root.rs +++ b/crates/graphql-api/src/objects/root.rs @@ -1,31 +1,29 @@ -use std::{collections::HashMap, fmt}; - use async_graphql::{ - ID, OutputType, - connection::{self, CursorType as _}, + ID, + connection::{self, CursorType}, }; use deadpool_redis::redis::{AsyncCommands, ToRedisArgs}; use entity::{event as event_entity, event_edition, global_records, maps, players, records}; use records_lib::{ - Database, RedisConnection, RedisPool, internal, must, + Database, RedisConnection, RedisPool, must, opt_event::OptEvent, ranks, - redis_key::{map_ranking, player_ranking}, + redis_key::{MapRanking, PlayerRanking, map_ranking, player_ranking}, sync, }; use sea_orm::{ - ColumnTrait as _, ConnectionTrait, DbConn, EntityTrait as _, JoinType, Order, QueryFilter as _, - QueryOrder as _, QuerySelect as _, QueryTrait, RelationTrait, StreamTrait, + ColumnTrait, ConnectionTrait, DbConn, EntityTrait, FromQueryResult, Identity, QueryFilter as _, + QueryOrder as _, QuerySelect, Select, SelectModel, StreamTrait, TransactionTrait, prelude::Expr, - sea_query::{ExprTrait as _, Func}, + sea_query::{Asterisk, ExprTrait as _, Func, IntoIden, IntoValueTuple, SelectStatement}, }; use crate::{ cursors::{ - CURSOR_DEFAULT_LIMIT, CURSOR_LIMIT_RANGE, ConnectionParameters, F64Cursor, - RecordDateCursor, TextCursor, + ConnectionParameters, F64Cursor, RecordDateCursor, TextCursor, expr_tuple::IntoExprTuple, + query_builder::CursorQueryBuilder, query_trait::CursorPaginable, }, - error::{self, ApiGqlError, GqlResult}, + error::{self, ApiGqlError, CursorDecodeError, CursorDecodeErrorKind, GqlResult}, objects::{ event::Event, event_edition::EventEdition, @@ -38,7 +36,15 @@ use crate::{ player_with_score::PlayerWithScore, ranked_record::RankedRecord, records_filter::RecordsFilter, + sort::{PlayerMapRankingSort, UnorderedRecordSort}, + sort_order::SortOrder, sort_state::SortState, + sortable_fields::{PlayerMapRankingSortableField, UnorderedRecordSortableField}, + }, + utils::{ + page_input::{PaginationInput, apply_cursor_input}, + pagination_result::{PaginationResult, get_paginated}, + records_filter::apply_filter, }, }; @@ -56,17 +62,10 @@ async fn get_record( return Err(ApiGqlError::from_record_not_found_error(record_id)); }; - let mut ranking_session = ranks::RankingSession::try_from_pool(redis_pool).await?; + let mut redis_conn = redis_pool.get().await?; let out = records::RankedRecord { - rank: ranks::get_rank_in_session( - &mut ranking_session, - record.map_id, - record.record_player_id, - record.time, - event, - ) - .await?, + rank: ranks::get_rank(&mut redis_conn, record.map_id, record.time, event).await?, record, } .into(); @@ -94,17 +93,10 @@ async fn get_records( let mut ranked_records = Vec::with_capacity(records.len()); - let mut ranking_session = ranks::RankingSession::try_from_pool(redis_pool).await?; + let mut redis_conn = redis_pool.get().await?; for record in records { - let rank = ranks::get_rank_in_session( - &mut ranking_session, - record.map_id, - record.record_player_id, - record.time, - event, - ) - .await?; + let rank = ranks::get_rank(&mut redis_conn, record.map_id, record.time, event).await?; ranked_records.push( records::RankedRecord { @@ -118,201 +110,57 @@ async fn get_records( Ok(ranked_records) } -async fn get_records_connection( +pub(crate) async fn get_records_connection_impl( conn: &C, redis_pool: &RedisPool, - ConnectionParameters { - after, - before, - first, - last, - }: ConnectionParameters, + connection_parameters: ConnectionParameters, event: OptEvent<'_>, - filter: Option, + sort: Option, + base_query: Select, ) -> GqlResult> { - let limit = if let Some(first) = first { - if !CURSOR_LIMIT_RANGE.contains(&first) { - return Err(ApiGqlError::from_cursor_range_error( - "first", - CURSOR_LIMIT_RANGE, - first, - )); - } - first - } else if let Some(last) = last { - if !CURSOR_LIMIT_RANGE.contains(&last) { - return Err(ApiGqlError::from_cursor_range_error( - "last", - CURSOR_LIMIT_RANGE, - last, - )); - } - last - } else { - CURSOR_DEFAULT_LIMIT - }; - - let has_previous_page = after.is_some(); - - // Decode cursors if provided - let after_timestamp = match after { - Some(cursor) => { - let decoded = RecordDateCursor::decode_cursor(&cursor) - .map_err(|e| ApiGqlError::from_cursor_decode_error("after", cursor.0, e))?; - Some(decoded) - } - None => None, - }; - - let before_timestamp = match before { - Some(cursor) => { - let decoded = RecordDateCursor::decode_cursor(&cursor) - .map_err(|e| ApiGqlError::from_cursor_decode_error("before", cursor.0, e))?; - Some(decoded) - } - None => None, - }; - - // Build query with appropriate ordering - let mut query = global_records::Entity::find(); - - // Apply filters if provided - if let Some(filter) = filter { - // Join with players table if needed for player filters - if filter.player.is_some() { - query = query.join_as( - JoinType::InnerJoin, - global_records::Relation::Players.def(), - "p", - ); - } - - // Join with maps table if needed for map filters - if let Some(m) = &filter.map { - query = query.join_as( - JoinType::InnerJoin, - global_records::Relation::Maps.def(), - "m", - ); - - // Join again with players table if filtering on map author - if m.author.is_some() { - query = query.join_as(JoinType::InnerJoin, maps::Relation::Players.def(), "p2"); - } - } - - if let Some(filter) = filter.player { - // Apply player login filter - if let Some(login) = filter.player_login { - query = query - .filter(Expr::col(("p", players::Column::Login)).like(format!("%{login}%"))); - } - - // Apply player name filter - if let Some(name) = filter.player_name { - query = query.filter( - Func::cust("rm_mp_style") - .arg(Expr::col(("p", players::Column::Name))) - .like(format!("%{name}%")), - ); - } - } - - if let Some(filter) = filter.map { - // Apply map UID filter - if let Some(uid) = filter.map_uid { - query = - query.filter(Expr::col(("m", maps::Column::GameId)).like(format!("%{uid}%"))); - } - - // Apply map name filter - if let Some(name) = filter.map_name { - query = query.filter( - Func::cust("rm_mp_style") - .arg(Expr::col(("m", maps::Column::Name))) - .like(format!("%{name}%")), - ); - } - - if let Some(filter) = filter.author { - // Apply player login filter - if let Some(login) = filter.player_login { - query = query.filter( - Expr::col(("p2", players::Column::Login)).like(format!("%{login}%")), - ); - } - - // Apply player name filter - if let Some(name) = filter.player_name { - query = query.filter( - Func::cust("rm_mp_style") - .arg(Expr::col(("p2", players::Column::Name))) - .like(format!("%{name}%")), - ); - } - } - } - - // Apply date filters - if let Some(before_date) = filter.before_date { - query = query.filter(global_records::Column::RecordDate.lt(before_date)); - } - - if let Some(after_date) = filter.after_date { - query = query.filter(global_records::Column::RecordDate.gt(after_date)); - } - - // Apply time filters - if let Some(time_gt) = filter.time_gt { - query = query.filter(global_records::Column::Time.gt(time_gt)); - } + let pagination_input = PaginationInput::try_from_input(connection_parameters)?; - if let Some(time_lt) = filter.time_lt { - query = query.filter(global_records::Column::Time.lt(time_lt)); - } - } - - // Apply cursor filters - if let Some(timestamp) = after_timestamp { - query = query.filter(global_records::Column::RecordDate.lt(timestamp.0)); - } + let mut query = base_query.paginate_cursor_by(( + global_records::Column::RecordDate, + global_records::Column::RecordId, + )); - if let Some(timestamp) = before_timestamp { - query = query.filter(global_records::Column::RecordDate.gt(timestamp.0)); - } + apply_cursor_input(&mut query, &pagination_input); - // Apply ordering - query = query.order_by( - global_records::Column::RecordDate, - if last.is_some() { - Order::Asc - } else { - Order::Desc - }, + // Record dates are ordered by desc by default + let is_sort_asc = matches!( + sort, + Some(UnorderedRecordSort { + field: UnorderedRecordSortableField::Date, + order: Some(SortOrder::Descending), + }) ); - // Fetch one extra to determine if there's a next page - query = query.limit((limit + 1) as u64); + if is_sort_asc { + query.asc(); + } else { + query.desc(); + } - let records = query.all(conn).await?; + // inline get_paginated + let PaginationResult { + mut connection, + iter: records, + } = get_paginated(conn, query, &pagination_input).await?; - let mut connection = connection::Connection::new(has_previous_page, records.len() > limit); connection.edges.reserve(records.len()); - let mut ranking_session = ranks::RankingSession::try_from_pool(redis_pool).await?; + let mut redis_conn = redis_pool.get().await?; - for record in records.into_iter().take(limit) { - let rank = ranks::get_rank_in_session( - &mut ranking_session, - record.map_id, - record.record_player_id, - record.time, - event, - ) - .await?; + for record in records { + let rank = ranks::get_rank(&mut redis_conn, record.map_id, record.time, event).await?; connection.edges.push(connection::Edge::new( - ID(RecordDateCursor(record.record_date.and_utc()).encode_cursor()), + ID(RecordDateCursor { + record_date: record.record_date.and_utc(), + data: record.record_id, + } + .encode_cursor()), records::RankedRecord { rank, record: record.into(), @@ -324,6 +172,27 @@ async fn get_records_connection( Ok(connection) } +pub(crate) async fn get_records_connection( + conn: &C, + redis_pool: &RedisPool, + connection_parameters: ConnectionParameters, + event: OptEvent<'_>, + sort: Option, + filter: Option, +) -> GqlResult> { + let base_query = apply_filter(global_records::Entity::find(), filter.as_ref()); + + get_records_connection_impl( + conn, + redis_pool, + connection_parameters, + event, + sort, + base_query, + ) + .await +} + #[async_graphql::Object] impl QueryRoot { async fn event_edition_from_mx_id( @@ -436,6 +305,7 @@ impl QueryRoot { .await } + #[allow(clippy::too_many_arguments)] async fn players( &self, ctx: &async_graphql::Context<'_>, @@ -444,34 +314,34 @@ impl QueryRoot { first: Option, last: Option, filter: Option, + sort: Option, ) -> GqlResult> { let db = ctx.data_unchecked::(); let mut redis_conn = db.redis_pool.get().await?; - connection::query( + connection::query_with( after, before, first, last, |after, before, first, last| async move { - get_players_connection( - &db.sql_conn, - &mut redis_conn, - ConnectionParameters { - after, - before, - first, - last, - }, - filter, - ) - .await + let input = ::new(ConnectionParameters { + after, + before, + first, + last, + }) + .with_filter(filter) + .with_sort(sort); + + get_players_connection(&db.sql_conn, &mut redis_conn, input).await }, ) .await .map_err(error::map_gql_err) } + #[allow(clippy::too_many_arguments)] async fn maps( &self, ctx: &async_graphql::Context<'_>, @@ -480,28 +350,27 @@ impl QueryRoot { first: Option, last: Option, filter: Option, + sort: Option, ) -> GqlResult> { let db = ctx.data_unchecked::(); let mut redis_conn = db.redis_pool.get().await?; - connection::query( + connection::query_with( after, before, first, last, |after, before, first, last| async move { - get_maps_connection( - &db.sql_conn, - &mut redis_conn, - ConnectionParameters { - after, - before, - first, - last, - }, - filter, - ) - .await + let input = ::new(ConnectionParameters { + after, + before, + first, + last, + }) + .with_filter(filter) + .with_sort(sort); + + get_maps_connection(&db.sql_conn, &mut redis_conn, input).await }, ) .await @@ -519,433 +388,418 @@ impl QueryRoot { before: Option, #[graphql(desc = "Number of records to fetch (default: 50, max: 100)")] first: Option, #[graphql(desc = "Number of records to fetch from the end (for backward pagination)")] last: Option, + sort: Option, filter: Option, ) -> GqlResult> { let db = ctx.data_unchecked::(); - let conn = ctx.data_unchecked::(); - sync::transaction(conn, async |txn| { - connection::query( - after, - before, - first, - last, - |after, before, first, last| async move { - get_records_connection( - txn, - &db.redis_pool, - ConnectionParameters { - after, - before, - first, - last, - }, - Default::default(), - filter, - ) - .await - }, - ) - .await - .map_err(error::map_gql_err) - }) + connection::query_with( + after, + before, + first, + last, + |after, before, first, last| async move { + get_records_connection( + &db.sql_conn, + &db.redis_pool, + ConnectionParameters { + after, + before, + first, + last, + }, + Default::default(), + sort, + filter, + ) + .await + }, + ) .await + .map_err(error::map_gql_err) } } -/// If a filter is provided, the result is ordered by the login of the players, so the cursors become -/// based on them. Otherwise, the result is ordered by the score of the players. -async fn get_players_connection( - conn: &C, - redis_conn: &mut RedisConnection, - ConnectionParameters { - after, - before, - first, - last, - }: ConnectionParameters, - filter: Option, -) -> GqlResult> { - match filter { - Some(filter) => { - let after = after - .map(|after| { - TextCursor::decode_cursor(&after.0) - .map_err(|e| ApiGqlError::from_cursor_decode_error("after", after.0, e)) - }) - .transpose()?; +enum EitherRedisKey { + A(A), + B(B), +} - let before = before - .map(|before| { - TextCursor::decode_cursor(&before.0) - .map_err(|e| ApiGqlError::from_cursor_decode_error("before", before.0, e)) - }) - .transpose()?; +impl ToRedisArgs for EitherRedisKey +where + A: ToRedisArgs, + B: ToRedisArgs, +{ + fn write_redis_args(&self, out: &mut W) + where + W: ?Sized + deadpool_redis::redis::RedisWrite, + { + match self { + EitherRedisKey::A(a) => ToRedisArgs::write_redis_args(a, out), + EitherRedisKey::B(b) => ToRedisArgs::write_redis_args(b, out), + } + } +} - let limit = first - .map(|f| f as u64) - .or(last.map(|l| l as _)) - .map(|l| l.min(100)) - .unwrap_or(50); +fn custom_source_or(source: Option, default: F) -> EitherRedisKey +where + F: FnOnce() -> D, +{ + source + .map(EitherRedisKey::B) + .unwrap_or_else(|| EitherRedisKey::A(default())) +} - let has_previous_page = after.is_some(); +pub(crate) struct ConnectionInput { + connection_parameters: ConnectionParameters, + filter: Option, + sort: Option, + source: Option, +} - let query = players::Entity::find() - .apply_if(filter.player_login, |query, login| { - query.filter(players::Column::Login.like(format!("%{login}%"))) - }) - .apply_if(filter.player_name, |query, name| { - query.filter( - Func::cust("rm_mp_style") - .arg(Expr::col((players::Entity, players::Column::Name))) - .like(format!("%{name}%")), - ) - }) - .apply_if(after, |query, after| { - query.filter(players::Column::Login.gt(after.0)) - }) - .apply_if(before, |query, before| { - query.filter(players::Column::Login.lt(before.0)) - }) - .order_by( - players::Column::Login, - if first.is_some() { - Order::Asc - } else { - Order::Desc - }, - ) - .limit(limit + 1) - .all(conn) - .await?; - - let mut connection = - connection::Connection::new(has_previous_page, query.len() > limit as _); - connection.edges.reserve(limit as _); - - for player in query.into_iter().take(limit as _) { - let score = redis_conn.zscore(player_ranking(), player.id).await?; - let rank: i32 = redis_conn.zrevrank(player_ranking(), player.id).await?; - connection.edges.push(connection::Edge::new( - ID(TextCursor(player.login.clone()).encode_cursor()), - PlayerWithScore { - score, - rank: rank + 1, - player: player.into(), - }, - )); - } +// derive(Default) adds Default: bound to the generics +impl Default for ConnectionInput { + fn default() -> Self { + Self { + connection_parameters: Default::default(), + filter: Default::default(), + sort: Default::default(), + source: Default::default(), + } + } +} - Ok(connection) +impl ConnectionInput { + pub(crate) fn new(connection_parameters: ConnectionParameters) -> Self { + Self { + connection_parameters, + ..Default::default() } - None => { - let has_previous_page = after.is_some(); - - let player_with_scores = - build_scores(redis_conn, player_ranking(), after, before, first, last).await?; - - let players = players::Entity::find() - .filter(players::Column::Id.is_in(player_with_scores.keys().copied())) - .all(conn) - .await?; - - let limit = first.or(last).unwrap_or(50); - let mut connection = connection::Connection::<_, PlayerWithScore>::new( - has_previous_page, - players.len() > limit, - ); - - build_connections( - redis_conn, - player_ranking(), - &mut connection, - players, - limit, - &player_with_scores, - ) - .await?; + } - if last.is_some() { - connection.edges.sort_by_key(|edge| -edge.node.rank); - } else { - connection.edges.sort_by_key(|edge| edge.node.rank); - } + pub(crate) fn with_filter(mut self, filter: Option) -> Self { + self.filter = filter; + self + } + + pub(crate) fn with_sort(mut self, sort: Option) -> Self { + self.sort = sort; + self + } - Ok(connection) + #[allow(unused)] // used for testing + pub(crate) fn with_source(self, source: U) -> ConnectionInput { + ConnectionInput { + connection_parameters: self.connection_parameters, + filter: self.filter, + sort: self.sort, + source: Some(source), } } } -/// If a filter is provided, the result is ordered by the UID of the maps, so the cursors become -/// based on them. Otherwise, the result is ordered by the score of the maps. -async fn get_maps_connection( - conn: &C, - redis_conn: &mut RedisConnection, - ConnectionParameters { - after, - before, - first, - last, - }: ConnectionParameters, - filter: Option, -) -> GqlResult> { - match filter { - Some(filter) => { - let after = after - .map(|after| { - TextCursor::decode_cursor(&after.0) - .map_err(|e| ApiGqlError::from_cursor_decode_error("after", after.0, e)) - }) - .transpose()?; - - let before = before - .map(|before| { - TextCursor::decode_cursor(&before.0) - .map_err(|e| ApiGqlError::from_cursor_decode_error("before", before.0, e)) - }) - .transpose()?; - - let limit = first - .map(|f| f as u64) - .or(last.map(|l| l as _)) - .map(|l| l.min(100)) - .unwrap_or(50); +pub(crate) enum PlayerMapRankingCursor { + Name(TextCursor), + Score(F64Cursor), +} - let has_previous_page = after.is_some(); +impl From for PlayerMapRankingCursor { + #[inline] + fn from(value: TextCursor) -> Self { + Self::Name(value) + } +} - let query = maps::Entity::find() - .apply_if(filter.author, |query, filter| { - query - .inner_join(players::Entity) - .apply_if(filter.player_login, |query, login| { - query.filter(players::Column::Login.like(format!("%{login}%"))) - }) - .apply_if(filter.player_name, |query, name| { - query.filter( - Func::cust("rm_mp_style") - .arg(Expr::col((players::Entity, players::Column::Name))) - .like(format!("%{name}%")), - ) - }) - }) - .apply_if(filter.map_uid, |query, uid| { - query.filter(maps::Column::GameId.like(format!("%{uid}%"))) - }) - .apply_if(filter.map_name, |query, name| { - query.filter( - Func::cust("rm_mp_style") - .arg(Expr::col((maps::Entity, maps::Column::Name))) - .like(format!("%{name}%")), - ) - }) - .apply_if(after, |query, after| { - query.filter(maps::Column::GameId.gt(after.0)) - }) - .apply_if(before, |query, before| { - query.filter(maps::Column::GameId.lt(before.0)) - }) - .order_by( - maps::Column::GameId, - if first.is_some() { - Order::Asc - } else { - Order::Desc - }, - ) - .limit(limit + 1) - .all(conn) - .await?; - - let mut connection = - connection::Connection::new(has_previous_page, query.len() > limit as _); - connection.edges.reserve(limit as _); - - for map in query.into_iter().take(limit as _) { - let score = redis_conn.zscore(map_ranking(), map.id).await?; - let rank: i32 = redis_conn.zrevrank(map_ranking(), map.id).await?; - connection.edges.push(connection::Edge::new( - ID(TextCursor(map.game_id.clone()).encode_cursor()), - MapWithScore { - score, - rank: rank + 1, - map: map.into(), - }, - )); - } +impl From for PlayerMapRankingCursor { + #[inline] + fn from(value: F64Cursor) -> Self { + Self::Score(value) + } +} - Ok(connection) +impl CursorType for PlayerMapRankingCursor { + type Error = CursorDecodeError; + + fn decode_cursor(s: &str) -> Result { + match F64Cursor::decode_cursor(s) { + Ok(score) => Ok(score.into()), + Err(CursorDecodeError { + kind: CursorDecodeErrorKind::InvalidPrefix, + }) => match TextCursor::decode_cursor(s) { + Ok(text) => Ok(text.into()), + Err(e) => Err(e), + }, + Err(e) => Err(e), } - None => { - let has_previous_page = after.is_some(); - - let map_with_scores = - build_scores(redis_conn, map_ranking(), after, before, first, last).await?; - - let maps = maps::Entity::find() - .filter(maps::Column::Id.is_in(map_with_scores.keys().copied())) - .all(conn) - .await?; - - let limit = first.or(last).unwrap_or(50); - let mut connection = connection::Connection::<_, MapWithScore>::new( - has_previous_page, - maps.len() > limit, - ); - - build_connections( - redis_conn, - map_ranking(), - &mut connection, - maps, - limit, - &map_with_scores, - ) - .await?; - - if last.is_some() { - connection.edges.sort_by_key(|edge| -edge.node.rank); - } else { - connection.edges.sort_by_key(|edge| edge.node.rank); - } + } - Ok(connection) + fn encode_cursor(&self) -> String { + match self { + PlayerMapRankingCursor::Name(text_cursor) => text_cursor.encode_cursor(), + PlayerMapRankingCursor::Score(f64_cursor) => f64_cursor.encode_cursor(), } } } -async fn build_connections( - redis_conn: &mut RedisConnection, - redis_key: K, - connection: &mut connection::Connection, - items: I, - limit: usize, - scores: &HashMap, -) -> GqlResult<()> -where - K: ToRedisArgs + Sync, - I: IntoIterator, - U: HasId, - T: Ranking + OutputType, -{ - connection.edges.reserve(limit); - - for map in items.into_iter().take(limit) { - let score = scores - .get(&map.get_id()) - .copied() - .ok_or_else(|| internal!("missing score entry for ID {}", map.get_id()))?; - let rank: i32 = redis_conn.zrevrank(&redis_key, map.get_id()).await?; - connection.edges.push(connection::Edge::new( - ID(F64Cursor(score).encode_cursor()), - T::from_node(rank + 1, score, map), - )); +impl IntoExprTuple for &PlayerMapRankingCursor { + fn into_expr_tuple(self) -> crate::cursors::expr_tuple::ExprTuple { + match self { + PlayerMapRankingCursor::Name(name) => IntoExprTuple::into_expr_tuple(name), + PlayerMapRankingCursor::Score(score) => IntoExprTuple::into_expr_tuple(score), + } } +} - Ok(()) +impl IntoValueTuple for &PlayerMapRankingCursor { + fn into_value_tuple(self) -> sea_orm::sea_query::ValueTuple { + match self { + PlayerMapRankingCursor::Name(name) => IntoValueTuple::into_value_tuple(name), + PlayerMapRankingCursor::Score(score) => IntoValueTuple::into_value_tuple(score), + } + } } -trait Ranking { - type Item; +pub(crate) type PlayersConnectionInput = + ConnectionInput; - fn from_node(rank: i32, score: f64, node: Self::Item) -> Self; +#[derive(FromQueryResult)] +struct PlayerWithUnstyledName { + #[sea_orm(nested)] + player: players::Model, + unstyled_player_name: String, } -trait HasId { - fn get_id(&self) -> u32; -} +pub(crate) async fn get_players_connection( + conn: &C, + redis_conn: &mut RedisConnection, + input: PlayersConnectionInput, +) -> GqlResult> +where + C: ConnectionTrait, + S: ToRedisArgs + Send + Sync, +{ + let pagination_input = PaginationInput::try_from_input(input.connection_parameters)?; + let cursor_encoder = match input.sort.map(|s| s.field) { + Some(PlayerMapRankingSortableField::Name) => |player: &PlayerWithUnstyledName| { + TextCursor { + text: player.unstyled_player_name.clone(), + data: player.player.id, + } + .encode_cursor() + }, + _ => |player: &PlayerWithUnstyledName| { + F64Cursor { + score: player.player.score, + data: player.player.id, + } + .encode_cursor() + }, + }; -impl HasId for maps::Model { - fn get_id(&self) -> u32 { - self.id + let mut query = players::Entity::find().expr_as( + Func::cust("rm_mp_style").arg(Expr::col((players::Entity, players::Column::Name))), + "unstyled_player_name", + ); + let query = SelectStatement::new() + .column(Asterisk) + .from_subquery(QuerySelect::query(&mut query).take(), "player") + .apply_if(input.filter, |query, filter| { + query + .apply_if(filter.player_login, |query, login| { + query.and_where( + Expr::col(("player", players::Column::Login)).like(format!("%{login}%")), + ); + }) + .apply_if(filter.player_name, |query, name| { + query.and_where( + Expr::col(("player", "unstyled_player_name")).like(format!("%{name}%")), + ); + }); + }) + .take(); + + let mut query = match ( + pagination_input.get_cursor(), + input.sort.as_ref().map(|s| s.field), + ) { + (Some(PlayerMapRankingCursor::Name(_)), _) + | (None, Some(PlayerMapRankingSortableField::Name)) => { + CursorQueryBuilder::>::new( + query, + "player".into_iden(), + Identity::Binary( + "unstyled_player_name".into_iden(), + players::Column::Id.into_iden(), + ), + ) + } + _ => CursorQueryBuilder::new( + query, + "player".into_iden(), + (players::Column::Score, players::Column::Id), + ), } -} + .into_model::(); -impl HasId for players::Model { - fn get_id(&self) -> u32 { - self.id - } -} + apply_cursor_input(&mut query, &pagination_input); -impl Ranking for MapWithScore { - type Item = maps::Model; + match input.sort.and_then(|s| s.order) { + Some(SortOrder::Descending) => query.desc(), + _ => query.asc(), + }; - fn from_node(rank: i32, score: f64, node: Self::Item) -> Self { - Self { - rank, - score, - map: node.into(), - } + let PaginationResult { + mut connection, + iter: players, + } = get_paginated(conn, query, &pagination_input).await?; + + connection.edges.reserve(players.len()); + let source = custom_source_or(input.source, player_ranking); + + for player in players { + let rank: i32 = redis_conn.zrevrank(&source, player.player.id).await?; + connection.edges.push(connection::Edge::new( + ID((cursor_encoder)(&player)), + PlayerWithScore { + rank: rank + 1, + player: player.player.into(), + }, + )); } + + Ok(connection) } -impl Ranking for PlayerWithScore { - type Item = players::Model; +pub(crate) type MapsConnectionInput = + ConnectionInput; - fn from_node(rank: i32, score: f64, node: Self::Item) -> Self { - Self { - rank, - score, - player: node.into(), - } - } +#[derive(FromQueryResult)] +struct MapWithUnstyledName { + #[sea_orm(nested)] + map: maps::Model, + unstyled_map_name: String, } -async fn build_scores( - redis_conn: &mut deadpool_redis::Connection, - redis_key: K, - after: Option, - before: Option, - first: Option, - last: Option, -) -> Result, ApiGqlError> +pub(crate) async fn get_maps_connection( + conn: &C, + redis_conn: &mut RedisConnection, + input: MapsConnectionInput, +) -> GqlResult> where - K: ToRedisArgs + Sync + fmt::Display, + C: ConnectionTrait, + S: ToRedisArgs + Send + Sync, { - let ids: Vec = match after { - Some(after) => { - let decoded = F64Cursor::decode_cursor(&after.0) - .map_err(|e| ApiGqlError::from_cursor_decode_error("after", after.0, e))?; - let first = first.map(|f| f.min(100)).unwrap_or(50); - redis_conn - .zrevrangebyscore_limit_withscores(&redis_key, decoded.0, "-inf", 1, first as _) - .await? - } - None => match before { - Some(before) => { - let decoded = F64Cursor::decode_cursor(&before.0) - .map_err(|e| ApiGqlError::from_cursor_decode_error("before", before.0, e))?; - let last = last.map(|l| l.min(100)).unwrap_or(50); - redis_conn - .zrangebyscore_limit_withscores(&redis_key, decoded.0, "+inf", 1, last as _) - .await? + let pagination_input = PaginationInput::try_from_input(input.connection_parameters)?; + let cursor_encoder = match input.sort.map(|s| s.field) { + Some(PlayerMapRankingSortableField::Name) => |map: &MapWithUnstyledName| { + TextCursor { + text: map.unstyled_map_name.clone(), + data: map.map.id, } - None => match last { - Some(last) => { - redis_conn - .zrange_withscores(&redis_key, 0, last.min(100) as isize - 1) - .await? - } - None => { - let first = first.map(|f| f.min(100)).unwrap_or(50); - redis_conn - .zrevrange_withscores(&redis_key, 0, first as isize - 1) - .await? - } - }, + .encode_cursor() + }, + _ => |map: &MapWithUnstyledName| { + F64Cursor { + score: map.map.score, + data: map.map.id, + } + .encode_cursor() }, }; - let (ids, _) = ids.as_chunks::<2>(); - let with_scores = ids - .iter() - .map(|[id, score]| { - let id = id - .parse::() - .map_err(|e| internal!("got invalid ID `{id}` in {redis_key} ZSET: {e}"))?; - let score = score - .parse::() - .map_err(|e| internal!("got invalid score `{score}` in {redis_key} ZSET: {e}"))?; - GqlResult::Ok((id, score)) + + let mut query = maps::Entity::find().expr_as( + Func::cust("rm_mp_style").arg(Expr::col((maps::Entity, maps::Column::Name))), + "unstyled_map_name", + ); + let query = SelectStatement::new() + .column(Asterisk) + .from_subquery(QuerySelect::query(&mut query).take(), "map") + .apply_if(input.filter, |query, filter| { + query + .apply_if(filter.author, |query, filter| { + query + .join_as( + sea_orm::JoinType::InnerJoin, + players::Entity, + "author", + Expr::col(("author", players::Column::Id)) + .eq(Expr::col(("map", maps::Column::PlayerId))), + ) + .apply_if(filter.player_login, |query, login| { + query.and_where( + Expr::col(("author", players::Column::Login)) + .like(format!("%{login}%")), + ); + }) + .apply_if(filter.player_name, |query, name| { + query.and_where( + Func::cust("rm_mp_style") + .arg(Expr::col(("author", players::Column::Name))) + .like(format!("%{name}%")), + ); + }); + }) + .apply_if(filter.map_uid, |query, uid| { + query.and_where( + Expr::col(("map", maps::Column::GameId)).like(format!("%{uid}%")), + ); + }) + .apply_if(filter.map_name, |query, name| { + query.and_where( + Expr::col(("map", "unstyled_map_name")).like(format!("%{name}%")), + ); + }); }) - .collect::>>()?; - Ok(with_scores) + .take(); + + let mut query = match ( + pagination_input.get_cursor(), + input.sort.as_ref().map(|s| s.field), + ) { + (Some(PlayerMapRankingCursor::Name(_)), _) + | (_, Some(PlayerMapRankingSortableField::Name)) => { + CursorQueryBuilder::>::new( + query, + "map".into_iden(), + Identity::Binary( + "unstyled_map_name".into_iden(), + maps::Column::Id.into_iden(), + ), + ) + } + _ => CursorQueryBuilder::new( + query, + "map".into_iden(), + (maps::Column::Score, maps::Column::Id), + ), + } + .into_model::(); + + apply_cursor_input(&mut query, &pagination_input); + + match input.sort.and_then(|s| s.order) { + Some(SortOrder::Descending) => query.desc(), + _ => query.asc(), + }; + + let PaginationResult { + mut connection, + iter: maps, + } = get_paginated(conn, query, &pagination_input).await?; + + connection.edges.reserve(maps.len()); + let source = custom_source_or(input.source, map_ranking); + + for map in maps { + let rank: i32 = redis_conn.zrevrank(&source, map.map.id).await?; + connection.edges.push(connection::Edge::new( + ID((cursor_encoder)(&map)), + MapWithScore { + rank: rank + 1, + map: map.map.into(), + }, + )); + } + + Ok(connection) } diff --git a/crates/graphql-api/src/objects/sort.rs b/crates/graphql-api/src/objects/sort.rs index fff3467..f4c77d7 100644 --- a/crates/graphql-api/src/objects/sort.rs +++ b/crates/graphql-api/src/objects/sort.rs @@ -2,17 +2,25 @@ use async_graphql::InputObject; use crate::objects::{ sort_order::SortOrder, - sortable_fields::{MapRecordSortableField, UnorderedRecordSortableField}, + sortable_fields::{ + MapRecordSortableField, PlayerMapRankingSortableField, UnorderedRecordSortableField, + }, }; -#[derive(InputObject, Clone)] +#[derive(Debug, InputObject, Clone, Copy)] pub(crate) struct UnorderedRecordSort { pub field: UnorderedRecordSortableField, pub order: Option, } -#[derive(InputObject, Clone)] +#[derive(Debug, InputObject, Clone, Copy)] pub(crate) struct MapRecordSort { pub field: MapRecordSortableField, pub order: Option, } + +#[derive(Debug, InputObject, Clone, Copy)] +pub(crate) struct PlayerMapRankingSort { + pub field: PlayerMapRankingSortableField, + pub order: Option, +} diff --git a/crates/graphql-api/src/objects/sort_order.rs b/crates/graphql-api/src/objects/sort_order.rs index e0cc01c..f5a2e98 100644 --- a/crates/graphql-api/src/objects/sort_order.rs +++ b/crates/graphql-api/src/objects/sort_order.rs @@ -1,6 +1,6 @@ use async_graphql::Enum; -#[derive(Clone, Copy, PartialEq, PartialOrd, Eq, Ord, Enum)] +#[derive(Debug, Clone, Copy, PartialEq, PartialOrd, Eq, Ord, Enum)] pub(crate) enum SortOrder { Ascending, Descending, diff --git a/crates/graphql-api/src/objects/sortable_fields.rs b/crates/graphql-api/src/objects/sortable_fields.rs index 3309f5c..480bfd8 100644 --- a/crates/graphql-api/src/objects/sortable_fields.rs +++ b/crates/graphql-api/src/objects/sortable_fields.rs @@ -1,12 +1,18 @@ use async_graphql::Enum; -#[derive(Clone, Copy, PartialEq, PartialOrd, Eq, Ord, Enum)] +#[derive(Debug, Clone, Copy, PartialEq, PartialOrd, Eq, Ord, Enum)] pub(crate) enum UnorderedRecordSortableField { Date, } -#[derive(Clone, Copy, PartialEq, PartialOrd, Eq, Ord, Enum)] +#[derive(Debug, Clone, Copy, PartialEq, PartialOrd, Eq, Ord, Enum)] pub(crate) enum MapRecordSortableField { Date, Rank, } + +#[derive(Debug, Clone, Copy, PartialEq, PartialOrd, Eq, Ord, Enum)] +pub(crate) enum PlayerMapRankingSortableField { + Name, + Rank, +} diff --git a/crates/graphql-api/src/tests/maps_records_connection.rs b/crates/graphql-api/src/tests/maps_records_connection.rs new file mode 100644 index 0000000..0b6f23f --- /dev/null +++ b/crates/graphql-api/src/tests/maps_records_connection.rs @@ -0,0 +1,1370 @@ +use mkenv::prelude::*; +use std::time::Duration; + +use async_graphql::connection::CursorType; +use chrono::SubsecRound; +use entity::{maps, players, records}; +use itertools::Itertools; +use sea_orm::{ActiveValue::Set, EntityTrait}; + +use crate::{ + config::InitError, + cursors::{ConnectionParameters, RecordDateCursor, RecordRankCursor}, + objects::{ + map::get_map_records_connection, sort::MapRecordSort, sort_order::SortOrder, + sortable_fields::MapRecordSortableField, + }, +}; + +fn setup() { + match crate::init_config() { + Ok(_) | Err(InitError::ConfigAlreadySet) => (), + Err(InitError::Config(e)) => { + panic!("error during test setup: {e}"); + } + } +} + +#[derive(Debug, PartialEq)] +struct Record { + record_id: u32, + cursor: String, + rank: i32, + map_id: u32, + player_id: u32, + time: i32, + record_date: chrono::DateTime, + flags: u32, + respawn_count: i32, +} + +#[tokio::test] +async fn default_page() -> anyhow::Result<()> { + setup(); + + let default_limit = crate::config().cursor_default_limit.get(); + let record_amount = default_limit * 2; + + let players = (1..=record_amount).map(|i| players::ActiveModel { + id: Set(i as _), + login: Set(format!("player_{i}_login")), + name: Set(format!("player_{i}_name")), + role: Set(0), + ..Default::default() + }); + + let map_id = test_env::get_map_id(); + let map = maps::ActiveModel { + id: Set(map_id), + game_id: Set("map_uid".to_owned()), + name: Set("map_name".to_owned()), + player_id: Set(1), + ..Default::default() + }; + + // The higher the record ID, the less recent the record + let now = chrono::Utc::now().naive_utc().trunc_subsecs(0); + let record_dates = (0..record_amount) + .map(|i| now - Duration::from_secs(3600 * (i as u64 + 1))) + .collect_vec(); + + let records = (0..record_amount).map(|i| records::ActiveModel { + record_id: Set((i + 1) as _), + map_id: Set(map_id), + record_player_id: Set((i + 1) as _), + flags: Set(682), + time: Set((1000 * (i + 1)) as _), + respawn_count: Set(0), + record_date: Set(record_dates[i]), + ..Default::default() + }); + + test_env::wrap(async |db| { + players::Entity::insert_many(players) + .exec(&db.sql_conn) + .await?; + maps::Entity::insert(map).exec(&db.sql_conn).await?; + records::Entity::insert_many(records) + .exec(&db.sql_conn) + .await?; + + let result = get_map_records_connection( + &db.sql_conn, + &db.redis_pool, + map_id, + Default::default(), + Default::default(), + Default::default(), + Default::default(), + ) + .await?; + + itertools::assert_equal( + result.edges.into_iter().map(|edge| Record { + record_id: edge.node.inner.record.record_id, + cursor: edge.cursor.0, + rank: edge.node.inner.rank, + map_id: edge.node.inner.record.map_id, + player_id: edge.node.inner.record.record_player_id, + flags: edge.node.inner.record.flags, + record_date: edge.node.inner.record.record_date.and_utc(), + respawn_count: edge.node.inner.record.respawn_count, + time: edge.node.inner.record.time, + }), + (0..default_limit).map(|i| { + let record_date = record_dates[i].and_utc(); + Record { + record_id: i as u32 + 1, + cursor: RecordRankCursor { + time: (i as i32 + 1) * 1000, + record_date, + data: i as u32 + 1, + } + .encode_cursor(), + rank: i as i32 + 1, + map_id, + flags: 682, + player_id: i as u32 + 1, + record_date, + respawn_count: 0, + time: 1000 * (i as i32 + 1), + } + }), + ); + + assert!(result.has_next_page); + assert!(!result.has_previous_page); + + anyhow::Ok(()) + }) + .await +} + +#[tokio::test] +async fn default_page_desc() -> anyhow::Result<()> { + setup(); + + let default_limit = crate::config().cursor_default_limit.get(); + let record_amount = default_limit * 2; + + let players = (1..=record_amount).map(|i| players::ActiveModel { + id: Set(i as _), + login: Set(format!("player_{i}_login")), + name: Set(format!("player_{i}_name")), + role: Set(0), + ..Default::default() + }); + + let map_id = test_env::get_map_id(); + let map = maps::ActiveModel { + id: Set(map_id), + game_id: Set("map_uid".to_owned()), + name: Set("map_name".to_owned()), + player_id: Set(1), + ..Default::default() + }; + + // The higher the record ID, the less recent the record + let now = chrono::Utc::now().naive_utc().trunc_subsecs(0); + let record_dates = (0..record_amount) + .map(|i| now - Duration::from_secs(3600 * (i as u64 + 1))) + .collect_vec(); + + let records = (0..record_amount).map(|i| records::ActiveModel { + record_id: Set((i + 1) as _), + map_id: Set(map_id), + record_player_id: Set((i + 1) as _), + flags: Set(682), + time: Set((1000 * (i + 1)) as _), + respawn_count: Set(0), + record_date: Set(record_dates[i]), + ..Default::default() + }); + + test_env::wrap(async |db| { + players::Entity::insert_many(players) + .exec(&db.sql_conn) + .await?; + maps::Entity::insert(map).exec(&db.sql_conn).await?; + records::Entity::insert_many(records) + .exec(&db.sql_conn) + .await?; + + let result = get_map_records_connection( + &db.sql_conn, + &db.redis_pool, + map_id, + Default::default(), + Default::default(), + Some(MapRecordSort { + field: MapRecordSortableField::Rank, + order: Some(SortOrder::Descending), + }), + Default::default(), + ) + .await?; + + itertools::assert_equal( + result.edges.into_iter().map(|edge| Record { + record_id: edge.node.inner.record.record_id, + cursor: edge.cursor.0, + rank: edge.node.inner.rank, + map_id: edge.node.inner.record.map_id, + player_id: edge.node.inner.record.record_player_id, + flags: edge.node.inner.record.flags, + record_date: edge.node.inner.record.record_date.and_utc(), + respawn_count: edge.node.inner.record.respawn_count, + time: edge.node.inner.record.time, + }), + (0..default_limit).map(|i| { + let i = record_amount - 1 - i; + let record_date = record_dates[i].and_utc(); + Record { + record_id: i as u32 + 1, + cursor: RecordRankCursor { + time: (i as i32 + 1) * 1000, + record_date, + data: i as u32 + 1, + } + .encode_cursor(), + rank: i as i32 + 1, + map_id, + flags: 682, + player_id: i as u32 + 1, + record_date, + respawn_count: 0, + time: 1000 * (i as i32 + 1), + } + }), + ); + + assert!(result.has_next_page); + assert!(!result.has_previous_page); + + anyhow::Ok(()) + }) + .await +} + +#[tokio::test] +async fn default_page_date() -> anyhow::Result<()> { + setup(); + + let default_limit = crate::config().cursor_default_limit.get(); + let record_amount = default_limit * 2; + + let players = (1..=record_amount).map(|i| players::ActiveModel { + id: Set(i as _), + login: Set(format!("player_{i}_login")), + name: Set(format!("player_{i}_name")), + role: Set(0), + ..Default::default() + }); + + let map_id = test_env::get_map_id(); + let map = maps::ActiveModel { + id: Set(map_id), + game_id: Set("map_uid".to_owned()), + name: Set("map_name".to_owned()), + player_id: Set(1), + ..Default::default() + }; + + // The higher the record ID, the less recent the record + let now = chrono::Utc::now().naive_utc().trunc_subsecs(0); + let record_dates = (0..record_amount) + .map(|i| now - Duration::from_secs(3600 * (i as u64 + 1))) + .collect_vec(); + + let records = (0..record_amount).map(|i| records::ActiveModel { + record_id: Set((i + 1) as _), + map_id: Set(map_id), + record_player_id: Set((i + 1) as _), + flags: Set(682), + time: Set((1000 * (i + 1)) as _), + respawn_count: Set(0), + record_date: Set(record_dates[i]), + ..Default::default() + }); + + test_env::wrap(async |db| { + players::Entity::insert_many(players) + .exec(&db.sql_conn) + .await?; + maps::Entity::insert(map).exec(&db.sql_conn).await?; + records::Entity::insert_many(records) + .exec(&db.sql_conn) + .await?; + + let result = get_map_records_connection( + &db.sql_conn, + &db.redis_pool, + map_id, + Default::default(), + Default::default(), + Some(MapRecordSort { + field: MapRecordSortableField::Date, + order: Default::default(), + }), + Default::default(), + ) + .await?; + + let mut expected = (0..default_limit) + .map(|i| { + let i = record_amount - 1 - i; + let record_date = record_dates[i].and_utc(); + Record { + record_id: i as u32 + 1, + cursor: RecordDateCursor { + record_date, + data: i as u32 + 1, + } + .encode_cursor(), + rank: i as i32 + 1, + map_id, + flags: 682, + player_id: i as u32 + 1, + record_date, + respawn_count: 0, + time: 1000 * (i as i32 + 1), + } + }) + .collect_vec(); + + expected.sort_by_key(|record| record.record_date); + + itertools::assert_equal( + result.edges.into_iter().map(|edge| Record { + record_id: edge.node.inner.record.record_id, + cursor: edge.cursor.0, + rank: edge.node.inner.rank, + map_id: edge.node.inner.record.map_id, + player_id: edge.node.inner.record.record_player_id, + flags: edge.node.inner.record.flags, + record_date: edge.node.inner.record.record_date.and_utc(), + respawn_count: edge.node.inner.record.respawn_count, + time: edge.node.inner.record.time, + }), + expected, + ); + + assert!(result.has_next_page); + assert!(!result.has_previous_page); + + anyhow::Ok(()) + }) + .await +} + +#[tokio::test] +async fn default_page_date_desc() -> anyhow::Result<()> { + setup(); + + let default_limit = crate::config().cursor_default_limit.get(); + let record_amount = default_limit * 2; + + let players = (1..=record_amount).map(|i| players::ActiveModel { + id: Set(i as _), + login: Set(format!("player_{i}_login")), + name: Set(format!("player_{i}_name")), + role: Set(0), + ..Default::default() + }); + + let map_id = test_env::get_map_id(); + let map = maps::ActiveModel { + id: Set(map_id), + game_id: Set("map_uid".to_owned()), + name: Set("map_name".to_owned()), + player_id: Set(1), + ..Default::default() + }; + + // The higher the record ID, the less recent the record + let now = chrono::Utc::now().naive_utc().trunc_subsecs(0); + let record_dates = (0..record_amount) + .map(|i| now - Duration::from_secs(3600 * (i as u64 + 1))) + .collect_vec(); + + let records = (0..record_amount).map(|i| records::ActiveModel { + record_id: Set((i + 1) as _), + map_id: Set(map_id), + record_player_id: Set((i + 1) as _), + flags: Set(682), + time: Set((1000 * (i + 1)) as _), + respawn_count: Set(0), + record_date: Set(record_dates[i]), + ..Default::default() + }); + + test_env::wrap(async |db| { + players::Entity::insert_many(players) + .exec(&db.sql_conn) + .await?; + maps::Entity::insert(map).exec(&db.sql_conn).await?; + records::Entity::insert_many(records) + .exec(&db.sql_conn) + .await?; + + let result = get_map_records_connection( + &db.sql_conn, + &db.redis_pool, + map_id, + Default::default(), + Default::default(), + Some(MapRecordSort { + field: MapRecordSortableField::Date, + order: Some(SortOrder::Descending), + }), + Default::default(), + ) + .await?; + + let mut expected = (0..default_limit) + .map(|i| { + let record_date = record_dates[i].and_utc(); + Record { + record_id: i as u32 + 1, + cursor: RecordDateCursor { + record_date, + data: i as u32 + 1, + } + .encode_cursor(), + rank: i as i32 + 1, + map_id, + flags: 682, + player_id: i as u32 + 1, + record_date, + respawn_count: 0, + time: 1000 * (i as i32 + 1), + } + }) + .collect_vec(); + + expected.sort_by(|a, b| a.record_date.cmp(&b.record_date).reverse()); + + itertools::assert_equal( + result.edges.into_iter().map(|edge| Record { + record_id: edge.node.inner.record.record_id, + cursor: edge.cursor.0, + rank: edge.node.inner.rank, + map_id: edge.node.inner.record.map_id, + player_id: edge.node.inner.record.record_player_id, + flags: edge.node.inner.record.flags, + record_date: edge.node.inner.record.record_date.and_utc(), + respawn_count: edge.node.inner.record.respawn_count, + time: edge.node.inner.record.time, + }), + expected, + ); + + assert!(result.has_next_page); + assert!(!result.has_previous_page); + + anyhow::Ok(()) + }) + .await +} + +#[tracing::instrument] +async fn test_first_x_after_y( + record_amount: usize, + is_desc: bool, + first: Option, + after_idx: usize, + expected_len: usize, + has_next_page: bool, +) -> anyhow::Result<()> { + let players = (1..=record_amount).map(|i| players::ActiveModel { + id: Set(i as _), + login: Set(format!("player_{i}_login")), + name: Set(format!("player_{i}_name")), + role: Set(0), + ..Default::default() + }); + + let map_id = test_env::get_map_id(); + let map = maps::ActiveModel { + id: Set(map_id), + game_id: Set("map_uid".to_owned()), + name: Set("map_name".to_owned()), + player_id: Set(1), + ..Default::default() + }; + + // The higher the record ID, the less recent the record + let now = chrono::Utc::now().naive_utc().trunc_subsecs(0); + let record_dates = (0..record_amount) + .map(|i| now - Duration::from_secs(3600 * (i as u64 + 1))) + .collect_vec(); + + let records = (0..record_amount).map(|i| records::ActiveModel { + record_id: Set((i + 1) as _), + map_id: Set(map_id), + record_player_id: Set((i + 1) as _), + flags: Set(682), + time: Set((1000 * (i + 1)) as _), + respawn_count: Set(0), + record_date: Set(record_dates[i]), + ..Default::default() + }); + + test_env::wrap(async |db| { + players::Entity::insert_many(players) + .exec(&db.sql_conn) + .await?; + maps::Entity::insert(map).exec(&db.sql_conn).await?; + records::Entity::insert_many(records) + .exec(&db.sql_conn) + .await?; + + let result = get_map_records_connection( + &db.sql_conn, + &db.redis_pool, + map_id, + Default::default(), + ConnectionParameters { + first, + after: Some({ + let idx = if is_desc { + record_amount.saturating_sub(1).saturating_sub(after_idx) + } else { + after_idx + }; + RecordRankCursor { + time: (idx as i32 + 1) * 1000, + record_date: record_dates[idx].and_utc(), + data: idx as u32 + 1, + } + .into() + }), + ..Default::default() + }, + is_desc.then_some(MapRecordSort { + field: MapRecordSortableField::Rank, + order: Some(SortOrder::Descending), + }), + Default::default(), + ) + .await?; + + itertools::assert_equal( + result.edges.into_iter().map(|edge| Record { + record_id: edge.node.inner.record.record_id, + cursor: edge.cursor.0, + rank: edge.node.inner.rank, + map_id: edge.node.inner.record.map_id, + player_id: edge.node.inner.record.record_player_id, + flags: edge.node.inner.record.flags, + record_date: edge.node.inner.record.record_date.and_utc(), + respawn_count: edge.node.inner.record.respawn_count, + time: edge.node.inner.record.time, + }), + (0..expected_len).map(|i| { + let i = if is_desc { + record_amount + .saturating_sub(1) + .saturating_sub(after_idx) + .saturating_sub(1) + .saturating_sub(i) + } else { + after_idx + 1 + i + }; + let record_date = record_dates[i].and_utc(); + Record { + record_id: i as u32 + 1, + cursor: RecordRankCursor { + time: (i as i32 + 1) * 1000, + record_date, + data: i as u32 + 1, + } + .encode_cursor(), + rank: i as i32 + 1, + map_id, + flags: 682, + player_id: i as u32 + 1, + record_date, + respawn_count: 0, + time: 1000 * (i as i32 + 1), + } + }), + ); + + assert_eq!(result.has_next_page, has_next_page); + assert!(result.has_previous_page); + + anyhow::Ok(()) + }) + .await +} + +#[tracing::instrument] +async fn test_first_x_after_y_date( + record_amount: usize, + is_desc: bool, + first: Option, + after_idx: usize, + expected_len: usize, + has_next_page: bool, +) -> anyhow::Result<()> { + // Testing the date order is pretty the same as testing the rank order, + // but the order is reversed (higher rank means an earlier date). + // + // So when sorting by date with ascending order, in this test, it's like + // sorting by rank with descending order. + + let players = (1..=record_amount).map(|i| players::ActiveModel { + id: Set(i as _), + login: Set(format!("player_{i}_login")), + name: Set(format!("player_{i}_name")), + role: Set(0), + ..Default::default() + }); + + let map_id = test_env::get_map_id(); + let map = maps::ActiveModel { + id: Set(map_id), + game_id: Set("map_uid".to_owned()), + name: Set("map_name".to_owned()), + player_id: Set(1), + ..Default::default() + }; + + // The higher the record ID, the less recent the record + let now = chrono::Utc::now().naive_utc().trunc_subsecs(0); + let record_dates = (0..record_amount) + .map(|i| now - Duration::from_secs(3600 * (i as u64 + 1))) + .collect_vec(); + + let records = (0..record_amount).map(|i| records::ActiveModel { + record_id: Set((i + 1) as _), + map_id: Set(map_id), + record_player_id: Set((i + 1) as _), + flags: Set(682), + time: Set((1000 * (i + 1)) as _), + respawn_count: Set(0), + record_date: Set(record_dates[i]), + ..Default::default() + }); + + test_env::wrap(async |db| { + players::Entity::insert_many(players) + .exec(&db.sql_conn) + .await?; + maps::Entity::insert(map).exec(&db.sql_conn).await?; + records::Entity::insert_many(records) + .exec(&db.sql_conn) + .await?; + + let result = get_map_records_connection( + &db.sql_conn, + &db.redis_pool, + map_id, + Default::default(), + ConnectionParameters { + first, + after: Some({ + let idx = if is_desc { + after_idx + } else { + record_amount.saturating_sub(1).saturating_sub(after_idx) + }; + RecordDateCursor { + record_date: record_dates[idx].and_utc(), + data: idx as u32 + 1, + } + .into() + }), + ..Default::default() + }, + Some(MapRecordSort { + field: MapRecordSortableField::Date, + order: is_desc.then_some(SortOrder::Descending), + }), + Default::default(), + ) + .await?; + + let mut expected = (0..expected_len) + .map(|i| { + let i = if is_desc { + after_idx + 1 + i + } else { + record_amount + .saturating_sub(1) + .saturating_sub(after_idx) + .saturating_sub(1) + .saturating_sub(i) + }; + let record_date = record_dates[i].and_utc(); + Record { + record_id: i as u32 + 1, + cursor: RecordDateCursor { + record_date, + data: i as u32 + 1, + } + .encode_cursor(), + rank: i as i32 + 1, + map_id, + flags: 682, + player_id: i as u32 + 1, + record_date, + respawn_count: 0, + time: 1000 * (i as i32 + 1), + } + }) + .collect_vec(); + + expected.sort_by(if is_desc { + |a: &Record, b: &Record| a.record_date.cmp(&b.record_date).reverse() + } else { + |a: &Record, b: &Record| a.record_date.cmp(&b.record_date) + }); + + itertools::assert_equal( + result.edges.into_iter().map(|edge| Record { + record_id: edge.node.inner.record.record_id, + cursor: edge.cursor.0, + rank: edge.node.inner.rank, + map_id: edge.node.inner.record.map_id, + player_id: edge.node.inner.record.record_player_id, + flags: edge.node.inner.record.flags, + record_date: edge.node.inner.record.record_date.and_utc(), + respawn_count: edge.node.inner.record.respawn_count, + time: edge.node.inner.record.time, + }), + expected, + ); + + assert_eq!(result.has_next_page, has_next_page); + assert!(result.has_previous_page); + + anyhow::Ok(()) + }) + .await +} + +#[tokio::test] +async fn first_x_after_y() -> anyhow::Result<()> { + setup(); + + test_first_x_after_y(10, false, Some(3), 5, 3, true).await?; + test_first_x_after_y(10, false, Some(3), 6, 3, false).await?; + test_first_x_after_y(10, false, Some(3), 7, 2, false).await?; + + Ok(()) +} + +#[tokio::test] +async fn first_x_after_y_desc() -> anyhow::Result<()> { + setup(); + + test_first_x_after_y(10, true, Some(3), 5, 3, true).await?; + test_first_x_after_y(10, true, Some(3), 6, 3, false).await?; + test_first_x_after_y(10, true, Some(3), 7, 2, false).await?; + + Ok(()) +} + +#[tokio::test] +async fn first_x_after_y_date() -> anyhow::Result<()> { + setup(); + + test_first_x_after_y_date(10, false, Some(3), 5, 3, true).await?; + test_first_x_after_y_date(10, false, Some(3), 6, 3, false).await?; + test_first_x_after_y_date(10, false, Some(4), 6, 3, false).await?; + + Ok(()) +} + +#[tokio::test] +async fn first_x_after_y_date_desc() -> anyhow::Result<()> { + setup(); + + test_first_x_after_y_date(10, true, Some(3), 5, 3, true).await?; + test_first_x_after_y_date(10, true, Some(3), 6, 3, false).await?; + test_first_x_after_y_date(10, true, Some(4), 6, 3, false).await?; + + Ok(()) +} + +#[tokio::test] +async fn after_y() -> anyhow::Result<()> { + setup(); + + let default_limit = crate::config().cursor_default_limit.get(); + let record_amount = default_limit * 2; + + test_first_x_after_y(record_amount, false, None, 5, default_limit, true).await?; + test_first_x_after_y( + record_amount, + false, + None, + default_limit - 1, + default_limit, + false, + ) + .await?; + test_first_x_after_y( + record_amount, + false, + None, + default_limit + 1, + default_limit - 2, + false, + ) + .await?; + + Ok(()) +} + +#[tokio::test] +async fn after_y_desc() -> anyhow::Result<()> { + setup(); + + let default_limit = crate::config().cursor_default_limit.get(); + let record_amount = default_limit * 2; + + test_first_x_after_y(record_amount, true, None, 5, default_limit, true).await?; + test_first_x_after_y( + record_amount, + true, + None, + default_limit - 1, + default_limit, + false, + ) + .await?; + test_first_x_after_y( + record_amount, + true, + None, + default_limit + 1, + default_limit - 2, + false, + ) + .await?; + + Ok(()) +} + +#[tokio::test] +async fn after_y_date() -> anyhow::Result<()> { + setup(); + + let default_limit = crate::config().cursor_default_limit.get(); + let record_amount = default_limit * 2; + + test_first_x_after_y_date(record_amount, false, None, 5, default_limit, true).await?; + test_first_x_after_y_date( + record_amount, + false, + None, + default_limit - 1, + default_limit, + false, + ) + .await?; + test_first_x_after_y_date( + record_amount, + false, + None, + default_limit, + default_limit - 1, + false, + ) + .await?; + + Ok(()) +} + +#[tokio::test] +async fn after_y_date_desc() -> anyhow::Result<()> { + setup(); + + let default_limit = crate::config().cursor_default_limit.get(); + let record_amount = default_limit * 2; + + test_first_x_after_y_date(record_amount, true, None, 5, default_limit, true).await?; + test_first_x_after_y_date( + record_amount, + true, + None, + default_limit - 1, + default_limit, + false, + ) + .await?; + test_first_x_after_y_date( + record_amount, + true, + None, + default_limit, + default_limit - 1, + false, + ) + .await?; + + Ok(()) +} + +#[tracing::instrument] +async fn test_last_x_before_y( + record_amount: usize, + is_desc: bool, + last: Option, + before_idx: usize, + expected_len: usize, + has_previous_page: bool, +) -> anyhow::Result<()> { + let players = (1..=record_amount).map(|i| players::ActiveModel { + id: Set(i as _), + login: Set(format!("player_{i}_login")), + name: Set(format!("player_{i}_name")), + role: Set(0), + ..Default::default() + }); + + let map_id = test_env::get_map_id(); + let map = maps::ActiveModel { + id: Set(map_id), + game_id: Set("map_uid".to_owned()), + name: Set("map_name".to_owned()), + player_id: Set(1), + ..Default::default() + }; + + // The higher the record ID, the less recent the record + let now = chrono::Utc::now().naive_utc().trunc_subsecs(0); + let record_dates = (0..record_amount) + .map(|i| now - Duration::from_secs(3600 * (i as u64 + 1))) + .collect_vec(); + + let records = (0..record_amount).map(|i| records::ActiveModel { + record_id: Set((i + 1) as _), + map_id: Set(map_id), + record_player_id: Set((i + 1) as _), + flags: Set(682), + time: Set((1000 * (i + 1)) as _), + respawn_count: Set(0), + record_date: Set(record_dates[i]), + ..Default::default() + }); + + test_env::wrap(async |db| { + players::Entity::insert_many(players) + .exec(&db.sql_conn) + .await?; + maps::Entity::insert(map).exec(&db.sql_conn).await?; + records::Entity::insert_many(records) + .exec(&db.sql_conn) + .await?; + + let result = get_map_records_connection( + &db.sql_conn, + &db.redis_pool, + map_id, + Default::default(), + ConnectionParameters { + last, + before: Some({ + let idx = if is_desc { + record_amount.saturating_sub(1).saturating_sub(before_idx) + } else { + before_idx + }; + RecordRankCursor { + time: (idx as i32 + 1) * 1000, + record_date: record_dates[idx].and_utc(), + data: idx as u32 + 1, + } + .into() + }), + ..Default::default() + }, + is_desc.then_some(MapRecordSort { + field: MapRecordSortableField::Rank, + order: Some(SortOrder::Descending), + }), + Default::default(), + ) + .await?; + + itertools::assert_equal( + result.edges.into_iter().map(|edge| Record { + record_id: edge.node.inner.record.record_id, + cursor: edge.cursor.0, + rank: edge.node.inner.rank, + map_id: edge.node.inner.record.map_id, + player_id: edge.node.inner.record.record_player_id, + flags: edge.node.inner.record.flags, + record_date: edge.node.inner.record.record_date.and_utc(), + respawn_count: edge.node.inner.record.respawn_count, + time: edge.node.inner.record.time, + }), + (0..expected_len).map(|i| { + let i = if is_desc { + record_amount - 1 - (i + before_idx - expected_len) + } else { + i + before_idx - expected_len + }; + let record_date = record_dates[i].and_utc(); + Record { + record_id: i as u32 + 1, + cursor: RecordRankCursor { + time: (i as i32 + 1) * 1000, + record_date, + data: i as u32 + 1, + } + .encode_cursor(), + rank: i as i32 + 1, + map_id, + flags: 682, + player_id: i as u32 + 1, + record_date, + respawn_count: 0, + time: 1000 * (i as i32 + 1), + } + }), + ); + + assert!(result.has_next_page); + assert_eq!(result.has_previous_page, has_previous_page); + + anyhow::Ok(()) + }) + .await +} + +#[tracing::instrument] +async fn test_last_x_before_y_date( + record_amount: usize, + is_desc: bool, + last: Option, + before_idx: usize, + expected_len: usize, + has_previous_page: bool, +) -> anyhow::Result<()> { + let players = (1..=record_amount).map(|i| players::ActiveModel { + id: Set(i as _), + login: Set(format!("player_{i}_login")), + name: Set(format!("player_{i}_name")), + role: Set(0), + ..Default::default() + }); + + let map_id = test_env::get_map_id(); + let map = maps::ActiveModel { + id: Set(map_id), + game_id: Set("map_uid".to_owned()), + name: Set("map_name".to_owned()), + player_id: Set(1), + ..Default::default() + }; + + // The higher the record ID, the less recent the record + let now = chrono::Utc::now().naive_utc().trunc_subsecs(0); + let record_dates = (0..record_amount) + .map(|i| now - Duration::from_secs(3600 * (i as u64 + 1))) + .collect_vec(); + + let records = (0..record_amount).map(|i| records::ActiveModel { + record_id: Set((i + 1) as _), + map_id: Set(map_id), + record_player_id: Set((i + 1) as _), + flags: Set(682), + time: Set((1000 * (i + 1)) as _), + respawn_count: Set(0), + record_date: Set(record_dates[i]), + ..Default::default() + }); + + test_env::wrap(async |db| { + players::Entity::insert_many(players) + .exec(&db.sql_conn) + .await?; + maps::Entity::insert(map).exec(&db.sql_conn).await?; + records::Entity::insert_many(records) + .exec(&db.sql_conn) + .await?; + + let result = get_map_records_connection( + &db.sql_conn, + &db.redis_pool, + map_id, + Default::default(), + ConnectionParameters { + last, + before: Some({ + let idx = if is_desc { + before_idx + } else { + record_amount.saturating_sub(1).saturating_sub(before_idx) + }; + RecordDateCursor { + record_date: record_dates[idx].and_utc(), + data: idx as u32 + 1, + } + .into() + }), + ..Default::default() + }, + Some(MapRecordSort { + field: MapRecordSortableField::Date, + order: is_desc.then_some(SortOrder::Descending), + }), + Default::default(), + ) + .await?; + + let mut expected = (0..expected_len) + .map(|i| { + let i = if is_desc { + i + before_idx - expected_len + } else { + record_amount - 1 - (i + before_idx - expected_len) + }; + let record_date = record_dates[i].and_utc(); + Record { + record_id: i as u32 + 1, + cursor: RecordDateCursor { + record_date, + data: i as u32 + 1, + } + .encode_cursor(), + rank: i as i32 + 1, + map_id, + flags: 682, + player_id: i as u32 + 1, + record_date, + respawn_count: 0, + time: 1000 * (i as i32 + 1), + } + }) + .collect_vec(); + + expected.sort_by(if is_desc { + |a: &Record, b: &Record| a.record_date.cmp(&b.record_date).reverse() + } else { + |a: &Record, b: &Record| a.record_date.cmp(&b.record_date) + }); + + itertools::assert_equal( + result.edges.into_iter().map(|edge| Record { + record_id: edge.node.inner.record.record_id, + cursor: edge.cursor.0, + rank: edge.node.inner.rank, + map_id: edge.node.inner.record.map_id, + player_id: edge.node.inner.record.record_player_id, + flags: edge.node.inner.record.flags, + record_date: edge.node.inner.record.record_date.and_utc(), + respawn_count: edge.node.inner.record.respawn_count, + time: edge.node.inner.record.time, + }), + expected, + ); + + assert!(result.has_next_page); + assert_eq!(result.has_previous_page, has_previous_page); + + anyhow::Ok(()) + }) + .await +} + +#[tokio::test] +async fn last_x_before_y() -> anyhow::Result<()> { + setup(); + + test_last_x_before_y(10, false, Some(3), 4, 3, true).await?; + test_last_x_before_y(10, false, Some(3), 3, 3, false).await?; + test_last_x_before_y(10, false, Some(3), 2, 2, false).await?; + + Ok(()) +} + +#[tokio::test] +async fn last_x_before_y_desc() -> anyhow::Result<()> { + setup(); + + test_last_x_before_y(10, true, Some(3), 4, 3, true).await?; + test_last_x_before_y(10, true, Some(3), 3, 3, false).await?; + test_last_x_before_y(10, true, Some(3), 2, 2, false).await?; + + Ok(()) +} + +#[tokio::test] +async fn last_x_before_y_date() -> anyhow::Result<()> { + setup(); + + test_last_x_before_y_date(10, false, Some(3), 4, 3, true).await?; + test_last_x_before_y_date(10, false, Some(4), 4, 4, false).await?; + test_last_x_before_y_date(10, false, Some(3), 2, 2, false).await?; + + Ok(()) +} + +#[tokio::test] +async fn last_x_before_y_date_desc() -> anyhow::Result<()> { + setup(); + + test_last_x_before_y_date(10, true, Some(3), 4, 3, true).await?; + test_last_x_before_y_date(10, true, Some(4), 4, 4, false).await?; + test_last_x_before_y_date(10, true, Some(3), 2, 2, false).await?; + + Ok(()) +} + +#[tokio::test] +async fn before_x() -> anyhow::Result<()> { + setup(); + + let default_limit = crate::config().cursor_default_limit.get(); + let record_amount = default_limit * 2; + + test_last_x_before_y( + record_amount, + false, + None, + record_amount - 10, + default_limit, + true, + ) + .await?; + test_last_x_before_y( + record_amount, + false, + None, + default_limit, + default_limit, + false, + ) + .await?; + test_last_x_before_y( + record_amount, + false, + None, + default_limit - 1, + default_limit - 1, + false, + ) + .await?; + + Ok(()) +} + +#[tokio::test] +async fn before_x_desc() -> anyhow::Result<()> { + setup(); + + let default_limit = crate::config().cursor_default_limit.get(); + let record_amount = default_limit * 2; + + test_last_x_before_y( + record_amount, + true, + None, + record_amount - 10, + default_limit, + true, + ) + .await?; + test_last_x_before_y( + record_amount, + true, + None, + default_limit, + default_limit, + false, + ) + .await?; + test_last_x_before_y( + record_amount, + true, + None, + default_limit - 1, + default_limit - 1, + false, + ) + .await?; + + Ok(()) +} + +#[tokio::test] +async fn before_x_date() -> anyhow::Result<()> { + setup(); + + let default_limit = crate::config().cursor_default_limit.get(); + let record_amount = default_limit * 2; + + test_last_x_before_y_date( + record_amount, + false, + None, + default_limit + 1, + default_limit, + true, + ) + .await?; + test_last_x_before_y_date( + record_amount, + false, + None, + default_limit, + default_limit, + false, + ) + .await?; + test_last_x_before_y_date( + record_amount, + false, + None, + default_limit - 1, + default_limit - 1, + false, + ) + .await?; + + Ok(()) +} + +#[tokio::test] +async fn before_x_date_desc() -> anyhow::Result<()> { + setup(); + + let default_limit = crate::config().cursor_default_limit.get(); + let record_amount = default_limit * 2; + + test_last_x_before_y_date( + record_amount, + true, + None, + default_limit + 1, + default_limit, + true, + ) + .await?; + test_last_x_before_y_date( + record_amount, + true, + None, + default_limit, + default_limit, + false, + ) + .await?; + test_last_x_before_y_date( + record_amount, + true, + None, + default_limit - 1, + default_limit - 1, + false, + ) + .await?; + + Ok(()) +} diff --git a/crates/graphql-api/src/tests/mod.rs b/crates/graphql-api/src/tests/mod.rs new file mode 100644 index 0000000..83cfc05 --- /dev/null +++ b/crates/graphql-api/src/tests/mod.rs @@ -0,0 +1,6 @@ +mod queryroot_maps_connection; +mod queryroot_players_connection; +mod queryroot_records_connection; + +mod maps_records_connection; +mod players_records_connection; diff --git a/crates/graphql-api/src/tests/players_records_connection.rs b/crates/graphql-api/src/tests/players_records_connection.rs new file mode 100644 index 0000000..cbe44e2 --- /dev/null +++ b/crates/graphql-api/src/tests/players_records_connection.rs @@ -0,0 +1,876 @@ +use mkenv::prelude::*; +use rand::Rng; +use std::time::Duration; + +use async_graphql::connection::CursorType; +use chrono::SubsecRound; +use entity::{maps, players, records}; +use itertools::Itertools; +use sea_orm::{ActiveValue::Set, EntityTrait}; + +use crate::{ + config::InitError, + cursors::{ConnectionParameters, RecordDateCursor}, + objects::{ + player::get_player_records_connection, sort::UnorderedRecordSort, sort_order::SortOrder, + sortable_fields::UnorderedRecordSortableField, + }, +}; + +fn setup() { + match crate::init_config() { + Ok(_) | Err(InitError::ConfigAlreadySet) => (), + Err(InitError::Config(e)) => { + panic!("error during test setup: {e}"); + } + } +} + +#[derive(Debug, PartialEq)] +struct Record { + record_id: u32, + cursor: String, + rank: i32, + map_id: u32, + player_id: u32, + time: i32, + record_date: chrono::DateTime, + flags: u32, + respawn_count: i32, +} + +#[tracing::instrument] +async fn test_default_page(is_desc: bool) -> anyhow::Result<()> { + let default_limit = crate::config().cursor_default_limit.get(); + let record_amount = default_limit * 2; + + let player = players::ActiveModel { + id: Set(1), + login: Set("boogalogin".to_owned()), + name: Set("booganame".to_owned()), + role: Set(0), + ..Default::default() + }; + + let map_ids = rand::rng() + .sample_iter(rand::distr::StandardUniform) + .take(record_amount) + .collect_vec(); + + let maps = map_ids.iter().map(|map_id| maps::ActiveModel { + id: Set(*map_id), + game_id: Set(format!("map_{map_id}_uid")), + name: Set(format!("map_{map_id}_name")), + player_id: Set(1), + ..Default::default() + }); + + // The higher the record ID, the less recent the record + let now = chrono::Utc::now().naive_utc().trunc_subsecs(0); + let record_dates = (0..record_amount) + .map(|i| now - Duration::from_secs(3600 * (i as u64 + 1))) + .collect_vec(); + + let records = (0..record_amount) + .zip(map_ids.iter()) + .map(|(i, map_id)| records::ActiveModel { + record_id: Set((i + 1) as _), + map_id: Set(*map_id), + record_player_id: Set(1), + flags: Set(682), + time: Set(1000), + respawn_count: Set(0), + record_date: Set(record_dates[i]), + ..Default::default() + }); + + test_env::wrap(async |db| { + players::Entity::insert(player).exec(&db.sql_conn).await?; + maps::Entity::insert_many(maps).exec(&db.sql_conn).await?; + records::Entity::insert_many(records) + .exec(&db.sql_conn) + .await?; + + let result = get_player_records_connection( + &db.sql_conn, + &db.redis_pool, + 1, + Default::default(), + ConnectionParameters { + before: None, + after: None, + first: None, + last: None, + }, + is_desc.then_some(UnorderedRecordSort { + field: UnorderedRecordSortableField::Date, + order: Some(SortOrder::Descending), + }), + None, + ) + .await?; + + itertools::assert_equal( + result.edges.into_iter().map(|edge| Record { + record_id: edge.node.inner.record.record_id, + cursor: edge.cursor.0, + rank: edge.node.inner.rank, + map_id: edge.node.inner.record.map_id, + player_id: edge.node.inner.record.record_player_id, + flags: edge.node.inner.record.flags, + record_date: edge.node.inner.record.record_date.and_utc(), + respawn_count: edge.node.inner.record.respawn_count, + time: edge.node.inner.record.time, + }), + (0..default_limit).map(|i| { + let i = if is_desc { record_amount - 1 - i } else { i }; + let record_date = record_dates[i].and_utc(); + Record { + record_id: i as u32 + 1, + cursor: RecordDateCursor { + record_date, + data: i + 1, + } + .encode_cursor(), + rank: 1, + map_id: map_ids[i], + flags: 682, + player_id: 1, + record_date, + respawn_count: 0, + time: 1000, + } + }), + ); + + assert!(result.has_next_page); + assert!(!result.has_previous_page); + + anyhow::Ok(()) + }) + .await +} + +#[tokio::test] +async fn default_page() -> anyhow::Result<()> { + setup(); + test_default_page(false).await +} + +#[tokio::test] +async fn default_page_desc() -> anyhow::Result<()> { + setup(); + test_default_page(true).await +} + +#[tokio::test] +async fn first_x_after_y() -> anyhow::Result<()> { + setup(); + test_first_x_after_y( + FirstXAfterYParams { + record_amount: 20, + first: Some(6), + after_idx: 4, + is_desc: false, + }, + true, + 6, + ) + .await?; + test_first_x_after_y( + FirstXAfterYParams { + record_amount: 20, + first: Some(4), + after_idx: 4, + is_desc: false, + }, + true, + 4, + ) + .await?; + test_first_x_after_y( + FirstXAfterYParams { + record_amount: 20, + first: Some(4), + after_idx: 14, + is_desc: false, + }, + true, + 4, + ) + .await?; + test_first_x_after_y( + FirstXAfterYParams { + record_amount: 20, + first: Some(5), + after_idx: 14, + is_desc: false, + }, + false, + 5, + ) + .await?; + test_first_x_after_y( + FirstXAfterYParams { + record_amount: 20, + first: Some(4), + after_idx: 17, + is_desc: false, + }, + false, + 2, + ) + .await?; + + Ok(()) +} + +struct FirstXAfterYParams { + record_amount: usize, + first: Option, + after_idx: usize, + is_desc: bool, +} + +#[tracing::instrument( + skip(params), + fields( + record_amount = params.record_amount, + first = params.first, + after_idx = params.after_idx, + is_desc = params.is_desc, + ) +)] +async fn test_first_x_after_y( + params: FirstXAfterYParams, + has_next_page: bool, + expected_len: usize, +) -> anyhow::Result<()> { + let limit = params + .first + .unwrap_or_else(|| crate::config().cursor_default_limit.get()); + + let computed_has_next_page = params.after_idx + 1 + limit < params.record_amount; + if has_next_page != computed_has_next_page { + tracing::warn!( + computed_has_next_page = computed_has_next_page, + "wrong has_next_page", + ); + } + let computed_expected_len = limit.min( + params + .record_amount + .saturating_sub(params.after_idx) + .saturating_sub(1), + ); + if expected_len != computed_expected_len { + tracing::warn!( + computed_expected_len = computed_expected_len, + "wrong expected_len", + ); + } + + let player = players::ActiveModel { + id: Set(1), + login: Set("boogalogin".to_owned()), + name: Set("booganame".to_owned()), + role: Set(0), + ..Default::default() + }; + + let map_ids = rand::rng() + .sample_iter(rand::distr::StandardUniform) + .take(params.record_amount) + .collect_vec(); + + let maps = map_ids.iter().map(|map_id| maps::ActiveModel { + id: Set(*map_id), + game_id: Set(format!("map_{map_id}_uid")), + name: Set(format!("map_{map_id}_name")), + player_id: Set(1), + ..Default::default() + }); + + // The higher the record ID, the less recent the record + let now = chrono::Utc::now().naive_utc().trunc_subsecs(0); + let record_dates = (0..params.record_amount) + .map(|i| now - Duration::from_secs(3600 * (i as u64 + 1))) + .collect_vec(); + + let records = (0..params.record_amount) + .zip(map_ids.iter()) + .map(|(i, map_id)| records::ActiveModel { + record_id: Set((i + 1) as _), + map_id: Set(*map_id), + record_player_id: Set(1), + flags: Set(682), + time: Set(1000), + respawn_count: Set(0), + record_date: Set(record_dates[i]), + ..Default::default() + }); + + test_env::wrap(async |db| { + players::Entity::insert(player).exec(&db.sql_conn).await?; + maps::Entity::insert_many(maps).exec(&db.sql_conn).await?; + records::Entity::insert_many(records) + .exec(&db.sql_conn) + .await?; + + let result = get_player_records_connection( + &db.sql_conn, + &db.redis_pool, + 1, + Default::default(), + ConnectionParameters { + first: params.first, + after: Some({ + let idx = if params.is_desc { + params + .record_amount + .saturating_sub(1) + .saturating_sub(params.after_idx) + } else { + params.after_idx + }; + RecordDateCursor { + record_date: record_dates[idx].and_utc(), + data: idx as u32 + 1, + } + }), + ..Default::default() + }, + params.is_desc.then_some(UnorderedRecordSort { + field: UnorderedRecordSortableField::Date, + order: Some(SortOrder::Descending), + }), + None, + ) + .await?; + + itertools::assert_equal( + result.edges.into_iter().map(|edge| Record { + record_id: edge.node.inner.record.record_id, + cursor: edge.cursor.0, + rank: edge.node.inner.rank, + map_id: edge.node.inner.record.map_id, + player_id: edge.node.inner.record.record_player_id, + flags: edge.node.inner.record.flags, + record_date: edge.node.inner.record.record_date.and_utc(), + respawn_count: edge.node.inner.record.respawn_count, + time: edge.node.inner.record.time, + }), + (0..expected_len).map(|i| { + let i = if params.is_desc { + params + .record_amount + .saturating_sub(1) + .saturating_sub(params.after_idx) + .saturating_sub(1) + .saturating_sub(i) + } else { + params.after_idx + 1 + i + }; + let record_date = record_dates[i].and_utc(); + Record { + record_id: i as u32 + 1, + cursor: RecordDateCursor { + record_date, + data: i as u32 + 1, + } + .encode_cursor(), + rank: 1, + map_id: map_ids[i], + flags: 682, + player_id: 1, + record_date, + respawn_count: 0, + time: 1000, + } + }), + ); + + assert_eq!(result.has_next_page, has_next_page); + assert!(result.has_previous_page); + + anyhow::Ok(()) + }) + .await +} + +struct LastXBeforeYParams { + record_amount: usize, + last: Option, + before_idx: usize, + is_desc: bool, +} + +#[tracing::instrument( + skip(params), + fields( + record_amount = params.record_amount, + last = params.last, + before_idx = params.before_idx, + is_desc = params.is_desc, + ) +)] +async fn test_last_x_before_y( + params: LastXBeforeYParams, + has_previous_page: bool, + expected_len: usize, +) -> anyhow::Result<()> { + let limit = params + .last + .unwrap_or_else(|| crate::config().cursor_default_limit.get()); + + let computed_has_previous_page = limit < params.before_idx; + if has_previous_page != computed_has_previous_page { + tracing::warn!( + computed_has_previous_page = computed_has_previous_page, + "wrong has_previous_page", + ); + } + let computed_expected_len = limit.min(params.before_idx).min(params.record_amount); + if expected_len != computed_expected_len { + tracing::warn!( + computed_expected_len = computed_expected_len, + "wrong expected_len", + ); + } + + let player = players::ActiveModel { + id: Set(1), + login: Set("boogalogin".to_owned()), + name: Set("booganame".to_owned()), + role: Set(0), + ..Default::default() + }; + + let map_ids = rand::rng() + .sample_iter(rand::distr::StandardUniform) + .take(params.record_amount) + .collect_vec(); + + let maps = map_ids.iter().map(|map_id| maps::ActiveModel { + id: Set(*map_id), + game_id: Set(format!("map_{map_id}_uid")), + name: Set(format!("map_{map_id}_name")), + player_id: Set(1), + ..Default::default() + }); + + // The higher the record ID, the less recent the record + let now = chrono::Utc::now().naive_utc().trunc_subsecs(0); + let record_dates = (0..params.record_amount) + .map(|i| now - Duration::from_secs(3600 * (i as u64 + 1))) + .collect_vec(); + + let records = (0..params.record_amount) + .zip(map_ids.iter()) + .map(|(i, map_id)| records::ActiveModel { + record_id: Set((i + 1) as _), + map_id: Set(*map_id), + record_player_id: Set(1), + flags: Set(682), + time: Set(1000), + respawn_count: Set(0), + record_date: Set(record_dates[i]), + ..Default::default() + }); + + test_env::wrap(async |db| { + players::Entity::insert(player).exec(&db.sql_conn).await?; + maps::Entity::insert_many(maps).exec(&db.sql_conn).await?; + records::Entity::insert_many(records) + .exec(&db.sql_conn) + .await?; + + let result = get_player_records_connection( + &db.sql_conn, + &db.redis_pool, + 1, + Default::default(), + ConnectionParameters { + before: Some({ + let idx = if params.is_desc { + params + .record_amount + .saturating_sub(1) + .saturating_sub(params.before_idx) + } else { + params.before_idx + }; + RecordDateCursor { + record_date: record_dates[idx].and_utc(), + data: idx as u32 + 1, + } + }), + after: None, + first: None, + last: params.last, + }, + params.is_desc.then_some(UnorderedRecordSort { + field: UnorderedRecordSortableField::Date, + order: Some(SortOrder::Descending), + }), + None, + ) + .await?; + + itertools::assert_equal( + result.edges.into_iter().map(|edge| Record { + record_id: edge.node.inner.record.record_id, + cursor: edge.cursor.0, + rank: edge.node.inner.rank, + map_id: edge.node.inner.record.map_id, + player_id: edge.node.inner.record.record_player_id, + flags: edge.node.inner.record.flags, + record_date: edge.node.inner.record.record_date.and_utc(), + respawn_count: edge.node.inner.record.respawn_count, + time: edge.node.inner.record.time, + }), + (0..expected_len).map(|i| { + let i = if params.is_desc { + params.record_amount - 1 - (i + params.before_idx - expected_len) + } else { + i + params.before_idx - expected_len + }; + let record_date = record_dates[i].and_utc(); + Record { + record_id: i as u32 + 1, + cursor: RecordDateCursor { + record_date, + data: i as u32 + 1, + } + .encode_cursor(), + rank: 1, + map_id: map_ids[i], + flags: 682, + player_id: 1, + record_date, + respawn_count: 0, + time: 1000, + } + }), + ); + + assert!(result.has_next_page); + assert_eq!(result.has_previous_page, has_previous_page); + + anyhow::Ok(()) + }) + .await +} + +#[tokio::test] +async fn first_x_after_y_desc() -> anyhow::Result<()> { + setup(); + + test_first_x_after_y( + FirstXAfterYParams { + record_amount: 20, + first: Some(6), + after_idx: 4, + is_desc: true, + }, + true, + 6, + ) + .await?; + test_first_x_after_y( + FirstXAfterYParams { + record_amount: 20, + first: Some(4), + after_idx: 4, + is_desc: true, + }, + true, + 4, + ) + .await?; + test_first_x_after_y( + FirstXAfterYParams { + record_amount: 20, + first: Some(4), + after_idx: 14, + is_desc: true, + }, + true, + 4, + ) + .await?; + test_first_x_after_y( + FirstXAfterYParams { + record_amount: 20, + first: Some(5), + after_idx: 14, + is_desc: true, + }, + false, + 5, + ) + .await?; + test_first_x_after_y( + FirstXAfterYParams { + record_amount: 20, + first: Some(5), + after_idx: 16, + is_desc: true, + }, + false, + 3, + ) + .await?; + + Ok(()) +} + +#[tokio::test] +async fn after_x() -> anyhow::Result<()> { + setup(); + let default_limit = crate::config().cursor_default_limit.get(); + + test_first_x_after_y( + FirstXAfterYParams { + record_amount: 100, + first: None, + after_idx: 15, + is_desc: false, + }, + true, + default_limit, + ) + .await?; + test_first_x_after_y( + FirstXAfterYParams { + record_amount: 100, + first: None, + after_idx: 100 - default_limit - 1, + is_desc: false, + }, + false, + default_limit, + ) + .await?; + test_first_x_after_y( + FirstXAfterYParams { + record_amount: 100, + first: None, + after_idx: 79, + is_desc: false, + }, + false, + 20, + ) + .await?; + + Ok(()) +} + +#[tokio::test] +async fn after_x_desc() -> anyhow::Result<()> { + setup(); + let default_limit = crate::config().cursor_default_limit.get(); + + test_first_x_after_y( + FirstXAfterYParams { + record_amount: 100, + first: None, + after_idx: 15, + is_desc: true, + }, + true, + default_limit, + ) + .await?; + test_first_x_after_y( + FirstXAfterYParams { + record_amount: 100, + first: None, + after_idx: 49, + is_desc: true, + }, + false, + default_limit, + ) + .await?; + test_first_x_after_y( + FirstXAfterYParams { + record_amount: 100, + first: None, + after_idx: 79, + is_desc: true, + }, + false, + 20, + ) + .await?; + + Ok(()) +} + +#[tokio::test] +async fn last_x_before_y() -> anyhow::Result<()> { + setup(); + + test_last_x_before_y( + LastXBeforeYParams { + record_amount: 10, + last: Some(3), + before_idx: 6, + is_desc: false, + }, + true, + 3, + ) + .await?; + test_last_x_before_y( + LastXBeforeYParams { + record_amount: 10, + last: Some(3), + before_idx: 3, + is_desc: false, + }, + false, + 3, + ) + .await?; + test_last_x_before_y( + LastXBeforeYParams { + record_amount: 10, + last: Some(4), + before_idx: 3, + is_desc: false, + }, + false, + 3, + ) + .await?; + + Ok(()) +} + +#[tokio::test] +async fn last_x_before_y_desc() -> anyhow::Result<()> { + setup(); + + test_last_x_before_y( + LastXBeforeYParams { + record_amount: 10, + last: Some(3), + before_idx: 6, + is_desc: true, + }, + true, + 3, + ) + .await?; + test_last_x_before_y( + LastXBeforeYParams { + record_amount: 10, + last: Some(3), + before_idx: 3, + is_desc: true, + }, + false, + 3, + ) + .await?; + test_last_x_before_y( + LastXBeforeYParams { + record_amount: 10, + last: Some(4), + before_idx: 3, + is_desc: true, + }, + false, + 3, + ) + .await?; + + Ok(()) +} + +#[tokio::test] +async fn before_x() -> anyhow::Result<()> { + setup(); + let default_limit = crate::config().cursor_default_limit.get(); + + test_last_x_before_y( + LastXBeforeYParams { + record_amount: 100, + last: None, + before_idx: 75, + is_desc: false, + }, + true, + default_limit, + ) + .await?; + test_last_x_before_y( + LastXBeforeYParams { + record_amount: 100, + last: None, + before_idx: 100 - default_limit, + is_desc: false, + }, + false, + default_limit, + ) + .await?; + test_last_x_before_y( + LastXBeforeYParams { + record_amount: 100, + last: None, + before_idx: 3, + is_desc: false, + }, + false, + 3, + ) + .await?; + + Ok(()) +} + +#[tokio::test] +async fn before_x_desc() -> anyhow::Result<()> { + setup(); + let default_limit = crate::config().cursor_default_limit.get(); + + test_last_x_before_y( + LastXBeforeYParams { + record_amount: 100, + last: None, + before_idx: 75, + is_desc: true, + }, + true, + default_limit, + ) + .await?; + test_last_x_before_y( + LastXBeforeYParams { + record_amount: 100, + last: None, + before_idx: 100 - default_limit, + is_desc: true, + }, + false, + default_limit, + ) + .await?; + test_last_x_before_y( + LastXBeforeYParams { + record_amount: 100, + last: None, + before_idx: 3, + is_desc: true, + }, + false, + 3, + ) + .await?; + + Ok(()) +} diff --git a/crates/graphql-api/src/tests/queryroot_maps_connection.rs b/crates/graphql-api/src/tests/queryroot_maps_connection.rs new file mode 100644 index 0000000..315c887 --- /dev/null +++ b/crates/graphql-api/src/tests/queryroot_maps_connection.rs @@ -0,0 +1,548 @@ +use async_graphql::connection::CursorType; +use deadpool_redis::redis::{self, ToRedisArgs}; +use entity::{maps, players}; +use mkenv::Layer as _; +use rand::Rng; +use records_lib::RedisConnection; +use sea_orm::{ActiveValue::Set, EntityTrait}; + +use crate::{ + config::InitError, + cursors::{ConnectionParameters, F64Cursor}, + objects::root::{MapsConnectionInput, get_maps_connection}, +}; + +fn setup() { + match crate::init_config() { + Ok(_) | Err(InitError::ConfigAlreadySet) => (), + Err(InitError::Config(e)) => { + panic!("error during test setup: {e}"); + } + } +} + +#[derive(Debug)] +struct Map { + cursor: String, + rank: i32, + id: u32, + uid: String, + name: String, + score: f64, +} + +impl PartialEq for Map { + fn eq(&self, other: &Self) -> bool { + const ALLOWED_ERROR: f64 = 0.05; + self.cursor == other.cursor + && self.rank == other.rank + && self.id == other.id + && self.uid == other.uid + && self.name == other.name + && (self.score - other.score).abs() < ALLOWED_ERROR + } +} + +fn gen_map_ranking_key() -> String { + "__test_map_ranking_" + .chars() + .chain( + rand::rng() + .sample_iter(rand::distr::Alphabetic) + .take(20) + .map(char::from), + ) + .collect() +} + +async fn fill_redis_lb( + redis_conn: &mut RedisConnection, + key: impl ToRedisArgs, + iter: I, +) -> anyhow::Result<()> +where + I: IntoIterator, + M: ToRedisArgs, + S: ToRedisArgs, +{ + let mut pipe = redis::pipe(); + pipe.atomic(); + pipe.del(&key); + for (member, score) in iter { + pipe.zadd(&key, member, score); + } + pipe.exec_async(redis_conn).await?; + Ok(()) +} + +#[tokio::test] +async fn default_page() -> anyhow::Result<()> { + setup(); + + let default_limit = crate::config().cursor_default_limit.get(); + let map_amount = default_limit * 2; + + let author = players::ActiveModel { + id: Set(1), + login: Set("boogalogin".to_owned()), + name: Set("booganame".to_owned()), + role: Set(0), + ..Default::default() + }; + + let maps = (0..map_amount).map(|i| maps::ActiveModel { + id: Set((i + 1) as _), + game_id: Set(format!("map_{i}_uid")), + name: Set(format!("map_{i}_name")), + score: Set(i as _), + player_id: Set(1), + ..Default::default() + }); + + let source = gen_map_ranking_key(); + + test_env::wrap(async |db| { + players::Entity::insert(author).exec(&db.sql_conn).await?; + maps::Entity::insert_many(maps).exec(&db.sql_conn).await?; + + let mut redis_conn = db.redis_pool.get().await?; + fill_redis_lb( + &mut redis_conn, + &source, + (1..=map_amount).zip(0..map_amount), + ) + .await?; + + let result = get_maps_connection( + &db.sql_conn, + &mut redis_conn, + ::default().with_source(source), + ) + .await?; + + itertools::assert_equal( + result.edges.into_iter().map(|edge| Map { + cursor: edge.cursor.0, + rank: edge.node.rank, + id: edge.node.map.inner.id, + uid: edge.node.map.inner.game_id, + name: edge.node.map.inner.name, + score: edge.node.map.inner.score, + }), + (0..default_limit).map(|i| Map { + cursor: F64Cursor { + score: i as _, + data: (i + 1), + } + .encode_cursor(), + id: (i + 1) as _, + uid: format!("map_{i}_uid"), + name: format!("map_{i}_name"), + rank: (map_amount - i) as _, + score: i as _, + }), + ); + + anyhow::Ok(()) + }) + .await +} + +struct FirstXAfterYParams { + map_amount: usize, + first: Option, + after_idx: usize, +} + +#[tracing::instrument( + skip(params), + fields( + map_amount = params.map_amount, + first = params.first, + after_idx = params.after_idx, + ) +)] +async fn test_first_x_after_y( + params: FirstXAfterYParams, + has_next_page: bool, + expected_len: usize, +) -> anyhow::Result<()> { + let limit = params + .first + .unwrap_or_else(|| crate::config().cursor_default_limit.get()); + + let computed_has_next_page = params.after_idx + 1 + limit < params.map_amount; + if has_next_page != computed_has_next_page { + tracing::warn!( + computed_has_next_page = computed_has_next_page, + "wrong has_next_page", + ); + } + let computed_expected_len = limit.min( + params + .map_amount + .saturating_sub(params.after_idx) + .saturating_sub(1), + ); + if expected_len != computed_expected_len { + tracing::warn!( + computed_expected_len = computed_expected_len, + "wrong expected_len", + ); + } + + let author = players::ActiveModel { + id: Set(1), + login: Set("boogalogin".to_owned()), + name: Set("booganame".to_owned()), + role: Set(0), + ..Default::default() + }; + + let maps = (0..params.map_amount).map(|i| maps::ActiveModel { + id: Set((i + 1) as _), + game_id: Set(format!("map_{i}_uid")), + name: Set(format!("map_{i}_name")), + score: Set(i as _), + player_id: Set(1), + ..Default::default() + }); + + let source = gen_map_ranking_key(); + + test_env::wrap(async |db| { + players::Entity::insert(author).exec(&db.sql_conn).await?; + maps::Entity::insert_many(maps).exec(&db.sql_conn).await?; + + let mut redis_conn = db.redis_pool.get().await?; + fill_redis_lb( + &mut redis_conn, + &source, + (1..=params.map_amount).zip(0..params.map_amount), + ) + .await?; + + let result = get_maps_connection( + &db.sql_conn, + &mut redis_conn, + ::new(ConnectionParameters { + first: params.first, + after: Some( + F64Cursor { + score: params.after_idx as _, + data: params.after_idx as u32 + 1, + } + .into(), + ), + ..Default::default() + }) + .with_source(source), + ) + .await?; + + itertools::assert_equal( + result.edges.into_iter().map(|edge| Map { + cursor: edge.cursor.0, + rank: edge.node.rank, + id: edge.node.map.inner.id, + uid: edge.node.map.inner.game_id, + name: edge.node.map.inner.name, + score: edge.node.map.inner.score, + }), + (0..expected_len).map(|i| { + let i = params.after_idx + 1 + i; + Map { + cursor: F64Cursor { + score: i as _, + data: (i + 1), + } + .encode_cursor(), + id: (i + 1) as _, + uid: format!("map_{i}_uid"), + name: format!("map_{i}_name"), + rank: (params.map_amount - i) as _, + score: i as _, + } + }), + ); + + anyhow::Ok(()) + }) + .await +} + +struct LastXBeforeYParams { + map_amount: usize, + last: Option, + before_idx: usize, +} + +#[tracing::instrument( + skip(params), + fields( + map_amount = params.map_amount, + last = params.last, + before_idx = params.before_idx, + ) +)] +async fn test_last_x_before_y( + params: LastXBeforeYParams, + has_previous_page: bool, + expected_len: usize, +) -> anyhow::Result<()> { + let limit = params + .last + .unwrap_or_else(|| crate::config().cursor_default_limit.get()); + + let computed_has_previous_page = limit < params.before_idx; + if has_previous_page != computed_has_previous_page { + tracing::warn!( + computed_has_previous_page = computed_has_previous_page, + "wrong has_previous_page", + ); + } + let computed_expected_len = limit.min(params.before_idx).min(params.map_amount); + if expected_len != computed_expected_len { + tracing::warn!( + computed_expected_len = computed_expected_len, + "wrong expected_len", + ); + } + + let author = players::ActiveModel { + id: Set(1), + login: Set("boogalogin".to_owned()), + name: Set("booganame".to_owned()), + role: Set(0), + ..Default::default() + }; + + let maps = (0..params.map_amount).map(|i| maps::ActiveModel { + id: Set((i + 1) as _), + game_id: Set(format!("map_{i}_uid")), + name: Set(format!("map_{i}_name")), + score: Set(i as _), + player_id: Set(1), + ..Default::default() + }); + + let source = gen_map_ranking_key(); + + test_env::wrap(async |db| { + players::Entity::insert(author).exec(&db.sql_conn).await?; + maps::Entity::insert_many(maps).exec(&db.sql_conn).await?; + + let mut redis_conn = db.redis_pool.get().await?; + fill_redis_lb( + &mut redis_conn, + &source, + (1..=params.map_amount).zip(0..params.map_amount), + ) + .await?; + + let result = get_maps_connection( + &db.sql_conn, + &mut redis_conn, + ::new(ConnectionParameters { + last: params.last, + before: Some( + F64Cursor { + score: params.before_idx as _, + data: params.before_idx as u32 + 1, + } + .into(), + ), + ..Default::default() + }) + .with_source(source), + ) + .await?; + + itertools::assert_equal( + result.edges.into_iter().map(|edge| Map { + cursor: edge.cursor.0, + rank: edge.node.rank, + id: edge.node.map.inner.id, + uid: edge.node.map.inner.game_id, + name: edge.node.map.inner.name, + score: edge.node.map.inner.score, + }), + (0..expected_len).map(|i| { + let i = i + params.before_idx - expected_len; + Map { + cursor: F64Cursor { + score: i as _, + data: (i + 1), + } + .encode_cursor(), + id: (i + 1) as _, + uid: format!("map_{i}_uid"), + name: format!("map_{i}_name"), + rank: (params.map_amount - i) as _, + score: i as _, + } + }), + ); + + anyhow::Ok(()) + }) + .await +} + +#[tokio::test] +async fn first_x_after_y() -> anyhow::Result<()> { + setup(); + + test_first_x_after_y( + FirstXAfterYParams { + map_amount: 10, + first: Some(3), + after_idx: 2, + }, + true, + 3, + ) + .await?; + test_first_x_after_y( + FirstXAfterYParams { + map_amount: 10, + first: Some(10), + after_idx: 2, + }, + false, + 7, + ) + .await?; + test_first_x_after_y( + FirstXAfterYParams { + map_amount: 10, + first: Some(6), + after_idx: 2, + }, + true, + 6, + ) + .await?; + + Ok(()) +} + +#[tokio::test] +async fn after_x() -> anyhow::Result<()> { + setup(); + + let default_limit = crate::config().cursor_default_limit.get(); + let map_amount = default_limit * 2; + + test_first_x_after_y( + FirstXAfterYParams { + map_amount, + first: None, + after_idx: 2, + }, + true, + default_limit, + ) + .await?; + test_first_x_after_y( + FirstXAfterYParams { + map_amount, + first: None, + after_idx: default_limit + 7, + }, + false, + default_limit - 8, + ) + .await?; + test_first_x_after_y( + FirstXAfterYParams { + map_amount, + first: None, + after_idx: default_limit - 1, + }, + true, + default_limit, + ) + .await?; + + Ok(()) +} + +#[tokio::test] +async fn last_x_before_y() -> anyhow::Result<()> { + setup(); + + test_last_x_before_y( + LastXBeforeYParams { + map_amount: 10, + last: Some(3), + before_idx: 6, + }, + true, + 3, + ) + .await?; + test_last_x_before_y( + LastXBeforeYParams { + map_amount: 10, + last: Some(6), + before_idx: 6, + }, + false, + 6, + ) + .await?; + test_last_x_before_y( + LastXBeforeYParams { + map_amount: 10, + last: Some(7), + before_idx: 6, + }, + false, + 6, + ) + .await?; + + Ok(()) +} + +#[tokio::test] +async fn last_x() -> anyhow::Result<()> { + setup(); + + let default_limit = crate::config().cursor_default_limit.get(); + let map_amount = default_limit * 2; + + test_last_x_before_y( + LastXBeforeYParams { + map_amount, + last: None, + before_idx: default_limit + 1, + }, + true, + default_limit, + ) + .await?; + test_last_x_before_y( + LastXBeforeYParams { + map_amount, + last: None, + before_idx: default_limit, + }, + false, + default_limit, + ) + .await?; + test_last_x_before_y( + LastXBeforeYParams { + map_amount, + last: None, + before_idx: default_limit - 3, + }, + false, + default_limit - 3, + ) + .await?; + + Ok(()) +} diff --git a/crates/graphql-api/src/tests/queryroot_players_connection.rs b/crates/graphql-api/src/tests/queryroot_players_connection.rs new file mode 100644 index 0000000..5cc670b --- /dev/null +++ b/crates/graphql-api/src/tests/queryroot_players_connection.rs @@ -0,0 +1,526 @@ +use async_graphql::connection::CursorType; +use deadpool_redis::redis::{self, ToRedisArgs}; +use entity::players; +use mkenv::Layer as _; +use rand::Rng; +use records_lib::RedisConnection; +use sea_orm::{ActiveValue::Set, EntityTrait}; + +use crate::{ + config::InitError, + cursors::{ConnectionParameters, F64Cursor}, + objects::root::{PlayersConnectionInput, get_players_connection}, +}; + +fn setup() { + match crate::init_config() { + Ok(_) | Err(InitError::ConfigAlreadySet) => (), + Err(InitError::Config(e)) => { + panic!("error during test setup: {e}"); + } + } +} + +#[derive(Debug)] +struct Player { + cursor: String, + rank: i32, + id: u32, + login: String, + name: String, + score: f64, +} + +impl PartialEq for Player { + fn eq(&self, other: &Self) -> bool { + const ALLOWED_ERROR: f64 = 0.05; + self.cursor == other.cursor + && self.rank == other.rank + && self.id == other.id + && self.login == other.login + && self.name == other.name + && (self.score - other.score).abs() < ALLOWED_ERROR + } +} + +fn gen_player_ranking_key() -> String { + "__test_player_ranking_" + .chars() + .chain( + rand::rng() + .sample_iter(rand::distr::Alphabetic) + .take(20) + .map(char::from), + ) + .collect() +} + +async fn fill_redis_lb( + redis_conn: &mut RedisConnection, + key: impl ToRedisArgs, + iter: I, +) -> anyhow::Result<()> +where + I: IntoIterator, + P: ToRedisArgs, + S: ToRedisArgs, +{ + let mut pipe = redis::pipe(); + pipe.atomic(); + pipe.del(&key); + for (player_id, score) in iter { + pipe.zadd(&key, player_id, score); + } + pipe.exec_async(redis_conn).await?; + Ok(()) +} + +#[tokio::test] +async fn default_page() -> anyhow::Result<()> { + setup(); + + let default_limit = crate::config().cursor_default_limit.get(); + let player_amount = default_limit * 2; + + let players = (0..player_amount).map(|i| players::ActiveModel { + id: Set((i + 1) as _), + login: Set(format!("player_{i}_login")), + name: Set(format!("player_{i}_name")), + role: Set(0), + score: Set(i as _), + ..Default::default() + }); + + let source = gen_player_ranking_key(); + + test_env::wrap(async |db| { + players::Entity::insert_many(players) + .exec(&db.sql_conn) + .await?; + + let mut redis_conn = db.redis_pool.get().await?; + fill_redis_lb( + &mut redis_conn, + &source, + (1..=player_amount).zip(0..player_amount), + ) + .await?; + + let result = get_players_connection( + &db.sql_conn, + &mut redis_conn, + ::default().with_source(source), + ) + .await?; + + itertools::assert_equal( + result.edges.into_iter().map(|edge| Player { + cursor: edge.cursor.0, + rank: edge.node.rank, + id: edge.node.player.inner.id, + login: edge.node.player.inner.login, + name: edge.node.player.inner.name, + score: edge.node.player.inner.score, + }), + (0..default_limit).map(|i| Player { + cursor: F64Cursor { + score: i as _, + data: (i + 1), + } + .encode_cursor(), + id: (i + 1) as _, + login: format!("player_{i}_login"), + name: format!("player_{i}_name"), + rank: (player_amount - i) as _, + score: i as _, + }), + ); + + anyhow::Ok(()) + }) + .await +} + +struct FirstXAfterYParams { + player_amount: usize, + first: Option, + after_idx: usize, +} + +#[tracing::instrument( + skip(params), + fields( + player_amount = params.player_amount, + first = params.first, + after_idx = params.after_idx, + ) +)] +async fn test_first_x_after_y( + params: FirstXAfterYParams, + has_next_page: bool, + expected_len: usize, +) -> anyhow::Result<()> { + let limit = params + .first + .unwrap_or_else(|| crate::config().cursor_default_limit.get()); + + let computed_has_next_page = params.after_idx + 1 + limit < params.player_amount; + if has_next_page != computed_has_next_page { + tracing::warn!( + computed_has_next_page = computed_has_next_page, + "wrong has_next_page", + ); + } + let computed_expected_len = limit.min( + params + .player_amount + .saturating_sub(params.after_idx) + .saturating_sub(1), + ); + if expected_len != computed_expected_len { + tracing::warn!( + computed_expected_len = computed_expected_len, + "wrong expected_len", + ); + } + + let players = (0..params.player_amount).map(|i| players::ActiveModel { + id: Set((i + 1) as _), + login: Set(format!("player_{i}_login")), + name: Set(format!("player_{i}_name")), + role: Set(0), + score: Set(i as _), + ..Default::default() + }); + + let source = gen_player_ranking_key(); + + test_env::wrap(async |db| { + players::Entity::insert_many(players) + .exec(&db.sql_conn) + .await?; + + let mut redis_conn = db.redis_pool.get().await?; + fill_redis_lb( + &mut redis_conn, + &source, + (1..=params.player_amount).zip(0..params.player_amount), + ) + .await?; + + let result = get_players_connection( + &db.sql_conn, + &mut redis_conn, + ::new(ConnectionParameters { + first: params.first, + after: Some( + F64Cursor { + score: params.after_idx as _, + data: params.after_idx as u32 + 1, + } + .into(), + ), + ..Default::default() + }) + .with_source(source), + ) + .await?; + + itertools::assert_equal( + result.edges.into_iter().map(|edge| Player { + cursor: edge.cursor.0, + rank: edge.node.rank, + id: edge.node.player.inner.id, + login: edge.node.player.inner.login, + name: edge.node.player.inner.name, + score: edge.node.player.inner.score, + }), + (0..expected_len).map(|i| { + let i = params.after_idx + 1 + i; + Player { + cursor: F64Cursor { + score: i as _, + data: (i + 1), + } + .encode_cursor(), + id: (i + 1) as _, + login: format!("player_{i}_login"), + name: format!("player_{i}_name"), + rank: (params.player_amount - i) as _, + score: i as _, + } + }), + ); + + anyhow::Ok(()) + }) + .await +} + +struct LastXBeforeYParams { + player_amount: usize, + last: Option, + before_idx: usize, +} + +#[tracing::instrument( + skip(params), + fields( + player_amount = params.player_amount, + last = params.last, + before_idx = params.before_idx, + ) +)] +async fn test_last_x_before_y( + params: LastXBeforeYParams, + has_previous_page: bool, + expected_len: usize, +) -> anyhow::Result<()> { + let limit = params + .last + .unwrap_or_else(|| crate::config().cursor_default_limit.get()); + + let computed_has_previous_page = limit < params.before_idx; + if has_previous_page != computed_has_previous_page { + tracing::warn!( + computed_has_previous_page = computed_has_previous_page, + "wrong has_previous_page", + ); + } + let computed_expected_len = limit.min(params.before_idx).min(params.player_amount); + if expected_len != computed_expected_len { + tracing::warn!( + computed_expected_len = computed_expected_len, + "wrong expected_len", + ); + } + + let players = (0..params.player_amount).map(|i| players::ActiveModel { + id: Set((i + 1) as _), + login: Set(format!("player_{i}_login")), + name: Set(format!("player_{i}_name")), + role: Set(0), + score: Set(i as _), + ..Default::default() + }); + + let source = gen_player_ranking_key(); + + test_env::wrap(async |db| { + players::Entity::insert_many(players) + .exec(&db.sql_conn) + .await?; + let mut redis_conn = db.redis_pool.get().await?; + fill_redis_lb( + &mut redis_conn, + &source, + (1..=params.player_amount).zip(0..params.player_amount), + ) + .await?; + + let result = get_players_connection( + &db.sql_conn, + &mut redis_conn, + ::new(ConnectionParameters { + last: params.last, + before: Some( + F64Cursor { + score: params.before_idx as _, + data: params.before_idx as u32 + 1, + } + .into(), + ), + ..Default::default() + }) + .with_source(source), + ) + .await?; + + itertools::assert_equal( + result.edges.into_iter().map(|edge| Player { + cursor: edge.cursor.0, + rank: edge.node.rank, + id: edge.node.player.inner.id, + login: edge.node.player.inner.login, + name: edge.node.player.inner.name, + score: edge.node.player.inner.score, + }), + (0..expected_len).map(|i| { + let i = i + params.before_idx - expected_len; + Player { + cursor: F64Cursor { + score: i as _, + data: (i + 1), + } + .encode_cursor(), + id: (i + 1) as _, + login: format!("player_{i}_login"), + name: format!("player_{i}_name"), + rank: (params.player_amount - i) as _, + score: i as _, + } + }), + ); + + anyhow::Ok(()) + }) + .await +} + +#[tokio::test] +async fn first_x_after_y() -> anyhow::Result<()> { + setup(); + + test_first_x_after_y( + FirstXAfterYParams { + player_amount: 10, + first: Some(3), + after_idx: 2, + }, + true, + 3, + ) + .await?; + test_first_x_after_y( + FirstXAfterYParams { + player_amount: 10, + first: Some(10), + after_idx: 2, + }, + false, + 7, + ) + .await?; + test_first_x_after_y( + FirstXAfterYParams { + player_amount: 10, + first: Some(6), + after_idx: 2, + }, + true, + 6, + ) + .await?; + + Ok(()) +} + +#[tokio::test] +async fn after_x() -> anyhow::Result<()> { + setup(); + + let default_limit = crate::config().cursor_default_limit.get(); + let player_amount = default_limit * 2; + + test_first_x_after_y( + FirstXAfterYParams { + player_amount, + first: None, + after_idx: 2, + }, + true, + default_limit, + ) + .await?; + test_first_x_after_y( + FirstXAfterYParams { + player_amount, + first: None, + after_idx: default_limit + 7, + }, + false, + default_limit - 8, + ) + .await?; + test_first_x_after_y( + FirstXAfterYParams { + player_amount, + first: None, + after_idx: default_limit - 1, + }, + true, + default_limit, + ) + .await?; + + Ok(()) +} + +#[tokio::test] +async fn last_x_before_y() -> anyhow::Result<()> { + setup(); + + test_last_x_before_y( + LastXBeforeYParams { + player_amount: 10, + last: Some(3), + before_idx: 6, + }, + true, + 3, + ) + .await?; + test_last_x_before_y( + LastXBeforeYParams { + player_amount: 10, + last: Some(6), + before_idx: 6, + }, + false, + 6, + ) + .await?; + test_last_x_before_y( + LastXBeforeYParams { + player_amount: 10, + last: Some(7), + before_idx: 6, + }, + false, + 6, + ) + .await?; + + Ok(()) +} + +#[tokio::test] +async fn last_x() -> anyhow::Result<()> { + setup(); + + let default_limit = crate::config().cursor_default_limit.get(); + let player_amount = default_limit * 2; + + test_last_x_before_y( + LastXBeforeYParams { + player_amount, + last: None, + before_idx: default_limit + 1, + }, + true, + default_limit, + ) + .await?; + test_last_x_before_y( + LastXBeforeYParams { + player_amount, + last: None, + before_idx: default_limit, + }, + false, + default_limit, + ) + .await?; + test_last_x_before_y( + LastXBeforeYParams { + player_amount, + last: None, + before_idx: default_limit - 3, + }, + false, + default_limit - 3, + ) + .await?; + + Ok(()) +} diff --git a/crates/graphql-api/src/tests/queryroot_records_connection.rs b/crates/graphql-api/src/tests/queryroot_records_connection.rs new file mode 100644 index 0000000..ec6e259 --- /dev/null +++ b/crates/graphql-api/src/tests/queryroot_records_connection.rs @@ -0,0 +1,861 @@ +use mkenv::prelude::*; +use std::time::Duration; + +use async_graphql::connection::CursorType; +use chrono::SubsecRound; +use entity::{maps, players, records}; +use itertools::Itertools; +use sea_orm::{ActiveValue::Set, EntityTrait}; + +use crate::{ + config::InitError, + cursors::{ConnectionParameters, RecordDateCursor}, + objects::{ + root::get_records_connection, sort::UnorderedRecordSort, sort_order::SortOrder, + sortable_fields::UnorderedRecordSortableField, + }, +}; + +fn setup() { + match crate::init_config() { + Ok(_) | Err(InitError::ConfigAlreadySet) => (), + Err(InitError::Config(e)) => { + panic!("error during test setup: {e}"); + } + } +} + +#[derive(Debug, PartialEq)] +struct Record { + record_id: u32, + cursor: String, + rank: i32, + map_id: u32, + player_id: u32, + time: i32, + record_date: chrono::DateTime, + flags: u32, + respawn_count: i32, +} + +#[tracing::instrument] +async fn test_default_page(is_desc: bool) -> anyhow::Result<()> { + let default_limit = crate::config().cursor_default_limit.get(); + let player_amount = default_limit * 2; + + let players = (0..player_amount).map(|i| players::ActiveModel { + id: Set((i + 1) as _), + login: Set(format!("player_{i}_login")), + name: Set(format!("player_{i}_name")), + role: Set(0), + ..Default::default() + }); + + let map_id = test_env::get_map_id(); + let map = maps::ActiveModel { + id: Set(map_id), + game_id: Set("map_uid".to_owned()), + name: Set("map_name".to_owned()), + player_id: Set(1), + ..Default::default() + }; + + // The higher the player ID, the less recent the record + let now = chrono::Utc::now().naive_utc().trunc_subsecs(0); + let record_dates = (0..player_amount) + .map(|i| now - Duration::from_secs(3600 * (i as u64 + 1))) + .collect_vec(); + + let records = (0..player_amount).map(|i| records::ActiveModel { + record_id: Set((i + 1) as _), + map_id: Set(map_id), + record_player_id: Set((i + 1) as _), + flags: Set(682), + time: Set(1000), + respawn_count: Set(0), + record_date: Set(record_dates[i]), + ..Default::default() + }); + + test_env::wrap(async |db| { + players::Entity::insert_many(players) + .exec(&db.sql_conn) + .await?; + maps::Entity::insert(map).exec(&db.sql_conn).await?; + records::Entity::insert_many(records) + .exec(&db.sql_conn) + .await?; + + let result = get_records_connection( + &db.sql_conn, + &db.redis_pool, + ConnectionParameters { + before: None, + after: None, + first: None, + last: None, + }, + Default::default(), + is_desc.then_some(UnorderedRecordSort { + field: UnorderedRecordSortableField::Date, + order: Some(SortOrder::Descending), + }), + None, + ) + .await?; + + itertools::assert_equal( + result.edges.into_iter().map(|edge| Record { + record_id: edge.node.inner.record.record_id, + cursor: edge.cursor.0, + rank: edge.node.inner.rank, + map_id: edge.node.inner.record.map_id, + player_id: edge.node.inner.record.record_player_id, + flags: edge.node.inner.record.flags, + record_date: edge.node.inner.record.record_date.and_utc(), + respawn_count: edge.node.inner.record.respawn_count, + time: edge.node.inner.record.time, + }), + (0..default_limit).map(|i| { + let i = if is_desc { player_amount - 1 - i } else { i }; + let record_date = record_dates[i].and_utc(); + Record { + record_id: i as u32 + 1, + cursor: RecordDateCursor { + record_date, + data: i + 1, + } + .encode_cursor(), + rank: 1, + map_id, + flags: 682, + player_id: (i + 1) as _, + record_date, + respawn_count: 0, + time: 1000, + } + }), + ); + + assert!(result.has_next_page); + assert!(!result.has_previous_page); + + anyhow::Ok(()) + }) + .await +} + +#[tokio::test] +async fn default_page() -> anyhow::Result<()> { + setup(); + test_default_page(false).await +} + +#[tokio::test] +async fn default_page_desc() -> anyhow::Result<()> { + setup(); + test_default_page(true).await +} + +#[tokio::test] +async fn first_x_after_y() -> anyhow::Result<()> { + setup(); + test_first_x_after_y( + FirstXAfterYParams { + player_amount: 20, + first: Some(6), + after_idx: 4, + is_desc: false, + }, + true, + 6, + ) + .await?; + test_first_x_after_y( + FirstXAfterYParams { + player_amount: 20, + first: Some(4), + after_idx: 4, + is_desc: false, + }, + true, + 4, + ) + .await?; + test_first_x_after_y( + FirstXAfterYParams { + player_amount: 20, + first: Some(4), + after_idx: 14, + is_desc: false, + }, + true, + 4, + ) + .await?; + test_first_x_after_y( + FirstXAfterYParams { + player_amount: 20, + first: Some(5), + after_idx: 14, + is_desc: false, + }, + false, + 5, + ) + .await?; + test_first_x_after_y( + FirstXAfterYParams { + player_amount: 20, + first: Some(4), + after_idx: 17, + is_desc: false, + }, + false, + 2, + ) + .await?; + + Ok(()) +} + +struct FirstXAfterYParams { + player_amount: usize, + first: Option, + after_idx: usize, + is_desc: bool, +} + +#[tracing::instrument( + skip(params), + fields( + player_amount = params.player_amount, + first = params.first, + after_idx = params.after_idx, + is_desc = params.is_desc, + ) +)] +async fn test_first_x_after_y( + params: FirstXAfterYParams, + has_next_page: bool, + expected_len: usize, +) -> anyhow::Result<()> { + let limit = params + .first + .unwrap_or_else(|| crate::config().cursor_default_limit.get()); + + let computed_has_next_page = params.after_idx + 1 + limit < params.player_amount; + if has_next_page != computed_has_next_page { + tracing::warn!( + computed_has_next_page = computed_has_next_page, + "wrong has_next_page", + ); + } + let computed_expected_len = limit.min( + params + .player_amount + .saturating_sub(params.after_idx) + .saturating_sub(1), + ); + if expected_len != computed_expected_len { + tracing::warn!( + computed_expected_len = computed_expected_len, + "wrong expected_len", + ); + } + + let players = (0..params.player_amount).map(|i| players::ActiveModel { + id: Set((i + 1) as _), + login: Set(format!("player_{i}_login")), + name: Set(format!("player_{i}_name")), + role: Set(0), + ..Default::default() + }); + + let map_id = test_env::get_map_id(); + let map = maps::ActiveModel { + id: Set(map_id), + game_id: Set("map_uid".to_owned()), + name: Set("map_name".to_owned()), + player_id: Set(1), + ..Default::default() + }; + + let now = chrono::Utc::now().naive_utc().trunc_subsecs(0); + // The higher the player ID, the less recent the record + let record_dates = (0..params.player_amount) + .map(|i| now - Duration::from_secs(3600 * (i as u64 + 1))) + .collect_vec(); + + let records = (0..params.player_amount).map(|i| records::ActiveModel { + record_id: Set((i + 1) as _), + map_id: Set(map_id), + record_player_id: Set((i + 1) as _), + flags: Set(682), + time: Set(1000), + respawn_count: Set(0), + record_date: Set(record_dates[i]), + ..Default::default() + }); + + test_env::wrap(async |db| { + players::Entity::insert_many(players) + .exec(&db.sql_conn) + .await?; + maps::Entity::insert(map).exec(&db.sql_conn).await?; + records::Entity::insert_many(records) + .exec(&db.sql_conn) + .await?; + + let result = get_records_connection( + &db.sql_conn, + &db.redis_pool, + ConnectionParameters { + before: None, + after: Some({ + let idx = if params.is_desc { + params + .player_amount + .saturating_sub(1) + .saturating_sub(params.after_idx) + } else { + params.after_idx + }; + RecordDateCursor { + record_date: record_dates[idx].and_utc(), + data: idx as u32 + 1, + } + }), + first: params.first, + last: None, + }, + Default::default(), + params.is_desc.then_some(UnorderedRecordSort { + field: UnorderedRecordSortableField::Date, + order: Some(SortOrder::Descending), + }), + None, + ) + .await?; + + itertools::assert_equal( + result.edges.into_iter().map(|edge| Record { + record_id: edge.node.inner.record.record_id, + cursor: edge.cursor.0, + rank: edge.node.inner.rank, + map_id: edge.node.inner.record.map_id, + player_id: edge.node.inner.record.record_player_id, + flags: edge.node.inner.record.flags, + record_date: edge.node.inner.record.record_date.and_utc(), + respawn_count: edge.node.inner.record.respawn_count, + time: edge.node.inner.record.time, + }), + (0..expected_len).map(|i| { + let i = if params.is_desc { + params + .player_amount + .saturating_sub(1) + .saturating_sub(params.after_idx) + .saturating_sub(1) + .saturating_sub(i) + } else { + params.after_idx + 1 + i + }; + let record_date = record_dates[i].and_utc(); + Record { + record_id: i as u32 + 1, + cursor: RecordDateCursor { + record_date, + data: i as u32 + 1, + } + .encode_cursor(), + rank: 1, + map_id, + flags: 682, + player_id: (i + 1) as _, + record_date, + respawn_count: 0, + time: 1000, + } + }), + ); + + assert_eq!(result.has_next_page, has_next_page); + assert!(result.has_previous_page); + + anyhow::Ok(()) + }) + .await +} + +struct LastXBeforeYParams { + player_amount: usize, + last: Option, + before_idx: usize, + is_desc: bool, +} + +#[tracing::instrument( + skip(params), + fields( + player_amount = params.player_amount, + last = params.last, + before_idx = params.before_idx, + is_desc = params.is_desc, + ) +)] +async fn test_last_x_before_y( + params: LastXBeforeYParams, + has_previous_page: bool, + expected_len: usize, +) -> anyhow::Result<()> { + let limit = params + .last + .unwrap_or_else(|| crate::config().cursor_default_limit.get()); + + let computed_has_previous_page = limit < params.before_idx; + if has_previous_page != computed_has_previous_page { + tracing::warn!( + computed_has_previous_page = computed_has_previous_page, + "wrong has_previous_page", + ); + } + let computed_expected_len = limit.min(params.before_idx).min(params.player_amount); + if expected_len != computed_expected_len { + tracing::warn!( + computed_expected_len = computed_expected_len, + "wrong expected_len", + ); + } + + let players = (0..params.player_amount).map(|i| players::ActiveModel { + id: Set((i + 1) as _), + login: Set(format!("player_{i}_login")), + name: Set(format!("player_{i}_name")), + role: Set(0), + ..Default::default() + }); + + let map_id = test_env::get_map_id(); + let map = maps::ActiveModel { + id: Set(map_id), + game_id: Set("map_uid".to_owned()), + name: Set("map_name".to_owned()), + player_id: Set(1), + ..Default::default() + }; + + let now = chrono::Utc::now().naive_utc().trunc_subsecs(0); + // The higher the player ID, the less recent the record + let record_dates = (0..params.player_amount) + .map(|i| now - Duration::from_secs(3600 * (i as u64 + 1))) + .collect_vec(); + + let records = (0..params.player_amount).map(|i| records::ActiveModel { + record_id: Set((i + 1) as _), + map_id: Set(map_id), + record_player_id: Set((i + 1) as _), + flags: Set(682), + time: Set(1000), + respawn_count: Set(0), + record_date: Set(record_dates[i]), + ..Default::default() + }); + + test_env::wrap(async |db| { + players::Entity::insert_many(players) + .exec(&db.sql_conn) + .await?; + maps::Entity::insert(map).exec(&db.sql_conn).await?; + records::Entity::insert_many(records) + .exec(&db.sql_conn) + .await?; + + let result = get_records_connection( + &db.sql_conn, + &db.redis_pool, + ConnectionParameters { + before: Some({ + let idx = if params.is_desc { + params + .player_amount + .saturating_sub(1) + .saturating_sub(params.before_idx) + } else { + params.before_idx + }; + RecordDateCursor { + record_date: record_dates[idx].and_utc(), + data: idx as u32 + 1, + } + }), + after: None, + first: None, + last: params.last, + }, + Default::default(), + params.is_desc.then_some(UnorderedRecordSort { + field: UnorderedRecordSortableField::Date, + order: Some(SortOrder::Descending), + }), + None, + ) + .await?; + + itertools::assert_equal( + result.edges.into_iter().map(|edge| Record { + record_id: edge.node.inner.record.record_id, + cursor: edge.cursor.0, + rank: edge.node.inner.rank, + map_id: edge.node.inner.record.map_id, + player_id: edge.node.inner.record.record_player_id, + flags: edge.node.inner.record.flags, + record_date: edge.node.inner.record.record_date.and_utc(), + respawn_count: edge.node.inner.record.respawn_count, + time: edge.node.inner.record.time, + }), + (0..expected_len).map(|i| { + let i = if params.is_desc { + params.player_amount - 1 - (i + params.before_idx - expected_len) + } else { + i + params.before_idx - expected_len + }; + let record_date = record_dates[i].and_utc(); + Record { + record_id: i as u32 + 1, + cursor: RecordDateCursor { + record_date, + data: i as u32 + 1, + } + .encode_cursor(), + rank: 1, + map_id, + flags: 682, + player_id: (i + 1) as _, + record_date, + respawn_count: 0, + time: 1000, + } + }), + ); + + assert!(result.has_next_page); + assert_eq!(result.has_previous_page, has_previous_page); + + anyhow::Ok(()) + }) + .await +} + +#[tokio::test] +async fn first_x_after_y_desc() -> anyhow::Result<()> { + setup(); + + test_first_x_after_y( + FirstXAfterYParams { + player_amount: 20, + first: Some(6), + after_idx: 4, + is_desc: true, + }, + true, + 6, + ) + .await?; + test_first_x_after_y( + FirstXAfterYParams { + player_amount: 20, + first: Some(4), + after_idx: 4, + is_desc: true, + }, + true, + 4, + ) + .await?; + test_first_x_after_y( + FirstXAfterYParams { + player_amount: 20, + first: Some(4), + after_idx: 14, + is_desc: true, + }, + true, + 4, + ) + .await?; + test_first_x_after_y( + FirstXAfterYParams { + player_amount: 20, + first: Some(5), + after_idx: 14, + is_desc: true, + }, + false, + 5, + ) + .await?; + test_first_x_after_y( + FirstXAfterYParams { + player_amount: 20, + first: Some(5), + after_idx: 16, + is_desc: true, + }, + false, + 3, + ) + .await?; + + Ok(()) +} + +#[tokio::test] +async fn after_x() -> anyhow::Result<()> { + setup(); + let default_limit = crate::config().cursor_default_limit.get(); + + test_first_x_after_y( + FirstXAfterYParams { + player_amount: 100, + first: None, + after_idx: 15, + is_desc: false, + }, + true, + default_limit, + ) + .await?; + test_first_x_after_y( + FirstXAfterYParams { + player_amount: 100, + first: None, + after_idx: 100 - default_limit - 1, + is_desc: false, + }, + false, + default_limit, + ) + .await?; + test_first_x_after_y( + FirstXAfterYParams { + player_amount: 100, + first: None, + after_idx: 79, + is_desc: false, + }, + false, + 20, + ) + .await?; + + Ok(()) +} + +#[tokio::test] +async fn after_x_desc() -> anyhow::Result<()> { + setup(); + let default_limit = crate::config().cursor_default_limit.get(); + + test_first_x_after_y( + FirstXAfterYParams { + player_amount: 100, + first: None, + after_idx: 15, + is_desc: true, + }, + true, + default_limit, + ) + .await?; + test_first_x_after_y( + FirstXAfterYParams { + player_amount: 100, + first: None, + after_idx: 49, + is_desc: true, + }, + false, + default_limit, + ) + .await?; + test_first_x_after_y( + FirstXAfterYParams { + player_amount: 100, + first: None, + after_idx: 79, + is_desc: true, + }, + false, + 20, + ) + .await?; + + Ok(()) +} + +#[tokio::test] +async fn last_x_before_y() -> anyhow::Result<()> { + setup(); + + test_last_x_before_y( + LastXBeforeYParams { + player_amount: 10, + last: Some(3), + before_idx: 6, + is_desc: false, + }, + true, + 3, + ) + .await?; + test_last_x_before_y( + LastXBeforeYParams { + player_amount: 10, + last: Some(3), + before_idx: 3, + is_desc: false, + }, + false, + 3, + ) + .await?; + test_last_x_before_y( + LastXBeforeYParams { + player_amount: 10, + last: Some(4), + before_idx: 3, + is_desc: false, + }, + false, + 3, + ) + .await?; + + Ok(()) +} + +#[tokio::test] +async fn last_x_before_y_desc() -> anyhow::Result<()> { + setup(); + + test_last_x_before_y( + LastXBeforeYParams { + player_amount: 10, + last: Some(3), + before_idx: 6, + is_desc: true, + }, + true, + 3, + ) + .await?; + test_last_x_before_y( + LastXBeforeYParams { + player_amount: 10, + last: Some(3), + before_idx: 3, + is_desc: true, + }, + false, + 3, + ) + .await?; + test_last_x_before_y( + LastXBeforeYParams { + player_amount: 10, + last: Some(4), + before_idx: 3, + is_desc: true, + }, + false, + 3, + ) + .await?; + + Ok(()) +} + +#[tokio::test] +async fn before_x() -> anyhow::Result<()> { + setup(); + let default_limit = crate::config().cursor_default_limit.get(); + + test_last_x_before_y( + LastXBeforeYParams { + player_amount: 100, + last: None, + before_idx: 75, + is_desc: false, + }, + true, + default_limit, + ) + .await?; + test_last_x_before_y( + LastXBeforeYParams { + player_amount: 100, + last: None, + before_idx: 100 - default_limit, + is_desc: false, + }, + false, + default_limit, + ) + .await?; + test_last_x_before_y( + LastXBeforeYParams { + player_amount: 100, + last: None, + before_idx: 3, + is_desc: false, + }, + false, + 3, + ) + .await?; + + Ok(()) +} + +#[tokio::test] +async fn before_x_desc() -> anyhow::Result<()> { + setup(); + let default_limit = crate::config().cursor_default_limit.get(); + + test_last_x_before_y( + LastXBeforeYParams { + player_amount: 100, + last: None, + before_idx: 75, + is_desc: true, + }, + true, + default_limit, + ) + .await?; + test_last_x_before_y( + LastXBeforeYParams { + player_amount: 100, + last: None, + before_idx: 100 - default_limit, + is_desc: true, + }, + false, + default_limit, + ) + .await?; + test_last_x_before_y( + LastXBeforeYParams { + player_amount: 100, + last: None, + before_idx: 3, + is_desc: true, + }, + false, + 3, + ) + .await?; + + Ok(()) +} diff --git a/crates/graphql-api/src/utils/mod.rs b/crates/graphql-api/src/utils/mod.rs new file mode 100644 index 0000000..c11c676 --- /dev/null +++ b/crates/graphql-api/src/utils/mod.rs @@ -0,0 +1,3 @@ +pub mod page_input; +pub mod pagination_result; +pub mod records_filter; diff --git a/crates/graphql-api/src/utils/page_input.rs b/crates/graphql-api/src/utils/page_input.rs new file mode 100644 index 0000000..f0599ef --- /dev/null +++ b/crates/graphql-api/src/utils/page_input.rs @@ -0,0 +1,80 @@ +use mkenv::prelude::*; +use sea_orm::SelectorTrait; + +use crate::{ + cursors::{ConnectionParameters, expr_tuple::IntoExprTuple, query_builder::CursorQueryBuilder}, + error::{ApiGqlError, GqlResult}, +}; + +pub enum PaginationDirection { + After { cursor: Option }, + Before { cursor: C }, +} + +pub struct PaginationInput { + pub dir: PaginationDirection, + pub limit: usize, +} + +impl PaginationInput { + pub fn get_cursor(&self) -> Option<&C> { + match &self.dir { + PaginationDirection::After { cursor } => cursor.as_ref(), + PaginationDirection::Before { cursor } => Some(cursor), + } + } +} + +impl PaginationInput { + pub fn try_from_input(input: ConnectionParameters) -> GqlResult { + match input { + ConnectionParameters { + first, + after, + last: None, + before: None, + } => { + let limit = first + .map(|t| t.min(crate::config().cursor_max_limit.get())) + .unwrap_or(crate::config().cursor_default_limit.get()); + Ok(Self { + limit, + dir: PaginationDirection::After { cursor: after }, + }) + } + ConnectionParameters { + last, + before: Some(before), + after: None, + first: None, + } => { + let limit = last + .map(|t| t.min(crate::config().cursor_max_limit.get())) + .unwrap_or(crate::config().cursor_default_limit.get()); + Ok(Self { + limit, + dir: PaginationDirection::Before { cursor: before }, + }) + } + _ => Err(ApiGqlError::from_pagination_input_error()), + } + } +} + +pub fn apply_cursor_input(cursor: &mut CursorQueryBuilder, input: &PaginationInput) +where + S: SelectorTrait, + for<'a> &'a C: IntoExprTuple, +{ + match &input.dir { + PaginationDirection::After { cursor: after } => { + cursor.first(input.limit as _); + if let Some(after) = after { + cursor.after(after); + } + } + PaginationDirection::Before { cursor: before } => { + cursor.last(input.limit as _).before(before); + } + } +} diff --git a/crates/graphql-api/src/utils/pagination_result.rs b/crates/graphql-api/src/utils/pagination_result.rs new file mode 100644 index 0000000..d8b0ced --- /dev/null +++ b/crates/graphql-api/src/utils/pagination_result.rs @@ -0,0 +1,57 @@ +use async_graphql::{ID, OutputType, connection}; +use sea_orm::{ConnectionTrait, SelectorTrait}; + +use crate::{ + cursors::query_builder::CursorQueryBuilder, + error::GqlResult, + utils::page_input::{PaginationDirection, PaginationInput}, +}; + +pub(crate) struct PaginationResult +where + T: OutputType, +{ + pub(crate) connection: connection::Connection, + pub(crate) iter: I, +} + +pub(crate) async fn get_paginated( + conn: &C, + mut query: CursorQueryBuilder, + pagination_input: &PaginationInput, +) -> GqlResult::Item>, T>> +where + C: ConnectionTrait, + T: OutputType, + S: SelectorTrait, +{ + match &pagination_input.dir { + PaginationDirection::After { cursor } => { + let players = query + .first(pagination_input.limit as u64 + 1) + .all(conn) + .await?; + Ok(PaginationResult { + connection: connection::Connection::new( + cursor.is_some(), + players.len() > pagination_input.limit, + ), + iter: itertools::Either::Left(players.into_iter().take(pagination_input.limit)), + }) + } + PaginationDirection::Before { .. } => { + let players = query + .last(pagination_input.limit as u64 + 1) + .all(conn) + .await?; + let amount_to_skip = players.len().saturating_sub(pagination_input.limit); + Ok(PaginationResult { + connection: connection::Connection::new( + players.len() > pagination_input.limit, + true, + ), + iter: itertools::Either::Right(players.into_iter().skip(amount_to_skip)), + }) + } + } +} diff --git a/crates/graphql-api/src/utils/records_filter.rs b/crates/graphql-api/src/utils/records_filter.rs new file mode 100644 index 0000000..0be647c --- /dev/null +++ b/crates/graphql-api/src/utils/records_filter.rs @@ -0,0 +1,161 @@ +use entity::{global_event_records, global_records, maps, players, records}; +use sea_orm::{ + ColumnTrait, EntityTrait, JoinType, QueryFilter as _, QuerySelect as _, RelationDef, + RelationTrait as _, Select, + prelude::Expr, + sea_query::{ExprTrait as _, Func}, +}; + +use crate::objects::records_filter::RecordsFilter; + +pub trait RecordsTableFilterConstructor { + type Column: ColumnTrait; + + const COL_RECORD_DATE: Self::Column; + const COL_TIME: Self::Column; + + fn get_players_relation() -> RelationDef; + + fn get_maps_relation() -> RelationDef; +} + +impl RecordsTableFilterConstructor for global_records::Entity { + type Column = global_records::Column; + + const COL_RECORD_DATE: Self::Column = global_records::Column::RecordDate; + const COL_TIME: Self::Column = global_records::Column::Time; + + fn get_players_relation() -> RelationDef { + global_records::Relation::Players.def() + } + + fn get_maps_relation() -> RelationDef { + global_records::Relation::Maps.def() + } +} + +impl RecordsTableFilterConstructor for records::Entity { + type Column = records::Column; + + const COL_RECORD_DATE: Self::Column = records::Column::RecordDate; + + const COL_TIME: Self::Column = records::Column::Time; + + fn get_players_relation() -> RelationDef { + records::Relation::Players.def() + } + + fn get_maps_relation() -> RelationDef { + records::Relation::Maps.def() + } +} + +impl RecordsTableFilterConstructor for global_event_records::Entity { + type Column = global_event_records::Column; + + const COL_RECORD_DATE: Self::Column = global_event_records::Column::RecordDate; + + const COL_TIME: Self::Column = global_event_records::Column::Time; + + fn get_players_relation() -> RelationDef { + global_event_records::Relation::Players.def() + } + + fn get_maps_relation() -> RelationDef { + global_event_records::Relation::Maps.def() + } +} + +pub fn apply_filter(mut query: Select, filter: Option<&RecordsFilter>) -> Select +where + E: RecordsTableFilterConstructor + EntityTrait, +{ + let Some(filter) = filter else { + return query; + }; + + // Join with players table if needed for player filters + if filter.player.is_some() { + query = query.join_as(JoinType::InnerJoin, E::get_players_relation(), "p"); + } + + // Join with maps table if needed for map filters + if let Some(m) = &filter.map { + query = query.join_as(JoinType::InnerJoin, E::get_maps_relation(), "m"); + + // Join again with players table if filtering on map author + if m.author.is_some() { + query = query.join_as(JoinType::InnerJoin, maps::Relation::Players.def(), "p2"); + } + } + + if let Some(filter) = &filter.player { + // Apply player login filter + if let Some(login) = &filter.player_login { + query = + query.filter(Expr::col(("p", players::Column::Login)).like(format!("%{login}%"))); + } + + // Apply player name filter + if let Some(name) = &filter.player_name { + query = query.filter( + Func::cust("rm_mp_style") + .arg(Expr::col(("p", players::Column::Name))) + .like(format!("%{name}%")), + ); + } + } + + if let Some(filter) = &filter.map { + // Apply map UID filter + if let Some(uid) = &filter.map_uid { + query = query.filter(Expr::col(("m", maps::Column::GameId)).like(format!("%{uid}%"))); + } + + // Apply map name filter + if let Some(name) = &filter.map_name { + query = query.filter( + Func::cust("rm_mp_style") + .arg(Expr::col(("m", maps::Column::Name))) + .like(format!("%{name}%")), + ); + } + + if let Some(filter) = &filter.author { + // Apply player login filter + if let Some(login) = &filter.player_login { + query = query + .filter(Expr::col(("p2", players::Column::Login)).like(format!("%{login}%"))); + } + + // Apply player name filter + if let Some(name) = &filter.player_name { + query = query.filter( + Func::cust("rm_mp_style") + .arg(Expr::col(("p2", players::Column::Name))) + .like(format!("%{name}%")), + ); + } + } + } + + // Apply date filters + if let Some(before_date) = filter.before_date { + query = query.filter(E::COL_RECORD_DATE.lt(before_date)); + } + + if let Some(after_date) = filter.after_date { + query = query.filter(E::COL_RECORD_DATE.gt(after_date)); + } + + // Apply time filters + if let Some(time_gt) = filter.time_gt { + query = query.filter(E::COL_TIME.gt(time_gt)); + } + + if let Some(time_lt) = filter.time_lt { + query = query.filter(E::COL_TIME.lt(time_lt)); + } + + query +} diff --git a/crates/graphql-schema-generator/Cargo.toml b/crates/graphql-schema-generator/Cargo.toml index af8e9ba..fc1cd29 100644 --- a/crates/graphql-schema-generator/Cargo.toml +++ b/crates/graphql-schema-generator/Cargo.toml @@ -5,5 +5,5 @@ edition = "2024" [dependencies] anyhow.workspace = true -clap = { version = "4.5.48", features = ["derive"] } +clap = { version = "4.5.54", features = ["derive"] } graphql-api = { path = "../graphql-api" } diff --git a/crates/migration/Cargo.toml b/crates/migration/Cargo.toml index 5ba042b..1027e02 100644 --- a/crates/migration/Cargo.toml +++ b/crates/migration/Cargo.toml @@ -11,7 +11,7 @@ path = "src/lib.rs" [dependencies] sea-orm = { workspace = true, features = ["sqlx", "sqlx-dep"] } entity = { path = "../entity" } -sea-orm-migration = { version = "1.1.0", features = ["runtime-tokio"] } +sea-orm-migration = { version = "1.1.19", features = ["runtime-tokio"] } tokio = { workspace = true } [features] diff --git a/crates/migration/src/lib.rs b/crates/migration/src/lib.rs index 97032f5..3d9e721 100644 --- a/crates/migration/src/lib.rs +++ b/crates/migration/src/lib.rs @@ -8,6 +8,7 @@ mod m20250918_215914_add_event_edition_maps_availability; mod m20250918_222203_add_event_edition_maps_disability; mod m20251002_100827_add_rm_mp_style_func; mod m20251004_214656_add_maps_medal_times; +mod m20260103_142202_players_maps_score; use sea_orm_migration::prelude::*; @@ -29,6 +30,7 @@ impl MigratorTrait for Migrator { Box::new(m20250918_222203_add_event_edition_maps_disability::Migration), Box::new(m20251002_100827_add_rm_mp_style_func::Migration), Box::new(m20251004_214656_add_maps_medal_times::Migration), + Box::new(m20260103_142202_players_maps_score::Migration), ] } } diff --git a/crates/migration/src/m20260103_142202_players_maps_score.rs b/crates/migration/src/m20260103_142202_players_maps_score.rs new file mode 100644 index 0000000..a686441 --- /dev/null +++ b/crates/migration/src/m20260103_142202_players_maps_score.rs @@ -0,0 +1,63 @@ +use sea_orm_migration::prelude::*; + +#[derive(DeriveMigrationName)] +pub struct Migration; + +#[async_trait::async_trait] +impl MigrationTrait for Migration { + async fn up(&self, manager: &SchemaManager) -> Result<(), DbErr> { + manager + .alter_table( + Table::alter() + .table(Players::Table) + .add_column(ColumnDef::new(Players::Score).double().default(0).take()) + .take(), + ) + .await?; + + manager + .alter_table( + Table::alter() + .table(Maps::Table) + .add_column(ColumnDef::new(Maps::Score).double().default(0).take()) + .take(), + ) + .await?; + + Ok(()) + } + + async fn down(&self, manager: &SchemaManager) -> Result<(), DbErr> { + manager + .alter_table( + Table::alter() + .table(Players::Table) + .drop_column(Players::Score) + .take(), + ) + .await?; + + manager + .alter_table( + Table::alter() + .table(Maps::Table) + .drop_column(Maps::Score) + .take(), + ) + .await?; + + Ok(()) + } +} + +#[derive(DeriveIden)] +enum Players { + Table, + Score, +} + +#[derive(DeriveIden)] +enum Maps { + Table, + Score, +} diff --git a/crates/records_lib/Cargo.toml b/crates/records_lib/Cargo.toml index b7acfaa..5dc8e87 100644 --- a/crates/records_lib/Cargo.toml +++ b/crates/records_lib/Cargo.toml @@ -27,7 +27,7 @@ mysql = ["sea-orm/sqlx-mysql"] postgres = ["sea-orm/sqlx-postgres"] sqlite = ["sea-orm/sqlx-sqlite"] sea-orm-proxy = ["sea-orm/proxy"] -test = ["sea-orm/mock"] +mock = ["sea-orm/mock"] [build-dependencies] rustc_version = "0.4.1" diff --git a/crates/records_lib/src/env.rs b/crates/records_lib/src/env.rs index 6f68583..d4f6a90 100644 --- a/crates/records_lib/src/env.rs +++ b/crates/records_lib/src/env.rs @@ -3,200 +3,214 @@ use std::time::Duration; use entity::types::InGameAlignment; use once_cell::sync::OnceCell; -mkenv::make_env! { -/// The environment used to set up a connection to the MySQL/MariaDB database. -pub DbUrlEnv: - /// The database URL. - #[cfg(debug_assertions)] - db_url: { - id: DbUrl(String), - kind: normal, - var: "DATABASE_URL", - desc: "The URL to the MySQL/MariaDB database", - }, - /// The path to the file containing the database URL. - #[cfg(not(debug_assertions))] - db_url: { - id: DbUrl(String), - kind: file, - var: "DATABASE_URL", - desc: "The path to the file containing the URL to the MySQL/MariaDB database", - }, +#[cfg(debug_assertions)] +mkenv::make_config! { + /// The environment used to set up a connection to the MySQL/MariaDB database. + pub struct DbUrlEnv { + /// The database URL. + pub db_url: { + var_name: "DATABASE_URL", + description: "The URL to the MySQL/MariaDB database", + } + } +} +#[cfg(not(debug_assertions))] +mkenv::make_config! { + /// The environment used to set up a connection to the MySQL/MariaDB database. + pub struct DbUrlEnv { + /// The path to the file containing the database URL. + pub db_url: { + var_name: "DATABASE_URL", + layers: [ + file_read(), + ], + description: "The path to the file containing the URL to the MySQL/MariaDB database", + } + } } -mkenv::make_env! { -/// The environment used to set up a connection with the Redis database. -pub RedisUrlEnv: - /// The URL to the Redis database. - redis_url: { - id: RedisUrl(String), - kind: normal, - var: "REDIS_URL", - desc: "The URL to the Redis database", +mkenv::make_config! { + /// The environment used to set up a connection with the Redis database. + pub struct RedisUrlEnv { + /// The URL to the Redis database. + pub redis_url: { + var_name: "REDIS_URL", + description: "The URL to the Redis database", + } } } -mkenv::make_env! { +mkenv::make_config! { /// The environment used to set up a connection to the databases of the API. - pub DbEnv includes [ + pub struct DbEnv { /// The environment for the MySQL/MariaDB database. - DbUrlEnv as db_url, + pub db_url: { DbUrlEnv }, /// The environment for the Redis database. - RedisUrlEnv as redis_url - ]: + pub redis_url: { RedisUrlEnv }, + } } -const DEFAULT_MAPPACK_TTL: i64 = 604_800; - -// In game default parameter values - -const DEFAULT_INGAME_SUBTITLE_ON_NEWLINE: bool = false; - -const DEFAULT_INGAME_TITLES_ALIGN: InGameAlignment = InGameAlignment::Left; -const DEFAULT_INGAME_LB_LINK_ALIGN: InGameAlignment = InGameAlignment::Left; -const DEFAULT_INGAME_AUTHORS_ALIGN: InGameAlignment = InGameAlignment::Right; - -const DEFAULT_INGAME_TITLES_POS_X: f64 = 181.; -const DEFAULT_INGAME_TITLES_POS_Y: f64 = -29.5; - -const DEFAULT_INGAME_LB_LINK_POS_X: f64 = 181.; -const DEFAULT_INGAME_LB_LINK_POS_Y: f64 = -29.5; - -const DEFAULT_INGAME_AUTHORS_POS_X: f64 = 181.; -const DEFAULT_INGAME_AUTHORS_POS_Y: f64 = -29.5; - -const DEFAULT_EVENT_SCORES_INTERVAL_SECONDS: u64 = 6 * 3600; -const DEFAULT_PLAYER_MAP_RANKING_INTERVAL_SECONDS: u64 = 3600 * 24 * 7; - -mkenv::make_env! { -/// The environment used by this crate. -pub LibEnv: - /// The time-to-live for the mappacks saved in the Redis database. - mappack_ttl: { - id: MappackTtl(i64), - kind: parse, - var: "RECORDS_API_MAPPACK_TTL", - desc: "The TTL (time-to-live) of the mappacks stored in Redis", - default: DEFAULT_MAPPACK_TTL, - }, - - /// The default alignment of the titles of an event edition in the Titlepack menu. - ingame_default_titles_align: { - id: InGameDefaultTitlesAlign(InGameAlignment), - kind: parse, - var: "RECORDS_API_INGAME_DEFAULT_TITLES_ALIGN", - desc: "The default alignment (either L for left or R for right) of the titles of \ - an event edition in the Titlepack menu", - default: DEFAULT_INGAME_TITLES_ALIGN, - }, - - /// The default alignment of an event edition title in the Titlepack menu. - ingame_default_lb_link_align: { - id: InGameDefaultLbLinkAlign(InGameAlignment), - kind: parse, - var: "RECORDS_API_INGAME_DEFAULT_LB_LINK_ALIGN", - desc: "The default alignment (either L for left or R for right) of the leaderboards link of \ - an event edition in the Titlepack menu", - default: DEFAULT_INGAME_LB_LINK_ALIGN, - }, - - /// The default alignment of an event edition title in the Titlepack menu. - ingame_default_authors_align: { - id: InGameDefaultAuthorsAlign(InGameAlignment), - kind: parse, - var: "RECORDS_API_INGAME_DEFAULT_AUTHORS_ALIGN", - desc: "The default alignment (either L for left or R for right) of the author list of \ - an event edition in the Titlepack menu", - default: DEFAULT_INGAME_AUTHORS_ALIGN, - }, - - /// The default position in X axis of the titles of an event edition in the Titlepack menu. - ingame_default_titles_pos_x: { - id: InGameDefaultTitlesPosX(f64), - kind: parse, - var: "RECORDS_API_INGAME_DEFAULT_TITLES_POS_X", - desc: "The default position in X axis of the titles of an event edition in \ - the Titlepack menu (double)", - default: DEFAULT_INGAME_TITLES_POS_X, - }, - - /// The default position in Y axis of the titles of an event edition in the Titlepack menu. - ingame_default_titles_pos_y: { - id: InGameDefaultTitlesPosY(f64), - kind: parse, - var: "RECORDS_API_INGAME_DEFAULT_TITLES_POS_Y", - desc: "The default position in Y axis of the titles of an event edition in \ - the Titlepack menu (double)", - default: DEFAULT_INGAME_TITLES_POS_Y, - }, - - /// The default value of the boolean related to either to put the subtitle of an event edition - /// on a new line or not, in the Titlepack menu. - ingame_default_subtitle_on_newline: { - id: InGameDefaultSubtitleOnNewLine(bool), - kind: parse, - var: "RECORDS_API_INGAME_DEFAULT_SUBTITLE_ON_NEWLINE", - desc: "The default value of the boolean related to either to put the subtitle of an \ - event edition on a new line or not in the Titlepack menu (boolean)", - default: DEFAULT_INGAME_SUBTITLE_ON_NEWLINE, - }, - - /// The default position in X axis of the leaderboards link of an event edition in the Titlepack menu. - ingame_default_lb_link_pos_x: { - id: InGameDefaultLbLinkPosX(f64), - kind: parse, - var: "RECORDS_API_INGAME_DEFAULT_LB_LINK_POS_X", - desc: "The default position in X axis of the leaderboards link of an event edition in \ - the Titlepack menu (double)", - default: DEFAULT_INGAME_LB_LINK_POS_X, - }, - - /// The default position in Y axis of the leaderboards link of an event edition in the Titlepack menu. - ingame_default_lb_link_pos_y: { - id: InGameDefaultLbLinkPosY(f64), - kind: parse, - var: "RECORDS_API_INGAME_DEFAULT_LB_LINK_POS_Y", - desc: "The default position in Y axis of the leaderboards link of an event edition in \ - the Titlepack menu (double)", - default: DEFAULT_INGAME_LB_LINK_POS_Y, - }, - - /// The default position in X axis of the author list of an event edition in the Titlepack menu. - ingame_default_authors_pos_x: { - id: InGameDefaultAuthorsPosX(f64), - kind: parse, - var: "RECORDS_API_INGAME_DEFAULT_AUTHORS_POS_X", - desc: "The default position in X axis of the author list of an event edition in \ - the Titlepack menu (double)", - default: DEFAULT_INGAME_AUTHORS_POS_X, - }, - - /// The default position in Y axis of the author list of an event edition in the Titlepack menu. - ingame_default_authors_pos_y: { - id: InGameDefaultAuthorsPosY(f64), - kind: parse, - var: "RECORDS_API_INGAME_DEFAULT_AUTHORS_POS_Y", - desc: "The default position in Y axis of the author list of an event edition in \ - the Titlepack menu (double)", - default: DEFAULT_INGAME_AUTHORS_POS_Y, - }, - - /// The interval of a mappack scores update - event_scores_interval: { - id: EventScoresInterval(Duration), - kind: parse(from_secs), - var: "EVENT_SCORES_INTERVAL_SECONDS", - desc: "The interval of the update of the event scores, in seconds", - default: DEFAULT_EVENT_SCORES_INTERVAL_SECONDS, - }, - - /// The interval of the player/map ranking update - player_map_ranking_scores_interval: { - id: PlayerMapRankingScoresInterval(Duration), - kind: parse(from_secs), - var: "PLAYER_MAP_RANKING_SCORES_INTERVAL", - desc: "The interval of the update of the player/map ranking scores, in seconds", - default: DEFAULT_PLAYER_MAP_RANKING_INTERVAL_SECONDS, +mkenv::make_config! { + /// The environment used by this crate. + pub struct LibEnv { + /// The time-to-live for the mappacks saved in the Redis database. + pub mappack_ttl: { + var_name: "RECORDS_API_MAPPACK_TTL", + layers: [ + parsed_from_str(), + or_default_val(|| 604_800), + ], + description: "The TTL (time-to-live) of the mappacks stored in Redis", + default_val_fmt: "604,800", + }, + + /// The default alignment of the titles of an event edition in the Titlepack menu. + pub ingame_default_titles_align: { + var_name: "RECORDS_API_INGAME_DEFAULT_TITLES_ALIGN", + layers: [ + parsed_from_str(), + or_default_val(|| InGameAlignment::Left), + ], + description: "The default alignment (either L for left or R for right) of the titles of \ + an event edition in the Titlepack menu", + default_val_fmt: "left", + }, + + /// The default alignment of an event edition title in the Titlepack menu. + pub ingame_default_lb_link_align: { + var_name: "RECORDS_API_INGAME_DEFAULT_LB_LINK_ALIGN", + layers: [ + parsed_from_str(), + or_default_val(|| InGameAlignment::Left), + ], + description: "The default alignment (either L for left or R for right) of the leaderboards link of \ + an event edition in the Titlepack menu", + default_val_fmt: "left", + }, + + + /// The default alignment of an event edition title in the Titlepack menu. + pub ingame_default_authors_align: { + var_name: "RECORDS_API_INGAME_DEFAULT_AUTHORS_ALIGN", + layers: [ + parsed_from_str(), + or_default_val(|| InGameAlignment::Right), + ], + description: "The default alignment (either L for left or R for right) of the author list of \ + an event edition in the Titlepack menu", + default_val_fmt: "right", + }, + + /// The default position in X axis of the titles of an event edition in the Titlepack menu. + pub ingame_default_titles_pos_x: { + var_name: "RECORDS_API_INGAME_DEFAULT_TITLES_POS_X", + layers: [ + parsed_from_str(), + or_default_val(|| 181.), + ], + description: "The default position in X axis of the titles of an event edition in \ + the Titlepack menu (double)", + default_val_fmt: "181.0", + }, + + /// The default position in Y axis of the titles of an event edition in the Titlepack menu. + pub ingame_default_titles_pos_y: { + var_name: "RECORDS_API_INGAME_DEFAULT_TITLES_POS_Y", + layers: [ + parsed_from_str(), + or_default_val(|| -29.5), + ], + description: "The default position in Y axis of the titles of an event edition in \ + the Titlepack menu (double)", + default_val_fmt: "-29.5", + }, + + /// The default value of the boolean related to either to put the subtitle of an event edition + /// on a new line or not, in the Titlepack menu. + pub ingame_default_subtitle_on_newline: { + var_name: "RECORDS_API_INGAME_DEFAULT_SUBTITLE_ON_NEWLINE", + layers: [ + parsed_from_str(), + or_default_val(|| false) + ], + description: "The default value of the boolean related to either to put the subtitle of an \ + event edition on a new line or not in the Titlepack menu (boolean)", + default_val_fmt: "false", + }, + + /// The default position in X axis of the leaderboards link of an event edition in the Titlepack menu. + pub ingame_default_lb_link_pos_x: { + var_name: "RECORDS_API_INGAME_DEFAULT_LB_LINK_POS_X", + layers: [ + parsed_from_str(), + or_default_val(|| 181.), + ], + description: "The default position in X axis of the leaderboards link of an event edition in \ + the Titlepack menu (double)", + default_val_fmt: "181.0", + }, + + /// The default position in Y axis of the leaderboards link of an event edition in the Titlepack menu. + pub ingame_default_lb_link_pos_y: { + var_name: "RECORDS_API_INGAME_DEFAULT_LB_LINK_POS_Y", + layers: [ + parsed_from_str(), + or_default_val(|| -29.5), + ], + description: "The default position in Y axis of the leaderboards link of an event edition in \ + the Titlepack menu (double)", + default_val_fmt: "-29.5", + }, + + /// The default position in X axis of the author list of an event edition in the Titlepack menu. + pub ingame_default_authors_pos_x: { + var_name: "RECORDS_API_INGAME_DEFAULT_AUTHORS_POS_X", + layers: [ + parsed_from_str(), + or_default_val(|| 181.), + ], + description: "The default position in X axis of the author list of an event edition in \ + the Titlepack menu (double)", + default_val_fmt: "181.0", + }, + + /// The default position in Y axis of the author list of an event edition in the Titlepack menu. + pub ingame_default_authors_pos_y: { + var_name: "RECORDS_API_INGAME_DEFAULT_AUTHORS_POS_Y", + layers: [ + parsed_from_str(), + or_default_val(|| -29.5), + ], + description: "The default position in Y axis of the author list of an event edition in \ + the Titlepack menu (double)", + default_val_fmt: "-29.5", + }, + + /// The interval of a mappack scores update + pub event_scores_interval: { + var_name: "EVENT_SCORES_INTERVAL_SECONDS", + layers: [ + parsed(|input| { + input.parse().map(Duration::from_secs).map_err(From::from) + }), + or_default_val(|| Duration::from_secs(6 * 3600)), + ], + description: "The interval of the update of the event scores, in seconds", + default_val_fmt: "6h", + }, + + /// The interval of the player/map ranking update + pub player_map_ranking_scores_interval: { + var_name: "PLAYER_MAP_RANKING_SCORES_INTERVAL", + layers: [ + parsed(|input| { + input.parse().map(Duration::from_secs).map_err(From::from) + }), + or_default_val(|| Duration::from_secs(3600 * 24 * 7)), + ], + description: "The interval of the update of the player/map ranking scores, in seconds", + default_val_fmt: "every week", + } } } @@ -214,6 +228,5 @@ pub fn init_env(env: LibEnv) { /// **Caution**: To use this function, the [`init_env()`] function must have been called at the start /// of the program. pub fn env() -> &'static LibEnv { - // SAFETY: this function is always called when `init_env()` is called at the start. - unsafe { ENV.get_unchecked() } + ENV.get().unwrap() } diff --git a/crates/records_lib/src/leaderboard.rs b/crates/records_lib/src/leaderboard.rs index e5c66eb..339e88b 100644 --- a/crates/records_lib/src/leaderboard.rs +++ b/crates/records_lib/src/leaderboard.rs @@ -113,7 +113,6 @@ impl CompetRankingByKeyIter for I {} #[derive(Debug, Clone, FromQueryResult)] struct RecordQueryRow { - player_id: u32, login: String, nickname: String, time: i32, @@ -184,18 +183,11 @@ pub async fn leaderboard_into( rows.reserve(result.len()); - let mut ranking_session = ranks::RankingSession::try_from_pool(redis_pool).await?; + let mut redis_conn = redis_pool.get().await?; for r in result { rows.push(Row { - rank: ranks::get_rank_in_session( - &mut ranking_session, - map_id, - r.player_id, - r.time, - event, - ) - .await?, + rank: ranks::get_rank(&mut redis_conn, map_id, r.time, event).await?, login: r.login, nickname: r.nickname, time: r.time, diff --git a/crates/records_lib/src/mappack.rs b/crates/records_lib/src/mappack.rs index 3b70b43..525e615 100644 --- a/crates/records_lib/src/mappack.rs +++ b/crates/records_lib/src/mappack.rs @@ -1,5 +1,6 @@ //! This module contains anything related to mappacks in this library. +use mkenv::prelude::*; use std::{fmt, time::SystemTime}; use deadpool_redis::redis::{self, AsyncCommands, SetExpiry, SetOptions, ToRedisArgs}; @@ -132,7 +133,7 @@ impl AnyMappackId<'_> { /// /// Only regular MX mappacks have a time-to-live. fn get_ttl(&self) -> Option { - self.has_ttl().then_some(crate::env().mappack_ttl) + self.has_ttl().then_some(crate::env().mappack_ttl.get()) } } @@ -365,7 +366,7 @@ async fn calc_scores( let mut scores = Vec::::with_capacity(mappack.len()); - let mut ranking_session = ranks::RankingSession::try_from_pool(redis_pool).await?; + let mut redis_conn = redis_pool.get().await?; for (i, map) in mappack.iter().enumerate() { let mut query = Query::select(); @@ -420,14 +421,7 @@ async fn calc_scores( } let record = RankedRecordRow { - rank: ranks::get_rank_in_session( - &mut ranking_session, - map.id, - record.record.record_player_id, - record.record.time, - event, - ) - .await?, + rank: ranks::get_rank(&mut redis_conn, map.id, record.record.time, event).await?, record, }; records.push(record); diff --git a/crates/records_lib/src/pool.rs b/crates/records_lib/src/pool.rs index 47acd3e..e0925b5 100644 --- a/crates/records_lib/src/pool.rs +++ b/crates/records_lib/src/pool.rs @@ -25,7 +25,7 @@ pub enum DatabaseCreationError { } #[inline(always)] -#[cfg(feature = "test")] +#[cfg(feature = "mock")] const fn empty_query_results() -> std::iter::Empty> { std::iter::empty::>() } @@ -55,8 +55,8 @@ impl Database { /// with initial data for the mock database. /// /// This is used for testing, by simulating an SQL database. - #[cfg(feature = "test")] - #[cfg_attr(nightly, doc(cfg(feature = "test")))] + #[cfg(feature = "mock")] + #[cfg_attr(nightly, doc(cfg(feature = "mock")))] pub fn from_mock_db_with_initial( db_backend: sea_orm::DbBackend, redis_url: String, @@ -78,8 +78,8 @@ impl Database { /// with initial query results for the mock database. /// /// This is used for testing, by simulating an SQL database. - #[cfg(feature = "test")] - #[cfg_attr(nightly, doc(cfg(feature = "test")))] + #[cfg(feature = "mock")] + #[cfg_attr(nightly, doc(cfg(feature = "mock")))] pub fn from_mock_db_with_query_results( db_backend: sea_orm::DbBackend, redis_url: String, @@ -95,8 +95,8 @@ impl Database { /// with initial exec results for the mock database. /// /// This is used for testing, by simulating an SQL database. - #[cfg(feature = "test")] - #[cfg_attr(nightly, doc(cfg(feature = "test")))] + #[cfg(feature = "mock")] + #[cfg_attr(nightly, doc(cfg(feature = "mock")))] pub fn from_mock_db_with_exec_results( db_backend: sea_orm::DbBackend, redis_url: String, @@ -112,8 +112,8 @@ impl Database { /// with no data in the mock database. /// /// This is used for testing, by simulating an SQL database. - #[cfg(feature = "test")] - #[cfg_attr(nightly, doc(cfg(feature = "test")))] + #[cfg(feature = "mock")] + #[cfg_attr(nightly, doc(cfg(feature = "mock")))] pub fn from_mock_db( db_backend: sea_orm::DbBackend, redis_url: String, @@ -129,7 +129,7 @@ pub fn clone_dbconn(conn: &DbConn) -> DbConn { sea_orm::DatabaseConnection::SqlxMySqlPoolConnection(conn) => { sea_orm::DatabaseConnection::SqlxMySqlPoolConnection(conn.clone()) } - #[cfg(feature = "test")] + #[cfg(feature = "mock")] sea_orm::DatabaseConnection::MockDatabaseConnection(conn) => { sea_orm::DatabaseConnection::MockDatabaseConnection(conn.clone()) } diff --git a/crates/records_lib/src/ranks.rs b/crates/records_lib/src/ranks.rs index 9d3e126..9cf76ba 100644 --- a/crates/records_lib/src/ranks.rs +++ b/crates/records_lib/src/ranks.rs @@ -1,16 +1,9 @@ //! Module which contains utility functions used to update maps leaderboards and get players ranks. use crate::{ - RedisConnection, RedisPool, - error::RecordsResult, - internal, - opt_event::OptEvent, - redis_key::{MapKey, map_key}, -}; -use deadpool_redis::{ - PoolError, - redis::{self, AsyncCommands}, + RedisConnection, RedisPool, error::RecordsResult, opt_event::OptEvent, redis_key::map_key, }; +use deadpool_redis::redis::{self, AsyncCommands}; use entity::{event_edition_records, records}; use futures::TryStreamExt; use sea_orm::{ @@ -127,131 +120,25 @@ async fn force_update_locked( Ok(()) } -// This is just a Redis connection retrieved from the pool, we just don't expose it to make sure -// we exlusively own it in the body of our rank retrieval implementation. -/// A current ranking session. -/// -/// See the documentation of the [`get_rank_in_session`] function for more information. -pub struct RankingSession { - redis_conn: RedisConnection, -} - -impl RankingSession { - /// Returns a new ranking session from the provided Redis pool. - pub async fn try_from_pool(pool: &RedisPool) -> Result { - let redis_conn = pool.get().await?; - Ok(Self { redis_conn }) - } -} - -async fn get_rank_impl( - redis_conn: &mut RedisConnection, - key: &MapKey<'_>, - player_id: u32, - time: i32, -) -> RecordsResult> { - let score: Option = redis_conn.zscore(key, player_id).await?; - - let mut pipe = redis::pipe(); - pipe.atomic() - .zadd(key, player_id, time) - .ignore() - .zcount(key, "-inf", time - 1); - - // Restore the previous state - let _ = match score { - Some(old_time) => pipe.zadd(key, player_id, old_time).ignore(), - None => pipe.zrem(key, player_id).ignore(), - }; - - let response: Option<(i32,)> = pipe.query_async(redis_conn).await?; - - Ok(response.map(|(t,)| t + 1)) -} - -/// Gets the rank of the time of a player on a map, using the current ranking session. -/// -/// This is like [`get_rank`], but used when retrieving a large amount of ranks during an operation, -/// to avoid creating a new Redis connection from the pool each time. +/// Gets the rank of the time of a player on a map. /// /// ## Example /// /// ```ignore -/// let mut session = ranks::RankingSession::try_from_pool(&pool).await?; /// let rank1 = -/// ranks::get_rank_in_session(&mut session, map_id, player1_id, time1, Default::default()) +/// ranks::get_rank_in_session(&mut redis_conn, map_id, time1, Default::default()) /// .await?; /// let rank2 = -/// ranks::get_rank_in_session(&mut session, map_id, player2_id, time2, Default::default()) +/// ranks::get_rank_in_session(&mut redis_conn, map_id, time2, Default::default()) /// .await?; /// ``` -/// -/// See the documentation of the [`get_rank`] function for more information. -pub async fn get_rank_in_session( - session: &mut RankingSession, - map_id: u32, - player_id: u32, - time: i32, - event: OptEvent<'_>, -) -> RecordsResult { - const MAX_TXN_RETRY_COUNT: usize = 100; - - let key = map_key(map_id, event); - - for _ in 0..MAX_TXN_RETRY_COUNT { - redis::cmd("WATCH") - .arg(&key) - .exec_async(&mut session.redis_conn) - .await?; - - let result = get_rank_impl(&mut session.redis_conn, &key, player_id, time).await; - - match result { - Ok(Some(rank)) => return Ok(rank), - // The watchpoint triggered so a null response was returned, we restart the transaction - Ok(None) => (), - Err(e) => { - redis::cmd("UNWATCH") - .exec_async(&mut session.redis_conn) - .await?; - return Err(e); - } - } - } - - // We reached the max amount of retry, which is very unlikely. - // Just get the rank without the watch part. - get_rank_impl(&mut session.redis_conn, &key, player_id, time) - .await - .and_then(|r| { - r.ok_or_else(|| { - internal!( - "couldn't retrieve rank of member {player_id} on ZSET `{key}` \ - (required score: {time})" - ) - }) - }) -} - -/// Gets the rank of the time of a player on a map. -/// -/// Note: if you intend to use this function in a loop, prefer using the [`get_rank_in_session`] -/// function instead. -/// -/// This function is concurrency-safe, so it guarantees that at the moment it is called, it returns -/// the correct rank for the provided time. It also keeps the leaderboard unchanged. -/// -/// However, it doesn't guarantee that the related ZSET of the map's leaderboard is synchronized -/// with its SQL database version. For this, please use the [`update_leaderboard`] function. -/// -/// The ranking type is the standard competition ranking (1224). pub async fn get_rank( - redis_pool: &RedisPool, + redis_conn: &mut RedisConnection, map_id: u32, - player_id: u32, time: i32, event: OptEvent<'_>, ) -> RecordsResult { - let mut session = RankingSession::try_from_pool(redis_pool).await?; - get_rank_in_session(&mut session, map_id, player_id, time, event).await + let key = map_key(map_id, event); + let count: i32 = redis_conn.zcount(key, "-inf", time - 1).await?; + Ok(count + 1) } diff --git a/crates/socc/src/main.rs b/crates/socc/src/main.rs index 26ec42e..c01e9e7 100644 --- a/crates/socc/src/main.rs +++ b/crates/socc/src/main.rs @@ -7,7 +7,7 @@ use std::{future::Future, time::Duration}; use anyhow::Context; -use mkenv::Env as _; +use mkenv::prelude::*; use records_lib::{Database, DbEnv, LibEnv}; use tokio::{task::JoinHandle, time}; use tracing::info; @@ -44,19 +44,28 @@ fn setup_tracing() -> anyhow::Result<()> { .map_err(|e| anyhow::format_err!("{e}")) } -mkenv::make_env! {Env includes [DbEnv as db_env, LibEnv as lib_env]:} +mkenv::make_config! { + struct Env { + db_env: { DbEnv }, + lib_env: { LibEnv }, + } +} #[tokio::main] async fn main() -> anyhow::Result<()> { dotenvy::dotenv()?; setup_tracing()?; - let env = Env::try_get()?; - let event_scores_interval = env.lib_env.event_scores_interval; - let player_ranking_scores_interval = env.lib_env.player_map_ranking_scores_interval; + let env = Env::define(); + env.init(); + let event_scores_interval = env.lib_env.event_scores_interval.get(); + let player_ranking_scores_interval = env.lib_env.player_map_ranking_scores_interval.get(); records_lib::init_env(env.lib_env); - let db = - Database::from_db_url(env.db_env.db_url.db_url, env.db_env.redis_url.redis_url).await?; + let db = Database::from_db_url( + env.db_env.db_url.db_url.get(), + env.db_env.redis_url.redis_url.get(), + ) + .await?; let event_scores_handle = tokio::spawn(handle( db.clone(), diff --git a/crates/socc/src/player_ranking.rs b/crates/socc/src/player_ranking.rs index 39f97a3..347a771 100644 --- a/crates/socc/src/player_ranking.rs +++ b/crates/socc/src/player_ranking.rs @@ -1,14 +1,16 @@ use anyhow::Context as _; use chrono::{DateTime, Utc}; use deadpool_redis::redis; +use entity::{maps, players}; use player_map_ranking::compute_scores; use records_lib::{ Database, RedisPool, redis_key::{map_ranking, player_ranking}, + sync, }; -use sea_orm::ConnectionTrait; +use sea_orm::{ActiveValue::Set, ConnectionTrait, EntityTrait, TransactionTrait}; -async fn do_update( +async fn do_update( conn: &C, redis_pool: &RedisPool, from: Option>, @@ -25,12 +27,35 @@ async fn do_update( let mut pipe = redis::pipe(); let pipe = pipe.atomic(); - for (player, score) in scores.player_scores { - pipe.zadd(player_ranking(), player.inner.id, score); - } - for (map, score) in scores.map_scores { - pipe.zadd(map_ranking(), map.inner.id, score); - } + sync::transaction(conn, async |txn| { + for (player, score) in scores.player_scores { + players::Entity::update(players::ActiveModel { + id: Set(player.inner.id), + score: Set(score), + ..Default::default() + }) + .exec(txn) + .await?; + + pipe.zadd(player_ranking(), player.inner.id, score); + } + + for (map, score) in scores.map_scores { + maps::Entity::update(maps::ActiveModel { + id: Set(map.inner.id), + score: Set(score), + ..Default::default() + }) + .exec(txn) + .await?; + + pipe.zadd(map_ranking(), map.inner.id, score); + } + + anyhow::Ok(()) + }) + .await?; + pipe.exec_async(&mut redis_conn) .await .context("couldn't save scores to Redis")?; diff --git a/crates/test-env/Cargo.toml b/crates/test-env/Cargo.toml new file mode 100644 index 0000000..9d66def --- /dev/null +++ b/crates/test-env/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "test-env" +version = "0.1.0" +edition = "2024" + +[dependencies] +anyhow.workspace = true +dotenvy.workspace = true +sea-orm.workspace = true +tracing.workspace = true +tracing-subscriber.workspace = true +records-lib = { path = "../records_lib" } +migration = { path = "../migration" } +futures.workspace = true +mkenv.workspace = true +rand.workspace = true + +[features] +default = [] +mysql = ["records-lib/mysql"] +postgres = ["records-lib/postgres"] diff --git a/crates/test-env/src/lib.rs b/crates/test-env/src/lib.rs new file mode 100644 index 0000000..034daee --- /dev/null +++ b/crates/test-env/src/lib.rs @@ -0,0 +1,146 @@ +use std::{env, panic}; + +use anyhow::Context as _; +use futures::FutureExt as _; +use migration::MigratorTrait as _; +use mkenv::prelude::*; +use records_lib::{Database, DbEnv, pool::get_redis_pool}; +use sea_orm::{ConnectionTrait as _, DbConn}; +use tracing_subscriber::fmt::TestWriter; + +fn is_db_drop_forced() -> bool { + env::args_os().any(|arg| arg == "--force-drop-db") +} + +pub fn get_map_id() -> u32 { + rand::random_range(..u32::MAX) +} + +pub trait IntoResult { + type Out; + + fn into_result(self) -> anyhow::Result; +} + +impl IntoResult for () { + type Out = (); + + fn into_result(self) -> anyhow::Result { + Ok(()) + } +} + +impl IntoResult for Result +where + anyhow::Error: From, +{ + type Out = T; + + fn into_result(self) -> anyhow::Result { + self.map_err(From::from) + } +} + +pub fn init_env() -> anyhow::Result<()> { + match dotenvy::dotenv() { + Err(err) if !err.not_found() => return Err(err).context("cannot retrieve .env files"), + _ => (), + } + + let _ = tracing_subscriber::fmt() + .with_writer(TestWriter::new()) + .try_init(); + + Ok(()) +} + +pub async fn wrap(test: F) -> anyhow::Result<::Out> +where + F: AsyncFnOnce(Database) -> R, + R: IntoResult, +{ + init_env()?; + let env = DbEnv::define(); + + let master_db = sea_orm::Database::connect(&env.db_url.db_url.get()).await?; + + // For some reasons, on MySQL/MariaDB, using a schema name with some capital letters + // may produce the error code 1932 (42S02) "Table 'X' doesn't exist in engine" when + // doing a query. + let db_name = format!( + "_test_db_{}", + records_lib::gen_random_str(10).to_lowercase() + ); + + master_db + .execute_unprepared(&format!("create database {db_name}")) + .await?; + tracing::info!("Created database {db_name}"); + + let db = match master_db { + #[cfg(feature = "mysql")] + sea_orm::DatabaseConnection::SqlxMySqlPoolConnection(_) => { + use sea_orm::sqlx; + + let connect_options = master_db.get_mysql_connection_pool().connect_options(); + let connect_options = (*connect_options).clone(); + let options = connect_options.database(&db_name); + let db = sqlx::mysql::MySqlPool::connect_with(options).await?; + DbConn::from(db) + } + #[cfg(feature = "postgres")] + sea_orm::DatabaseConnection::SqlxPostgresPoolConnection(_) => { + use sea_orm::sqlx; + + let connect_options = master_db.get_postgres_connection_pool().connect_options(); + let connect_options = (*connect_options).clone(); + let options = connect_options.database(&db_name); + let db = sqlx::postgres::PgPool::connect_with(options).await?; + DbConn::from(db) + } + _ => unreachable!("must enable either `mysql` or `postgres` feature for testing"), + }; + + migration::Migrator::up(&db, None).await?; + + let r = panic::AssertUnwindSafe(test(Database { + sql_conn: db, + redis_pool: get_redis_pool(env.redis_url.redis_url.get())?, + })) + .catch_unwind() + .await; + + if is_db_drop_forced() { + master_db + .execute_unprepared(&format!("drop database {db_name}")) + .await?; + tracing::info!("Database {db_name} force-deleted"); + match r { + Ok(r) => r.into_result(), + Err(e) => { + tracing::info!("Test failed"); + panic::resume_unwind(e) + } + } + } else { + match r.map(IntoResult::into_result) { + Ok(Ok(out)) => { + master_db + .execute_unprepared(&format!("drop database {db_name}")) + .await?; + Ok(out) + } + other => { + tracing::info!( + "Test failed, leaving database {db_name} as-is. \ + Run with `--force-drop-db` to drop the database everytime." + ); + match other { + Ok(Err(e)) => Err(e), + Err(e) => panic::resume_unwind(e), + _ => unreachable!(), + } + } + } + } +}