from sqlmodel import select

from account.models import Account


class AccountResource:
    """Data-access helpers for ``Account`` rows.

    Every method is a classmethod that receives the SQLModel ``session`` (or a
    ``filters`` object) from the caller. ``create``, ``create_equity_account``,
    ``update`` and ``delete`` commit the session themselves.
    """

    @classmethod
    def get_by_path(cls, session, path):
        """Return the first account whose ``path`` equals *path*.

        A falsy *path* (``None`` or ``""``) returns ``None`` without querying.
        """
        if not path:
            return None
        return session.exec(select(Account).where(Account.path == path)).first()

    @classmethod
    def get_parent(cls, session, model):
        """Load *model*'s parent, store it on ``model.parent_account`` and return it.

        Returns ``None`` when ``model.parent_account_id`` is ``None``. Otherwise
        returns the result of ``get`` for that id, which is ``None`` if no such
        row exists.
        """
        if model.parent_account_id is None:
            return None
        model.parent_account = cls.get(session, model.parent_account_id)
        return model.parent_account

    @classmethod
    def schema_to_model(cls, session, schema, model=None):
        """Build a validated ``Account`` from *schema*.

        Args:
            session: session used to look up the parent account.
            schema: input data. When *model* is given it is passed as the
                ``update`` mapping for ``Account.model_validate`` (so presumably
                a dict of changed fields -- confirm against callers).
            model: existing account to start from; ``None`` builds a new one.

        Returns:
            A new ``Account`` instance with family and path recomputed and its
            parent validated. Neither *schema* nor *model* is mutated.

        Raises:
            Whatever ``Account.model_validate`` raises for invalid input, plus
            ``KeyError`` / ``ValueError`` from ``validate_parent``.
        """
        if model is not None:
            model = Account.model_validate(model, update=schema)
        else:
            # path and family are derived below, so start them empty. Pass the
            # overrides via `update` instead of assigning onto the caller's
            # schema object, so the input is left untouched.
            model = Account.model_validate(schema, update={"path": "", "family": ""})
        model.compute_family()
        # The parent check runs before compute_path(), so model.path still
        # holds the account's current path (empty for a new account).
        cls.validate_parent(session, model)
        model.compute_path()
        return model

    @classmethod
    def validate_parent(cls, session, model):
        """Check that *model*'s parent exists and is a legal parent.

        Side effect: ``get_parent`` stores the loaded parent on
        ``model.parent_account``.

        Returns:
            ``True`` when the model has no parent, otherwise *model*.

        Raises:
            KeyError: ``parent_account_id`` does not resolve to an account.
            ValueError: the parent belongs to a different family, or the parent's
                path starts with the model's own path. That covers both a
                descendant and the account itself being chosen as parent.
        """
        if model.parent_account_id is None:
            return True
        parent = cls.get_parent(session, model)
        if not parent:
            raise KeyError("Parent account not found.")
        if parent.family != model.family:
            raise ValueError("Account family mismatch with parent account..")
        # An empty path (new account) cannot have descendants yet.
        if model.path and parent.path.startswith(model.path):
            raise ValueError("Parent Account is descendant")
        return model

    @classmethod
    def create(cls, account, session):
        """Validate *account*, insert it, commit, and return the refreshed row.

        Note: the argument order here is ``(account, session)``, unlike the
        other methods.
        """
        account_db = cls.schema_to_model(session, account)
        session.add(account_db)
        # commit() flushes the pending INSERT itself; an explicit flush/refresh
        # beforehand is unnecessary.
        session.commit()
        session.refresh(account_db)
        return account_db

    @classmethod
    def create_equity_account(cls, session):
        """Insert and commit the root "Equity" account at path ``/Equity/``."""
        account_db = Account(name="Equity", family="Equity", type="Equity", path="/Equity/")
        session.add(account_db)
        session.commit()
        session.refresh(account_db)
        return account_db

    @classmethod
    def select(cls):
        """Return a base ``SELECT`` over ``Account``.

        Inside method bodies the bare name ``select`` resolves to the
        module-level ``sqlmodel.select``, not this classmethod, so this is not
        recursive.
        """
        return select(Account)

    @classmethod
    def list(cls, filters):
        """Apply *filters*' ``filter`` and then ``sort`` to the base account query."""
        return filters.sort(filters.filter(cls.select()))

    @classmethod
    def list_accounts(cls, filters):
        """Filtered query restricted to the Asset and Liability families."""
        return cls.list(filters).where(Account.family.in_(["Asset", "Liability"]))

    @classmethod
    def list_assets(cls, filters):
        """Filtered query restricted to the Asset family."""
        return cls.list(filters).where(Account.family == "Asset")

    @classmethod
    def list_liabilities(cls, filters):
        """Filtered query restricted to the Liability family."""
        return cls.list(filters).where(Account.family == "Liability")

    @classmethod
    def list_categories(cls, filters):
        """Filtered query for Expense/Income accounts, matched on ``Account.type``."""
        # NOTE(review): siblings (list_expenses/list_incomes/list_accounts)
        # filter on Account.family; this one filters on Account.type. Confirm
        # whether that difference is intentional.
        return cls.list(filters).where(Account.type.in_(["Expense", "Income"]))

    @classmethod
    def list_expenses(cls, filters):
        """Filtered query restricted to the Expense family."""
        return cls.list(filters).where(Account.family == "Expense")

    @classmethod
    def list_incomes(cls, filters):
        """Filtered query restricted to the Income family."""
        return cls.list(filters).where(Account.family == "Income")

    @classmethod
    def get(cls, session, account_id):
        """Return the account with primary key *account_id*, or ``None``."""
        return session.get(Account, account_id)

    @classmethod
    def update(cls, session, account_db, account_data):
        """Apply *account_data* to *account_db*, cascade path changes, and commit.

        Children paths are rewritten (via ``update_children_path``) when either
        the account's path or its name changed.
        """
        previous_path = account_db.path
        # Capture the name BEFORE sqlmodel_update overwrites it. Comparing
        # against account_db.name afterwards would always see the new name
        # and never detect a rename.
        previous_name = account_db.name
        account_db.sqlmodel_update(cls.schema_to_model(session, account_data, account_db))
        if account_db.path != previous_path or account_db.name != previous_name:
            account_db.update_children_path(session, previous_path)
        session.add(account_db)
        session.commit()
        session.refresh(account_db)
        return account_db

    @classmethod
    def delete(cls, session, account):
        """Delete *account* and commit."""
        session.delete(account)
        session.commit()