from enum import Enum
from typing import Optional
from uuid import UUID, uuid4

from fastapi_filter.contrib.sqlalchemy import Filter
from pydantic import Field as PydField
from pydantic.json_schema import SkipJsonSchema
from sqlalchemy.sql import text  # noqa: F401 - no longer used here; kept in case other modules import it
from sqlmodel import Field, Relationship, SQLModel, select


class AccountBase(SQLModel):
    """Fields shared by the account table model and its write schemas."""

    name: str = Field(index=True)
    parent_account_id: Optional[UUID] = Field(default=None, foreign_key="account.id")


class AccountBaseId(AccountBase):
    """Account fields as stored and read, including the derived ``family`` and ``path``."""

    id: UUID | None = Field(default_factory=uuid4, primary_key=True)
    family: str = Field(index=True)
    type: str = Field(index=True)
    # Folder-like prefix: the parent's child path (see Account.get_path), or the
    # root path for top-level accounts; the account's own name is not included.
    # NOTE(review): siblings therefore share the same value, so unique=True looks
    # like it allows only one account per parent -- confirm this is intended.
    path: str = Field(index=True, unique=True)


class Account(AccountBaseId, table=True):
    """Account/category table model, arranged as a tree through ``parent_account_id``."""

    parent_account: Optional["Account"] = Relationship(
        back_populates="children_accounts",
        sa_relationship_kwargs=dict(remote_side="Account.id"),
    )
    children_accounts: list["Account"] = Relationship(back_populates="parent_account")

    def get_child_path(self):
        """Return the path prefix used by this account's direct children."""
        return f"{self.path}{self.name}/"

    def get_root_path(self):
        """Return the top-level path for this account's kind and family."""
        root = "/Categories" if self.is_category() else "/Accounts"
        return f"{root}/{self.family}/"

    def update_children_path(self, session, old_path, old_name=None):
        """Rewrite the ``path`` of every descendant after this account moved or was renamed.

        Descendant paths start with this account's child path
        (``path + name + "/"``, see :meth:`get_child_path`). The old prefix is
        rebuilt from ``old_path``/``old_name`` and replaced by the current one.

        Args:
            session: SQLModel session; modified rows are added to it and the
                caller is responsible for committing.
            old_path: this account's ``path`` before the change.
            old_name: this account's ``name`` before the change. Defaults to the
                current name (i.e. only the path changed).
        """
        if old_name is None:
            old_name = self.name
        old_prefix = f"{old_path}{old_name}/"
        new_prefix = self.get_child_path()
        if old_prefix == new_prefix:
            return
        # Bound, escaped prefix match instead of string-built SQL: names may contain
        # quotes or LIKE wildcards (% and _), and only the leading prefix may change.
        descendants = session.exec(
            select(Account).where(Account.path.startswith(old_prefix, autoescape=True))
        ).all()
        for descendant in descendants:
            descendant.path = new_prefix + descendant.path[len(old_prefix):]
            session.add(descendant)

    def is_category(self):
        """Return True if ``type`` is one of the ``CategoryType`` values."""
        # Accept both plain strings and enum members by comparing raw values.
        type_value = getattr(self.type, "value", self.type)
        return type_value in [v.value for v in CategoryType]

    def get_parent(self, session):
        """Load the parent account into ``parent_account`` and return it (None if top-level)."""
        if self.parent_account_id is None:
            return None
        self.parent_account = self.get(session, self.parent_account_id)
        return self.parent_account

    def get_path(self, session):
        """Return the path this account should have.

        Top-level accounts get :meth:`get_root_path`. Other accounts get their
        parent's :meth:`get_child_path`; the loaded parent is also stored on
        ``parent_account``.
        """
        if self.parent_account_id is None:
            return self.get_root_path()
        self.parent_account = self.get(session, self.parent_account_id)
        return self.parent_account.get_child_path()

    def get_family(self):
        """Return "Asset" or "Liability" for those account types, otherwise the type value."""
        # Compare raw values: `str in EnumClass` raises TypeError before Python 3.12,
        # and ``type`` may hold either a plain string or an enum member.
        type_value = getattr(self.type, "value", self.type)
        if type_value in {member.value for member in Asset}:
            return "Asset"
        if type_value in {member.value for member in Liability}:
            return "Liability"
        return type_value

    @classmethod
    def schema_to_model(cls, session, schema, model=None):
        """Build a validated ``Account`` from write data, deriving ``family`` and ``path``.

        Args:
            session: SQLModel session used to look up the parent account.
            schema: for creation, a write schema instance (its ``path``/``family``
                are blanked before validation); for updates, the mapping of field
                updates applied on top of ``model``.
            model: existing account to update, or None to build a new one.

        Returns:
            A new ``Account`` instance (not added to the session).

        Raises:
            ValueError: if validation fails (pydantic's ``ValidationError`` is a
                ``ValueError``) or the parent is invalid (see :meth:`validate_parent`).
        """
        if model is not None:
            model = cls.model_validate(model, update=schema)
        else:
            # path/family are always derived below; ignore any client-supplied values.
            schema.path = ""
            schema.family = ""
            model = cls.model_validate(schema)
        model.family = model.get_family()
        cls.validate_parent(session, model)
        model.path = model.get_path(session)
        return model

    @classmethod
    def validate_parent(cls, session, model):
        """Check that ``model``'s parent exists, shares its family and is outside its subtree.

        Returns:
            True if the parent is valid, or if there is no parent.

        Raises:
            ValueError: if the parent is missing, belongs to another family, or
                is ``model`` itself or one of its descendants.
        """
        if model.parent_account_id is None:
            return True
        parent = model.get_parent(session)
        if not parent:
            raise ValueError("Parent account not found.")
        if parent.family != model.family:
            raise ValueError("Account family mismatch with parent account..")
        if parent.id == model.id:
            raise ValueError("Parent Account is descendant")
        # Compare against the persisted row: descendants were stored under its
        # current path/name, while ``model`` may carry a new name or a blanked path.
        # An account that is not stored yet has no descendants.
        stored = cls.get(session, model.id) if model.id is not None else None
        if stored is not None and parent.path.startswith(stored.get_child_path()):
            raise ValueError("Parent Account is descendant")
        return True

    @classmethod
    def create(cls, account, session):
        """Validate ``account`` (a write schema), insert it, commit and return the stored row."""
        account_db = cls.schema_to_model(session, account)
        session.add(account_db)
        session.commit()
        session.refresh(account_db)
        return account_db

    @classmethod
    def list(cls):
        """Return a ``select`` over all accounts and categories."""
        return select(Account)

    @classmethod
    def list_accounts(cls):
        """Return a ``select`` over accounts whose type is not a category type."""
        return cls.list().where(Account.type.not_in([v.value for v in CategoryType]))

    @classmethod
    def list_assets(cls):
        """Return a ``select`` over accounts in the "Asset" family."""
        return cls.list().where(Account.family == "Asset")

    @classmethod
    def list_liabilities(cls):
        """Return a ``select`` over accounts in the "Liability" family."""
        return cls.list().where(Account.family == "Liability")

    @classmethod
    def list_categories(cls):
        """Return a ``select`` over accounts whose type is a category type."""
        return cls.list().where(Account.type.in_([v.value for v in CategoryType]))

    @classmethod
    def list_expenses(cls):
        """Return a ``select`` over accounts in the "Expense" family."""
        return cls.list().where(Account.family == "Expense")

    @classmethod
    def list_incomes(cls):
        """Return a ``select`` over accounts in the "Income" family."""
        return cls.list().where(Account.family == "Income")

    @classmethod
    def get(cls, session, account_id):
        """Return the account with primary key ``account_id``, or None."""
        return session.get(Account, account_id)

    @classmethod
    def update(cls, session, account_db, account_data):
        """Apply ``account_data`` (a mapping of field updates) to ``account_db`` and commit.

        If the account's path or name changed, the stored paths of its
        descendants are rewritten in the same transaction.
        """
        previous_path = account_db.path
        # Capture the old name before updating: comparing against the input after
        # the update would always be equal, so renames never reached the children.
        previous_name = account_db.name
        account_db.sqlmodel_update(cls.schema_to_model(session, account_data, account_db))
        if account_db.path != previous_path or account_db.name != previous_name:
            account_db.update_children_path(session, previous_path, previous_name)
        session.add(account_db)
        session.commit()
        session.refresh(account_db)
        return account_db

    @classmethod
    def delete(cls, session, account):
        """Delete ``account`` and commit."""
        session.delete(account)
        session.commit()


