From 539410c18ba72009acd06ce1d46e154108254a0d Mon Sep 17 00:00:00 2001 From: ewandor Date: Wed, 12 Feb 2025 20:11:59 +0100 Subject: [PATCH] Moving requests to resource --- api/app/account/account_routes.py | 20 ++-- api/app/account/category_routes.py | 20 ++-- api/app/account/fixtures.py | 8 +- api/app/account/models.py | 148 ++--------------------------- api/app/account/resource.py | 56 +++++------ 5 files changed, 56 insertions(+), 196 deletions(-) diff --git a/api/app/account/account_routes.py b/api/app/account/account_routes.py index 0d5f36a..0334784 100644 --- a/api/app/account/account_routes.py +++ b/api/app/account/account_routes.py @@ -9,6 +9,8 @@ from fastapi_pagination.ext.sqlmodel import paginate from account.schemas import AccountCreate, AccountRead, AccountUpdate from account.models import Account +from account.resource import AccountResource + from db import SessionDep from user.manager import get_current_user @@ -25,49 +27,49 @@ router = APIRouter() @router.post("") def create_account(account: AccountCreate, session: SessionDep, current_user=Depends(get_current_user)) -> AccountRead: - result = Account.create(account, session) + result = AccountResource.create(account, session) return result @router.get("") def read_accounts(session: SessionDep, filters: AccountFilters = FilterDepends(AccountFilters), current_user=Depends(get_current_user)) -> Page[AccountRead]: - return paginate(session, Account.list_accounts(filters)) + return paginate(session, AccountResource.list_accounts(filters)) @router.get("/assets") def read_assets(session: SessionDep, filters: AccountFilters = FilterDepends(AccountFilters), current_user=Depends(get_current_user)) -> Page[AccountRead]: - return paginate(session, Account.list_assets(filters)) + return paginate(session, AccountResource.list_assets(filters)) @router.get("/liabilities") def read_liabilities(session: SessionDep, filters: AccountFilters = FilterDepends(AccountFilters), current_user=Depends(get_current_user)) -> Page[AccountRead]: - return paginate(session, Account.list_liabilities(filters)) + return paginate(session, AccountResource.list_liabilities(filters)) @router.get("/{account_id}") def read_account(account_id: UUID, session: SessionDep, current_user=Depends(get_current_user)) -> AccountRead: - account = Account.get(session, account_id) + account = AccountResource.get(session, account_id) if not account: raise HTTPException(status_code=404, detail="Account not found") return account @router.put("/{account_id}") def update_account(account_id: UUID, account: AccountUpdate, session: SessionDep, current_user=Depends(get_current_user)) -> AccountRead: - db_account = Account.get(session, account_id) + db_account = AccountResource.get(session, account_id) if not db_account: raise HTTPException(status_code=404, detail="Account not found") account_data = account.model_dump(exclude_unset=True) - account = Account.update(session, db_account, account_data) + account = AccountResource.update(session, db_account, account_data) return account @router.delete("/{account_id}") def delete_account(account_id: UUID, session: SessionDep, current_user=Depends(get_current_user)): - account = Account.get(session, account_id) + account = AccountResource.get(session, account_id) if not account: raise HTTPException(status_code=404, detail="Account not found") - Account.delete(session, account) + AccountResource.delete(session, account) return {"ok": True} diff --git a/api/app/account/category_routes.py b/api/app/account/category_routes.py index d6834b9..7dffd15 100644 --- a/api/app/account/category_routes.py +++ b/api/app/account/category_routes.py @@ -8,6 +8,8 @@ from fastapi_pagination.ext.sqlmodel import paginate from account.account_routes import AccountFilters from account.schemas import CategoryRead, CategoryCreate, CategoryUpdate from account.models import Account +from account.resource import AccountResource + from db import SessionDep from user.manager import get_current_user @@ -15,49 +17,49 @@ router = APIRouter() @router.post("") def create_category(category: CategoryCreate, session: SessionDep, current_user=Depends(get_current_user)) -> CategoryRead: - result = Account.create(category, session) + result = AccountResource.create(category, session) return result @router.get("") def read_categories(session: SessionDep, filters: AccountFilters = FilterDepends(AccountFilters), current_user=Depends(get_current_user)) -> Page[CategoryRead]: - return paginate(session, Account.list_categories(filters)) + return paginate(session, AccountResource.list_categories(filters)) @router.get("expenses") def read_expenses(session: SessionDep, filters: AccountFilters = FilterDepends(AccountFilters), current_user=Depends(get_current_user)) -> Page[CategoryRead]: - return paginate(session, Account.list_expenses(filters)) + return paginate(session, AccountResource.list_expenses(filters)) @router.get("incomes") def read_incomes(session: SessionDep, filters: AccountFilters = FilterDepends(AccountFilters), current_user=Depends(get_current_user)) -> Page[CategoryRead]: - return paginate(session, Account.list_incomes(filters)) + return paginate(session, AccountResource.list_incomes(filters)) @router.get("/{category_id}") def read_category(category_id: UUID, session: SessionDep, current_user=Depends(get_current_user)) -> CategoryRead: - category = Account.get(session, category_id) + category = AccountResource.get(session, category_id) if not category: raise HTTPException(status_code=404, detail="Category not found") return category @router.put("/{category_id}") def update_category(category_id: UUID, category: CategoryUpdate, session: SessionDep, current_user=Depends(get_current_user)) -> CategoryRead: - db_category = Account.get(session, category_id) + db_category = AccountResource.get(session, category_id) if not db_category: raise HTTPException(status_code=404, detail="Category not found") category_data = category.model_dump(exclude_unset=True) - category = Account.update(session, db_category, category_data) + category = AccountResource.update(session, db_category, category_data) return category @router.delete("/{category_id}") def delete_category(category_id: UUID, session: SessionDep, current_user=Depends(get_current_user)): - category = Account.get(session, category_id) + category = AccountResource.get(session, category_id) if not category: raise HTTPException(status_code=404, detail="Category not found") - Account.delete(session, category) + AccountResource.delete(session, category) return {"ok": True} diff --git a/api/app/account/fixtures.py b/api/app/account/fixtures.py index 44a9c80..88f3c52 100644 --- a/api/app/account/fixtures.py +++ b/api/app/account/fixtures.py @@ -1,22 +1,22 @@ from datetime import date -from account.models import Account +from account.resource import AccountResource from account.schemas import AccountCreate, CategoryCreate def inject_fixtures(session): for f in fixtures_account: f = prepare_dict(session, f) schema = AccountCreate(**f) - Account.create(schema, session) + AccountResource.create(schema, session) for f in fixtures_category: f = prepare_dict(session, f) schema = CategoryCreate(**f) - Account.create(schema, session) + AccountResource.create(schema, session) def prepare_dict(session, entry): if entry['parent_path']: - parent = Account.get_by_path(session, entry['parent_path']) + parent = AccountResource.get_by_path(session, entry['parent_path']) entry['parent_account_id'] = parent.id else: entry['parent_account_id'] = None diff --git a/api/app/account/models.py b/api/app/account/models.py index eb45a56..6050c87 100644 --- a/api/app/account/models.py +++ b/api/app/account/models.py @@ -1,6 +1,7 @@ -from typing import Optional +from typing import Optional, Any -from sqlmodel import select, Relationship +from pydantic import computed_field +from sqlmodel import Relationship from sqlalchemy.sql import text from account.enums import CategoryFamily, Asset, Liability, AccountFamily @@ -14,6 +15,8 @@ class Account(AccountBaseId, table=True): children_accounts: list["Account"] = Relationship(back_populates='parent_account') transaction_splits: list["Split"] = Relationship(back_populates='account') + def is_category(self): + return self.family in [v.value for v in CategoryFamily] def get_child_path(self, child): return f"{self.path}{child.name}/" @@ -22,35 +25,11 @@ class Account(AccountBaseId, table=True): root = "/Categories" if self.is_category() else "/Accounts" return f"{root}/{self.family}/{self.name}/" - def update_children_path(self, session, old_path): - request = f"UPDATE {self.__tablename__} SET path=REPLACE(path, '{old_path}', '{self.path}') WHERE path LIKE '{old_path}{self.name }/%'" - session.exec(text(request)) - - def is_category(self): - return self.family in [v.value for v in CategoryFamily] - - @classmethod - def get_by_path(cls, session, path): - if not path: - return None - - return session.exec(select(cls).where(cls.path == path)).first() - - - def get_parent(self, session): - if self.parent_account_id is None: - return None - - self.parent_account = self.get(session, self.parent_account_id) - return self.parent_account - - def compute_path(self, session): - if self.parent_account_id is None: + def compute_path(self): + if self.parent_account is None: self.path = self.get_root_path() else: - self.parent_account = self.get(session, self.parent_account_id) self.path = self.parent_account.get_child_path(self) - return self.path def compute_family(self): @@ -61,116 +40,3 @@ class Account(AccountBaseId, table=True): else: self.family = self.type return self.family - - @classmethod - def schema_to_model(cls, session, schema, model=None): - try: - if model: - model = cls.model_validate(model, update=schema) - else: - schema.path = "" - schema.family = "" - model = cls.model_validate(schema) - except Exception as e: - print(e) - - model.compute_family() - model.validate_parent(session) - model.compute_path(session) - return model - - def validate_parent(self, session): - if self.parent_account_id is None: - return True - - parent = self.get_parent(session) - if not parent: - raise KeyError("Parent account not found.") - - if parent.family != self.family: - raise ValueError("Account family mismatch with parent account..") - - if self.path and parent.path.startswith(self.path): - raise ValueError("Parent Account is descendant") - - return True - - @classmethod - def create(cls, account, session): - account_db = cls.schema_to_model(session, account) - session.add(account_db) - session.flush() - session.refresh(account_db) - - session.commit() - session.refresh(account_db) - - return account_db - - @classmethod - def create_equity_account(cls, session): - account_db = Account(name="Equity", family="Equity", type="Equity", path="/Equity/") - session.add(account_db) - session.commit() - session.refresh(account_db) - - return account_db - - @classmethod - def select(cls): - return select(Account) - - @classmethod - def list(cls, filters): - return filters.sort(filters.filter( - cls.select() - )) - - @classmethod - def list_accounts(cls, filters): - return cls.list(filters).where( - Account.family.in_(["Asset", "Liability"]) - ) - - @classmethod - def list_assets(cls, filters): - return cls.list(filters).where(Account.family == "Asset") - - @classmethod - def list_liabilities(cls, filters): - return cls.list(filters).where(Account.family == "Liability") - - @classmethod - def list_categories(cls, filters): - return cls.list(filters).where( - Account.type.in_(["Expense", "Income"]) - ) - - @classmethod - def list_expenses(cls, filters): - return cls.list(filters).where(Account.family == "Expense") - - @classmethod - def list_incomes(cls, filters): - return cls.list(filters).where(Account.family == "Income") - - @classmethod - def get(cls, session, account_id): - return session.get(Account, account_id) - - @classmethod - def update(cls, session, account_db, account_data): - previous_path = account_db.path - account_db.sqlmodel_update(cls.schema_to_model(session, account_data, account_db)) - if previous_path != account_db.path or account_data['name'] != account_db.name: - account_db.update_children_path(session, previous_path) - session.add(account_db) - session.commit() - session.refresh(account_db) - return account_db - - @classmethod - def delete(cls, session, account): - session.delete(account) - session.commit() - diff --git a/api/app/account/resource.py b/api/app/account/resource.py index 985e2bf..65309c1 100644 --- a/api/app/account/resource.py +++ b/api/app/account/resource.py @@ -1,68 +1,58 @@ +from sqlmodel import select + +from account.models import Account + class AccountResource: @classmethod def get_by_path(cls, session, path): if not path: return None - return session.exec(select(cls).where(cls.path == path)).first() + return session.exec(select(Account).where(Account.path == path)).first() - def get_parent(self, session): - if self.parent_account_id is None: + @classmethod + def get_parent(cls, session, model): + if model.parent_account_id is None: return None - self.parent_account = self.get(session, self.parent_account_id) - return self.parent_account - - def compute_path(self, session): - if self.parent_account_id is None: - self.path = self.get_root_path() - else: - self.parent_account = self.get(session, self.parent_account_id) - self.path = self.parent_account.get_child_path(self) - - return self.path - - def compute_family(self): - if self.type in Asset: - self.family = AccountFamily.Asset.value - elif self.type in Liability: - self.family = AccountFamily.Liability.value - else: - self.family = self.type - return self.family + model.parent_account = cls.get(session, model.parent_account_id) + return model.parent_account @classmethod def schema_to_model(cls, session, schema, model=None): try: if model: - model = cls.model_validate(model, update=schema) + model = Account.model_validate(model, update=schema) else: schema.path = "" schema.family = "" - model = cls.model_validate(schema) + model = Account.model_validate(schema) except Exception as e: print(e) + raise model.compute_family() - model.validate_parent(session) - model.compute_path(session) + cls.validate_parent(session, model) + model.compute_path() return model - def validate_parent(self, session): - if self.parent_account_id is None: + @classmethod + def validate_parent(cls, session, model): + if model.parent_account_id is None: return True - parent = self.get_parent(session) + parent = cls.get_parent(session, model) if not parent: raise KeyError("Parent account not found.") - if parent.family != self.family: + if parent.family != model.family: raise ValueError("Account family mismatch with parent account..") - if self.path and parent.path.startswith(self.path): + if model.path and parent.path.startswith(model.path): raise ValueError("Parent Account is descendant") - return True + model.parent_account = parent + return model @classmethod def create(cls, account, session):