Initializing multitenant

This commit is contained in:
2025-03-30 17:10:42 +02:00
parent 1a247f14ce
commit 50fdf22afc
14 changed files with 272 additions and 135 deletions

116
back/app/core/filter.py Normal file
View File

@@ -0,0 +1,116 @@
from collections.abc import Callable, Mapping
from typing import Any, Optional, Union
from pydantic import ValidationInfo, field_validator
from fastapi_filter.base.filter import BaseFilterModel
_odm_operator_transformer: dict[str, Callable[[Optional[str]], Optional[dict[str, Any]]]] = {
"neq": lambda value: {"$ne": value},
"gt": lambda value: {"$gt": value},
"gte": lambda value: {"$gte": value},
"in": lambda value: {"$in": value},
"isnull": lambda value: None if value else {"$ne": None},
"lt": lambda value: {"$lt": value},
"lte": lambda value: {"$lte": value},
"not": lambda value: {"$ne": value},
"ne": lambda value: {"$ne": value},
"not_in": lambda value: {"$nin": value},
"nin": lambda value: {"$nin": value},
"like": lambda value: {"$regex": f".*{value}.*"},
"ilike": lambda value: {"$regex": f".*{value}.*", "$options": "i"},
"exists": lambda value: {"$exists": value},
}
class Filter(BaseFilterModel):
"""Base filter for beanie related filters.
Example:
```python
class MyModel:
id: PrimaryKey()
name: StringField(null=True)
count: IntField()
created_at: DatetimeField()
class MyModelFilter(Filter):
id: Optional[int]
id__in: Optional[str]
count: Optional[int]
count__lte: Optional[int]
created_at__gt: Optional[datetime]
name__ne: Optional[str]
name__nin: Optional[list[str]]
name__isnull: Optional[bool]
```
"""
def sort(self, query):
if not self.ordering_values:
return query
return query.sort(*self.ordering_values)
@field_validator("*", mode="before")
@classmethod
def split_str(
cls: type["BaseFilterModel"], value: Optional[str], field: ValidationInfo
) -> Optional[Union[list[str], str]]:
if (
field.field_name is not None
and (
field.field_name == cls.Constants.ordering_field_name
or field.field_name.endswith("__in")
or field.field_name.endswith("__nin")
)
and isinstance(value, str)
):
if not value:
# Empty string should return [] not ['']
return []
return list(value.split(","))
return value
def _get_filter_conditions(self, nesting_depth: int = 1) -> list[tuple[Mapping[str, Any], Mapping[str, Any]]]:
filter_conditions: list[tuple[Mapping[str, Any], Mapping[str, Any]]] = []
for field_name, value in self.filtering_fields:
field_value = getattr(self, field_name)
if isinstance(field_value, Filter):
if not field_value.model_dump(exclude_none=True, exclude_unset=True):
continue
filter_conditions.append(
(
{field_name: _odm_operator_transformer["neq"](None)},
{"fetch_links": True, "nesting_depth": nesting_depth},
)
)
for part, part_options in field_value._get_filter_conditions(nesting_depth=nesting_depth + 1): # noqa: SLF001
for sub_field_name, sub_value in part.items():
filter_conditions.append(
(
{f"{field_name}.{sub_field_name}": sub_value},
{"fetch_links": True, "nesting_depth": nesting_depth, **part_options},
)
)
elif "__" in field_name:
stripped_field_name, operator = field_name.split("__")
search_criteria = _odm_operator_transformer[operator](value)
filter_conditions.append(({stripped_field_name: search_criteria}, {}))
elif field_name == self.Constants.search_field_name and hasattr(self.Constants, "search_model_fields"):
search_conditions = [
{search_field: _odm_operator_transformer["ilike"](value)}
for search_field in self.Constants.search_model_fields
]
filter_conditions.append(({"$or": search_conditions}, {}))
else:
filter_conditions.append(({field_name: value}, {}))
return filter_conditions
def filter(self, query):
data = self._get_filter_conditions()
for filter_condition, filter_kwargs in data:
query = query.find(filter_condition, **filter_kwargs)
return query

