Initializing multitenant

This commit is contained in:
2025-03-30 17:10:42 +02:00
parent 1a247f14ce
commit 50fdf22afc
14 changed files with 272 additions and 135 deletions

116
back/app/core/filter.py Normal file
View File

@@ -0,0 +1,116 @@
from collections.abc import Callable, Mapping
from typing import Any, Optional, Union
from pydantic import ValidationInfo, field_validator
from fastapi_filter.base.filter import BaseFilterModel
_odm_operator_transformer: dict[str, Callable[[Optional[str]], Optional[dict[str, Any]]]] = {
"neq": lambda value: {"$ne": value},
"gt": lambda value: {"$gt": value},
"gte": lambda value: {"$gte": value},
"in": lambda value: {"$in": value},
"isnull": lambda value: None if value else {"$ne": None},
"lt": lambda value: {"$lt": value},
"lte": lambda value: {"$lte": value},
"not": lambda value: {"$ne": value},
"ne": lambda value: {"$ne": value},
"not_in": lambda value: {"$nin": value},
"nin": lambda value: {"$nin": value},
"like": lambda value: {"$regex": f".*{value}.*"},
"ilike": lambda value: {"$regex": f".*{value}.*", "$options": "i"},
"exists": lambda value: {"$exists": value},
}
class Filter(BaseFilterModel):
"""Base filter for beanie related filters.
Example:
```python
class MyModel:
id: PrimaryKey()
name: StringField(null=True)
count: IntField()
created_at: DatetimeField()
class MyModelFilter(Filter):
id: Optional[int]
id__in: Optional[str]
count: Optional[int]
count__lte: Optional[int]
created_at__gt: Optional[datetime]
name__ne: Optional[str]
name__nin: Optional[list[str]]
name__isnull: Optional[bool]
```
"""
def sort(self, query):
if not self.ordering_values:
return query
return query.sort(*self.ordering_values)
@field_validator("*", mode="before")
@classmethod
def split_str(
cls: type["BaseFilterModel"], value: Optional[str], field: ValidationInfo
) -> Optional[Union[list[str], str]]:
if (
field.field_name is not None
and (
field.field_name == cls.Constants.ordering_field_name
or field.field_name.endswith("__in")
or field.field_name.endswith("__nin")
)
and isinstance(value, str)
):
if not value:
# Empty string should return [] not ['']
return []
return list(value.split(","))
return value
def _get_filter_conditions(self, nesting_depth: int = 1) -> list[tuple[Mapping[str, Any], Mapping[str, Any]]]:
filter_conditions: list[tuple[Mapping[str, Any], Mapping[str, Any]]] = []
for field_name, value in self.filtering_fields:
field_value = getattr(self, field_name)
if isinstance(field_value, Filter):
if not field_value.model_dump(exclude_none=True, exclude_unset=True):
continue
filter_conditions.append(
(
{field_name: _odm_operator_transformer["neq"](None)},
{"fetch_links": True, "nesting_depth": nesting_depth},
)
)
for part, part_options in field_value._get_filter_conditions(nesting_depth=nesting_depth + 1): # noqa: SLF001
for sub_field_name, sub_value in part.items():
filter_conditions.append(
(
{f"{field_name}.{sub_field_name}": sub_value},
{"fetch_links": True, "nesting_depth": nesting_depth, **part_options},
)
)
elif "__" in field_name:
stripped_field_name, operator = field_name.split("__")
search_criteria = _odm_operator_transformer[operator](value)
filter_conditions.append(({stripped_field_name: search_criteria}, {}))
elif field_name == self.Constants.search_field_name and hasattr(self.Constants, "search_model_fields"):
search_conditions = [
{search_field: _odm_operator_transformer["ilike"](value)}
for search_field in self.Constants.search_model_fields
]
filter_conditions.append(({"$or": search_conditions}, {}))
else:
filter_conditions.append(({field_name: value}, {}))
return filter_conditions
def filter(self, query):
data = self._get_filter_conditions()
for filter_condition, filter_kwargs in data:
query = query.find(filter_condition, **filter_kwargs)
return query

