diff --git a/api/app/account/fixtures.py b/api/app/account/fixtures.py index bb9d99b..44a9c80 100644 --- a/api/app/account/fixtures.py +++ b/api/app/account/fixtures.py @@ -1,44 +1,68 @@ +from datetime import date + from account.models import Account -from account.schemas import AccountCreate +from account.schemas import AccountCreate, CategoryCreate def inject_fixtures(session): - for f in fixtures: - if f['parent_path']: - parent = Account.get_by_path(session, f['parent_path']) - f['parent_account_id'] = parent.id - else: - f['parent_account_id'] = None - del f['parent_path'] + for f in fixtures_account: + f = prepare_dict(session, f) schema = AccountCreate(**f) - Account.create(schema) + Account.create(schema, session) -fixtures = [{ + for f in fixtures_category: + f = prepare_dict(session, f) + schema = CategoryCreate(**f) + Account.create(schema, session) + +def prepare_dict(session, entry): + if entry['parent_path']: + parent = Account.get_by_path(session, entry['parent_path']) + entry['parent_account_id'] = parent.id + else: + entry['parent_account_id'] = None + del entry['parent_path'] + return entry + +fixtures_account = [ + { "name": "Current Assets", "parent_path": None, - "type": "Asset" - },{ + "type": "Asset", + "opening_date": date(1970, 1, 1), + }, + { "name": "Cash in Wallet", - "parent_path": "/Accounts/Asset/", + "parent_path": "/Accounts/Asset/Current Assets/", "type": "Asset", - },{ + "opening_date": date(1970, 1, 1), + }, + { "name": "Checking Account", - "parent_path": "/Accounts/Asset/", + "parent_path": "/Accounts/Asset/Current Assets/", "type": "Asset", - },{ + "opening_date": date(1970, 1, 1), + }, + { "name": "Savings Account", - "parent_path": "/Accounts/Asset/", + "parent_path": "/Accounts/Asset/Current Assets/", "type": "Asset", + "opening_date": date(1970, 1, 1), }, { "name": "Debt Accounts", "parent_path": None, "type": "Liability", + "opening_date": date(1970, 1, 1), }, { "name": "Credit Card", - "parent_path": "/Accounts/Liability/", + "parent_path": "/Accounts/Liability/Debt Accounts/", "type": "Liability", + "opening_date": date(1970, 1, 1), }, +] + +fixtures_category = [ { "name": "Salary", "parent_path": None, @@ -61,12 +85,12 @@ fixtures = [{ }, { "name": "Rent", - "parent_path": "/Categories/Expense/Home", + "parent_path": "/Categories/Expense/Home/", "type": "Expense", }, { "name": "Electricity", - "parent_path": "/Categories/Expense/Home", + "parent_path": "/Categories/Expense/Home/", "type": "Expense", }, { @@ -78,11 +102,9 @@ fixtures = [{ "name": "Groceries", "parent_path": None, "type": "Expense", - }, + } ] - - """ diff --git a/api/app/account/models.py b/api/app/account/models.py index 5b35a70..eb45a56 100644 --- a/api/app/account/models.py +++ b/api/app/account/models.py @@ -12,13 +12,15 @@ class Account(AccountBaseId, table=True): sa_relationship_kwargs=dict(remote_side='Account.id') ) children_accounts: list["Account"] = Relationship(back_populates='parent_account') + transaction_splits: list["Split"] = Relationship(back_populates='account') - def get_child_path(self): - return f"{self.path}{self.name}/" + + def get_child_path(self, child): + return f"{self.path}{child.name}/" def get_root_path(self): root = "/Categories" if self.is_category() else "/Accounts" - return f"{root}/{self.family}/" + return f"{root}/{self.family}/{self.name}/" def update_children_path(self, session, old_path): request = f"UPDATE {self.__tablename__} SET path=REPLACE(path, '{old_path}', '{self.path}') WHERE path LIKE '{old_path}{self.name }/%'" @@ -32,7 +34,7 @@ class Account(AccountBaseId, table=True): if not path: return None - return session.exec(select(cls).where(cls.path == path)) + return session.exec(select(cls).where(cls.path == path)).first() def get_parent(self, session): @@ -47,15 +49,15 @@ class Account(AccountBaseId, table=True): self.path = self.get_root_path() else: self.parent_account = self.get(session, self.parent_account_id) - self.path = self.parent_account.get_child_path() + self.path = self.parent_account.get_child_path(self) return self.path def compute_family(self): if self.type in Asset: - self.family = AccountFamily.Asset + self.family = AccountFamily.Asset.value elif self.type in Liability: - self.family = AccountFamily.Liability + self.family = AccountFamily.Liability.value else: self.family = self.type return self.family @@ -74,7 +76,7 @@ class Account(AccountBaseId, table=True): model.compute_family() model.validate_parent(session) - model.path = model.get_path(session) + model.compute_path(session) return model def validate_parent(self, session): @@ -88,7 +90,7 @@ class Account(AccountBaseId, table=True): if parent.family != self.family: raise ValueError("Account family mismatch with parent account..") - if parent.path.startswith(self.path): + if self.path and parent.path.startswith(self.path): raise ValueError("Parent Account is descendant") return True @@ -114,9 +116,15 @@ class Account(AccountBaseId, table=True): return account_db + @classmethod + def select(cls): + return select(Account) + @classmethod def list(cls, filters): - return filters.sort(filters.filter(select(Account))) + return filters.sort(filters.filter( + cls.select() + )) @classmethod def list_accounts(cls, filters): diff --git a/api/app/account/resource.py b/api/app/account/resource.py new file mode 100644 index 0000000..985e2bf --- /dev/null +++ b/api/app/account/resource.py @@ -0,0 +1,144 @@ +class AccountResource: + @classmethod + def get_by_path(cls, session, path): + if not path: + return None + + return session.exec(select(cls).where(cls.path == path)).first() + + def get_parent(self, session): + if self.parent_account_id is None: + return None + + self.parent_account = self.get(session, self.parent_account_id) + return self.parent_account + + def compute_path(self, session): + if self.parent_account_id is None: + self.path = self.get_root_path() + else: + self.parent_account = self.get(session, self.parent_account_id) + self.path = self.parent_account.get_child_path(self) + + return self.path + + def compute_family(self): + if self.type in Asset: + self.family = AccountFamily.Asset.value + elif self.type in Liability: + self.family = AccountFamily.Liability.value + else: + self.family = self.type + return self.family + + @classmethod + def schema_to_model(cls, session, schema, model=None): + try: + if model: + model = cls.model_validate(model, update=schema) + else: + schema.path = "" + schema.family = "" + model = cls.model_validate(schema) + except Exception as e: + print(e) + + model.compute_family() + model.validate_parent(session) + model.compute_path(session) + return model + + def validate_parent(self, session): + if self.parent_account_id is None: + return True + + parent = self.get_parent(session) + if not parent: + raise KeyError("Parent account not found.") + + if parent.family != self.family: + raise ValueError("Account family mismatch with parent account..") + + if self.path and parent.path.startswith(self.path): + raise ValueError("Parent Account is descendant") + + return True + + @classmethod + def create(cls, account, session): + account_db = cls.schema_to_model(session, account) + session.add(account_db) + session.flush() + session.refresh(account_db) + + session.commit() + session.refresh(account_db) + + return account_db + + @classmethod + def create_equity_account(cls, session): + account_db = Account(name="Equity", family="Equity", type="Equity", path="/Equity/") + session.add(account_db) + session.commit() + session.refresh(account_db) + + return account_db + + @classmethod + def select(cls): + return select(Account) + + @classmethod + def list(cls, filters): + return filters.sort(filters.filter( + cls.select() + )) + + @classmethod + def list_accounts(cls, filters): + return cls.list(filters).where( + Account.family.in_(["Asset", "Liability"]) + ) + + @classmethod + def list_assets(cls, filters): + return cls.list(filters).where(Account.family == "Asset") + + @classmethod + def list_liabilities(cls, filters): + return cls.list(filters).where(Account.family == "Liability") + + @classmethod + def list_categories(cls, filters): + return cls.list(filters).where( + Account.type.in_(["Expense", "Income"]) + ) + + @classmethod + def list_expenses(cls, filters): + return cls.list(filters).where(Account.family == "Expense") + + @classmethod + def list_incomes(cls, filters): + return cls.list(filters).where(Account.family == "Income") + + @classmethod + def get(cls, session, account_id): + return session.get(Account, account_id) + + @classmethod + def update(cls, session, account_db, account_data): + previous_path = account_db.path + account_db.sqlmodel_update(cls.schema_to_model(session, account_data, account_db)) + if previous_path != account_db.path or account_data['name'] != account_db.name: + account_db.update_children_path(session, previous_path) + session.add(account_db) + session.commit() + session.refresh(account_db) + return account_db + + @classmethod + def delete(cls, session, account): + session.delete(account) + session.commit() \ No newline at end of file diff --git a/api/app/account/schemas.py b/api/app/account/schemas.py index 2506bde..a272cfa 100644 --- a/api/app/account/schemas.py +++ b/api/app/account/schemas.py @@ -20,14 +20,12 @@ class AccountBaseId(AccountBase): path: str = Field(index=True) class AccountRead(AccountBaseId): - opening_date: date = Field() - opening_balance: Decimal = Field(decimal_places=2, default=0) + pass#opening_date: date = Field() + #opening_balance: Decimal = Field(decimal_places=2, default=0) class BaseAccountWrite(AccountBase): path: SkipJsonSchema[str] = Field(default="") family: SkipJsonSchema[str] = Field(default="") - opening_date: date = Field() - opening_balance: Decimal = Field(decimal_places=2, default=0) class AccountWrite(BaseAccountWrite): type: Asset | Liability = Field() @@ -40,6 +38,8 @@ class AccountWrite(BaseAccountWrite): } } }) + opening_date: date = Field() + opening_balance: Decimal = Field(decimal_places=2, default=0) class AccountCreate(AccountWrite): pass @@ -52,6 +52,17 @@ class CategoryRead(AccountBaseId): class CategoryWrite(BaseAccountWrite): type: CategoryFamily = Field() + parent_account_id: UUID | None = PydField(default=None, json_schema_extra={ + "foreign_key": { + "reference": { + "resource": "categories", + "schema": "CategoryRead", + "label": "name" + } + } + }) + opening_date: date = date(1970, 1, 1) + opening_balance: Decimal = 0 class CategoryCreate(CategoryWrite): pass diff --git a/api/app/db.py b/api/app/db.py index 0ba4938..30f397e 100644 --- a/api/app/db.py +++ b/api/app/db.py @@ -15,7 +15,7 @@ def create_db_and_tables(): SQLModel.metadata.create_all(engine) def drop_tables(): - SQLModel.metadata.create_all(engine) + SQLModel.metadata.drop_all(engine) def get_session() -> Session: with Session(engine) as session: diff --git a/api/app/transaction/models.py b/api/app/transaction/models.py index 81bb163..9114e6c 100644 --- a/api/app/transaction/models.py +++ b/api/app/transaction/models.py @@ -96,7 +96,7 @@ class TransactionUpdate(TransactionWrite): class Split(SplitBaseId, table=True): transaction: Transaction = Relationship(back_populates="splits") - account: Account | None = Relationship() + account: Account | None = Relationship(back_populates="transaction_splits") payee: Payee | None = Relationship() @classmethod