From c8f07c07eade739d7cd8e4f7d7e41bcec4c9dab7 Mon Sep 17 00:00:00 2001 From: nebilam <49345234+Nebilam@users.noreply.github.com> Date: Mon, 22 Dec 2025 22:51:53 +0100 Subject: [PATCH 01/11] add by_id method to get a service --- database/src/repos/service.rs | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/database/src/repos/service.rs b/database/src/repos/service.rs index d44b0e2..297a782 100644 --- a/database/src/repos/service.rs +++ b/database/src/repos/service.rs @@ -74,4 +74,12 @@ impl<'a> ServiceRepo<'a> { .await? .ok_or(DatabaseError::NotFound) } + + pub async fn by_id(&self, id: u32) -> Result { + sqlx::query_as("SELECT id, name, api_key FROM service WHERE id == ? LIMIT 1;") + .bind(id) + .fetch_optional(self.db) + .await? + .ok_or(DatabaseError::NotFound) + } } From acee6a4a5d4e32602e55772da89dedbc7c40880a Mon Sep 17 00:00:00 2001 From: nebilam <49345234+Nebilam@users.noreply.github.com> Date: Tue, 23 Dec 2025 16:56:48 +0100 Subject: [PATCH 02/11] add .vscode to gitignore --- .gitignore | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 4fc2701..fc60c92 100644 --- a/.gitignore +++ b/.gitignore @@ -4,4 +4,5 @@ *.db *.db-wal *.db-shm -/.direnv \ No newline at end of file +/.direnv +/.vscode From 20fb235c2b4452107f880dbd08d8293e047add69 Mon Sep 17 00:00:00 2001 From: nebilam <49345234+Nebilam@users.noreply.github.com> Date: Tue, 23 Dec 2025 17:02:26 +0100 Subject: [PATCH 03/11] add by_goal_id method to achievement repo --- database/src/repos/achievement.rs | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/database/src/repos/achievement.rs b/database/src/repos/achievement.rs index 877636c..671cb56 100644 --- a/database/src/repos/achievement.rs +++ b/database/src/repos/achievement.rs @@ -40,6 +40,31 @@ impl<'a> AchievementRepo<'a> { .await?) } + async fn by_goal_id(&self, goal_id: u32) -> Result, DatabaseError> { + Ok(query_as( + " + SELECT + achievement.id as achievement_id, + achievement.name as achievement_name, + service_id, + goal.id as goal_id, + description as goal_description, + sequence as goal_sequence + + FROM + goal + inner join achievement on achievement.id = goal.achievement_id + inner join goal on goal.achievement_id = achievement.id + WHERE + goal.id = ?; + + ", + ) + .bind(goal_id) + .fetch_all(self.db) + .await?) + } + pub async fn for_service( &self, service_id: u32, From 54208d834875c4b6dbc32d11ada430c4005e6c94 Mon Sep 17 00:00:00 2001 From: nebilam <49345234+Nebilam@users.noreply.github.com> Date: Tue, 23 Dec 2025 17:13:39 +0100 Subject: [PATCH 04/11] add route to unlock goals --- database/src/repos/achievement.rs | 40 +++++++++++++++++++++++-------- src/dto/achievement.rs | 13 ++++++++++ src/error.rs | 4 ++++ src/handlers/service.rs | 24 +++++++++++++++++-- src/lib.rs | 4 ++++ 5 files changed, 73 insertions(+), 12 deletions(-) diff --git a/database/src/repos/achievement.rs b/database/src/repos/achievement.rs index 671cb56..94824f9 100644 --- a/database/src/repos/achievement.rs +++ b/database/src/repos/achievement.rs @@ -44,20 +44,19 @@ impl<'a> AchievementRepo<'a> { Ok(query_as( " SELECT - achievement.id as achievement_id, - achievement.name as achievement_name, + achievement.id as achievement_id, + achievement.name as achievement_name, service_id, - goal.id as goal_id, - description as goal_description, - sequence as goal_sequence + goal2.id as goal_id, + goal2.description as goal_description, + goal2.sequence as goal_sequence FROM - goal - inner join achievement on achievement.id = goal.achievement_id - inner join goal on goal.achievement_id = achievement.id + goal as goal1 + inner join achievement on achievement.id = goal1.achievement_id + inner join goal as goal2 on goal2.achievement_id = achievement.id WHERE - goal.id = ?; - + goal1.id = ?; ", ) .bind(goal_id) @@ -144,4 +143,25 @@ impl<'a> AchievementRepo<'a> { tx.commit().await?; self.by_id(db_achievement.id).await } + + pub async fn unlock_goal( + &self, + user_id: u32, + goal_id: u32, + ) -> Result, DatabaseError> { + query( + " + INSERT INTO + unlock (user_id, goal_id) + VALUES + (?,?); + ", + ) + .bind(user_id) + .bind(goal_id) + .execute(self.db) + .await?; + + self.by_goal_id(goal_id).await + } } diff --git a/src/dto/achievement.rs b/src/dto/achievement.rs index 9046bfb..63ba143 100644 --- a/src/dto/achievement.rs +++ b/src/dto/achievement.rs @@ -34,6 +34,19 @@ impl AchievementPayload { Ok(achievements) } + + pub async fn unlock_goal( + db: &Database, + user_id: u32, + goal_id: u32, + ) -> Result { + let rows = db.achievements().unlock_goal(user_id, goal_id).await?; + + // pack rows into an achievement payload + let mut rows = rows.into_iter().peekable(); + let achievement = unpack_next_achievement(&mut rows).ok_or(AppError::NotFound)?; + Ok(achievement) + } } #[derive(Serialize, Deserialize, Debug, PartialEq)] diff --git a/src/error.rs b/src/error.rs index 6088190..d7d965a 100644 --- a/src/error.rs +++ b/src/error.rs @@ -60,6 +60,9 @@ pub enum AppError { #[error("User was not logged in")] NotLoggedIn, + #[error("Wrong api key")] + BadApiKey, + #[error("Forbidden")] Forbidden, @@ -80,6 +83,7 @@ impl AppError { let (status, msg) = match self { Self::PayloadError(_) => (StatusCode::BAD_REQUEST, "Payload error"), Self::NotLoggedIn => (StatusCode::UNAUTHORIZED, "Not logged in."), + Self::BadApiKey => (StatusCode::UNAUTHORIZED, "Bad api key."), Self::Forbidden => (StatusCode::FORBIDDEN, "Forbidden."), Self::NoFile => ( StatusCode::BAD_REQUEST, diff --git a/src/handlers/service.rs b/src/handlers/service.rs index 89e47c1..fe4dab5 100644 --- a/src/handlers/service.rs +++ b/src/handlers/service.rs @@ -1,9 +1,14 @@ use axum::{Json, extract::Path}; +use axum_extra::TypedHeader; use database::Database; +use headers::{Authorization, authorization::Bearer}; use crate::{ - dto::service::{ - ServiceCreatePayload, ServicePatchPayload, ServicePayloadAdmin, ServicePayloadUser, + dto::{ + achievement::AchievementPayload, + service::{ + ServiceCreatePayload, ServicePatchPayload, ServicePayloadAdmin, ServicePayloadUser, + }, }, error::AppError, }; @@ -42,4 +47,19 @@ impl ServiceHandler { ServicePayloadAdmin::regenerate_api_key(&db, service_id).await?, )) } + + pub async fn unlock_goal( + db: Database, + Path((user_id, service_id, goal_id)): Path<(u32, u32, u32)>, + api_key: TypedHeader>, + ) -> Result, AppError> { + let expected_api_key = db.services().by_id(service_id).await?.api_key; + if api_key.token() != expected_api_key { + return Err(AppError::BadApiKey); + } + + Ok(Json( + AchievementPayload::unlock_goal(&db, user_id, goal_id).await?, + )) + } } diff --git a/src/lib.rs b/src/lib.rs index 507b052..2b069c6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -78,6 +78,10 @@ fn open_routes() -> Router { .route("/oauth/callback", get(AuthHandler::callback)) .route("/image/{id}", get(ImageHandler::get)) .route("/version", get(VersionHandler::get)) + .route( + "/users/{id}/unlock/{service_id}/{goal_id}", + post(ServiceHandler::unlock_goal), + ) } fn authenticated_routes() -> Router { From 83c7a8bebbfaba965e7bb0dd27d01feebaffccb4 Mon Sep 17 00:00:00 2001 From: nebilam <49345234+Nebilam@users.noreply.github.com> Date: Tue, 23 Dec 2025 22:12:29 +0100 Subject: [PATCH 05/11] add tests for unlocking goals --- src/error.rs | 2 +- tests/achievement.rs | 33 ++++++++++++++++++++++++++++++- tests/common/router.rs | 45 +++++++++++++++++++++++++++++++++++++----- 3 files changed, 73 insertions(+), 7 deletions(-) diff --git a/src/error.rs b/src/error.rs index d7d965a..af3f6b7 100644 --- a/src/error.rs +++ b/src/error.rs @@ -48,7 +48,7 @@ pub enum AppError { #[error("Submitted image resolution was too large")] ImageResTooLarge, - #[error("The requested image was not found")] + #[error("Not found")] NotFound, #[error("Submitted file had an incorrect type")] diff --git a/tests/achievement.rs b/tests/achievement.rs index 8f087b3..8d75554 100644 --- a/tests/achievement.rs +++ b/tests/achievement.rs @@ -6,7 +6,9 @@ use zpi::dto::{ }; use crate::common::{ - into_struct::IntoStruct, router::AuthenticatedRouter, test_objects::TestObjects, + into_struct::IntoStruct, + router::{AuthenticatedRouter, UnauthenticatedRouter}, + test_objects::TestObjects, }; mod common; @@ -81,3 +83,32 @@ async fn post_achievements_wrong_sequence(db_pool: SqlitePool) { let response = router.post("/admin/services/1/achievements", &body).await; assert_eq!(response.status(), StatusCode::BAD_REQUEST); } + +#[sqlx::test(fixtures("services"))] +#[test_log::test] +async fn unlock_goal_wrong_api_key(db_pool: SqlitePool) { + let router = UnauthenticatedRouter::new(db_pool) + .await + .with_api_key("wrongapikey"); + + let response = router.post("/users/1/unlock/1/1", None::<()>).await; + assert_eq!(response.status(), StatusCode::UNAUTHORIZED); +} + +#[sqlx::test(fixtures("services", "achievements", "users"))] +#[test_log::test] +async fn unlock_goal(db_pool: SqlitePool) { + let router = UnauthenticatedRouter::new(db_pool) + .await + .with_api_key("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"); + + let response = router.post("/users/1/unlock/1/1", None::<()>).await; + assert_eq!(response.status(), StatusCode::OK); + + let data: AchievementPayload = response.into_struct().await; + + assert_eq!(data, TestObjects::achievement_1()); +} + +// TODO wat als goal niet bestaat -> status code 404 +// TODO wat als goal al unlocked is -> status code 200 diff --git a/tests/common/router.rs b/tests/common/router.rs index e64f07d..9f4bca7 100644 --- a/tests/common/router.rs +++ b/tests/common/router.rs @@ -76,7 +76,7 @@ impl AuthenticatedRouter { self.request(Method::PATCH, path, Some(body)).await } - /// send a patch request to an endpoint on this router + /// send a post request to an endpoint on this router /// /// must have a leading "/" pub async fn post(self, path: &str, body: T) -> Response { @@ -108,6 +108,7 @@ impl AuthenticatedRouter { } pub struct UnauthenticatedRouter { router: Router, + api_key: Option, } impl UnauthenticatedRouter { @@ -129,6 +130,7 @@ impl UnauthenticatedRouter { Self { router: api_router().layer(session_layer).with_state(state), + api_key: None, } } @@ -136,9 +138,42 @@ impl UnauthenticatedRouter { /// /// must have a leading "/" pub async fn get(self, path: &str) -> Response { - self.router - .oneshot(Request::builder().uri(path).body(Body::empty()).unwrap()) - .await - .unwrap() + self.request(Method::GET, path, None::<()>).await + } + + /// send a post request to an endpoint on this router + /// + /// must have a leading "/" + pub async fn post(self, path: &str, body: T) -> Response { + self.request(Method::POST, path, Some(body)).await + } + + /// send a request to an endpoint on this router + /// + /// must have a leading "/" + async fn request( + self, + method: Method, + path: &str, + body: Option, + ) -> Response { + let mut request_builder = Request::builder().method(method).uri(path); + + if let Some(api_key) = &self.api_key { + request_builder = request_builder.header(header::AUTHORIZATION, api_key); + } + + let request = match body { + Some(body) => request_builder + .header(header::CONTENT_TYPE, mime::APPLICATION_JSON.as_ref()) + .body(Json(body).into_response().into_body()), + None => request_builder.body(Body::empty()), + }; + self.router.oneshot(request.unwrap()).await.unwrap() + } + + pub fn with_api_key(mut self, api_key: &str) -> Self { + self.api_key = Some("Bearer ".to_string() + api_key); + self } } From 9580a821acc2119441973b4f74cf30e4ee642696 Mon Sep 17 00:00:00 2001 From: Hannes Date: Wed, 24 Dec 2025 02:07:41 +0100 Subject: [PATCH 06/11] refactor testrouters --- tests/achievement.rs | 25 ++---- tests/common/router.rs | 154 +++++++++++++---------------------- tests/common/test_objects.rs | 8 ++ tests/image.rs | 10 +-- tests/service.rs | 16 ++-- tests/tags.rs | 6 +- tests/user.rs | 22 ++--- 7 files changed, 93 insertions(+), 148 deletions(-) diff --git a/tests/achievement.rs b/tests/achievement.rs index 8d75554..be7d5d7 100644 --- a/tests/achievement.rs +++ b/tests/achievement.rs @@ -5,18 +5,14 @@ use zpi::dto::{ goal::GoalCreatePayload, }; -use crate::common::{ - into_struct::IntoStruct, - router::{AuthenticatedRouter, UnauthenticatedRouter}, - test_objects::TestObjects, -}; +use crate::common::{into_struct::IntoStruct, router::TestRouter, test_objects::TestObjects}; mod common; #[sqlx::test(fixtures("services", "achievements"))] #[test_log::test] async fn get_achievements_for_service(db_pool: SqlitePool) { - let router = AuthenticatedRouter::new(db_pool).await; + let router = TestRouter::as_admin(db_pool).await; let response = router.get("/admin/services/1/achievements").await; assert_eq!(response.status(), StatusCode::OK); @@ -32,7 +28,7 @@ async fn get_achievements_for_service(db_pool: SqlitePool) { #[sqlx::test(fixtures("services"))] #[test_log::test] async fn post_achievements_for_service(db_pool: SqlitePool) { - let router = AuthenticatedRouter::new(db_pool).await; + let router = TestRouter::as_admin(db_pool).await; let body = AchievementCreatePayload { name: "Achievements".into(), goals: vec![ @@ -58,7 +54,7 @@ async fn post_achievements_for_service(db_pool: SqlitePool) { #[sqlx::test(fixtures("services"))] #[test_log::test] async fn post_achievements_wrong_sequence(db_pool: SqlitePool) { - let router = AuthenticatedRouter::new(db_pool).await; + let router = TestRouter::as_admin(db_pool).await; let mut body = AchievementCreatePayload { name: "Achievements".into(), goals: vec![ @@ -73,10 +69,7 @@ async fn post_achievements_wrong_sequence(db_pool: SqlitePool) { ], }; - let response = router - .clone() - .post("/admin/services/1/achievements", &body) - .await; + let response = router.post("/admin/services/1/achievements", &body).await; assert_eq!(response.status(), StatusCode::BAD_REQUEST); body.goals[1].sequence = 1; @@ -87,9 +80,7 @@ async fn post_achievements_wrong_sequence(db_pool: SqlitePool) { #[sqlx::test(fixtures("services"))] #[test_log::test] async fn unlock_goal_wrong_api_key(db_pool: SqlitePool) { - let router = UnauthenticatedRouter::new(db_pool) - .await - .with_api_key("wrongapikey"); + let router = TestRouter::with_api_key(db_pool, "wrongapikey").await; let response = router.post("/users/1/unlock/1/1", None::<()>).await; assert_eq!(response.status(), StatusCode::UNAUTHORIZED); @@ -98,9 +89,7 @@ async fn unlock_goal_wrong_api_key(db_pool: SqlitePool) { #[sqlx::test(fixtures("services", "achievements", "users"))] #[test_log::test] async fn unlock_goal(db_pool: SqlitePool) { - let router = UnauthenticatedRouter::new(db_pool) - .await - .with_api_key("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"); + let router = TestRouter::with_api_key(db_pool, "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa").await; let response = router.post("/users/1/unlock/1/1", None::<()>).await; assert_eq!(response.status(), StatusCode::OK); diff --git a/tests/common/router.rs b/tests/common/router.rs index 9f4bca7..e8f8360 100644 --- a/tests/common/router.rs +++ b/tests/common/router.rs @@ -1,12 +1,13 @@ use std::{path::PathBuf, sync::Arc}; use axum::{ - Json, Router, + Json, body::Body, http::Request, response::{IntoResponse, Response}, }; use database::Database; +use dotenvy::dotenv; use reqwest::{Method, header}; use serde::Serialize; use sqlx::SqlitePool; @@ -16,35 +17,19 @@ use zpi::{ AppState, api_router, config::AppConfig, extractors::authenticated_user::AuthenticatedUser, }; -#[derive(Clone)] -pub struct AuthenticatedRouter { - router: Router, - cookie: String, +pub struct TestRouter { + router: axum::Router, + store: MemoryStore, + cookie: Option, + api_key: Option, } -impl AuthenticatedRouter { - pub async fn new(db: SqlitePool) -> Self { - let _ = dotenvy::dotenv(); - let store = Arc::new(MemoryStore::default()); - - let session_id = { - let session = Session::new(Some(Id(1)), store.clone(), None); - session - .insert( - "user", - AuthenticatedUser { - id: 1, - username: "cheese".to_string(), - admin: true, - }, - ) - .await - .unwrap(); - session.save().await.unwrap(); - session.id().unwrap() - }; +impl TestRouter { + pub fn new(db: SqlitePool) -> Self { + let _ = dotenv(); + let store = MemoryStore::default(); - let session_layer = SessionManagerLayer::new(Arc::into_inner(store).unwrap()) + let session_layer = SessionManagerLayer::new(store.clone()) .with_secure(false) .with_same_site(tower_sessions::cookie::SameSite::Lax); @@ -58,93 +43,64 @@ impl AuthenticatedRouter { Self { router: api_router().layer(session_layer).with_state(state), - cookie: format!("id={}", session_id), + store: store, + cookie: None, + api_key: None, } } - /// send a request to an endpoint on this router - /// - /// must have a leading "/" - pub async fn get(self, path: &str) -> Response { - self.request(Method::GET, path, None::<()>).await + pub async fn as_user(db: SqlitePool) -> Self { + Self::new(db) + .add_to_store(AuthenticatedUser { + id: 1, + username: "cheese".to_string(), + admin: false, + }) + .await } - /// send a patch request to an endpoint on this router - /// - /// must have a leading "/" - pub async fn patch(self, path: &str, body: T) -> Response { - self.request(Method::PATCH, path, Some(body)).await + pub async fn as_admin(db: SqlitePool) -> Self { + Self::new(db) + .add_to_store(AuthenticatedUser { + id: 1, + username: "cheese".to_string(), + admin: true, + }) + .await } - /// send a post request to an endpoint on this router - /// - /// must have a leading "/" - pub async fn post(self, path: &str, body: T) -> Response { - self.request(Method::POST, path, Some(body)).await + async fn add_to_store(mut self, user: AuthenticatedUser) -> Self { + let session = Session::new(Some(Id(1)), Arc::new(self.store.clone()), None); + session.insert("user", user).await.unwrap(); + session.save().await.unwrap(); + self.cookie.replace(format!("id={}", session.id().unwrap())); + self + } + + pub async fn with_api_key(db: SqlitePool, api_key: &str) -> Self { + let mut router = Self::new(db); + router.api_key = Some("Bearer ".to_string() + api_key); + router } /// send a request to an endpoint on this router /// /// must have a leading "/" - async fn request( - self, - method: Method, - path: &str, - body: Option, - ) -> Response { - let request_builder = Request::builder() - .method(method) - .uri(path) - .header(header::COOKIE, &self.cookie); - - let request = match body { - Some(body) => request_builder - .header(header::CONTENT_TYPE, mime::APPLICATION_JSON.as_ref()) - .body(Json(body).into_response().into_body()), - None => request_builder.body(Body::empty()), - }; - self.router.oneshot(request.unwrap()).await.unwrap() - } -} -pub struct UnauthenticatedRouter { - router: Router, - api_key: Option, -} - -impl UnauthenticatedRouter { - pub async fn new(db: SqlitePool) -> Self { - let _ = dotenvy::dotenv(); - let store = MemoryStore::default(); - - let session_layer = SessionManagerLayer::new(store) - .with_secure(false) - .with_same_site(tower_sessions::cookie::SameSite::Lax); - - let mut config = AppConfig::load().unwrap(); - config.image_path = PathBuf::from("./tests/test_images"); - - let state = AppState { - db: Database::new(db), - config, - }; - - Self { - router: api_router().layer(session_layer).with_state(state), - api_key: None, - } + pub async fn get(&self, path: &str) -> Response { + self.request(Method::GET, path, None::<()>).await } - /// send a request to an endpoint on this router + /// send a patch request to an endpoint on this router /// /// must have a leading "/" - pub async fn get(self, path: &str) -> Response { - self.request(Method::GET, path, None::<()>).await + pub async fn patch(&self, path: &str, body: T) -> Response { + self.request(Method::PATCH, path, Some(body)).await } /// send a post request to an endpoint on this router /// /// must have a leading "/" - pub async fn post(self, path: &str, body: T) -> Response { + pub async fn post(&self, path: &str, body: T) -> Response { self.request(Method::POST, path, Some(body)).await } @@ -152,7 +108,7 @@ impl UnauthenticatedRouter { /// /// must have a leading "/" async fn request( - self, + &self, method: Method, path: &str, body: Option, @@ -163,17 +119,17 @@ impl UnauthenticatedRouter { request_builder = request_builder.header(header::AUTHORIZATION, api_key); } + if let Some(cookie) = &self.cookie { + request_builder = request_builder.header(header::COOKIE, cookie); + } + let request = match body { Some(body) => request_builder .header(header::CONTENT_TYPE, mime::APPLICATION_JSON.as_ref()) .body(Json(body).into_response().into_body()), None => request_builder.body(Body::empty()), }; - self.router.oneshot(request.unwrap()).await.unwrap() - } - pub fn with_api_key(mut self, api_key: &str) -> Self { - self.api_key = Some("Bearer ".to_string() + api_key); - self + self.router.clone().oneshot(request.unwrap()).await.unwrap() } } diff --git a/tests/common/test_objects.rs b/tests/common/test_objects.rs index d304790..317ee2f 100644 --- a/tests/common/test_objects.rs +++ b/tests/common/test_objects.rs @@ -13,6 +13,14 @@ pub struct TestObjects; impl TestObjects { pub fn authenticated_user_1() -> AuthenticatedUser { + AuthenticatedUser { + id: 1, + username: "cheese".into(), + admin: false, + } + } + + pub fn admin_user_1() -> AuthenticatedUser { AuthenticatedUser { id: 1, username: "cheese".into(), diff --git a/tests/image.rs b/tests/image.rs index b90042d..33894d1 100644 --- a/tests/image.rs +++ b/tests/image.rs @@ -1,13 +1,13 @@ use reqwest::StatusCode; use sqlx::SqlitePool; -use crate::common::router::UnauthenticatedRouter; +use crate::common::router::TestRouter; mod common; #[sqlx::test] async fn get_image_default(db_pool: SqlitePool) { - let router = UnauthenticatedRouter::new(db_pool).await; + let router = TestRouter::new(db_pool); let response = router.get("/image/1").await; assert_eq!(response.status(), StatusCode::OK); @@ -15,7 +15,7 @@ async fn get_image_default(db_pool: SqlitePool) { #[sqlx::test] async fn get_image_placeholder(db_pool: SqlitePool) { - let router = UnauthenticatedRouter::new(db_pool).await; + let router = TestRouter::new(db_pool); let response = router.get("/image/1?placeholder=true").await; assert_eq!(response.status(), StatusCode::OK); @@ -23,7 +23,7 @@ async fn get_image_placeholder(db_pool: SqlitePool) { #[sqlx::test] async fn get_image_no_placeholder_404(db_pool: SqlitePool) { - let router = UnauthenticatedRouter::new(db_pool).await; + let router = TestRouter::new(db_pool); let response = router.get("/image/1?placeholder=false").await; assert_eq!(response.status(), StatusCode::NOT_FOUND); @@ -31,7 +31,7 @@ async fn get_image_no_placeholder_404(db_pool: SqlitePool) { #[sqlx::test] async fn get_image_no_placeholder(db_pool: SqlitePool) { - let router = UnauthenticatedRouter::new(db_pool).await; + let router = TestRouter::new(db_pool); let response = router.get("/image/2?placeholder=false").await; assert_eq!(response.status(), StatusCode::OK); diff --git a/tests/service.rs b/tests/service.rs index 02374d2..5e6f8cf 100644 --- a/tests/service.rs +++ b/tests/service.rs @@ -5,16 +5,14 @@ use zpi::dto::service::{ ServiceCreatePayload, ServicePatchPayload, ServicePayloadAdmin, ServicePayloadUser, }; -use crate::common::{ - into_struct::IntoStruct, router::AuthenticatedRouter, test_objects::TestObjects, -}; +use crate::common::{into_struct::IntoStruct, router::TestRouter, test_objects::TestObjects}; mod common; #[sqlx::test(fixtures("services"))] #[test_log::test] async fn get_all_services_as_admin(db_pool: SqlitePool) { - let router = AuthenticatedRouter::new(db_pool).await; + let router = TestRouter::as_admin(db_pool).await; let response = router.get("/admin/services").await; assert_eq!(response.status(), StatusCode::OK); @@ -27,7 +25,7 @@ async fn get_all_services_as_admin(db_pool: SqlitePool) { #[sqlx::test(fixtures("services"))] #[test_log::test] async fn get_all_services(db_pool: SqlitePool) { - let router = AuthenticatedRouter::new(db_pool).await; + let router = TestRouter::as_user(db_pool).await; let response = router.get("/services").await; assert_eq!(response.status(), StatusCode::OK); @@ -45,7 +43,7 @@ struct ApiKey { #[sqlx::test(fixtures("services"))] #[test_log::test] async fn users_dont_see_api_key(db_pool: SqlitePool) { - let router = AuthenticatedRouter::new(db_pool).await; + let router = TestRouter::as_user(db_pool).await; let response = router.get("/services").await; assert_eq!(response.status(), StatusCode::OK); @@ -58,7 +56,7 @@ async fn users_dont_see_api_key(db_pool: SqlitePool) { #[sqlx::test] #[test_log::test] async fn create_service(db_pool: SqlitePool) { - let router = AuthenticatedRouter::new(db_pool).await; + let router = TestRouter::as_admin(db_pool).await; let body = ServiceCreatePayload { name: "zpi".to_string(), }; @@ -77,7 +75,7 @@ async fn create_service(db_pool: SqlitePool) { #[test_log::test] async fn patch_service(db_pool: SqlitePool) { let new_name = "gamification2"; - let router = AuthenticatedRouter::new(db_pool).await; + let router = TestRouter::as_admin(db_pool).await; let body = ServicePatchPayload { name: new_name.to_string(), }; @@ -95,7 +93,7 @@ async fn patch_service(db_pool: SqlitePool) { #[sqlx::test(fixtures("services"))] #[test_log::test] async fn regenerate_api_key(db_pool: SqlitePool) { - let router = AuthenticatedRouter::new(db_pool).await; + let router = TestRouter::as_admin(db_pool).await; let response = router.post("/admin/services/1/apikey", "").await; // empty body assert_eq!(response.status(), StatusCode::OK); diff --git a/tests/tags.rs b/tests/tags.rs index b3d4b75..06b0222 100644 --- a/tests/tags.rs +++ b/tests/tags.rs @@ -2,16 +2,14 @@ use reqwest::StatusCode; use sqlx::SqlitePool; use zpi::dto::user::UserProfile; -use crate::common::{ - into_struct::IntoStruct, router::AuthenticatedRouter, test_objects::TestObjects, -}; +use crate::common::{into_struct::IntoStruct, router::TestRouter, test_objects::TestObjects}; mod common; #[sqlx::test(fixtures("users", "tags"))] #[test_log::test] async fn get_user_with_tags(db_pool: SqlitePool) { - let router = AuthenticatedRouter::new(db_pool).await; + let router = TestRouter::as_user(db_pool).await; let response = router.get("/users/2").await; assert_eq!(response.status(), StatusCode::OK); diff --git a/tests/user.rs b/tests/user.rs index dabd533..e993c4a 100644 --- a/tests/user.rs +++ b/tests/user.rs @@ -3,18 +3,14 @@ use reqwest::StatusCode; use sqlx::SqlitePool; use zpi::{dto::user::UserProfile, extractors::AuthenticatedUser}; -use crate::common::{ - into_struct::IntoStruct, - router::{AuthenticatedRouter, UnauthenticatedRouter}, - test_objects::TestObjects, -}; +use crate::common::{into_struct::IntoStruct, router::TestRouter, test_objects::TestObjects}; mod common; #[sqlx::test] #[test_log::test] async fn get_users_me(db_pool: SqlitePool) { - let router = AuthenticatedRouter::new(db_pool).await; + let router = TestRouter::as_user(db_pool).await; let response = router.get("/users/me").await; assert_eq!(response.status(), StatusCode::OK); @@ -25,7 +21,7 @@ async fn get_users_me(db_pool: SqlitePool) { #[sqlx::test] #[test_log::test] async fn get_users_me_unauthenticated(db_pool: SqlitePool) { - let router = UnauthenticatedRouter::new(db_pool).await; + let router = TestRouter::new(db_pool); let response = router.get("/users/me").await; assert_eq!(response.status(), StatusCode::UNAUTHORIZED); } @@ -33,7 +29,7 @@ async fn get_users_me_unauthenticated(db_pool: SqlitePool) { #[sqlx::test(fixtures("users"))] #[test_log::test] async fn patch_user(db_pool: SqlitePool) { - let router = AuthenticatedRouter::new(db_pool).await; + let router = TestRouter::as_user(db_pool).await; let body = UserPatch { about: "Changed about".to_string(), }; @@ -52,7 +48,7 @@ async fn patch_user(db_pool: SqlitePool) { #[sqlx::test(fixtures("users"))] #[test_log::test] async fn get_profile_by_id(db_pool: SqlitePool) { - let router = AuthenticatedRouter::new(db_pool).await; + let router = TestRouter::as_user(db_pool).await; let response = router.get("/users/1").await; assert_eq!(response.status(), StatusCode::OK); @@ -63,7 +59,7 @@ async fn get_profile_by_id(db_pool: SqlitePool) { #[sqlx::test] #[test_log::test] async fn get_profile_by_id_unauthenticated(db_pool: SqlitePool) { - let router = UnauthenticatedRouter::new(db_pool).await; + let router = TestRouter::new(db_pool); let response = router.get("/users/1").await; assert_eq!(response.status(), StatusCode::UNAUTHORIZED); } @@ -71,13 +67,13 @@ async fn get_profile_by_id_unauthenticated(db_pool: SqlitePool) { #[sqlx::test] #[test_log::test] async fn get_profile_404(db_pool: SqlitePool) { + let router = TestRouter::as_user(db_pool).await; + // test getting by id - let router = AuthenticatedRouter::new(db_pool.clone()).await; let response = router.get("/users/1").await; assert_eq!(response.status(), StatusCode::NOT_FOUND); // test getting by username - let router = AuthenticatedRouter::new(db_pool).await; let response = router.get("/users/cheese").await; assert_eq!(response.status(), StatusCode::NOT_FOUND); } @@ -85,7 +81,7 @@ async fn get_profile_404(db_pool: SqlitePool) { #[sqlx::test(fixtures("users"))] #[test_log::test] async fn get_profile_by_name(db_pool: SqlitePool) { - let router = AuthenticatedRouter::new(db_pool).await; + let router = TestRouter::as_user(db_pool).await; let response = router.get("/users/cheese").await; assert_eq!(response.status(), StatusCode::OK); From e398e6ecf234d99a6264bc9603fe7faa15f39eb0 Mon Sep 17 00:00:00 2001 From: Hannes Date: Wed, 24 Dec 2025 03:22:57 +0100 Subject: [PATCH 07/11] improve achievement tests --- tests/achievement.rs | 59 ++++++++++++++++++++++++++---------------- tests/common/router.rs | 2 +- 2 files changed, 38 insertions(+), 23 deletions(-) diff --git a/tests/achievement.rs b/tests/achievement.rs index be7d5d7..3708742 100644 --- a/tests/achievement.rs +++ b/tests/achievement.rs @@ -11,14 +11,20 @@ mod common; #[sqlx::test(fixtures("services", "achievements"))] #[test_log::test] -async fn get_achievements_for_service(db_pool: SqlitePool) { - let router = TestRouter::as_admin(db_pool).await; - let response = router.get("/admin/services/1/achievements").await; +async fn get_achievements_for_service(db: SqlitePool) { + let none = TestRouter::new(db.clone()); + let response = none.get("/admin/services/1/achievements").await; + assert_eq!(response.status(), StatusCode::UNAUTHORIZED); + + let user = TestRouter::as_user(db.clone()).await; + let response = user.get("/admin/services/1/achievements").await; + assert_eq!(response.status(), StatusCode::FORBIDDEN); + let admin = TestRouter::as_admin(db).await; + let response = admin.get("/admin/services/1/achievements").await; assert_eq!(response.status(), StatusCode::OK); let data: Vec = response.into_struct().await; - assert_eq!( data, vec![TestObjects::achievement_1(), TestObjects::achievement_2()] @@ -27,8 +33,7 @@ async fn get_achievements_for_service(db_pool: SqlitePool) { #[sqlx::test(fixtures("services"))] #[test_log::test] -async fn post_achievements_for_service(db_pool: SqlitePool) { - let router = TestRouter::as_admin(db_pool).await; +async fn post_achievements_for_service(db: SqlitePool) { let body = AchievementCreatePayload { name: "Achievements".into(), goals: vec![ @@ -42,19 +47,26 @@ async fn post_achievements_for_service(db_pool: SqlitePool) { }, ], }; - let response = router.post("/admin/services/1/achievements", body).await; + let none = TestRouter::new(db.clone()); + let response = none.post("/admin/services/1/achievements", &body).await; + assert_eq!(response.status(), StatusCode::UNAUTHORIZED); + + let user = TestRouter::as_user(db.clone()).await; + let response = user.post("/admin/services/1/achievements", &body).await; + assert_eq!(response.status(), StatusCode::FORBIDDEN); + + let admin = TestRouter::as_admin(db).await; + let response = admin.post("/admin/services/1/achievements", &body).await; assert_eq!(response.status(), StatusCode::OK); let data: AchievementPayload = response.into_struct().await; - assert_eq!(data, TestObjects::achievement_1()); } #[sqlx::test(fixtures("services"))] #[test_log::test] -async fn post_achievements_wrong_sequence(db_pool: SqlitePool) { - let router = TestRouter::as_admin(db_pool).await; +async fn post_achievements_wrong_sequence(db: SqlitePool) { let mut body = AchievementCreatePayload { name: "Achievements".into(), goals: vec![ @@ -69,6 +81,7 @@ async fn post_achievements_wrong_sequence(db_pool: SqlitePool) { ], }; + let router = TestRouter::as_admin(db.clone()).await; let response = router.post("/admin/services/1/achievements", &body).await; assert_eq!(response.status(), StatusCode::BAD_REQUEST); @@ -77,27 +90,29 @@ async fn post_achievements_wrong_sequence(db_pool: SqlitePool) { assert_eq!(response.status(), StatusCode::BAD_REQUEST); } -#[sqlx::test(fixtures("services"))] -#[test_log::test] -async fn unlock_goal_wrong_api_key(db_pool: SqlitePool) { - let router = TestRouter::with_api_key(db_pool, "wrongapikey").await; - - let response = router.post("/users/1/unlock/1/1", None::<()>).await; - assert_eq!(response.status(), StatusCode::UNAUTHORIZED); -} - #[sqlx::test(fixtures("services", "achievements", "users"))] #[test_log::test] -async fn unlock_goal(db_pool: SqlitePool) { - let router = TestRouter::with_api_key(db_pool, "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa").await; +async fn unlock_goal(db: SqlitePool) { + let none = TestRouter::new(db.clone()); + let response = none.post("/users/1/unlock/1/1", None::<()>).await; + assert_eq!(response.status(), StatusCode::BAD_REQUEST); // TODO + let router = TestRouter::with_api_key(db, "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"); let response = router.post("/users/1/unlock/1/1", None::<()>).await; assert_eq!(response.status(), StatusCode::OK); let data: AchievementPayload = response.into_struct().await; - assert_eq!(data, TestObjects::achievement_1()); } +#[sqlx::test(fixtures("services"))] +#[test_log::test] +async fn unlock_goal_wrong_api_key(db_pool: SqlitePool) { + let router = TestRouter::with_api_key(db_pool, "wrongapikey"); + + let response = router.post("/users/1/unlock/1/1", None::<()>).await; + assert_eq!(response.status(), StatusCode::UNAUTHORIZED); +} + // TODO wat als goal niet bestaat -> status code 404 // TODO wat als goal al unlocked is -> status code 200 diff --git a/tests/common/router.rs b/tests/common/router.rs index e8f8360..c3b9a52 100644 --- a/tests/common/router.rs +++ b/tests/common/router.rs @@ -77,7 +77,7 @@ impl TestRouter { self } - pub async fn with_api_key(db: SqlitePool, api_key: &str) -> Self { + pub fn with_api_key(db: SqlitePool, api_key: &str) -> Self { let mut router = Self::new(db); router.api_key = Some("Bearer ".to_string() + api_key); router From 93d64884cd27e0b888d60426394144c81cd26fef Mon Sep 17 00:00:00 2001 From: Hannes Date: Wed, 24 Dec 2025 03:36:53 +0100 Subject: [PATCH 08/11] make api key extractor, return unauthorized for no api key --- src/extractors/api_key.rs | 24 ++++++++++++++++++++++++ src/extractors/mod.rs | 1 + src/handlers/service.rs | 7 +++---- tests/achievement.rs | 2 +- 4 files changed, 29 insertions(+), 5 deletions(-) create mode 100644 src/extractors/api_key.rs diff --git a/src/extractors/api_key.rs b/src/extractors/api_key.rs new file mode 100644 index 0000000..26af31b --- /dev/null +++ b/src/extractors/api_key.rs @@ -0,0 +1,24 @@ +use axum::{extract::FromRequestParts, http::request::Parts}; +use axum_extra::TypedHeader; +use headers::{Authorization, authorization::Bearer}; + +use crate::error::AppError; + +#[derive(Debug)] +pub struct ApiKey(pub String); + +impl FromRequestParts for ApiKey +where + S: Send + Sync, +{ + type Rejection = AppError; + + async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { + let header = TypedHeader::>::from_request_parts(parts, state).await; + + match header { + Ok(TypedHeader(Authorization(bearer))) => Ok(ApiKey(bearer.token().to_string())), + _ => Err(AppError::BadApiKey), + } + } +} diff --git a/src/extractors/mod.rs b/src/extractors/mod.rs index 128e91e..d8be7bf 100644 --- a/src/extractors/mod.rs +++ b/src/extractors/mod.rs @@ -1,4 +1,5 @@ pub mod admin; +pub mod api_key; pub mod authenticated_user; pub mod config; pub mod database; diff --git a/src/handlers/service.rs b/src/handlers/service.rs index fe4dab5..9452ff4 100644 --- a/src/handlers/service.rs +++ b/src/handlers/service.rs @@ -1,7 +1,5 @@ use axum::{Json, extract::Path}; -use axum_extra::TypedHeader; use database::Database; -use headers::{Authorization, authorization::Bearer}; use crate::{ dto::{ @@ -11,6 +9,7 @@ use crate::{ }, }, error::AppError, + extractors::api_key::ApiKey, }; pub struct ServiceHandler; @@ -51,10 +50,10 @@ impl ServiceHandler { pub async fn unlock_goal( db: Database, Path((user_id, service_id, goal_id)): Path<(u32, u32, u32)>, - api_key: TypedHeader>, + ApiKey(api_key): ApiKey, ) -> Result, AppError> { let expected_api_key = db.services().by_id(service_id).await?.api_key; - if api_key.token() != expected_api_key { + if api_key != expected_api_key { return Err(AppError::BadApiKey); } diff --git a/tests/achievement.rs b/tests/achievement.rs index 3708742..5510231 100644 --- a/tests/achievement.rs +++ b/tests/achievement.rs @@ -95,7 +95,7 @@ async fn post_achievements_wrong_sequence(db: SqlitePool) { async fn unlock_goal(db: SqlitePool) { let none = TestRouter::new(db.clone()); let response = none.post("/users/1/unlock/1/1", None::<()>).await; - assert_eq!(response.status(), StatusCode::BAD_REQUEST); // TODO + assert_eq!(response.status(), StatusCode::UNAUTHORIZED); let router = TestRouter::with_api_key(db, "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"); let response = router.post("/users/1/unlock/1/1", None::<()>).await; From 40ca597adb4266f1f7cf666720d5161fdf371e84 Mon Sep 17 00:00:00 2001 From: nebilam <49345234+Nebilam@users.noreply.github.com> Date: Thu, 1 Jan 2026 00:33:43 +0100 Subject: [PATCH 09/11] remove .vscode and .direnv from gitignore Happy new year!!! --- .gitignore | 2 -- 1 file changed, 2 deletions(-) diff --git a/.gitignore b/.gitignore index fc60c92..f7379f7 100644 --- a/.gitignore +++ b/.gitignore @@ -4,5 +4,3 @@ *.db *.db-wal *.db-shm -/.direnv -/.vscode From f710a017648edd0fb2f2c4b53effe08f7c1eb0ea Mon Sep 17 00:00:00 2001 From: Nebilam <49345234+Nebilam@users.noreply.github.com> Date: Wed, 4 Feb 2026 13:51:29 +0100 Subject: [PATCH 10/11] add some more test for unlocking goals - test for unlocking a goal that doesn't exist - test for ulocking a goal that is already unlocked --- database/src/repos/achievement.rs | 38 +++++++++++++++++++++++++++++-- src/dto/achievement.rs | 11 ++++++++- tests/achievement.rs | 20 ++++++++++++++-- 3 files changed, 64 insertions(+), 5 deletions(-) diff --git a/database/src/repos/achievement.rs b/database/src/repos/achievement.rs index 94824f9..909948e 100644 --- a/database/src/repos/achievement.rs +++ b/database/src/repos/achievement.rs @@ -1,4 +1,4 @@ -use sqlx::{SqlitePool, query, query_as}; +use sqlx::{SqlitePool, query, query_as, query_scalar}; use crate::{ error::DatabaseError, @@ -40,7 +40,7 @@ impl<'a> AchievementRepo<'a> { .await?) } - async fn by_goal_id(&self, goal_id: u32) -> Result, DatabaseError> { + pub async fn by_goal_id(&self, goal_id: u32) -> Result, DatabaseError> { Ok(query_as( " SELECT @@ -164,4 +164,38 @@ impl<'a> AchievementRepo<'a> { self.by_goal_id(goal_id).await } + + pub async fn goal_exist(&self, goal_id: u32) -> Result { + Ok(query_scalar::<_, i32>( + " + SELECT + 1 + FROM + goal + WHERE + goal.id = ?; + ", + ) + .bind(goal_id) + .fetch_optional(self.db) + .await? + .is_some()) + } + + pub async fn goal_unlocked(&self, goal_id: u32) -> Result { + Ok(query_scalar::<_, i32>( + " + SELECT + 1 + FROM + unlock + WHERE + goal_id = ?; + ", + ) + .bind(goal_id) + .fetch_optional(self.db) + .await? + .is_some()) + } } diff --git a/src/dto/achievement.rs b/src/dto/achievement.rs index 63ba143..1453805 100644 --- a/src/dto/achievement.rs +++ b/src/dto/achievement.rs @@ -40,7 +40,16 @@ impl AchievementPayload { user_id: u32, goal_id: u32, ) -> Result { - let rows = db.achievements().unlock_goal(user_id, goal_id).await?; + if !db.achievements().goal_exist(goal_id).await? { + return Err(AppError::NotFound); + } + + let rows = if db.achievements().goal_unlocked(goal_id).await? { + // goal already unlocked + db.achievements().by_goal_id(goal_id).await? + } else { + db.achievements().unlock_goal(user_id, goal_id).await? + }; // pack rows into an achievement payload let mut rows = rows.into_iter().peekable(); diff --git a/tests/achievement.rs b/tests/achievement.rs index 5510231..ca07347 100644 --- a/tests/achievement.rs +++ b/tests/achievement.rs @@ -114,5 +114,21 @@ async fn unlock_goal_wrong_api_key(db_pool: SqlitePool) { assert_eq!(response.status(), StatusCode::UNAUTHORIZED); } -// TODO wat als goal niet bestaat -> status code 404 -// TODO wat als goal al unlocked is -> status code 200 +#[sqlx::test(fixtures("services"))] +#[test_log::test] +async fn unlock_goal_404(db: SqlitePool) { + let router = TestRouter::with_api_key(db, "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"); + let response = router.post("/users/1/unlock/1/3", None::<()>).await; + assert_eq!(response.status(), StatusCode::NOT_FOUND); +} + +#[sqlx::test(fixtures("services", "users", "achievements", "unlocks"))] +#[test_log::test] +async fn unlock_goal_already_unlocked(db: SqlitePool) { + let router = TestRouter::with_api_key(db, "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"); + let response = router.post("/users/1/unlock/1/3", None::<()>).await; + assert_eq!(response.status(), StatusCode::OK); + + let data: AchievementPayload = response.into_struct().await; + assert_eq!(data, TestObjects::achievement_2()); +} From 18e9aa69905796f30440778a5a3155fed95ae05a Mon Sep 17 00:00:00 2001 From: Hannes Date: Wed, 4 Feb 2026 15:16:09 +0100 Subject: [PATCH 11/11] fix spelling --- src/dto/achievement.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/dto/achievement.rs b/src/dto/achievement.rs index 1453805..38669c7 100644 --- a/src/dto/achievement.rs +++ b/src/dto/achievement.rs @@ -86,7 +86,7 @@ impl AchievementCreatePayload { } self.goals.sort_by_key(|x| x.sequence); - let ordered_1_seperated = self + let ordered_1_separated = self .goals .iter() .map(|x| x.sequence) @@ -97,7 +97,7 @@ impl AchievementCreatePayload { _ => false, }); if let Some(goal) = self.goals.first() - && (goal.sequence != 0 || !ordered_1_seperated) + && (goal.sequence != 0 || !ordered_1_separated) { return Err(AppError::PayloadError( "Sequence should start with 0 and count up by 1".into(),