View File

@@ -1,21 +1,66 @@
from datetime import datetime
from datetime import datetime, UTC
from typing import Optional
from beanie import Document
from pydantic import BaseModel, Field, validator
from beanie import PydanticObjectId
from pydantic import BaseModel, Field, computed_field
class CrudDocument(Document):
_id: str
created_at: datetime = Field(default=datetime.utcnow(), nullable=False, title="Créé le")
class CrudDocument(BaseModel):
id: Optional[PydanticObjectId] = Field(alias="_id", default=None)
created_at: datetime = Field(default=datetime.now(UTC), nullable=False, title="Créé le")
# created_by: str
updated_at: datetime = Field(default_factory=datetime.utcnow, nullable=False, title="Modifié le")
# updated_by: str
@validator("label", always=True, check_fields=False)
def generate_label(cls, v, values, **kwargs):
return v
@computed_field
def label(self) -> str:
return self.compute_label()
def compute_label(self) -> str:
return ""
class Settings:
fulltext_search = []
@classmethod
def _collection_name(cls):
return cls.__name__
@classmethod
def _get_collection(cls, db):
return db.get_collection(cls._collection_name())
@classmethod
async def create(cls, db, create_schema):
values = cls.model_validate(create_schema.model_dump()).model_dump(mode="json")
result = await cls._get_collection(db).insert_one(values)
return await cls.get(db, result.inserted_id)
@classmethod
def list(cls, db, filters):
query = filters.filter(cls._get_collection(db))
query = filters.sort(query)
return query
@classmethod
async def get(cls, db, model_id):
return cls.model_validate(await cls._get_collection(db).find_one({"_id": model_id}))
@classmethod
async def update(cls, db, model, update_schema):
update_query = {
"$set": {field: value for field, value in update_schema.model_dump(mode="json").items()}
}
await cls._get_collection(db).update_one({"_id": model.id}, update_query)
return await cls.get(db, model.id)
@classmethod
async def delete(cls, db, model):
await cls._get_collection(db).delete_one({"_id": model.id})
def text_area(*args, **kwargs):
kwargs['widget'] = {

View File

@@ -3,10 +3,12 @@ from beanie.odm.operators.find.comparison import In
from beanie.operators import And, RegEx, Eq
from fastapi import APIRouter, HTTPException, Depends
from fastapi_pagination import Page, Params, add_pagination
from fastapi_pagination.ext.beanie import paginate
from fastapi_filter import FilterDepends
from fastapi_pagination import Page, add_pagination
from fastapi_pagination.ext.motor import paginate
from ..user.manager import get_current_user, get_current_superuser, get_current_user_and_firm
from ..db import get_db_client
from ..user.manager import get_current_user
def parse_sort(sort_by):
@@ -15,8 +17,8 @@ def parse_sort(sort_by):
fields = []
for field in sort_by.split(','):
dir, col = field.split('(')
fields.append((col[:-1], 1 if dir == 'asc' else -1))
direction, column = field.split('(')
fields.append((column[:-1], 1 if direction == 'asc' else -1))
return fields
@@ -33,6 +35,7 @@ def parse_query(query: str, model):
for criterion in query.split(' AND '):
[column, operator, value] = criterion.split(' ', 2)
column = column.lower()
operand = None
if column == 'fulltext':
if not model.Settings.fulltext_search:
continue
@@ -50,68 +53,67 @@ def parse_query(query: str, model):
elif operator == 'in':
operand = In(column, value.split(','))
and_array.append(operand)
if operand:
and_array.append(operand)
if and_array:
return And(*and_array) if len(and_array) > 1 else and_array[0]
else:
return {}
#user=Depends(get_current_user)
def get_tenant_db_cursor(instance: str="westside", firm: str="cht", db_client=Depends(get_db_client), user=None):
return db_client[f"tenant_{instance}_{firm}"]
def get_crud_router(model, model_create, model_read, model_update):
def get_crud_router(model, model_create, model_read, model_update, model_filter):
model_name = model.__name__
router = APIRouter()
@router.post("/", response_description="{} added to the database".format(model.__name__))
async def create(instance: str, firm: str, item: model_create, user=Depends(get_current_user)) -> dict:
await item.validate_foreign_key()
o = await model(**item.dict()).create()
return {"message": "{} added successfully".format(model.__name__), "id": o.id}
@router.post("/", response_description=f"{model_name} added to the database")
async def create(schema: model_create, db=Depends(get_tenant_db_cursor)) -> model_read:
await schema.validate_foreign_key(db)
record = await model.create(db, schema)
return model_read.from_model(record)
@router.get("/{id}", response_description="{} record retrieved".format(model.__name__))
async def read_id(instance: str, firm: str, id: PydanticObjectId, user=Depends(get_current_user)) -> model_read:
item = await model.get(id)
return model_read(**item.dict())
@router.get("/", response_model=Page[model_read], response_description="{} records retrieved".format(model.__name__))
async def read_list(instance: str, firm: str, size: int = 50, page: int = 1, sort_by: str = None, query: str = None,
user=Depends(get_current_user_and_firm)) -> Page[model_read]:
sort = parse_sort(sort_by)
query = parse_query(query, model_read)
items = paginate(model.find(query), Params(**{'size': size, 'page': page}))
return await items
@router.put("/{id}", response_description="{} record updated".format(model.__name__))
async def update(instance: str, firm: str, id: PydanticObjectId, req: model_update, user=Depends(get_current_user)) -> model_read:
req = {k: v for k, v in req.dict().items() if v is not None}
update_query = {"$set": {
field: value for field, value in req.items()
}}
item = await model.get(id)
if not item:
@router.get("/{record_id}", response_description=f"{model_name} record retrieved")
async def read_one(record_id: PydanticObjectId, db=Depends(get_tenant_db_cursor)) -> model_read:
record = await model.get(db, record_id)
if not record:
raise HTTPException(
status_code=404,
detail="{} record not found!".format(model.__name__)
detail=f"{model_name} record not found!"
)
await item.update(update_query)
return model_read(**item.dict())
return model_read.from_model(record)
@router.delete("/{id}", response_description="{} record deleted from the database".format(model.__name__))
async def delete(instance: str, firm: str, id: PydanticObjectId, user=Depends(get_current_superuser)) -> dict:
item = await model.get(id)
@router.get("/", response_model=Page[model_read], response_description=f"{model_name} records retrieved")
async def read_list(filters: model_filter=FilterDepends(model_filter), db=Depends(get_tenant_db_cursor)) -> Page[model_read]:
return await paginate(model.list(db, filters))
if not item:
@router.put("/{record_id}", response_description=f"{model_name} record updated")
async def update(record_id: PydanticObjectId, schema: model_update, db=Depends(get_tenant_db_cursor)) -> model_read:
record = await model.get(db, record_id)
if not record:
raise HTTPException(
status_code=404,
detail="{} record not found!".format(model.__name__)
detail=f"{model_name} record not found!"
)
await item.delete()
record = await model.update(db, record, schema)
return model_read.from_model(record)
@router.delete("/{record_id}", response_description=f"{model_name} record deleted from the database")
async def delete(record_id: PydanticObjectId, db=Depends(get_tenant_db_cursor)) -> dict:
record = await model.get(db, record_id)
if not record:
raise HTTPException(
status_code=404,
detail=f"{model_name} record not found!"
)
await model.delete(db, record)
return {
"message": "{} deleted successfully".format(model.__name__)
"message": f"{model_name} deleted successfully"
}
add_pagination(router)

View File

@@ -1,12 +1,16 @@
from pydantic import BaseModel
from pydantic import BaseModel, Field
class Reader(BaseModel):
pass
# class Config:
# fields = {'id': '_id'}
id: str = Field()
@classmethod
def from_model(cls, model):
schema = cls.model_validate(model.model_dump())
schema.id = model.id
return schema
class Writer(BaseModel):
async def validate_foreign_key(self):
async def validate_foreign_key(self, db):
pass