diff --git a/.gitignore b/.gitignore index 4fc2701..f7379f7 100644 --- a/.gitignore +++ b/.gitignore @@ -4,4 +4,3 @@ *.db *.db-wal *.db-shm -/.direnv \ No newline at end of file diff --git a/database/src/repos/achievement.rs b/database/src/repos/achievement.rs index 877636c..909948e 100644 --- a/database/src/repos/achievement.rs +++ b/database/src/repos/achievement.rs @@ -1,4 +1,4 @@ -use sqlx::{SqlitePool, query, query_as}; +use sqlx::{SqlitePool, query, query_as, query_scalar}; use crate::{ error::DatabaseError, @@ -40,6 +40,30 @@ impl<'a> AchievementRepo<'a> { .await?) } + pub async fn by_goal_id(&self, goal_id: u32) -> Result, DatabaseError> { + Ok(query_as( + " + SELECT + achievement.id as achievement_id, + achievement.name as achievement_name, + service_id, + goal2.id as goal_id, + goal2.description as goal_description, + goal2.sequence as goal_sequence + + FROM + goal as goal1 + inner join achievement on achievement.id = goal1.achievement_id + inner join goal as goal2 on goal2.achievement_id = achievement.id + WHERE + goal1.id = ?; + ", + ) + .bind(goal_id) + .fetch_all(self.db) + .await?) + } + pub async fn for_service( &self, service_id: u32, @@ -119,4 +143,59 @@ impl<'a> AchievementRepo<'a> { tx.commit().await?; self.by_id(db_achievement.id).await } + + pub async fn unlock_goal( + &self, + user_id: u32, + goal_id: u32, + ) -> Result, DatabaseError> { + query( + " + INSERT INTO + unlock (user_id, goal_id) + VALUES + (?,?); + ", + ) + .bind(user_id) + .bind(goal_id) + .execute(self.db) + .await?; + + self.by_goal_id(goal_id).await + } + + pub async fn goal_exist(&self, goal_id: u32) -> Result { + Ok(query_scalar::<_, i32>( + " + SELECT + 1 + FROM + goal + WHERE + goal.id = ?; + ", + ) + .bind(goal_id) + .fetch_optional(self.db) + .await? + .is_some()) + } + + pub async fn goal_unlocked(&self, goal_id: u32) -> Result { + Ok(query_scalar::<_, i32>( + " + SELECT + 1 + FROM + unlock + WHERE + goal_id = ?; + ", + ) + .bind(goal_id) + .fetch_optional(self.db) + .await? + .is_some()) + } } diff --git a/database/src/repos/service.rs b/database/src/repos/service.rs index d44b0e2..297a782 100644 --- a/database/src/repos/service.rs +++ b/database/src/repos/service.rs @@ -74,4 +74,12 @@ impl<'a> ServiceRepo<'a> { .await? .ok_or(DatabaseError::NotFound) } + + pub async fn by_id(&self, id: u32) -> Result { + sqlx::query_as("SELECT id, name, api_key FROM service WHERE id == ? LIMIT 1;") + .bind(id) + .fetch_optional(self.db) + .await? + .ok_or(DatabaseError::NotFound) + } } diff --git a/src/dto/achievement.rs b/src/dto/achievement.rs index 9046bfb..38669c7 100644 --- a/src/dto/achievement.rs +++ b/src/dto/achievement.rs @@ -34,6 +34,28 @@ impl AchievementPayload { Ok(achievements) } + + pub async fn unlock_goal( + db: &Database, + user_id: u32, + goal_id: u32, + ) -> Result { + if !db.achievements().goal_exist(goal_id).await? { + return Err(AppError::NotFound); + } + + let rows = if db.achievements().goal_unlocked(goal_id).await? { + // goal already unlocked + db.achievements().by_goal_id(goal_id).await? + } else { + db.achievements().unlock_goal(user_id, goal_id).await? + }; + + // pack rows into an achievement payload + let mut rows = rows.into_iter().peekable(); + let achievement = unpack_next_achievement(&mut rows).ok_or(AppError::NotFound)?; + Ok(achievement) + } } #[derive(Serialize, Deserialize, Debug, PartialEq)] @@ -64,7 +86,7 @@ impl AchievementCreatePayload { } self.goals.sort_by_key(|x| x.sequence); - let ordered_1_seperated = self + let ordered_1_separated = self .goals .iter() .map(|x| x.sequence) @@ -75,7 +97,7 @@ impl AchievementCreatePayload { _ => false, }); if let Some(goal) = self.goals.first() - && (goal.sequence != 0 || !ordered_1_seperated) + && (goal.sequence != 0 || !ordered_1_separated) { return Err(AppError::PayloadError( "Sequence should start with 0 and count up by 1".into(), diff --git a/src/error.rs b/src/error.rs index 6088190..af3f6b7 100644 --- a/src/error.rs +++ b/src/error.rs @@ -48,7 +48,7 @@ pub enum AppError { #[error("Submitted image resolution was too large")] ImageResTooLarge, - #[error("The requested image was not found")] + #[error("Not found")] NotFound, #[error("Submitted file had an incorrect type")] @@ -60,6 +60,9 @@ pub enum AppError { #[error("User was not logged in")] NotLoggedIn, + #[error("Wrong api key")] + BadApiKey, + #[error("Forbidden")] Forbidden, @@ -80,6 +83,7 @@ impl AppError { let (status, msg) = match self { Self::PayloadError(_) => (StatusCode::BAD_REQUEST, "Payload error"), Self::NotLoggedIn => (StatusCode::UNAUTHORIZED, "Not logged in."), + Self::BadApiKey => (StatusCode::UNAUTHORIZED, "Bad api key."), Self::Forbidden => (StatusCode::FORBIDDEN, "Forbidden."), Self::NoFile => ( StatusCode::BAD_REQUEST, diff --git a/src/extractors/api_key.rs b/src/extractors/api_key.rs new file mode 100644 index 0000000..26af31b --- /dev/null +++ b/src/extractors/api_key.rs @@ -0,0 +1,24 @@ +use axum::{extract::FromRequestParts, http::request::Parts}; +use axum_extra::TypedHeader; +use headers::{Authorization, authorization::Bearer}; + +use crate::error::AppError; + +#[derive(Debug)] +pub struct ApiKey(pub String); + +impl FromRequestParts for ApiKey +where + S: Send + Sync, +{ + type Rejection = AppError; + + async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { + let header = TypedHeader::>::from_request_parts(parts, state).await; + + match header { + Ok(TypedHeader(Authorization(bearer))) => Ok(ApiKey(bearer.token().to_string())), + _ => Err(AppError::BadApiKey), + } + } +} diff --git a/src/extractors/mod.rs b/src/extractors/mod.rs index 128e91e..d8be7bf 100644 --- a/src/extractors/mod.rs +++ b/src/extractors/mod.rs @@ -1,4 +1,5 @@ pub mod admin; +pub mod api_key; pub mod authenticated_user; pub mod config; pub mod database; diff --git a/src/handlers/service.rs b/src/handlers/service.rs index 89e47c1..9452ff4 100644 --- a/src/handlers/service.rs +++ b/src/handlers/service.rs @@ -2,10 +2,14 @@ use axum::{Json, extract::Path}; use database::Database; use crate::{ - dto::service::{ - ServiceCreatePayload, ServicePatchPayload, ServicePayloadAdmin, ServicePayloadUser, + dto::{ + achievement::AchievementPayload, + service::{ + ServiceCreatePayload, ServicePatchPayload, ServicePayloadAdmin, ServicePayloadUser, + }, }, error::AppError, + extractors::api_key::ApiKey, }; pub struct ServiceHandler; @@ -42,4 +46,19 @@ impl ServiceHandler { ServicePayloadAdmin::regenerate_api_key(&db, service_id).await?, )) } + + pub async fn unlock_goal( + db: Database, + Path((user_id, service_id, goal_id)): Path<(u32, u32, u32)>, + ApiKey(api_key): ApiKey, + ) -> Result, AppError> { + let expected_api_key = db.services().by_id(service_id).await?.api_key; + if api_key != expected_api_key { + return Err(AppError::BadApiKey); + } + + Ok(Json( + AchievementPayload::unlock_goal(&db, user_id, goal_id).await?, + )) + } } diff --git a/src/lib.rs b/src/lib.rs index 507b052..2b069c6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -78,6 +78,10 @@ fn open_routes() -> Router { .route("/oauth/callback", get(AuthHandler::callback)) .route("/image/{id}", get(ImageHandler::get)) .route("/version", get(VersionHandler::get)) + .route( + "/users/{id}/unlock/{service_id}/{goal_id}", + post(ServiceHandler::unlock_goal), + ) } fn authenticated_routes() -> Router { diff --git a/tests/achievement.rs b/tests/achievement.rs index 8f087b3..ca07347 100644 --- a/tests/achievement.rs +++ b/tests/achievement.rs @@ -5,22 +5,26 @@ use zpi::dto::{ goal::GoalCreatePayload, }; -use crate::common::{ - into_struct::IntoStruct, router::AuthenticatedRouter, test_objects::TestObjects, -}; +use crate::common::{into_struct::IntoStruct, router::TestRouter, test_objects::TestObjects}; mod common; #[sqlx::test(fixtures("services", "achievements"))] #[test_log::test] -async fn get_achievements_for_service(db_pool: SqlitePool) { - let router = AuthenticatedRouter::new(db_pool).await; - let response = router.get("/admin/services/1/achievements").await; +async fn get_achievements_for_service(db: SqlitePool) { + let none = TestRouter::new(db.clone()); + let response = none.get("/admin/services/1/achievements").await; + assert_eq!(response.status(), StatusCode::UNAUTHORIZED); + + let user = TestRouter::as_user(db.clone()).await; + let response = user.get("/admin/services/1/achievements").await; + assert_eq!(response.status(), StatusCode::FORBIDDEN); + let admin = TestRouter::as_admin(db).await; + let response = admin.get("/admin/services/1/achievements").await; assert_eq!(response.status(), StatusCode::OK); let data: Vec = response.into_struct().await; - assert_eq!( data, vec![TestObjects::achievement_1(), TestObjects::achievement_2()] @@ -29,8 +33,7 @@ async fn get_achievements_for_service(db_pool: SqlitePool) { #[sqlx::test(fixtures("services"))] #[test_log::test] -async fn post_achievements_for_service(db_pool: SqlitePool) { - let router = AuthenticatedRouter::new(db_pool).await; +async fn post_achievements_for_service(db: SqlitePool) { let body = AchievementCreatePayload { name: "Achievements".into(), goals: vec![ @@ -44,19 +47,26 @@ async fn post_achievements_for_service(db_pool: SqlitePool) { }, ], }; - let response = router.post("/admin/services/1/achievements", body).await; + let none = TestRouter::new(db.clone()); + let response = none.post("/admin/services/1/achievements", &body).await; + assert_eq!(response.status(), StatusCode::UNAUTHORIZED); + + let user = TestRouter::as_user(db.clone()).await; + let response = user.post("/admin/services/1/achievements", &body).await; + assert_eq!(response.status(), StatusCode::FORBIDDEN); + + let admin = TestRouter::as_admin(db).await; + let response = admin.post("/admin/services/1/achievements", &body).await; assert_eq!(response.status(), StatusCode::OK); let data: AchievementPayload = response.into_struct().await; - assert_eq!(data, TestObjects::achievement_1()); } #[sqlx::test(fixtures("services"))] #[test_log::test] -async fn post_achievements_wrong_sequence(db_pool: SqlitePool) { - let router = AuthenticatedRouter::new(db_pool).await; +async fn post_achievements_wrong_sequence(db: SqlitePool) { let mut body = AchievementCreatePayload { name: "Achievements".into(), goals: vec![ @@ -71,13 +81,54 @@ async fn post_achievements_wrong_sequence(db_pool: SqlitePool) { ], }; - let response = router - .clone() - .post("/admin/services/1/achievements", &body) - .await; + let router = TestRouter::as_admin(db.clone()).await; + let response = router.post("/admin/services/1/achievements", &body).await; assert_eq!(response.status(), StatusCode::BAD_REQUEST); body.goals[1].sequence = 1; let response = router.post("/admin/services/1/achievements", &body).await; assert_eq!(response.status(), StatusCode::BAD_REQUEST); } + +#[sqlx::test(fixtures("services", "achievements", "users"))] +#[test_log::test] +async fn unlock_goal(db: SqlitePool) { + let none = TestRouter::new(db.clone()); + let response = none.post("/users/1/unlock/1/1", None::<()>).await; + assert_eq!(response.status(), StatusCode::UNAUTHORIZED); + + let router = TestRouter::with_api_key(db, "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"); + let response = router.post("/users/1/unlock/1/1", None::<()>).await; + assert_eq!(response.status(), StatusCode::OK); + + let data: AchievementPayload = response.into_struct().await; + assert_eq!(data, TestObjects::achievement_1()); +} + +#[sqlx::test(fixtures("services"))] +#[test_log::test] +async fn unlock_goal_wrong_api_key(db_pool: SqlitePool) { + let router = TestRouter::with_api_key(db_pool, "wrongapikey"); + + let response = router.post("/users/1/unlock/1/1", None::<()>).await; + assert_eq!(response.status(), StatusCode::UNAUTHORIZED); +} + +#[sqlx::test(fixtures("services"))] +#[test_log::test] +async fn unlock_goal_404(db: SqlitePool) { + let router = TestRouter::with_api_key(db, "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"); + let response = router.post("/users/1/unlock/1/3", None::<()>).await; + assert_eq!(response.status(), StatusCode::NOT_FOUND); +} + +#[sqlx::test(fixtures("services", "users", "achievements", "unlocks"))] +#[test_log::test] +async fn unlock_goal_already_unlocked(db: SqlitePool) { + let router = TestRouter::with_api_key(db, "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"); + let response = router.post("/users/1/unlock/1/3", None::<()>).await; + assert_eq!(response.status(), StatusCode::OK); + + let data: AchievementPayload = response.into_struct().await; + assert_eq!(data, TestObjects::achievement_2()); +} diff --git a/tests/common/router.rs b/tests/common/router.rs index e64f07d..c3b9a52 100644 --- a/tests/common/router.rs +++ b/tests/common/router.rs @@ -1,12 +1,13 @@ use std::{path::PathBuf, sync::Arc}; use axum::{ - Json, Router, + Json, body::Body, http::Request, response::{IntoResponse, Response}, }; use database::Database; +use dotenvy::dotenv; use reqwest::{Method, header}; use serde::Serialize; use sqlx::SqlitePool; @@ -16,35 +17,19 @@ use zpi::{ AppState, api_router, config::AppConfig, extractors::authenticated_user::AuthenticatedUser, }; -#[derive(Clone)] -pub struct AuthenticatedRouter { - router: Router, - cookie: String, +pub struct TestRouter { + router: axum::Router, + store: MemoryStore, + cookie: Option, + api_key: Option, } -impl AuthenticatedRouter { - pub async fn new(db: SqlitePool) -> Self { - let _ = dotenvy::dotenv(); - let store = Arc::new(MemoryStore::default()); - - let session_id = { - let session = Session::new(Some(Id(1)), store.clone(), None); - session - .insert( - "user", - AuthenticatedUser { - id: 1, - username: "cheese".to_string(), - admin: true, - }, - ) - .await - .unwrap(); - session.save().await.unwrap(); - session.id().unwrap() - }; +impl TestRouter { + pub fn new(db: SqlitePool) -> Self { + let _ = dotenv(); + let store = MemoryStore::default(); - let session_layer = SessionManagerLayer::new(Arc::into_inner(store).unwrap()) + let session_layer = SessionManagerLayer::new(store.clone()) .with_secure(false) .with_same_site(tower_sessions::cookie::SameSite::Lax); @@ -58,28 +43,64 @@ impl AuthenticatedRouter { Self { router: api_router().layer(session_layer).with_state(state), - cookie: format!("id={}", session_id), + store: store, + cookie: None, + api_key: None, } } + pub async fn as_user(db: SqlitePool) -> Self { + Self::new(db) + .add_to_store(AuthenticatedUser { + id: 1, + username: "cheese".to_string(), + admin: false, + }) + .await + } + + pub async fn as_admin(db: SqlitePool) -> Self { + Self::new(db) + .add_to_store(AuthenticatedUser { + id: 1, + username: "cheese".to_string(), + admin: true, + }) + .await + } + + async fn add_to_store(mut self, user: AuthenticatedUser) -> Self { + let session = Session::new(Some(Id(1)), Arc::new(self.store.clone()), None); + session.insert("user", user).await.unwrap(); + session.save().await.unwrap(); + self.cookie.replace(format!("id={}", session.id().unwrap())); + self + } + + pub fn with_api_key(db: SqlitePool, api_key: &str) -> Self { + let mut router = Self::new(db); + router.api_key = Some("Bearer ".to_string() + api_key); + router + } + /// send a request to an endpoint on this router /// /// must have a leading "/" - pub async fn get(self, path: &str) -> Response { + pub async fn get(&self, path: &str) -> Response { self.request(Method::GET, path, None::<()>).await } /// send a patch request to an endpoint on this router /// /// must have a leading "/" - pub async fn patch(self, path: &str, body: T) -> Response { + pub async fn patch(&self, path: &str, body: T) -> Response { self.request(Method::PATCH, path, Some(body)).await } - /// send a patch request to an endpoint on this router + /// send a post request to an endpoint on this router /// /// must have a leading "/" - pub async fn post(self, path: &str, body: T) -> Response { + pub async fn post(&self, path: &str, body: T) -> Response { self.request(Method::POST, path, Some(body)).await } @@ -87,15 +108,20 @@ impl AuthenticatedRouter { /// /// must have a leading "/" async fn request( - self, + &self, method: Method, path: &str, body: Option, ) -> Response { - let request_builder = Request::builder() - .method(method) - .uri(path) - .header(header::COOKIE, &self.cookie); + let mut request_builder = Request::builder().method(method).uri(path); + + if let Some(api_key) = &self.api_key { + request_builder = request_builder.header(header::AUTHORIZATION, api_key); + } + + if let Some(cookie) = &self.cookie { + request_builder = request_builder.header(header::COOKIE, cookie); + } let request = match body { Some(body) => request_builder @@ -103,42 +129,7 @@ impl AuthenticatedRouter { .body(Json(body).into_response().into_body()), None => request_builder.body(Body::empty()), }; - self.router.oneshot(request.unwrap()).await.unwrap() - } -} -pub struct UnauthenticatedRouter { - router: Router, -} - -impl UnauthenticatedRouter { - pub async fn new(db: SqlitePool) -> Self { - let _ = dotenvy::dotenv(); - let store = MemoryStore::default(); - - let session_layer = SessionManagerLayer::new(store) - .with_secure(false) - .with_same_site(tower_sessions::cookie::SameSite::Lax); - - let mut config = AppConfig::load().unwrap(); - config.image_path = PathBuf::from("./tests/test_images"); - - let state = AppState { - db: Database::new(db), - config, - }; - - Self { - router: api_router().layer(session_layer).with_state(state), - } - } - /// send a request to an endpoint on this router - /// - /// must have a leading "/" - pub async fn get(self, path: &str) -> Response { - self.router - .oneshot(Request::builder().uri(path).body(Body::empty()).unwrap()) - .await - .unwrap() + self.router.clone().oneshot(request.unwrap()).await.unwrap() } } diff --git a/tests/common/test_objects.rs b/tests/common/test_objects.rs index d304790..317ee2f 100644 --- a/tests/common/test_objects.rs +++ b/tests/common/test_objects.rs @@ -13,6 +13,14 @@ pub struct TestObjects; impl TestObjects { pub fn authenticated_user_1() -> AuthenticatedUser { + AuthenticatedUser { + id: 1, + username: "cheese".into(), + admin: false, + } + } + + pub fn admin_user_1() -> AuthenticatedUser { AuthenticatedUser { id: 1, username: "cheese".into(), diff --git a/tests/image.rs b/tests/image.rs index b90042d..33894d1 100644 --- a/tests/image.rs +++ b/tests/image.rs @@ -1,13 +1,13 @@ use reqwest::StatusCode; use sqlx::SqlitePool; -use crate::common::router::UnauthenticatedRouter; +use crate::common::router::TestRouter; mod common; #[sqlx::test] async fn get_image_default(db_pool: SqlitePool) { - let router = UnauthenticatedRouter::new(db_pool).await; + let router = TestRouter::new(db_pool); let response = router.get("/image/1").await; assert_eq!(response.status(), StatusCode::OK); @@ -15,7 +15,7 @@ async fn get_image_default(db_pool: SqlitePool) { #[sqlx::test] async fn get_image_placeholder(db_pool: SqlitePool) { - let router = UnauthenticatedRouter::new(db_pool).await; + let router = TestRouter::new(db_pool); let response = router.get("/image/1?placeholder=true").await; assert_eq!(response.status(), StatusCode::OK); @@ -23,7 +23,7 @@ async fn get_image_placeholder(db_pool: SqlitePool) { #[sqlx::test] async fn get_image_no_placeholder_404(db_pool: SqlitePool) { - let router = UnauthenticatedRouter::new(db_pool).await; + let router = TestRouter::new(db_pool); let response = router.get("/image/1?placeholder=false").await; assert_eq!(response.status(), StatusCode::NOT_FOUND); @@ -31,7 +31,7 @@ async fn get_image_no_placeholder_404(db_pool: SqlitePool) { #[sqlx::test] async fn get_image_no_placeholder(db_pool: SqlitePool) { - let router = UnauthenticatedRouter::new(db_pool).await; + let router = TestRouter::new(db_pool); let response = router.get("/image/2?placeholder=false").await; assert_eq!(response.status(), StatusCode::OK); diff --git a/tests/service.rs b/tests/service.rs index 02374d2..5e6f8cf 100644 --- a/tests/service.rs +++ b/tests/service.rs @@ -5,16 +5,14 @@ use zpi::dto::service::{ ServiceCreatePayload, ServicePatchPayload, ServicePayloadAdmin, ServicePayloadUser, }; -use crate::common::{ - into_struct::IntoStruct, router::AuthenticatedRouter, test_objects::TestObjects, -}; +use crate::common::{into_struct::IntoStruct, router::TestRouter, test_objects::TestObjects}; mod common; #[sqlx::test(fixtures("services"))] #[test_log::test] async fn get_all_services_as_admin(db_pool: SqlitePool) { - let router = AuthenticatedRouter::new(db_pool).await; + let router = TestRouter::as_admin(db_pool).await; let response = router.get("/admin/services").await; assert_eq!(response.status(), StatusCode::OK); @@ -27,7 +25,7 @@ async fn get_all_services_as_admin(db_pool: SqlitePool) { #[sqlx::test(fixtures("services"))] #[test_log::test] async fn get_all_services(db_pool: SqlitePool) { - let router = AuthenticatedRouter::new(db_pool).await; + let router = TestRouter::as_user(db_pool).await; let response = router.get("/services").await; assert_eq!(response.status(), StatusCode::OK); @@ -45,7 +43,7 @@ struct ApiKey { #[sqlx::test(fixtures("services"))] #[test_log::test] async fn users_dont_see_api_key(db_pool: SqlitePool) { - let router = AuthenticatedRouter::new(db_pool).await; + let router = TestRouter::as_user(db_pool).await; let response = router.get("/services").await; assert_eq!(response.status(), StatusCode::OK); @@ -58,7 +56,7 @@ async fn users_dont_see_api_key(db_pool: SqlitePool) { #[sqlx::test] #[test_log::test] async fn create_service(db_pool: SqlitePool) { - let router = AuthenticatedRouter::new(db_pool).await; + let router = TestRouter::as_admin(db_pool).await; let body = ServiceCreatePayload { name: "zpi".to_string(), }; @@ -77,7 +75,7 @@ async fn create_service(db_pool: SqlitePool) { #[test_log::test] async fn patch_service(db_pool: SqlitePool) { let new_name = "gamification2"; - let router = AuthenticatedRouter::new(db_pool).await; + let router = TestRouter::as_admin(db_pool).await; let body = ServicePatchPayload { name: new_name.to_string(), }; @@ -95,7 +93,7 @@ async fn patch_service(db_pool: SqlitePool) { #[sqlx::test(fixtures("services"))] #[test_log::test] async fn regenerate_api_key(db_pool: SqlitePool) { - let router = AuthenticatedRouter::new(db_pool).await; + let router = TestRouter::as_admin(db_pool).await; let response = router.post("/admin/services/1/apikey", "").await; // empty body assert_eq!(response.status(), StatusCode::OK); diff --git a/tests/tags.rs b/tests/tags.rs index b3d4b75..06b0222 100644 --- a/tests/tags.rs +++ b/tests/tags.rs @@ -2,16 +2,14 @@ use reqwest::StatusCode; use sqlx::SqlitePool; use zpi::dto::user::UserProfile; -use crate::common::{ - into_struct::IntoStruct, router::AuthenticatedRouter, test_objects::TestObjects, -}; +use crate::common::{into_struct::IntoStruct, router::TestRouter, test_objects::TestObjects}; mod common; #[sqlx::test(fixtures("users", "tags"))] #[test_log::test] async fn get_user_with_tags(db_pool: SqlitePool) { - let router = AuthenticatedRouter::new(db_pool).await; + let router = TestRouter::as_user(db_pool).await; let response = router.get("/users/2").await; assert_eq!(response.status(), StatusCode::OK); diff --git a/tests/user.rs b/tests/user.rs index dabd533..e993c4a 100644 --- a/tests/user.rs +++ b/tests/user.rs @@ -3,18 +3,14 @@ use reqwest::StatusCode; use sqlx::SqlitePool; use zpi::{dto::user::UserProfile, extractors::AuthenticatedUser}; -use crate::common::{ - into_struct::IntoStruct, - router::{AuthenticatedRouter, UnauthenticatedRouter}, - test_objects::TestObjects, -}; +use crate::common::{into_struct::IntoStruct, router::TestRouter, test_objects::TestObjects}; mod common; #[sqlx::test] #[test_log::test] async fn get_users_me(db_pool: SqlitePool) { - let router = AuthenticatedRouter::new(db_pool).await; + let router = TestRouter::as_user(db_pool).await; let response = router.get("/users/me").await; assert_eq!(response.status(), StatusCode::OK); @@ -25,7 +21,7 @@ async fn get_users_me(db_pool: SqlitePool) { #[sqlx::test] #[test_log::test] async fn get_users_me_unauthenticated(db_pool: SqlitePool) { - let router = UnauthenticatedRouter::new(db_pool).await; + let router = TestRouter::new(db_pool); let response = router.get("/users/me").await; assert_eq!(response.status(), StatusCode::UNAUTHORIZED); } @@ -33,7 +29,7 @@ async fn get_users_me_unauthenticated(db_pool: SqlitePool) { #[sqlx::test(fixtures("users"))] #[test_log::test] async fn patch_user(db_pool: SqlitePool) { - let router = AuthenticatedRouter::new(db_pool).await; + let router = TestRouter::as_user(db_pool).await; let body = UserPatch { about: "Changed about".to_string(), }; @@ -52,7 +48,7 @@ async fn patch_user(db_pool: SqlitePool) { #[sqlx::test(fixtures("users"))] #[test_log::test] async fn get_profile_by_id(db_pool: SqlitePool) { - let router = AuthenticatedRouter::new(db_pool).await; + let router = TestRouter::as_user(db_pool).await; let response = router.get("/users/1").await; assert_eq!(response.status(), StatusCode::OK); @@ -63,7 +59,7 @@ async fn get_profile_by_id(db_pool: SqlitePool) { #[sqlx::test] #[test_log::test] async fn get_profile_by_id_unauthenticated(db_pool: SqlitePool) { - let router = UnauthenticatedRouter::new(db_pool).await; + let router = TestRouter::new(db_pool); let response = router.get("/users/1").await; assert_eq!(response.status(), StatusCode::UNAUTHORIZED); } @@ -71,13 +67,13 @@ async fn get_profile_by_id_unauthenticated(db_pool: SqlitePool) { #[sqlx::test] #[test_log::test] async fn get_profile_404(db_pool: SqlitePool) { + let router = TestRouter::as_user(db_pool).await; + // test getting by id - let router = AuthenticatedRouter::new(db_pool.clone()).await; let response = router.get("/users/1").await; assert_eq!(response.status(), StatusCode::NOT_FOUND); // test getting by username - let router = AuthenticatedRouter::new(db_pool).await; let response = router.get("/users/cheese").await; assert_eq!(response.status(), StatusCode::NOT_FOUND); } @@ -85,7 +81,7 @@ async fn get_profile_404(db_pool: SqlitePool) { #[sqlx::test(fixtures("users"))] #[test_log::test] async fn get_profile_by_name(db_pool: SqlitePool) { - let router = AuthenticatedRouter::new(db_pool).await; + let router = TestRouter::as_user(db_pool).await; let response = router.get("/users/cheese").await; assert_eq!(response.status(), StatusCode::OK);