class AccountRead(AccountBaseId):
    """Read schema exposing all stored account fields."""

    pass


class AccountType(Enum):
    Asset = "Asset"  # Generic asset account
    Checkings = "Checkings"  # Standard checking account
    Savings = "Savings"  # Typical savings account
    Cash = "Cash"  # A shoe-box or pillowcase stuffed with cash
    Liability = "Liability"  # Generic liability account
    CreditCard = "CreditCard"  # Credit card account
    Loan = "Loan"  # Loan and mortgage account (liability)
    CertificateDep = "CertificateDep"  # Certificate of Deposit
    Investment = "Investment"  # Investment account
    MoneyMarket = "MoneyMarket"  # Money Market account
    Currency = "Currency"  # Currency trading account
    AssetLoan = "AssetLoan"  # Loan that is an asset of the owner of this object
    Stock = "Stock"  # Security account as sub-account of an investment
    Equity = "Equity"  # Equity account, e.g. opening/closing balance
    Income = "Income"  # Income account
    Expense = "Expense"  # Expense account


class Asset(Enum):
    """Account types that belong to the "Asset" family."""

    Asset = "Asset"
    Checkings = "Checkings"
    Savings = "Savings"
    Cash = "Cash"
    Investment = "Investment"


class Liability(Enum):
    """Account types that belong to the "Liability" family."""

    Liability = "Liability"
    CreditCard = "CreditCard"
    Loan = "Loan"


class BaseAccountWrite(AccountBase):
    """Common write schema; ``path``/``family`` are derived server-side and hidden from the schema."""

    path: SkipJsonSchema[str] = Field(default="")
    family: SkipJsonSchema[str] = Field(default="")


class AccountWrite(BaseAccountWrite):
    """Write schema for asset/liability accounts."""

    type: Asset | Liability = Field()
    parent_account_id: UUID | None = PydField(
        default=None,
        json_schema_extra={
            "foreign_key": {
                "reference": {
                    "resource": "accounts",
                    "schema": "AccountRead",
                    "label": "name",
                }
            }
        },
    )


class AccountCreate(AccountWrite):
    pass


class AccountUpdate(AccountWrite):
    pass


class CategoryType(Enum):
    """Types that mark an account as a category."""

    Income = "Income"
    Expense = "Expense"


class CategoryWrite(BaseAccountWrite):
    """Write schema for income/expense categories."""

    type: CategoryType = Field()


class CategoryCreate(CategoryWrite):
    pass


class CategoryUpdate(CategoryWrite):
    pass


class AccountFilter(Filter):
    """Query filter for accounts (name pattern match and search on ``name``)."""

    name__like: Optional[str] = None

    class Constants(Filter.Constants):
        model = Account
        search_model_fields = ["name"]