Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 33 additions & 7 deletions pymfdata/mongodb/connection.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,41 @@
from motor.motor_asyncio import AsyncIOMotorClient
from typing import Union
from motor.motor_asyncio import AsyncIOMotorClient, AsyncIOMotorDatabase
from pymongo import MongoClient
from pymongo.database import Database


class AsyncMotor:
class AsyncPyMongo:
client: AsyncIOMotorClient = None
db: AsyncIOMotorDatabase = None

def __init__(self, db_name, db_uri: str) -> None:
self._db_uri = db_uri
self.db_name = db_name
self.client: Union[AsyncIOMotorClient, None] = None
self.db = db_name

async def connect(self):
self.client = AsyncIOMotorClient(self._db_uri)
"""
:param minPoolSize: Minimum Pool Size
:param maxPoolSize: Maximum Pool size (Default: 100)
"""
async def connect(self, **kwargs):
self.client = AsyncIOMotorClient(self._db_uri, **kwargs)

async def disconnect(self):
self.client.close()


class SyncPyMongo:
client: MongoClient = None
db: Database = None

def __init__(self, db_name, db_uri: str) -> None:
self._db_uri = db_uri
self.db = db_name

"""
:param minPoolSize: Minimum Pool Size
:param maxPoolSize: Maximum Pool Size (Default: 100)
"""
def connect(self, **kwargs):
self.client = MongoClient(self._db_uri, **kwargs)

def disconnect(self):
self.client.close()
35 changes: 35 additions & 0 deletions pymfdata/mongodb/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from bson import ObjectId
from pydantic import BaseModel, Field
from pydantic.generics import GenericModel
from typing import Generic, TypeVar


class BaseObjectId(ObjectId):
@classmethod
def __get_validators__(cls):
yield cls.validate

@classmethod
def validate(cls, v):
if ObjectId.is_valid(v) is False:
raise TypeError('ObjectId invalid')
return ObjectId(v)

@classmethod
def __modify_schema__(cls, field_schema):
field_schema.update(type="string")


_T = TypeVar('_T', bound=BaseObjectId)


class BaseMongoDBModel(GenericModel, Generic[_T]):
id: _T = Field(default_factory=_T, alias="_id")

class Config:
# allow_population_by_field_name = True
arbitrary_types_allowed = True
json_encoders = {ObjectId: str}


_MT = TypeVar('_MT', bound=BaseMongoDBModel)
82 changes: 63 additions & 19 deletions pymfdata/mongodb/repository.py
Original file line number Diff line number Diff line change
@@ -1,42 +1,86 @@
from abc import ABC
from bson import ObjectId
from typing import final, List, Optional
from typing import final, List, Optional, Protocol
from motor.core import AgnosticBaseCursor, AgnosticCollection, Cursor
from pymongo.collection import Collection

from pymfdata.mongodb.connection import AsyncMotor
from pymfdata.mongodb.connection import AsyncPyMongo, SyncPyMongo
from pymfdata.mongodb.models import _T, _MT


class AsyncRepository(ABC):
def __init__(self, collection_name: str, motor: AsyncMotor) -> None:
self._collection = motor.client[motor.db_name][collection_name]
class AsyncRepository(Protocol[_MT, _T]):
_collection: AgnosticCollection

@property
def collection(self) -> AgnosticCollection:
assert self._collection is not None
return self._collection

# def __init__(self, collection_name: str, motor: AsyncPyMongo) -> None:
# self._collection: AgnosticCollection = motor.client[motor.db][collection_name]

@final
async def delete_by_id(self, item_id: str) -> bool:
row = await self._collection.delete_one({"_id": ObjectId(item_id)})
async def delete_by_id(self, item_id: _T) -> bool:
row = await self.collection.delete_one({"_id": item_id})
if not row:
return False

return True

@final
async def find_all(self) -> List[dict]:
cursor = self._collection.find()
results = list(map(lambda item: item, await cursor.to_list(length=100)))
async def find_all(self, **kwargs) -> List[dict]:
cursor: AgnosticBaseCursor = self.collection.find()
return list(map(lambda item: item, await cursor.to_list(**kwargs)))

@final
async def find_by_id(self, item_id: _T) -> Optional[dict]:
row = await self.collection.find_one({"_id": item_id})
if not row:
return None

return row

@final
async def save(self, req: _MT) -> dict:
return await self.collection.insert_one(req.dict())

@final
async def update_by_id(self, item_id: _T, req: _MT):
await self.collection.update_one({"_id": item_id}, req.dict())


class SyncRepository(Protocol[_MT, _T]):
_collection: Collection

return results
@property
def collection(self) -> Collection:
assert self._collection is not None
return self._collection

@final
def delete_by_id(self, item_id: _T) -> bool:
row = self.collection.find_one({"_id": item_id})
if not row:
return False

return True

@final
def find_all(self, **kwargs) -> List[dict]:
cursor: Cursor = self.collection.find(**kwargs)
return list(map(lambda item: item, cursor))

@final
async def find_by_id(self, item_id: str) -> Optional[dict]:
row = await self._collection.find_one({"_id": ObjectId(item_id)})
def find_by_id(self, item_id: _T) -> Optional[dict]:
row = self.collection.find_one({"_id": item_id})
if not row:
return None

return row

@final
async def save(self, req: dict) -> dict:
return await self._collection.insert_one(req)
def save(self, req: _MT) -> _T:
return self.collection.insert_one(req.dict()).inserted_id

@final
async def update_by_id(self, item_id: str, req: dict) -> dict:
await self._collection.update_one({"_id": ObjectId(item_id)}, req)
return await self.find_by_id(item_id)
def update_by_id(self, item_id: _T, req: _MT):
self.collection.update_one({"_id": item_id}, req.dict())
1 change: 0 additions & 1 deletion pymfdata/rdb/repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,6 @@ def session(self) -> Session:


class SyncRepository(BaseSyncRepository, Protocol[_MT, _T]):

@property
def _model(self):
return get_args(self.__orig_bases__[0])[0]
Expand Down