Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 9 additions & 10 deletions cot/src/auth/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
use std::any::Any;
use std::borrow::Cow;
use std::fmt::{Display, Formatter};
use std::sync::Arc;

use async_trait::async_trait;
// Importing `Auto` from `cot` instead of `crate` so that the migration generator
Expand Down Expand Up @@ -192,10 +191,10 @@ impl DatabaseUser {
/// use cot::auth::UserId;
/// use cot::auth::db::DatabaseUser;
/// use cot::common_types::Password;
/// use cot::db::Database;
/// use cot::html::Html;
/// use cot::request::extractors::RequestDb;
///
/// async fn view(RequestDb(db): RequestDb) -> cot::Result<Html> {
/// async fn view(db: Database) -> cot::Result<Html> {
/// let user =
/// DatabaseUser::create_user(&db, "testuser".to_string(), &Password::new("password123"))
/// .await?;
Expand All @@ -210,7 +209,7 @@ impl DatabaseUser {
/// # use cot::test::{TestDatabase, TestRequestBuilder};
/// # let mut test_database = TestDatabase::new_sqlite().await?;
/// # test_database.with_auth().run_migrations().await;
/// # view(RequestDb(test_database.database())).await?;
/// # view(test_database.database()).await?;
/// # test_database.cleanup().await?;
/// # Ok(())
/// # }
Expand Down Expand Up @@ -284,10 +283,10 @@ impl DatabaseUser {
/// use cot::auth::UserId;
/// use cot::auth::db::DatabaseUser;
/// use cot::common_types::Password;
/// use cot::db::Database;
/// use cot::html::Html;
/// use cot::request::extractors::RequestDb;
///
/// async fn view(RequestDb(db): RequestDb) -> cot::Result<Html> {
/// async fn view(db: Database) -> cot::Result<Html> {
/// let user =
/// DatabaseUser::create_user(&db, "testuser".to_string(), &Password::new("password123"))
/// .await?;
Expand Down Expand Up @@ -327,10 +326,10 @@ impl DatabaseUser {
/// use cot::auth::UserId;
/// use cot::auth::db::DatabaseUser;
/// use cot::common_types::Password;
/// use cot::db::Database;
/// use cot::html::Html;
/// use cot::request::extractors::RequestDb;
///
/// async fn view(RequestDb(db): RequestDb) -> cot::Result<Html> {
/// async fn view(db: Database) -> cot::Result<Html> {
/// let user =
/// DatabaseUser::create_user(&db, "testuser".to_string(), &Password::new("password123"))
/// .await?;
Expand Down Expand Up @@ -469,7 +468,7 @@ impl DatabaseUserCredentials {
/// [`DatabaseUserCredentials`] struct and ignores all other credential types.
#[derive(Debug, Clone)]
pub struct DatabaseUserBackend {
database: Arc<Database>,
database: Database,
}

impl DatabaseUserBackend {
Expand All @@ -495,7 +494,7 @@ impl DatabaseUserBackend {
/// }
/// ```
#[must_use]
pub fn new(database: Arc<Database>) -> Self {
pub fn new(database: Database) -> Self {
Self { database }
}
}
Expand Down
72 changes: 15 additions & 57 deletions cot/src/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ mod sea_query_db;
use std::fmt::{Display, Formatter, Write};
use std::hash::Hash;
use std::str::FromStr;
use std::sync::Arc;

use async_trait::async_trait;
pub use cot_macros::{model, query};
Expand Down Expand Up @@ -789,10 +790,9 @@ pub trait SqlxValueRef<'r>: Sized {
/// It is used to execute queries and interact with the database. The connection
/// is established when the structure is created and closed when
/// [`Self::close()`] is called.
#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct Database {
_url: String,
inner: DatabaseImpl,
inner: Arc<DatabaseImpl>,
}

#[derive(Debug)]
Expand Down Expand Up @@ -837,26 +837,23 @@ impl Database {
if url.starts_with("sqlite:") {
let inner = DatabaseSqlite::new(&url).await?;
return Ok(Self {
_url: url,
inner: DatabaseImpl::Sqlite(inner),
inner: Arc::new(DatabaseImpl::Sqlite(inner)),
});
}

#[cfg(feature = "postgres")]
if url.starts_with("postgresql:") {
let inner = DatabasePostgres::new(&url).await?;
return Ok(Self {
_url: url,
inner: DatabaseImpl::Postgres(inner),
inner: Arc::new(DatabaseImpl::Postgres(inner)),
});
}

#[cfg(feature = "mysql")]
if url.starts_with("mysql:") {
let inner = DatabaseMySql::new(&url).await?;
return Ok(Self {
_url: url,
inner: DatabaseImpl::MySql(inner),
inner: Arc::new(DatabaseImpl::MySql(inner)),
});
}

Expand Down Expand Up @@ -886,7 +883,7 @@ impl Database {
/// }
/// ```
pub async fn close(&self) -> Result<()> {
match &self.inner {
match &*self.inner {
#[cfg(feature = "sqlite")]
DatabaseImpl::Sqlite(inner) => inner.close().await,
#[cfg(feature = "postgres")]
Expand Down Expand Up @@ -1124,7 +1121,7 @@ impl Database {
return Ok(());
}

let max_params = match self.inner {
let max_params = match &*self.inner {
// https://sqlite.org/limits.html#max_variable_number
// Assuming SQLite > 3.32.0 (2020-05-22)
#[cfg(feature = "sqlite")]
Expand Down Expand Up @@ -1471,7 +1468,7 @@ impl Database {
.collect::<Vec<_>>();
let values = SqlxValues(sea_query::Values(values));

let result = match &self.inner {
let result = match &*self.inner {
#[cfg(feature = "sqlite")]
DatabaseImpl::Sqlite(inner) => inner.raw_with(query, values).await?,
#[cfg(feature = "postgres")]
Expand All @@ -1487,7 +1484,7 @@ impl Database {
where
T: SqlxBinder + Send + Sync,
{
let result = match &self.inner {
let result = match &*self.inner {
#[cfg(feature = "sqlite")]
DatabaseImpl::Sqlite(inner) => inner.fetch_option(statement).await?.map(Row::Sqlite),
#[cfg(feature = "postgres")]
Expand All @@ -1502,7 +1499,7 @@ impl Database {
}

fn supports_returning(&self) -> bool {
match self.inner {
match &*self.inner {
#[cfg(feature = "sqlite")]
DatabaseImpl::Sqlite(_) => true,
#[cfg(feature = "postgres")]
Expand All @@ -1516,7 +1513,7 @@ impl Database {
where
T: SqlxBinder + Send + Sync,
{
let result = match &self.inner {
let result = match &*self.inner {
#[cfg(feature = "sqlite")]
DatabaseImpl::Sqlite(inner) => inner
.fetch_all(statement)
Expand Down Expand Up @@ -1547,7 +1544,7 @@ impl Database {
where
T: SqlxBinder + Send + Sync,
{
let result = match &self.inner {
let result = match &*self.inner {
#[cfg(feature = "sqlite")]
DatabaseImpl::Sqlite(inner) => inner.execute_statement(statement).await?,
#[cfg(feature = "postgres")]
Expand All @@ -1563,7 +1560,7 @@ impl Database {
&self,
statement: T,
) -> Result<StatementResult> {
let result = match &self.inner {
let result = match &*self.inner {
#[cfg(feature = "sqlite")]
DatabaseImpl::Sqlite(inner) => inner.execute_schema(statement).await?,
#[cfg(feature = "postgres")]
Expand All @@ -1578,7 +1575,7 @@ impl Database {

impl ColumnTypeMapper for Database {
fn sea_query_column_type_for(&self, column_type: ColumnType) -> sea_query::ColumnType {
match &self.inner {
match &*self.inner {
#[cfg(feature = "sqlite")]
DatabaseImpl::Sqlite(inner) => inner.sea_query_column_type_for(column_type),
#[cfg(feature = "postgres")]
Expand Down Expand Up @@ -1735,45 +1732,6 @@ impl DatabaseBackend for Database {
}
}

#[async_trait]
impl DatabaseBackend for std::sync::Arc<Database> {
async fn insert_or_update<T: Model>(&self, data: &mut T) -> Result<()> {
Database::insert_or_update(self, data).await
}

async fn insert<T: Model>(&self, data: &mut T) -> Result<()> {
Database::insert(self, data).await
}

async fn update<T: Model>(&self, data: &mut T) -> Result<()> {
Database::update(self, data).await
}

async fn bulk_insert<T: Model>(&self, data: &mut [T]) -> Result<()> {
Database::bulk_insert(self, data).await
}

async fn bulk_insert_or_update<T: Model>(&self, data: &mut [T]) -> Result<()> {
Database::bulk_insert_or_update(self, data).await
}

async fn query<T: Model>(&self, query: &Query<T>) -> Result<Vec<T>> {
Database::query(self, query).await
}

async fn get<T: Model>(&self, query: &Query<T>) -> Result<Option<T>> {
Database::get(self, query).await
}

async fn exists<T: Model>(&self, query: &Query<T>) -> Result<bool> {
Database::exists(self, query).await
}

async fn delete<T: Model>(&self, query: &Query<T>) -> Result<StatementResult> {
Database::delete(self, query).await
}
}

/// Result of a statement execution.
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct StatementResult {
Expand Down
2 changes: 1 addition & 1 deletion cot/src/openapi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -837,7 +837,7 @@ impl ApiOperationPart for Method {}
impl ApiOperationPart for Session {}
impl ApiOperationPart for Auth {}
#[cfg(feature = "db")]
impl ApiOperationPart for crate::request::extractors::RequestDb {}
impl ApiOperationPart for crate::db::Database {}

impl<D: JsonSchema> ApiOperationPart for Json<D> {
fn modify_api_operation(
Expand Down
16 changes: 8 additions & 8 deletions cot/src/project.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1224,11 +1224,11 @@ impl Bootstrapper<WithApps> {
}

#[cfg(feature = "db")]
async fn init_database(config: &DatabaseConfig) -> cot::Result<Option<Arc<Database>>> {
async fn init_database(config: &DatabaseConfig) -> cot::Result<Option<Database>> {
match &config.url {
Some(url) => {
let database = Database::new(url.as_str()).await?;
Ok(Some(Arc::new(database)))
Ok(Some(database))
}
None => Ok(None),
}
Expand Down Expand Up @@ -1626,7 +1626,7 @@ impl BootstrapPhase for WithDatabase {
type Apps = <WithApps as BootstrapPhase>::Apps;
type Router = <WithApps as BootstrapPhase>::Router;
#[cfg(feature = "db")]
type Database = Option<Arc<Database>>;
type Database = Option<Database>;
type AuthBackend = <WithApps as BootstrapPhase>::AuthBackend;
#[cfg(feature = "cache")]
type Cache = ();
Expand All @@ -1649,7 +1649,7 @@ impl BootstrapPhase for WithCache {
type Apps = <WithApps as BootstrapPhase>::Apps;
type Router = <WithApps as BootstrapPhase>::Router;
#[cfg(feature = "db")]
type Database = Option<Arc<Database>>;
type Database = <WithDatabase as BootstrapPhase>::Database;
type AuthBackend = <WithApps as BootstrapPhase>::AuthBackend;
#[cfg(feature = "cache")]
type Cache = Cache;
Expand Down Expand Up @@ -1791,7 +1791,7 @@ impl ProjectContext<WithApps> {
#[must_use]
fn with_database(
self,
#[cfg(feature = "db")] database: Option<Arc<Database>>,
#[cfg(feature = "db")] database: Option<Database>,
) -> ProjectContext<WithDatabase> {
ProjectContext {
config: self.config,
Expand Down Expand Up @@ -1931,7 +1931,7 @@ impl<S: BootstrapPhase<Cache = Cache>> ProjectContext<S> {
}

#[cfg(feature = "db")]
impl<S: BootstrapPhase<Database = Option<Arc<Database>>>> ProjectContext<S> {
impl<S: BootstrapPhase<Database = Option<Database>>> ProjectContext<S> {
/// Returns the database for the project, if it is enabled.
///
/// # Examples
Expand All @@ -1952,7 +1952,7 @@ impl<S: BootstrapPhase<Database = Option<Arc<Database>>>> ProjectContext<S> {
/// ```
#[must_use]
#[cfg(feature = "db")]
pub fn try_database(&self) -> Option<&Arc<Database>> {
pub fn try_database(&self) -> Option<&Database> {
self.database.as_ref()
}

Expand Down Expand Up @@ -1980,7 +1980,7 @@ impl<S: BootstrapPhase<Database = Option<Arc<Database>>>> ProjectContext<S> {
#[cfg(feature = "db")]
#[must_use]
#[track_caller]
pub fn database(&self) -> &Arc<Database> {
pub fn database(&self) -> &Database {
self.try_database().expect(
"Database missing. Did you forget to add the database when configuring CotProject?",
)
Expand Down
6 changes: 3 additions & 3 deletions cot/src/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ pub trait RequestExt: private::Sealed {
/// ```
#[cfg(feature = "db")]
#[must_use]
fn db(&self) -> &Arc<Database>;
fn db(&self) -> &Database;

/// Get the content type of the request.
///
Expand Down Expand Up @@ -318,7 +318,7 @@ impl RequestExt for Request {
}

#[cfg(feature = "db")]
fn db(&self) -> &Arc<Database> {
fn db(&self) -> &Database {
self.context().database()
}

Expand Down Expand Up @@ -378,7 +378,7 @@ impl RequestExt for RequestHead {
}

#[cfg(feature = "db")]
fn db(&self) -> &Arc<Database> {
fn db(&self) -> &Database {
self.context().database()
}

Expand Down
Loading
Loading