# File: budget-forecast/api/app/account/models.py
# (177 lines, 5.3 KiB, Python)

from typing import Optional
from sqlmodel import select, Relationship
from sqlalchemy.sql import text
from account.enums import CategoryFamily, Asset, Liability, AccountFamily
from account.schemas import AccountBaseId
class Account(AccountBaseId, table=True):
    """Table model for an account or category node in a tree.

    Paths produced by ``get_root_path``/``get_child_path`` have the form
    ``/<Accounts|Categories>/<family>/<ancestor names>/.../<own name>/``:
    every path ends with the account's own name followed by ``/``. The equity
    account created by ``create_equity_account`` uses the fixed path
    ``/Equity/`` instead.
    """

    # Self-referential many-to-one link; remote_side marks Account.id as the
    # parent end so SQLAlchemy can tell the two directions apart.
    parent_account: Optional["Account"] = Relationship(
        back_populates="children_accounts",
        sa_relationship_kwargs=dict(remote_side='Account.id')
    )
    # Inverse of parent_account: accounts whose parent is this one.
    children_accounts: list["Account"] = Relationship(back_populates='parent_account')
    # Splits pointing at this account ("Split" is defined in another module).
    transaction_splits: list["Split"] = Relationship(back_populates='account')

    def get_child_path(self, child):
        """Return the path a direct child named ``child.name`` gets under this account."""
        return f"{self.path}{child.name}/"

    def get_root_path(self):
        """Return the path this account has when it has no parent."""
        root = "/Categories" if self.is_category() else "/Accounts"
        return f"{root}/{self.family}/{self.name}/"

    def update_children_path(self, session, old_path):
        """Rewrite stored paths that start with ``old_path`` to use ``self.path``.

        Matches this account's own row (if it still holds ``old_path``) and
        every descendant, whose paths all start with ``old_path``. All values
        are sent as bound parameters, so names containing quotes or LIKE
        wildcards are handled safely.

        The statement is raw SQL and is not committed here. Instances already
        loaded in ``session`` keep their old ``path`` until they are expired
        or refreshed.

        Args:
            session: Session used to execute the UPDATE.
            old_path: Path this account had before being renamed or moved.
        """
        # Escape LIKE metacharacters so old_path is matched literally. '!' is
        # used as the escape character because backslash handling differs
        # between SQL dialects.
        escaped = old_path.replace("!", "!!").replace("%", "!%").replace("_", "!_")
        # NOTE: REPLACE substitutes every occurrence of old_path in the matched
        # value; since old_path is a leading prefix here, that is normally the
        # only occurrence.
        statement = text(
            f"UPDATE {self.__tablename__} "
            "SET path = REPLACE(path, :old_path, :new_path) "
            "WHERE path LIKE :prefix ESCAPE '!'"
        ).bindparams(old_path=old_path, new_path=self.path, prefix=f"{escaped}%")
        session.exec(statement)

    def is_category(self):
        """Return True when ``self.family`` is one of the CategoryFamily values."""
        return self.family in [v.value for v in CategoryFamily]

    @classmethod
    def get_by_path(cls, session, path):
        """Return the first account whose path equals ``path``, or None.

        An empty or None ``path`` returns None without querying.
        """
        if not path:
            return None
        return session.exec(select(cls).where(cls.path == path)).first()

    def get_parent(self, session):
        """Load the parent account, store it on ``parent_account`` and return it.

        Returns None when ``parent_account_id`` is None or no row has that id.
        """
        if self.parent_account_id is None:
            return None
        self.parent_account = self.get(session, self.parent_account_id)
        return self.parent_account

    def compute_path(self, session):
        """Set ``self.path`` from the parent's path (or the root path) and return it.

        Raises:
            KeyError: ``parent_account_id`` is set but no such account exists.
        """
        parent = self.get_parent(session)
        if parent is not None:
            self.path = parent.get_child_path(self)
        elif self.parent_account_id is not None:
            # Same error validate_parent raises for a dangling parent id.
            raise KeyError("Parent account not found.")
        else:
            self.path = self.get_root_path()
        return self.path

    def compute_family(self):
        """Derive ``self.family`` from ``self.type`` and return it.

        Asset types map to the Asset family and liability types to the
        Liability family; any other type is used as the family verbatim.
        """
        # NOTE(review): `value in EnumClass` only accepts plain values on
        # Python 3.12+ (older versions raise TypeError for non-members);
        # confirm self.type's type and the runtime version.
        if self.type in Asset:
            self.family = AccountFamily.Asset.value
        elif self.type in Liability:
            self.family = AccountFamily.Liability.value
        else:
            self.family = self.type
        return self.family

    @classmethod
    def schema_to_model(cls, session, schema, model=None):
        """Build a validated Account and fill in its derived fields.

        Args:
            session: Session used to look up the parent account.
            schema: Input data. When ``model`` is given, this is passed as
                ``update=`` to ``model_validate``, presumably a dict of changed
                fields. Otherwise it is validated directly.
            model: Existing account to start from, or None to create a new one.

        Returns:
            A new Account instance with ``family`` and ``path`` computed.

        Raises:
            Whatever ``model_validate`` raises for invalid input, plus the
            KeyError/ValueError raised by ``validate_parent``.
        """
        if model is not None:
            model = cls.model_validate(model, update=schema)
        else:
            # path and family are derived below. Blank them through `update`
            # instead of mutating the caller's schema object.
            model = cls.model_validate(schema, update={"path": "", "family": ""})
        model.compute_family()
        model.validate_parent(session)
        model.compute_path(session)
        return model

    def validate_parent(self, session):
        """Check that the parent (if any) exists and is a legal parent.

        Returns:
            True when the account has no parent or the parent is valid.

        Raises:
            KeyError: The parent account does not exist.
            ValueError: The parent's family differs from this account's, or
                the parent lies under this account's current path (this also
                rejects an account being its own parent).
        """
        if self.parent_account_id is None:
            return True
        parent = self.get_parent(session)
        if not parent:
            raise KeyError("Parent account not found.")
        if parent.family != self.family:
            raise ValueError("Account family mismatch with parent account..")
        # self.path is still the pre-update path here, so a prefix match means
        # the requested parent is this account or one of its descendants.
        if self.path and parent.path.startswith(self.path):
            raise ValueError("Parent Account is descendant")
        return True

    @classmethod
    def create(cls, account, session):
        """Validate ``account``, insert it, commit, and return the refreshed row."""
        account_db = cls.schema_to_model(session, account)
        session.add(account_db)
        session.commit()
        session.refresh(account_db)
        return account_db

    @classmethod
    def create_equity_account(cls, session):
        """Insert, commit and return the "Equity" account at the fixed path ``/Equity/``."""
        account_db = cls(name="Equity", family="Equity", type="Equity", path="/Equity/")
        session.add(account_db)
        session.commit()
        session.refresh(account_db)
        return account_db

    @classmethod
    def select(cls):
        """Return a SELECT statement over this model."""
        # Inside the method body `select` resolves to the module-level
        # sqlmodel.select, not to this classmethod.
        return select(cls)

    @classmethod
    def list(cls, filters):
        """Return a SELECT with ``filters.filter`` and then ``filters.sort`` applied.

        NOTE: this classmethod shadows the builtin ``list`` inside the class
        namespace; the name is kept for existing callers.
        """
        return filters.sort(filters.filter(
            cls.select()
        ))

    @classmethod
    def list_accounts(cls, filters):
        """Filtered SELECT restricted to the Asset and Liability families."""
        return cls.list(filters).where(
            Account.family.in_(["Asset", "Liability"])
        )

    @classmethod
    def list_assets(cls, filters):
        """Filtered SELECT restricted to the Asset family."""
        return cls.list(filters).where(Account.family == "Asset")

    @classmethod
    def list_liabilities(cls, filters):
        """Filtered SELECT restricted to the Liability family."""
        return cls.list(filters).where(Account.family == "Liability")

    @classmethod
    def list_categories(cls, filters):
        """Filtered SELECT restricted to accounts whose type is Expense or Income."""
        return cls.list(filters).where(
            Account.type.in_(["Expense", "Income"])
        )

    @classmethod
    def list_expenses(cls, filters):
        """Filtered SELECT restricted to the Expense family."""
        return cls.list(filters).where(Account.family == "Expense")

    @classmethod
    def list_incomes(cls, filters):
        """Filtered SELECT restricted to the Income family."""
        return cls.list(filters).where(Account.family == "Income")

    @classmethod
    def get(cls, session, account_id):
        """Return the account with primary key ``account_id``, or None."""
        return session.get(cls, account_id)

    @classmethod
    def update(cls, session, account_db, account_data):
        """Apply ``account_data`` to ``account_db``, re-root descendants, and commit.

        Args:
            session: Session used for lookups and the commit.
            account_db: Persisted account to modify in place.
            account_data: Field values to apply, passed to ``schema_to_model``
                as its ``schema``. May omit fields (partial update).

        Returns:
            ``account_db`` refreshed after the commit.
        """
        previous_path = account_db.path
        account_db.sqlmodel_update(cls.schema_to_model(session, account_data, account_db))
        # Renaming or re-parenting always changes the path, so comparing paths
        # covers both cases. This also works when account_data omits "name".
        if previous_path != account_db.path:
            account_db.update_children_path(session, previous_path)
        session.add(account_db)
        session.commit()
        session.refresh(account_db)
        return account_db

    @classmethod
    def delete(cls, session, account):
        """Delete ``account`` and commit."""
        session.delete(account)
        session.commit()