View File

@@ -1,21 +1,66 @@
from datetime import datetime from datetime import datetime, UTC
from typing import Optional
from beanie import Document from beanie import PydanticObjectId
from pydantic import BaseModel, Field, validator from pydantic import BaseModel, Field, computed_field
class CrudDocument(Document): class CrudDocument(BaseModel):
_id: str id: Optional[PydanticObjectId] = Field(alias="_id", default=None)
created_at: datetime = Field(default=datetime.utcnow(), nullable=False, title="Créé le") created_at: datetime = Field(default=datetime.now(UTC), nullable=False, title="Créé le")
# created_by: str
updated_at: datetime = Field(default_factory=datetime.utcnow, nullable=False, title="Modifié le") updated_at: datetime = Field(default_factory=datetime.utcnow, nullable=False, title="Modifié le")
# updated_by: str
@validator("label", always=True, check_fields=False) @computed_field
def generate_label(cls, v, values, **kwargs): def label(self) -> str:
return v return self.compute_label()
def compute_label(self) -> str:
return ""
class Settings: class Settings:
fulltext_search = [] fulltext_search = []
@classmethod
def _collection_name(cls):
return cls.__name__
@classmethod
def _get_collection(cls, db):
return db.get_collection(cls._collection_name())
@classmethod
async def create(cls, db, create_schema):
values = cls.model_validate(create_schema.model_dump()).model_dump(mode="json")
result = await cls._get_collection(db).insert_one(values)
return await cls.get(db, result.inserted_id)
@classmethod
def list(cls, db, filters):
query = filters.filter(cls._get_collection(db))
query = filters.sort(query)
return query
@classmethod
async def get(cls, db, model_id):
return cls.model_validate(await cls._get_collection(db).find_one({"_id": model_id}))
@classmethod
async def update(cls, db, model, update_schema):
update_query = {
"$set": {field: value for field, value in update_schema.model_dump(mode="json").items()}
}
await cls._get_collection(db).update_one({"_id": model.id}, update_query)
return await cls.get(db, model.id)
@classmethod
async def delete(cls, db, model):
await cls._get_collection(db).delete_one({"_id": model.id})
def text_area(*args, **kwargs): def text_area(*args, **kwargs):
kwargs['widget'] = { kwargs['widget'] = {

View File

@@ -3,10 +3,12 @@ from beanie.odm.operators.find.comparison import In
from beanie.operators import And, RegEx, Eq from beanie.operators import And, RegEx, Eq
from fastapi import APIRouter, HTTPException, Depends from fastapi import APIRouter, HTTPException, Depends
from fastapi_pagination import Page, Params, add_pagination from fastapi_filter import FilterDepends
from fastapi_pagination.ext.beanie import paginate from fastapi_pagination import Page, add_pagination
from fastapi_pagination.ext.motor import paginate
from ..user.manager import get_current_user, get_current_superuser, get_current_user_and_firm from ..db import get_db_client
from ..user.manager import get_current_user
def parse_sort(sort_by): def parse_sort(sort_by):
@@ -15,8 +17,8 @@ def parse_sort(sort_by):
fields = [] fields = []
for field in sort_by.split(','): for field in sort_by.split(','):
dir, col = field.split('(') direction, column = field.split('(')
fields.append((col[:-1], 1 if dir == 'asc' else -1)) fields.append((column[:-1], 1 if direction == 'asc' else -1))
return fields return fields
@@ -33,6 +35,7 @@ def parse_query(query: str, model):
for criterion in query.split(' AND '): for criterion in query.split(' AND '):
[column, operator, value] = criterion.split(' ', 2) [column, operator, value] = criterion.split(' ', 2)
column = column.lower() column = column.lower()
operand = None
if column == 'fulltext': if column == 'fulltext':
if not model.Settings.fulltext_search: if not model.Settings.fulltext_search:
continue continue
@@ -50,68 +53,67 @@ def parse_query(query: str, model):
elif operator == 'in': elif operator == 'in':
operand = In(column, value.split(',')) operand = In(column, value.split(','))
and_array.append(operand) if operand:
and_array.append(operand)
if and_array: if and_array:
return And(*and_array) if len(and_array) > 1 else and_array[0] return And(*and_array) if len(and_array) > 1 else and_array[0]
else: else:
return {} return {}
#user=Depends(get_current_user)
def get_tenant_db_cursor(instance: str="westside", firm: str="cht", db_client=Depends(get_db_client), user=None):
return db_client[f"tenant_{instance}_{firm}"]
def get_crud_router(model, model_create, model_read, model_update): def get_crud_router(model, model_create, model_read, model_update, model_filter):
model_name = model.__name__
router = APIRouter() router = APIRouter()
@router.post("/", response_description="{} added to the database".format(model.__name__)) @router.post("/", response_description=f"{model_name} added to the database")
async def create(instance: str, firm: str, item: model_create, user=Depends(get_current_user)) -> dict: async def create(schema: model_create, db=Depends(get_tenant_db_cursor)) -> model_read:
await item.validate_foreign_key() await schema.validate_foreign_key(db)
o = await model(**item.dict()).create() record = await model.create(db, schema)
return {"message": "{} added successfully".format(model.__name__), "id": o.id} return model_read.from_model(record)
@router.get("/{id}", response_description="{} record retrieved".format(model.__name__)) @router.get("/{record_id}", response_description=f"{model_name} record retrieved")
async def read_id(instance: str, firm: str, id: PydanticObjectId, user=Depends(get_current_user)) -> model_read: async def read_one(record_id: PydanticObjectId, db=Depends(get_tenant_db_cursor)) -> model_read:
item = await model.get(id) record = await model.get(db, record_id)
return model_read(**item.dict()) if not record:
@router.get("/", response_model=Page[model_read], response_description="{} records retrieved".format(model.__name__))
async def read_list(instance: str, firm: str, size: int = 50, page: int = 1, sort_by: str = None, query: str = None,
user=Depends(get_current_user_and_firm)) -> Page[model_read]:
sort = parse_sort(sort_by)
query = parse_query(query, model_read)
items = paginate(model.find(query), Params(**{'size': size, 'page': page}))
return await items
@router.put("/{id}", response_description="{} record updated".format(model.__name__))
async def update(instance: str, firm: str, id: PydanticObjectId, req: model_update, user=Depends(get_current_user)) -> model_read:
req = {k: v for k, v in req.dict().items() if v is not None}
update_query = {"$set": {
field: value for field, value in req.items()
}}
item = await model.get(id)
if not item:
raise HTTPException( raise HTTPException(
status_code=404, status_code=404,
detail="{} record not found!".format(model.__name__) detail=f"{model_name} record not found!"
) )
await item.update(update_query) return model_read.from_model(record)
return model_read(**item.dict())
@router.delete("/{id}", response_description="{} record deleted from the database".format(model.__name__)) @router.get("/", response_model=Page[model_read], response_description=f"{model_name} records retrieved")
async def delete(instance: str, firm: str, id: PydanticObjectId, user=Depends(get_current_superuser)) -> dict: async def read_list(filters: model_filter=FilterDepends(model_filter), db=Depends(get_tenant_db_cursor)) -> Page[model_read]:
item = await model.get(id) return await paginate(model.list(db, filters))
if not item: @router.put("/{record_id}", response_description=f"{model_name} record updated")
async def update(record_id: PydanticObjectId, schema: model_update, db=Depends(get_tenant_db_cursor)) -> model_read:
record = await model.get(db, record_id)
if not record:
raise HTTPException( raise HTTPException(
status_code=404, status_code=404,
detail="{} record not found!".format(model.__name__) detail=f"{model_name} record not found!"
) )
await item.delete() record = await model.update(db, record, schema)
return model_read.from_model(record)
@router.delete("/{record_id}", response_description=f"{model_name} record deleted from the database")
async def delete(record_id: PydanticObjectId, db=Depends(get_tenant_db_cursor)) -> dict:
record = await model.get(db, record_id)
if not record:
raise HTTPException(
status_code=404,
detail=f"{model_name} record not found!"
)
await model.delete(db, record)
return { return {
"message": "{} deleted successfully".format(model.__name__) "message": f"{model_name} deleted successfully"
} }
add_pagination(router) add_pagination(router)

View File

@@ -1,12 +1,16 @@
from pydantic import BaseModel from pydantic import BaseModel, Field
class Reader(BaseModel): class Reader(BaseModel):
pass id: str = Field()
# class Config:
# fields = {'id': '_id'} @classmethod
def from_model(cls, model):
schema = cls.model_validate(model.model_dump())
schema.id = model.id
return schema
class Writer(BaseModel): class Writer(BaseModel):
async def validate_foreign_key(self): async def validate_foreign_key(self, db):
pass pass

View File

@@ -3,20 +3,22 @@ import motor.motor_asyncio
from beanie import init_beanie from beanie import init_beanie
from .user import User, AccessToken from .user import User, AccessToken
from .entity.models import Entity
from .template.models import ContractTemplate, ProvisionTemplate
from .contract.models import ContractDraft, Contract
DB_PASSWORD = "IBO3eber0mdw2R9pnInLdtFykQFY2f06" DB_PASSWORD = "IBO3eber0mdw2R9pnInLdtFykQFY2f06"
DATABASE_URL = f"mongodb://root:{DB_PASSWORD}@mongo:27017/" DATABASE_URL = f"mongodb://root:{DB_PASSWORD}@mongo:27017/"
client = motor.motor_asyncio.AsyncIOMotorClient(
DATABASE_URL, uuidRepresentation="standard"
)
async def init_db(): async def init_db():
client = motor.motor_asyncio.AsyncIOMotorClient( await init_beanie(database=client.core,
DATABASE_URL, uuidRepresentation="standard" document_models=[User, AccessToken, ], # Entity, ContractTemplate, ProvisionTemplate, ContractDraft, Contract,
)
await init_beanie(database=client.db_name,
document_models=[User, AccessToken, Entity, ContractTemplate, ProvisionTemplate, ContractDraft,
Contract, ],
allow_index_dropping=True) allow_index_dropping=True)
async def stop_db():
client.close()
def get_db_client():
yield client

View File

@@ -1,10 +1,11 @@
from datetime import date, datetime from datetime import date, datetime
from typing import List, Literal, Optional from typing import List, Literal, Optional
from pydantic import Field, BaseModel, validator from pydantic import Field, BaseModel
from beanie import Indexed from beanie import Indexed
from ..core.models import CrudDocument from ..core.models import CrudDocument
from ..core.filter import Filter
class EntityType(BaseModel): class EntityType(BaseModel):
@@ -75,14 +76,12 @@ class Entity(CrudDocument):
Fiche d'un client Fiche d'un client
""" """
entity_data: Individual | Corporation | Institution = Field(..., discriminator='type') entity_data: Individual | Corporation | Institution = Field(..., discriminator='type')
label: str = None
address: str = Field(default="", title='Adresse') address: str = Field(default="", title='Adresse')
@validator("label", always=True) def compute_label(self) -> str:
def generate_label(cls, v, values, **kwargs): if not self.entity_data:
if 'entity_data' not in values: return ""
return v return self.entity_data.label
return values['entity_data'].label
class Settings(CrudDocument.Settings): class Settings(CrudDocument.Settings):
fulltext_search = ['label'] fulltext_search = ['label']
@@ -96,6 +95,12 @@ class Entity(CrudDocument):
class Config: class Config:
title = 'Client' title = 'Client'
@classmethod
def get_create_resource(cls): class EntityFilters(Filter):
print('coucou') name__like: Optional[str] = None
order_by: Optional[list[str]] = None
class Constants(Filter.Constants):
model = Entity
search_model_fields = ["name"]

View File

@@ -1,7 +1,5 @@
from ..core.routes import get_crud_router from ..core.routes import get_crud_router
from .models import Entity from .models import Entity, EntityFilters
from .schemas import EntityCreate, EntityRead, EntityUpdate from .schemas import EntityCreate, EntityRead, EntityUpdate
router = get_crud_router(Entity, EntityCreate, EntityRead, EntityUpdate) router = get_crud_router(Entity, EntityCreate, EntityRead, EntityUpdate, EntityFilters)

View File

@@ -1,14 +1,11 @@
from typing import Optional from pydantic import Field
from pydantic import BaseModel, Field
from .models import Entity, Institution, Individual, Corporation from .models import Entity, Institution, Individual, Corporation
from ..core.schemas import Writer from ..core.schemas import Writer, Reader
class EntityRead(Entity, Reader):
class EntityRead(Entity):
pass pass
class EntityCreate(Writer): class EntityCreate(Writer):
entity_data: Individual | Corporation | Institution = Field(..., discriminator='type') entity_data: Individual | Corporation | Institution = Field(..., discriminator='type')
address: str = Field(default="", title='Adresse') address: str = Field(default="", title='Adresse')
@@ -16,6 +13,5 @@ class EntityCreate(Writer):
class Config: class Config:
title = "Création d'un client" title = "Création d'un client"
class EntityUpdate(EntityCreate): class EntityUpdate(EntityCreate):
pass pass

View File

@@ -1,19 +1,20 @@
from contextlib import asynccontextmanager
from fastapi import FastAPI from fastapi import FastAPI
from .contract import contract_router #from .contract import contract_router
from .db import init_db from .db import init_db, stop_db
from .user import user_router, get_auth_router from .user import user_router, get_auth_router
from .entity import entity_router from .entity import entity_router
from .template import template_router #from .template import template_router
# from .order import order_router
app = FastAPI(root_path="/api/v1")
@app.on_event("startup") @asynccontextmanager
async def on_startup(): async def lifespan(app: FastAPI):
await init_db() await init_db()
yield
await stop_db()
app = FastAPI(root_path="/api/v1", lifespan=lifespan)
app.include_router(get_auth_router(), prefix="/auth", tags=["auth"], ) app.include_router(get_auth_router(), prefix="/auth", tags=["auth"], )
app.include_router(user_router, prefix="/users", tags=["users"], ) app.include_router(user_router, prefix="/users", tags=["users"], )
@@ -21,8 +22,8 @@ app.include_router(user_router, prefix="/users", tags=["users"], )
multitenant_prefix = "/{instance}/{firm}" multitenant_prefix = "/{instance}/{firm}"
app.include_router(entity_router, prefix=f"{multitenant_prefix}/entity", tags=["entity"], ) app.include_router(entity_router, prefix=f"{multitenant_prefix}/entity", tags=["entity"], )
app.include_router(template_router, prefix=f"{multitenant_prefix}/template", tags=["template"], ) #app.include_router(template_router, prefix=f"{multitenant_prefix}/template", tags=["template"], )
app.include_router(contract_router, prefix=f"{multitenant_prefix}/contract", tags=["contract"], ) #app.include_router(contract_router, prefix=f"{multitenant_prefix}/contract", tags=["contract"], )
if __name__ == '__main__': if __name__ == '__main__':
import uvicorn import uvicorn

View File

@@ -1 +0,0 @@
from .routes import router as order_router

View File

@@ -1,13 +0,0 @@
from datetime import datetime
from beanie import Document
class Order(Document):
id: str
client: str
created_at: datetime
updated_at: datetime
class Settings:
name = "order_collection"

View File

@@ -1,5 +0,0 @@
from ..core.routes import get_crud_router
from .models import Order
from .schemas import OrderCreate, OrderRead, OrderUpdate
router = get_crud_router(Order, OrderCreate, OrderRead, OrderUpdate)

View File

@@ -1,14 +0,0 @@
import uuid
from pydantic import BaseModel
class OrderRead(BaseModel):
pass
class OrderCreate(BaseModel):
login: str
class OrderUpdate(BaseModel):
pass

View File

@@ -1,7 +1,8 @@
fastapi fastapi
fastapi_users fastapi_users
fastapi_users_db_beanie fastapi_users_db_beanie
fastapi-pagination fastapi_pagination
fastapi_filter
uvicorn uvicorn
jinja2 jinja2
weasyprint weasyprint