From 5afe52bbf0283634b317674613ca899135b20892 Mon Sep 17 00:00:00 2001 From: ewandor Date: Tue, 15 Apr 2025 21:11:54 +0200 Subject: [PATCH] Switching to the registry paradigm --- api/rpk-api/firm/contract/print/__init__.py | 24 ++++---- api/rpk-api/firm/contract/routes_contract.py | 22 +++---- api/rpk-api/firm/contract/routes_draft.py | 18 +++--- api/rpk-api/firm/contract/routes_signature.py | 12 ++-- api/rpk-api/firm/core/depends.py | 61 ++++++++++++------- api/rpk-api/firm/core/routes.py | 25 ++++---- 6 files changed, 90 insertions(+), 72 deletions(-) diff --git a/api/rpk-api/firm/contract/print/__init__.py b/api/rpk-api/firm/contract/print/__init__.py index 3268c77..93b195d 100644 --- a/api/rpk-api/firm/contract/print/__init__.py +++ b/api/rpk-api/firm/contract/print/__init__.py @@ -12,7 +12,7 @@ from weasyprint.text.fonts import FontConfiguration from pathlib import Path -from firm.core.depends import get_tenant_db_cursor +from firm.core.depends import get_tenant_registry from firm.entity.models import Entity from firm.template.models import ProvisionTemplate from firm.contract.models import ContractDraft, Contract, ContractStatus, replace_variables_in_value @@ -87,15 +87,15 @@ def retrieve_signature_png(filepath): preview_router = APIRouter() @preview_router.get("/draft/{draft_id}", response_class=HTMLResponse, tags=["Contract Draft"]) -async def preview_draft(draft_id: str, db=Depends(get_tenant_db_cursor)) -> str: - draft = await build_model(await ContractDraft.get(db, draft_id)) +async def preview_draft(draft_id: str, reg=Depends(get_tenant_registry)) -> str: + draft = await build_model(await ContractDraft.get(reg.db, draft_id)) return await render_print('', draft) @preview_router.get("/signature/{signature_id}", response_class=HTMLResponse, tags=["Signature"]) -async def preview_contract_by_signature(signature_id: UUID, db=Depends(get_tenant_db_cursor)) -> str: - contract = await Contract.find_by_signature_id(db, signature_id) +async def preview_contract_by_signature(signature_id: UUID, reg=Depends(get_tenant_registry)) -> str: + contract = await Contract.find_by_signature_id(reg.db, signature_id) for p in contract.parties: if p.signature_affixed: p.signature_png = retrieve_signature_png(f'media/signatures/{p.signature_uuid}.png') @@ -104,8 +104,8 @@ async def preview_contract_by_signature(signature_id: UUID, db=Depends(get_tenan @preview_router.get("/{contract_id}", response_class=HTMLResponse, tags=["Contract"]) -async def preview_contract(contract_id: str, db=Depends(get_tenant_db_cursor)) -> str: - contract = await Contract.get(db, contract_id) +async def preview_contract(contract_id: str, reg=Depends(get_tenant_registry)) -> str: + contract = await Contract.get(reg.db, contract_id) for p in contract.parties: if p.signature_affixed: p.signature_png = retrieve_signature_png(f'media/signatures/{p.signature_uuid}.png') @@ -115,8 +115,8 @@ async def preview_contract(contract_id: str, db=Depends(get_tenant_db_cursor)) - print_router = APIRouter() @print_router.get("/pdf/{contract_id}", response_class=FileResponse, tags=["Contract"]) -async def create_pdf(contract_id: str, db=Depends(get_tenant_db_cursor)) -> str: - contract = await Contract.get(db, contract_id) +async def create_pdf(contract_id: str, reg=Depends(get_tenant_registry)) -> str: + contract = await Contract.get(reg.db, contract_id) contract_path = "media/contracts/{}.pdf".format(contract_id) if not os.path.isfile(contract_path): if contract.status != ContractStatus.signed: @@ -133,7 +133,7 @@ async def create_pdf(contract_id: str, db=Depends(get_tenant_db_cursor)) -> str: html.write_pdf(contract_path, stylesheets=[css], font_config=font_config) - await contract.update_status(db, 'printed') + await contract.update_status(reg.db, 'printed') return FileResponse( contract_path, @@ -142,8 +142,8 @@ async def create_pdf(contract_id: str, db=Depends(get_tenant_db_cursor)) -> str: @print_router.get("/opengraph/{signature_id}", response_class=HTMLResponse, tags=["Signature"]) -async def get_signature_opengraph(signature_id: str, request: Request, db=Depends(get_tenant_db_cursor)) -> str: - contract = await Contract.find_by_signature_id(db, signature_id) +async def get_signature_opengraph(signature_id: str, request: Request, reg=Depends(get_tenant_registry)) -> str: + contract = await Contract.find_by_signature_id(reg.db, signature_id) signature = contract.get_signature(signature_id) template = templates.get_template("opengraph.html") diff --git a/api/rpk-api/firm/contract/routes_contract.py b/api/rpk-api/firm/contract/routes_contract.py index 6864d5c..358eb5e 100644 --- a/api/rpk-api/firm/contract/routes_contract.py +++ b/api/rpk-api/firm/contract/routes_contract.py @@ -2,7 +2,7 @@ import uuid from fastapi import Depends, HTTPException from firm.core.routes import get_crud_router -from firm.core.depends import get_logged_tenant_db_cursor +from firm.core.depends import get_authed_tenant_registry from firm.contract.models import Contract, ContractDraft, ContractDraftStatus, replace_variables_in_value, ContractFilters from firm.contract.schemas import ContractCreate, ContractRead, ContractUpdate, ContractInit @@ -17,10 +17,10 @@ del(contract_router.routes[3]) #update del(contract_router.routes[1]) #create @contract_router.post("/", response_description="Contract Successfully created") -async def create(schema: ContractCreate, db=Depends(get_logged_tenant_db_cursor)) -> ContractRead: - await schema.validate_foreign_key(db) +async def create(schema: ContractCreate, reg=Depends(get_authed_tenant_registry)) -> ContractRead: + await schema.validate_foreign_key(reg.db) - draft = await ContractDraft.get(db, schema.draft_id) + draft = await ContractDraft.get(reg.db, schema.draft_id) if not draft: raise HTTPException(status_code=404, detail=f"Contract draft not found!") @@ -31,7 +31,7 @@ async def create(schema: ContractCreate, db=Depends(get_logged_tenant_db_cursor) contract_dict = schema.model_dump() del(contract_dict['draft_id']) - lawyer = await Entity.get(db, db.user.entity_id) + lawyer = await Entity.get(reg.db, reg.user.entity_id) contract_dict['lawyer'] = lawyer.model_dump() contract_dict['name'] = draft.name @@ -39,9 +39,9 @@ async def create(schema: ContractCreate, db=Depends(get_logged_tenant_db_cursor) parties = [] for p in draft.parties: parties.append({ - 'entity': await Entity.get(db, p.entity_id), + 'entity': await Entity.get(reg.db, p.entity_id), 'part': p.part, - 'representative': await Entity.get(db, p.representative_id) if p.representative_id else None, + 'representative': await Entity.get(reg.db, p.representative_id) if p.representative_id else None, 'signature_uuid': str(uuid.uuid4()) }) @@ -50,7 +50,7 @@ async def create(schema: ContractCreate, db=Depends(get_logged_tenant_db_cursor) provisions = [] for p in draft.provisions: p = p.provision - provision = await ProvisionTemplate.get(db, p.provision_template_id) if p.type == "template" \ + provision = await ProvisionTemplate.get(reg.db, p.provision_template_id) if p.type == "template" \ else p provisions.append({ @@ -60,11 +60,11 @@ async def create(schema: ContractCreate, db=Depends(get_logged_tenant_db_cursor) contract_dict['provisions'] = provisions - record = await Contract.create(db, ContractInit(**contract_dict)) - await draft.update_status(db, ContractDraftStatus.published) + record = await Contract.create(reg.db, ContractInit(**contract_dict)) + await draft.update_status(reg.db, ContractDraftStatus.published) return ContractRead.from_model(record) @contract_router.put("/{record_id}", response_description="") -async def update(record_id: str, contract_form: ContractUpdate, db=Depends(get_logged_tenant_db_cursor)) -> ContractRead: +async def update(record_id: str, contract_form: ContractUpdate, reg=Depends(get_authed_tenant_registry)) -> ContractRead: raise HTTPException(status_code=400, detail="No modification on contract") diff --git a/api/rpk-api/firm/contract/routes_draft.py b/api/rpk-api/firm/contract/routes_draft.py index 620f990..5808dea 100644 --- a/api/rpk-api/firm/contract/routes_draft.py +++ b/api/rpk-api/firm/contract/routes_draft.py @@ -2,7 +2,7 @@ from beanie import PydanticObjectId from fastapi import HTTPException, Depends from firm.core.routes import get_crud_router -from firm.core.depends import get_logged_tenant_db_cursor +from firm.core.depends import get_authed_tenant_registry from firm.contract.models import ContractDraft, ContractDraftStatus, ContractDraftFilters from firm.contract.schemas import ContractDraftCreate, ContractDraftRead, ContractDraftUpdate @@ -14,17 +14,17 @@ del(draft_router.routes[1]) #post route @draft_router.post("/", response_description="Contract Draft added to the database") -async def create(schema: ContractDraftCreate, db=Depends(get_logged_tenant_db_cursor)) -> ContractDraftRead: - await schema.validate_foreign_key(db) - record = await ContractDraft.create(db, schema) - await record.check_is_ready(db) +async def create(schema: ContractDraftCreate, reg=Depends(get_authed_tenant_registry)) -> ContractDraftRead: + await schema.validate_foreign_key(reg.db) + record = await ContractDraft.create(reg.db, schema) + await record.check_is_ready(reg.db) return ContractDraftRead.from_model(record) @draft_router.put("/{record_id}", response_description="Contract Draft record updated") -async def update(record_id: PydanticObjectId, schema: ContractDraftUpdate, db=Depends(get_logged_tenant_db_cursor)) -> ContractDraftRead: - record = await ContractDraft.get(db, record_id) +async def update(record_id: PydanticObjectId, schema: ContractDraftUpdate, reg=Depends(get_authed_tenant_registry)) -> ContractDraftRead: + record = await ContractDraft.get(reg.db, record_id) if not record: raise HTTPException( status_code=404, @@ -36,7 +36,7 @@ async def update(record_id: PydanticObjectId, schema: ContractDraftUpdate, db=De detail="Contract Draft has already been published" ) - record = await ContractDraft.update(db, record, schema) - await record.check_is_ready(db) + record = await ContractDraft.update(reg.db, record, schema) + await record.check_is_ready(reg.db) return ContractDraftRead.from_model(record) diff --git a/api/rpk-api/firm/contract/routes_signature.py b/api/rpk-api/firm/contract/routes_signature.py index 140ae08..1e6c958 100644 --- a/api/rpk-api/firm/contract/routes_signature.py +++ b/api/rpk-api/firm/contract/routes_signature.py @@ -4,20 +4,20 @@ import shutil from uuid import UUID from firm.contract.models import Contract, Party -from firm.core.depends import get_tenant_db_cursor +from firm.core.depends import get_tenant_registry signature_router = APIRouter() @signature_router.get("/{signature_id}", response_description="") -async def get_signature(signature_id: UUID, db=Depends(get_tenant_db_cursor)) -> Party: - contract = await Contract.find_by_signature_id(db, signature_id) +async def get_signature(signature_id: UUID, reg=Depends(get_tenant_registry)) -> Party: + contract = await Contract.find_by_signature_id(reg.db, signature_id) signature = contract.get_signature(signature_id) return signature @signature_router.post("/{signature_id}", response_description="") -async def affix_signature(signature_id: UUID, signature_file: UploadFile = File(...), db=Depends(get_tenant_db_cursor)) -> bool: - contract = await Contract.find_by_signature_id(db, signature_id) +async def affix_signature(signature_id: UUID, signature_file: UploadFile = File(...), reg=Depends(get_tenant_registry)) -> bool: + contract = await Contract.find_by_signature_id(reg.db, signature_id) if not contract: raise HTTPException(status_code=404, detail="Contract record not found!") @@ -31,5 +31,5 @@ async def affix_signature(signature_id: UUID, signature_file: UploadFile = File( with open(f'media/signatures/{signature_id}.png', "wb") as buffer: shutil.copyfileobj(signature_file.file, buffer) - await contract.affix_signature(db, signature_index) + await contract.affix_signature(reg.db, signature_index) return True diff --git a/api/rpk-api/firm/core/depends.py b/api/rpk-api/firm/core/depends.py index 28c191f..b33b3c4 100644 --- a/api/rpk-api/firm/core/depends.py +++ b/api/rpk-api/firm/core/depends.py @@ -4,31 +4,48 @@ from hub.auth import get_current_user from firm.db import get_db_client from firm.current_firm import CurrentFirmModel -async def get_tenant_db_cursor(instance: str, firm: str, db_client=Depends(get_db_client)): - db_cursor = db_client[f"tenant_{instance}_{firm}"] - current_firm = await CurrentFirmModel.get(db_cursor) - if current_firm is None: - raise HTTPException(status_code=405, detail=f"Firm needs to be instantiated first") - db_cursor.firm = current_firm - return db_cursor +class Registry: + user = None -def get_logged_tenant_db_cursor(db_cursor=Depends(get_tenant_db_cursor), user=Depends(get_current_user)): - for firm in user.firms: - if firm == db_cursor.firm: - db_cursor.user = user - return db_cursor + def __init__(self, db_client, instance, firm): + self.db = db_client[f"tenant_{instance}_{firm}"] + self.instance = instance + self.firm = firm - raise HTTPException(status_code=404, detail="This firm doesn't exist or you are not allowed to access it.") + self.current_firm = CurrentFirmModel.get(self.db) -async def get_uninitialized_tenant_db_cursor(instance: str, firm: str, db_client=Depends(get_db_client), user=Depends(get_current_user)): - db_cursor = db_client[f"tenant_{instance}_{firm}"] - current_firm = await CurrentFirmModel.get(db_cursor) - if current_firm is not None: + def set_user(self, user): + for firm in user.firms: + if firm.instance == self.instance and firm.firm == firm: + self.user = user + self.db.user = user + return + + raise PermissionError + +async def get_tenant_registry(instance: str, firm: str, db_client=Depends(get_db_client)) -> Registry: + registry = Registry(db_client, instance, firm) + if await registry.current_firm is None: + raise HTTPException(status_code=405, detail=f"Firm needs to be initialized first") + + return registry + +def get_authed_tenant_registry(registry=Depends(get_tenant_registry), user=Depends(get_current_user)) -> Registry: + try: + registry.set_user(user) + except PermissionError: + raise HTTPException(status_code=404, detail="This firm doesn't exist or you are not allowed to access it.") + + return registry + +async def get_uninitialized_registry(instance: str, firm: str, db_client=Depends(get_db_client), user=Depends(get_current_user)) -> Registry: + registry = Registry(db_client, instance, firm) + if await registry.current_firm is not None: HTTPException(status_code=409, detail="Firm configuration already exists") - for firm in user.firms: - if firm == db_cursor.firm: - db_cursor.user = user - return db_cursor + try: + registry.set_user(user) + except PermissionError: + raise HTTPException(status_code=404, detail="This firm doesn't exist or you are not allowed to access it.") - raise HTTPException(status_code=404, detail="This firm doesn't exist or you are not allowed to access it.") + return registry diff --git a/api/rpk-api/firm/core/routes.py b/api/rpk-api/firm/core/routes.py index b0d0e0e..106904f 100644 --- a/api/rpk-api/firm/core/routes.py +++ b/api/rpk-api/firm/core/routes.py @@ -5,7 +5,7 @@ from fastapi_filter import FilterDepends from fastapi_pagination import Page, add_pagination from fastapi_pagination.ext.motor import paginate -from firm.core.depends import get_logged_tenant_db_cursor +from firm.core.depends import get_authed_tenant_registry from firm.core.models import CrudDocument from firm.core.schemas import Writer, Reader @@ -16,16 +16,17 @@ def get_crud_router(model: CrudDocument, model_create: Writer, model_read: Reade @router.get("/", response_model=Page[model_read], response_description=f"{model_name} records retrieved") async def read_list(filters: model_filter=FilterDepends(model_filter), db=Depends(get_logged_tenant_db_cursor)) -> Page[model_read]: return await paginate(**model.find(db, filters)) + async def read_list(filters: model_filter=FilterDepends(model_filter), reg=Depends(get_authed_tenant_registry)) -> Page[model_read]: @router.post("/", response_description=f"{model_name} added to the database") - async def create(schema: model_create, db=Depends(get_logged_tenant_db_cursor)) -> model_read: - await schema.validate_foreign_key(db) - record = await model.create(db, schema) + async def create(schema: model_create, reg=Depends(get_authed_tenant_registry)) -> model_read: + await schema.validate_foreign_key(reg.db) + record = await model.create(reg.db, schema) return model_read.validate_model(record) @router.get("/{record_id}", response_description=f"{model_name} record retrieved") - async def read_one(record_id: PydanticObjectId, db=Depends(get_logged_tenant_db_cursor)) -> model_read: - record = await model.get(db, record_id) + async def read_one(record_id: PydanticObjectId, reg=Depends(get_authed_tenant_registry)) -> model_read: + record = await model.get(reg.db, record_id) if not record: raise HTTPException( status_code=404, @@ -35,27 +36,27 @@ def get_crud_router(model: CrudDocument, model_create: Writer, model_read: Reade return model_read.from_model(record) @router.put("/{record_id}", response_description=f"{model_name} record updated") - async def update(record_id: PydanticObjectId, schema: model_update, db=Depends(get_logged_tenant_db_cursor)) -> model_read: - record = await model.get(db, record_id) + async def update(record_id: PydanticObjectId, schema: model_update, reg=Depends(get_authed_tenant_registry)) -> model_read: + record = await model.get(reg.db, record_id) if not record: raise HTTPException( status_code=404, detail=f"{model_name} record not found!" ) - record = await model.update(db, record, schema) + record = await model.update(reg.db, record, schema) return model_read.from_model(record) @router.delete("/{record_id}", response_description=f"{model_name} record deleted from the database") - async def delete(record_id: PydanticObjectId, db=Depends(get_logged_tenant_db_cursor)) -> dict: - record = await model.get(db, record_id) + async def delete(record_id: PydanticObjectId, reg=Depends(get_authed_tenant_registry)) -> dict: + record = await model.get(reg.db, record_id) if not record: raise HTTPException( status_code=404, detail=f"{model_name} record not found!" ) - await model.delete(db, record) + await model.delete(reg.db, record) return { "message": f"{model_name} deleted successfully" }