from typing import Optional
from sqlmodel import select, Relationship
from sqlalchemy.sql import text
from account.enums import CategoryFamily, Asset, Liability, AccountFamily
from account.schemas import AccountBaseId


class Account(AccountBaseId, table=True):
    """Account or category node in a self-referencing tree.

    Each row points at its parent through ``parent_account_id`` (presumably
    declared on ``AccountBaseId``) and stores a materialised ``path``: the
    family root (see :meth:`get_root_path`) for top-level rows, otherwise the
    parent's :meth:`get_child_path`. Siblings therefore share the same
    ``path``, and a row's descendants are exactly the rows whose ``path``
    starts with that row's :meth:`get_child_path`.
    """

    parent_account: Optional["Account"] = Relationship(
        back_populates="children_accounts",
        sa_relationship_kwargs=dict(remote_side='Account.id'),
    )
    children_accounts: list["Account"] = Relationship(back_populates='parent_account')

    def get_child_path(self) -> str:
        """Return the ``path`` value given to direct children of this row."""
        return f"{self.path}{self.name}/"

    def get_root_path(self) -> str:
        """Return the ``path`` of a top-level row with this row's family."""
        root = "/Categories" if self.is_category() else "/Accounts"
        # compute_family() may store an enum member; use its value so the path
        # holds e.g. "Asset" rather than "AccountFamily.Asset" (recent Python
        # versions format plain ``(str, Enum)`` members as "Class.MEMBER").
        family = getattr(self.family, "value", self.family)
        return f"{root}/{family}/"

    def update_children_path(self, session, old_path, old_name=None):
        """Rewrite the ``path`` of every descendant after a move and/or rename.

        Descendants are the rows whose ``path`` starts with the old child
        prefix ``f"{old_path}{old_name}/"``; that leading prefix is replaced by
        this row's current :meth:`get_child_path`. Changes are left pending in
        *session*; the caller commits.

        Args:
            session: Open session used to load and modify the descendants.
            old_path: This row's ``path`` before the change.
            old_name: This row's ``name`` before the change. Defaults to the
                current name (i.e. a move without a rename).
        """
        if old_name is None:
            old_name = self.name
        old_prefix = f"{old_path}{old_name}/"
        new_prefix = self.get_child_path()
        if old_prefix == new_prefix:
            return
        model = type(self)
        # Bound parameters instead of f-string SQL: names/paths are user data
        # (quotes could break or inject into the statement), and autoescape
        # keeps "%"/"_" in them from acting as LIKE wildcards.
        descendants = session.exec(
            select(model).where(
                model.path.startswith(old_prefix, autoescape=True),
                model.id != self.id,
            )
        ).all()
        for descendant in descendants:
            # Swap only the leading prefix (SQL REPLACE() hit every occurrence).
            descendant.path = new_prefix + descendant.path[len(old_prefix):]
            session.add(descendant)

    def is_category(self) -> bool:
        """Return True when ``family`` is one of the ``CategoryFamily`` values."""
        return self.family in [v.value for v in CategoryFamily]

    @classmethod
    def get_by_path(cls, session, path):
        """Return the result of selecting rows whose ``path`` equals *path*.

        Several rows (siblings) can share one ``path``, so this returns the
        executed result object, not a single row. Returns None for an empty
        or missing *path*.
        """
        if not path:
            return None
        return session.exec(select(cls).where(cls.path == path))

    def get_parent(self, session):
        """Load the parent row, cache it on ``parent_account`` and return it.

        Returns None when there is no ``parent_account_id`` or when no row
        with that id exists.
        """
        if self.parent_account_id is None:
            return None
        self.parent_account = self.get(session, self.parent_account_id)
        return self.parent_account

    def compute_path(self, session):
        """Recompute, store and return ``path`` from the parent (or the root).

        Raises:
            KeyError: ``parent_account_id`` is set but the parent row is missing.
        """
        if self.parent_account_id is None:
            self.path = self.get_root_path()
        else:
            parent = self.get_parent(session)
            if parent is None:
                raise KeyError("Parent account not found.")
            self.path = parent.get_child_path()
        return self.path

    def compute_family(self):
        """Derive, store and return ``family`` from ``type``.

        Asset types map to ``AccountFamily.Asset``, liability types to
        ``AccountFamily.Liability``; any other type is used as the family.
        """
        # NOTE(review): `value in EnumClass` raises TypeError for non-members
        # before Python 3.12 — confirm `self.type` is an enum member or the
        # runtime is 3.12+.
        if self.type in Asset:
            self.family = AccountFamily.Asset
        elif self.type in Liability:
            self.family = AccountFamily.Liability
        else:
            self.family = self.type
        return self.family

    @classmethod
    def schema_to_model(cls, session, schema, model=None):
        """Build a validated, not-yet-persisted model from input data.

        Args:
            session: Session used to look up the parent row.
            schema: Input data. For creation, an object with the account
                fields; for updates, a dict of changed fields.
            model: Existing row to update; when given, a new instance is built
                from it with *schema* applied and *model* itself is untouched.

        Returns:
            A new ``Account`` instance with ``family`` and ``path`` computed.

        Raises:
            pydantic.ValidationError: *schema* does not validate.
            KeyError: The referenced parent row does not exist.
            ValueError: The parent is invalid (see :meth:`validate_parent`).
        """
        if model is not None:
            model = cls.model_validate(model, update=schema)
        else:
            # path/family are derived below; supply placeholders without
            # mutating the caller's schema object.
            model = cls.model_validate(schema, update={"path": "", "family": ""})
        model.compute_family()
        model.validate_parent(session)
        model.compute_path(session)
        return model

    def validate_parent(self, session):
        """Check that ``parent_account_id`` points at an acceptable parent.

        Returns:
            True when there is no parent or the parent is valid.

        Raises:
            KeyError: The parent row does not exist.
            ValueError: The parent is this row itself, has a different family,
                or is one of this row's descendants (which would create a cycle).
        """
        if self.parent_account_id is None:
            return True
        parent = self.get_parent(session)
        if not parent:
            raise KeyError("Parent account not found.")
        if self.id is not None and parent.id == self.id:
            raise ValueError("Account cannot be its own parent.")
        if parent.family != self.family:
            raise ValueError("Account family mismatch with parent account..")
        if self.id is not None:
            # Descendants sit under the *stored* row's child prefix (old path
            # and old name); a row without a stored version has no descendants.
            stored = self.get(session, self.id)
            if stored is not None and parent.path.startswith(stored.get_child_path()):
                raise ValueError("Parent Account is descendant")
        return True

    @classmethod
    def create(cls, account, session):
        """Validate *account*, insert it, commit and return the refreshed row."""
        account_db = cls.schema_to_model(session, account)
        session.add(account_db)
        session.commit()
        session.refresh(account_db)
        return account_db

    @classmethod
    def create_equity_account(cls, session):
        """Insert, commit and return the top-level "Equity" account."""
        account_db = cls(name="Equity", family="Equity", type="Equity", path="/Equity/")
        session.add(account_db)
        session.commit()
        session.refresh(account_db)
        return account_db

    @classmethod
    def list(cls, filters):
        """Return a (not yet executed) filtered and sorted select of accounts.

        *filters* is assumed to expose ``filter(stmt)`` and ``sort(stmt)``
        that each return a statement — TODO confirm against its class.
        """
        return filters.sort(filters.filter(select(Account)))

    @classmethod
    def list_accounts(cls, filters):
        """Like :meth:`list`, restricted to the Asset and Liability families."""
        return cls.list(filters).where(
            Account.family.in_(["Asset", "Liability"])
        )

    @classmethod
    def list_assets(cls, filters):
        """Like :meth:`list`, restricted to the Asset family."""
        return cls.list(filters).where(Account.family == "Asset")

    @classmethod
    def list_liabilities(cls, filters):
        """Like :meth:`list`, restricted to the Liability family."""
        return cls.list(filters).where(Account.family == "Liability")

    @classmethod
    def list_categories(cls, filters):
        """Like :meth:`list`, restricted to Expense and Income types."""
        return cls.list(filters).where(
            Account.type.in_(["Expense", "Income"])
        )

    @classmethod
    def list_expenses(cls, filters):
        """Like :meth:`list`, restricted to the Expense family."""
        return cls.list(filters).where(Account.family == "Expense")

    @classmethod
    def list_incomes(cls, filters):
        """Like :meth:`list`, restricted to the Income family."""
        return cls.list(filters).where(Account.family == "Income")

    @classmethod
    def get(cls, session, account_id):
        """Return the account with primary key *account_id*, or None."""
        return session.get(Account, account_id)

    @classmethod
    def update(cls, session, account_db, account_data):
        """Apply *account_data* to *account_db*, cascade paths, commit and return it.

        Args:
            session: Open session that *account_db* belongs to.
            account_db: Persisted row to modify.
            account_data: Dict of field values to change (may be partial).

        Raises:
            Whatever :meth:`schema_to_model` raises; nothing is written then.
        """
        # Capture the pre-update identity: descendants' paths are built from it.
        previous_path = account_db.path
        previous_name = account_db.name
        account_db.sqlmodel_update(cls.schema_to_model(session, account_data, account_db))
        if account_db.path != previous_path or account_db.name != previous_name:
            account_db.update_children_path(session, previous_path, previous_name)
        session.add(account_db)
        session.commit()
        session.refresh(account_db)
        return account_db

    @classmethod
    def delete(cls, session, account):
        """Delete *account* and commit."""
        # NOTE(review): children are neither re-parented nor deleted here;
        # whether the database rejects or cascades this is not visible — confirm.
        session.delete(account)
        session.commit()