161 lines
4.9 KiB
Python
161 lines
4.9 KiB
Python
from typing import Optional
|
|
|
|
from sqlmodel import select, Relationship
|
|
from sqlalchemy.sql import text
|
|
|
|
from account.enums import CategoryFamily, Asset, Liability, AccountFamily
|
|
from account.schemas import AccountBaseId
|
|
|
|
class Account(AccountBaseId, table=True):
    """Persisted account/category node in a path-based hierarchy.

    Each row stores ``path``: the slash-delimited path of the folder that
    contains it; the row's own ``name`` is *not* part of its ``path``. A
    direct child therefore stores ``parent.path + parent.name + "/"`` (see
    ``get_child_path``), and parentless rows live under
    ``/Accounts/<family>/`` or ``/Categories/<family>/`` (see
    ``get_root_path``).

    NOTE: the ``list`` classmethod below shadows the builtin ``list`` inside
    this class body, so ``list[...]`` must only be used in annotations that
    appear *before* it (like ``children_accounts``).
    """

    parent_account: Optional["Account"] = Relationship(
        back_populates="children_accounts",
        # Self-referential many-to-one: ``Account.id`` is the parent (remote) side.
        sa_relationship_kwargs=dict(remote_side='Account.id')
    )
    children_accounts: list["Account"] = Relationship(back_populates='parent_account')

    def get_child_path(self):
        """Return the ``path`` value stored by this account's direct children."""
        return f"{self.path}{self.name}/"

    def get_root_path(self):
        """Return the top-level ``path`` for a parentless account of this family.

        Categories live under ``/Categories/<family>/``; everything else under
        ``/Accounts/<family>/``.
        """
        root = "/Categories" if self.is_category() else "/Accounts"
        return f"{root}/{self.family}/"

    @staticmethod
    def _escape_like(value):
        """Escape ``!``, ``%`` and ``_`` so *value* matches literally in a
        ``LIKE ... ESCAPE '!'`` pattern."""
        return value.replace("!", "!!").replace("%", "!%").replace("_", "!_")

    def update_children_path(self, session, old_path, old_name=None):
        """Rewrite the stored ``path`` of every descendant after a move/rename.

        Descendants were stored under ``old_path + old_name + "/"``; they are
        rewritten to live under ``self.get_child_path()``. The UPDATE is
        executed on *session* but not committed.

        :param session: active session used to execute the UPDATE.
        :param old_path: this account's ``path`` before the change.
        :param old_name: this account's ``name`` before the change; defaults
            to the current name (i.e. a move without a rename).
        """
        if old_name is None:
            old_name = self.name
        old_prefix = f"{old_path}{old_name}/"
        new_prefix = self.get_child_path()
        if old_prefix == new_prefix:
            return  # nothing moved, nothing renamed

        # Values are bound parameters (never interpolated into the SQL), so
        # quotes in account names cannot break or alter the statement, and
        # LIKE wildcards in names are escaped so they match literally.
        statement = text(
            f"UPDATE {self.__tablename__} "
            "SET path = REPLACE(path, :old_prefix, :new_prefix) "
            "WHERE path LIKE :pattern ESCAPE '!'"
        ).bindparams(
            old_prefix=old_prefix,
            new_prefix=new_prefix,
            pattern=self._escape_like(old_prefix) + "%",
        )
        session.exec(statement)

    def is_category(self):
        """Return True when ``family`` equals one of the ``CategoryFamily`` values."""
        return self.family in [v.value for v in CategoryFamily]

    def get_parent(self, session):
        """Load the parent row, store it on ``parent_account`` and return it.

        Returns None when ``parent_account_id`` is unset, or when no row has
        that id (``session.get`` returns None in that case).
        """
        if self.parent_account_id is None:
            return None

        self.parent_account = self.get(session, self.parent_account_id)
        return self.parent_account

    def compute_path(self, session):
        """Recompute and return ``path`` from the parent, or the family root.

        :raises KeyError: ``parent_account_id`` points at a missing row.
        """
        if self.parent_account_id is None:
            self.path = self.get_root_path()
        else:
            parent = self.get_parent(session)
            if parent is None:
                raise KeyError("Parent account not found.")
            self.path = parent.get_child_path()

        return self.path

    def compute_family(self):
        """Derive and return ``family`` from ``type``.

        Types belonging to ``Asset`` / ``Liability`` collapse to that family;
        any other type (e.g. a category type) is used as the family as-is.
        """
        # NOTE(review): ``x in SomeEnum`` with a non-member raises TypeError
        # before Python 3.12; this presumably relies on ``type`` being an enum
        # member or on 3.12+ value containment — confirm.
        if self.type in Asset:
            self.family = AccountFamily.Asset
        elif self.type in Liability:
            self.family = AccountFamily.Liability
        else:
            self.family = self.type

        return self.family

    @classmethod
    def schema_to_model(cls, session, schema, model=None):
        """Build a validated ``Account`` from *schema* and derive computed fields.

        :param session: session used to resolve the parent account.
        :param schema: incoming data (schema object or dict). It is not mutated.
        :param model: existing account to update; when given, *schema* is
            applied on top of it as an ``update`` mapping.
        :returns: a new ``Account`` instance with ``family`` and ``path`` set.
        :raises ValueError: *schema* fails validation (pydantic's
            ``ValidationError`` is a ``ValueError``), or the parent is invalid.
        :raises KeyError: the parent account does not exist.
        """
        if model is not None:
            model = cls.model_validate(model, update=schema)
        else:
            # ``path``/``family`` are derived below; blank them through
            # ``update`` rather than assigning onto the caller's object.
            model = cls.model_validate(schema, update={"path": "", "family": ""})

        model.compute_family()
        model.validate_parent(session)
        model.compute_path(session)
        return model

    def validate_parent(self, session):
        """Check that ``parent_account_id`` refers to an acceptable parent.

        :returns: True when there is no parent or the parent is acceptable.
        :raises KeyError: the parent account does not exist.
        :raises ValueError: the parent's family differs from this account's,
            or the parent is this account itself or one of its descendants
            (which would create a cycle).
        """
        if self.parent_account_id is None:
            return True

        parent = self.get_parent(session)
        if not parent:
            raise KeyError("Parent account not found.")

        if parent.family != self.family:
            raise ValueError("Account family mismatch with parent account..")

        if self._is_self_or_descendant(session, parent):
            raise ValueError("Parent Account is descendant")

        return True

    def _is_self_or_descendant(self, session, candidate):
        """Return True if *candidate* is this account or lies in its subtree.

        Walks upward from *candidate* via ``parent_account_id``. An unsaved
        account (``id`` is None) has no descendants yet. A ``seen`` set guards
        against looping forever on already-corrupted (cyclic) data.
        """
        if self.id is None:
            return False

        seen = set()
        node = candidate
        while node is not None and node.id not in seen:
            if node.id == self.id:
                return True
            seen.add(node.id)
            if node.parent_account_id is None:
                break
            node = self.get(session, node.parent_account_id)
        return False

    @classmethod
    def create(cls, account, session):
        """Validate *account*, insert it, commit and return the refreshed row."""
        account_db = cls.schema_to_model(session, account)
        session.add(account_db)
        session.commit()  # commit flushes the pending INSERT
        session.refresh(account_db)

        return account_db

    @classmethod
    def create_equity_account(cls, session):
        """Insert and commit the fixed ``Equity`` account; return it refreshed."""
        account_db = Account(name="Equity", family="Equity", type="Equity", path="/Equity/")
        session.add(account_db)
        session.commit()
        session.refresh(account_db)

        return account_db

    @classmethod
    def list(cls, filters):
        """Return a ``select(Account)`` with *filters* applied, then sorted."""
        return filters.sort(filters.filter(select(Account)))

    @classmethod
    def list_accounts(cls, filters):
        """Like ``list`` but restricted to the Asset and Liability families."""
        return cls.list(filters).where(
            Account.family.in_(["Asset", "Liability"])
        )

    @classmethod
    def list_assets(cls, filters):
        """Like ``list`` but restricted to the Asset family."""
        return cls.list(filters).where(Account.family == "Asset")

    @classmethod
    def list_liabilities(cls, filters):
        """Like ``list`` but restricted to the Liability family."""
        return cls.list(filters).where(Account.family == "Liability")

    @classmethod
    def list_categories(cls, filters):
        """Like ``list`` but restricted to Expense and Income types."""
        return cls.list(filters).where(
            Account.type.in_(["Expense", "Income"])
        )

    @classmethod
    def list_expenses(cls, filters):
        """Like ``list`` but restricted to the Expense family."""
        return cls.list(filters).where(Account.family == "Expense")

    @classmethod
    def list_incomes(cls, filters):
        """Like ``list`` but restricted to the Income family."""
        return cls.list(filters).where(Account.family == "Income")

    @classmethod
    def get(cls, session, account_id):
        """Return the account with primary key *account_id*, or None."""
        return session.get(Account, account_id)

    @classmethod
    def update(cls, session, account_db, account_data):
        """Apply *account_data* to *account_db*, re-home descendants, commit.

        :param account_data: mapping of fields to change; it is passed as
            ``update=`` to ``model_validate`` (presumably a dict of only the
            changed fields — verify against callers).
        :returns: the refreshed *account_db*.
        """
        previous_path = account_db.path
        previous_name = account_db.name
        account_db.sqlmodel_update(cls.schema_to_model(session, account_data, account_db))

        # Descendants embed this account's path *and* name, so a move or a
        # rename both require rewriting them. Compare against the values
        # captured before ``sqlmodel_update`` overwrote ``account_db``.
        if previous_path != account_db.path or previous_name != account_db.name:
            account_db.update_children_path(session, previous_path, previous_name)

        session.add(account_db)
        session.commit()
        session.refresh(account_db)
        return account_db

    @classmethod
    def delete(cls, session, account):
        """Delete *account* and commit."""
        # NOTE(review): descendants are not touched here; behavior for
        # accounts with children depends on the FK definition — confirm.
        session.delete(account)
        session.commit()
|
|
|