From 148b1d00c4f5262c7dcb39a305a658039ed64fb5 Mon Sep 17 00:00:00 2001 From: ewandor Date: Sun, 16 Feb 2025 03:17:29 +0100 Subject: [PATCH] Fully functional opening transactions --- api/app/account/account_routes.py | 25 +++++++---- api/app/account/fixtures.py | 12 +++--- api/app/account/resource.py | 57 ++++++++++++++++++++++++- api/app/account/schemas.py | 9 +++- api/app/core/types.py | 28 +++++++----- api/app/transaction/models.py | 4 +- api/app/transaction/resource.py | 20 ++------- gui/app/src/pages/accounts/edit.tsx | 6 +-- gui/app/src/providers/data-provider.tsx | 3 +- 9 files changed, 114 insertions(+), 50 deletions(-) diff --git a/api/app/account/account_routes.py b/api/app/account/account_routes.py index 3dba587..2e797e0 100644 --- a/api/app/account/account_routes.py +++ b/api/app/account/account_routes.py @@ -7,7 +7,7 @@ from fastapi_filter.contrib.sqlalchemy import Filter from fastapi_pagination import Page from fastapi_pagination.ext.sqlmodel import paginate -from account.schemas import AccountCreate, AccountRead, AccountUpdate +from account.schemas import AccountCreate, AccountRead, AccountUpdate, OpeningTransaction, OpeningTransactionUpdate from account.models import Account from account.resource import AccountResource @@ -57,13 +57,6 @@ def read_account(account_id: UUID, session: SessionDep, current_user=Depends(get raise HTTPException(status_code=404, detail="Account not found") return account -@router.get("/{account_id}/opening_state") -def read_account_opening_state(account_id: UUID, session: SessionDep, current_user=Depends(get_current_user)) -> TransactionRead: - transaction = TransactionResource.get_opening_transaction(session, account_id) - if not transaction: - raise HTTPException(status_code=404, detail="Account not found") - return transaction - @router.put("/{account_id}") def update_account(account_id: UUID, account: AccountUpdate, session: SessionDep, current_user=Depends(get_current_user)) -> AccountRead: db_account = AccountResource.get(session, account_id) @@ -74,6 +67,22 @@ def update_account(account_id: UUID, account: AccountUpdate, session: SessionDep account = AccountResource.update(session, db_account, account_data) return account +@router.get("/{account_id}/opening_state") +def read_account_opening_state(account_id: UUID, session: SessionDep, current_user=Depends(get_current_user)) -> OpeningTransaction: + transaction = AccountResource.get_opening_transaction(session, account_id) + if not transaction: + raise HTTPException(status_code=404, detail="Account not found") + return transaction + +@router.put("/{account_id}/opening_state") +def update_account_opening_state(account_id: UUID, opening_transaction: OpeningTransactionUpdate, session: SessionDep, current_user=Depends(get_current_user)) -> OpeningTransaction: + account = AccountResource.get(session, account_id) + if not account: + raise HTTPException(status_code=404, detail="Account not found") + + transaction = AccountResource.update_opening_transaction(session, account, opening_transaction) + return transaction + @router.delete("/{account_id}") def delete_account(account_id: UUID, session: SessionDep, current_user=Depends(get_current_user)): account = AccountResource.get(session, account_id) diff --git a/api/app/account/fixtures.py b/api/app/account/fixtures.py index 94ab0a7..69df116 100644 --- a/api/app/account/fixtures.py +++ b/api/app/account/fixtures.py @@ -29,42 +29,42 @@ fixtures_account = [ "name": "Current Assets", "parent_path": None, "type": "Asset", - "opening_date": date(1970, 1, 1), + "opening_date": date(1970, 1, 2), "opening_balance": Decimal("0.00"), }, { "name": "Cash in Wallet", "parent_path": "/Accounts/Asset/Current Assets/", "type": "Asset", - "opening_date": date(1970, 1, 1), + "opening_date": date(1970, 1, 3), "opening_balance": Decimal("0.00"), }, { "name": "Checking Account", "parent_path": "/Accounts/Asset/Current Assets/", "type": "Asset", - "opening_date": date(1970, 1, 1), + "opening_date": date(1970, 1, 4), "opening_balance": Decimal("0.00"), }, { "name": "Savings Account", "parent_path": "/Accounts/Asset/Current Assets/", "type": "Asset", - "opening_date": date(1970, 1, 1), + "opening_date": date(1970, 1, 5), "opening_balance": Decimal("0.00"), }, { "name": "Debt Accounts", "parent_path": None, "type": "Liability", - "opening_date": date(1970, 1, 1), + "opening_date": date(1970, 1, 6), "opening_balance": Decimal("0.00"), }, { "name": "Credit Card", "parent_path": "/Accounts/Liability/Debt Accounts/", "type": "Liability", - "opening_date": date(1970, 1, 1), + "opening_date": date(1970, 1, 7), "opening_balance": Decimal("0.00"), }, ] diff --git a/api/app/account/resource.py b/api/app/account/resource.py index c02ed80..3d12cd3 100644 --- a/api/app/account/resource.py +++ b/api/app/account/resource.py @@ -1,8 +1,11 @@ -from sqlalchemy import literal_column +from datetime import date + +from sqlalchemy import and_ from sqlalchemy.orm import aliased from sqlmodel import select from account.models import Account +from account.schemas import OpeningTransaction from transaction.models import Split, Transaction @@ -24,9 +27,9 @@ class AccountResource: @classmethod def create_opening_transaction(cls, session, account, schema): - equity_account = cls.get_by_path(session, "/Equity/") t = Transaction() + t.transaction_date = schema.opening_date split_opening = Split() split_opening.id = 0 split_opening.transaction = t @@ -41,6 +44,56 @@ class AccountResource: account.transaction_splits.append(split_opening) + @classmethod + def fetch_opening_transaction(cls, session, account_id): + split_account = aliased(Split) + split_equity = aliased(Split) + account_equity = aliased(Account) + + return session.execute(select(Transaction) + .join(split_account, and_(split_account.transaction_id == Transaction.id, split_account.account_id == account_id)) + .join(split_equity, split_equity.transaction_id == Transaction.id) + .join(account_equity, and_(account_equity.id == split_equity.account_id, account_equity.path == "/Equity/")) + ).first()[0] + + @classmethod + def get_opening_transaction(cls, session, account_id): + transaction = cls.fetch_opening_transaction(session, account_id) + + if transaction is None: + return None + + return OpeningTransaction( + opening_date=transaction.transaction_date, + opening_balance=transaction.splits[0].amount + ) + + @classmethod + def update_opening_transaction(cls, session, account, schema): + opening_transaction = cls.fetch_opening_transaction(session, account.id) + + stmt = select(Transaction).join(Split) \ + .where(Transaction.id != opening_transaction.id) \ + .where(Split.account_id == account.id) \ + .order_by(Transaction.transaction_date.asc()) + + first_transaction = session.exec(stmt).first() + if first_transaction and schema.opening_date > first_transaction[0].transaction_date: + raise ValueError("Account opening date is posterior to its first transaction date") + + opening_transaction = cls.fetch_opening_transaction(session, account.id) + opening_transaction.transaction_date = schema.opening_date + opening_transaction.splits[0].amount = schema.opening_balance + opening_transaction.splits[1].amount = - schema.opening_balance + + session.commit() + session.refresh(opening_transaction) + + return OpeningTransaction( + opening_date=opening_transaction.transaction_date, + opening_balance=opening_transaction.splits[0].amount + ) + @classmethod def schema_to_model(cls, session, schema, model=None): try: diff --git a/api/app/account/schemas.py b/api/app/account/schemas.py index 720af1e..f65d1ee 100644 --- a/api/app/account/schemas.py +++ b/api/app/account/schemas.py @@ -4,7 +4,7 @@ from typing import Optional from uuid import UUID, uuid4 from sqlmodel import Field, SQLModel -from pydantic import Field as PydField +from pydantic import Field as PydField, BaseModel from pydantic.json_schema import SkipJsonSchema from account.enums import Asset, Liability, CategoryFamily @@ -67,3 +67,10 @@ class CategoryCreate(CategoryWrite): class CategoryUpdate(CategoryWrite): pass + +class OpeningTransaction(BaseModel): + opening_date: date = Field() + opening_balance: MonetaryAmount = Field(default=0) + +class OpeningTransactionUpdate(OpeningTransaction): + pass diff --git a/api/app/core/types.py b/api/app/core/types.py index d46a769..3196096 100644 --- a/api/app/core/types.py +++ b/api/app/core/types.py @@ -2,25 +2,33 @@ from dataclasses import dataclass from decimal import Decimal from typing import Any -from pydantic import GetCoreSchemaHandler +from pydantic import GetCoreSchemaHandler, GetJsonSchemaHandler from pydantic_core import core_schema +from pydantic.json_schema import JsonSchemaValue @dataclass class MonetaryAmount: + @classmethod + def __get_pydantic_json_schema__( + cls, schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler + ) -> JsonSchemaValue: + json_schema = handler(schema) + if "anyOf" in json_schema: + for key, value in json_schema["anyOf"][0].items(): + json_schema[key] = value + del json_schema["anyOf"] + + json_schema["format"] = "monetary" + + return json_schema + @classmethod def __get_pydantic_core_schema__( cls, source: type[Any], handler: GetCoreSchemaHandler ) -> core_schema.CoreSchema: assert source is MonetaryAmount - return core_schema.no_info_after_validator_function( - cls._validate, - core_schema.decimal_schema(multiple_of=0.01), - serialization=core_schema.plain_serializer_function_ser_schema( - cls._serialize, - info_arg=False, - return_schema=core_schema.decimal_schema(multiple_of=0.01), - ), - ) + + return core_schema.decimal_schema(multiple_of=0.01) def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/api/app/transaction/models.py b/api/app/transaction/models.py index 84981ed..e434417 100644 --- a/api/app/transaction/models.py +++ b/api/app/transaction/models.py @@ -1,3 +1,4 @@ +from datetime import date from decimal import Decimal from typing import Optional from uuid import UUID, uuid4 @@ -11,7 +12,8 @@ from payee.models import Payee, PayeeRead class TransactionBase(SQLModel): - pass + transaction_date: date = Field() + payment_date: Optional[date] = Field(default=None) class TransactionBaseId(TransactionBase): id: UUID | None = Field(default_factory=uuid4, primary_key=True) diff --git a/api/app/transaction/resource.py b/api/app/transaction/resource.py index ea15e5f..73e6abb 100644 --- a/api/app/transaction/resource.py +++ b/api/app/transaction/resource.py @@ -1,15 +1,11 @@ -from decimal import Decimal -from typing import Optional -from uuid import UUID, uuid4 +from datetime import date from sqlalchemy.orm import aliased -from sqlmodel import Field, SQLModel, select, Relationship -from pydantic import Field as PydField +from sqlmodel import select from account.models import Account -from account.schemas import AccountRead +from account.schemas import OpeningTransaction from transaction.models import Transaction, Split -from payee.models import Payee, PayeeRead class TransactionResource: @classmethod @@ -29,16 +25,6 @@ class TransactionResource: def get(cls, session, transaction_id): return session.get(Transaction, transaction_id) - @classmethod - def get_opening_transaction(cls, session, account_id): - split_account = aliased(Split) - split_equity = aliased(Split) - account_filter = aliased(Account) - return session.exec(select(Transaction) - .join(split_account, split_account.account_id == account_id) - .join(split_equity) - .join(account_filter, account_filter.id == split_equity.account_id and Account.path == "/Equity/")).first() - @classmethod def update(cls, session, transaction_db, transaction_data): transaction_db.sqlmodel_update(Transaction.model_validate(transaction_data)) diff --git a/gui/app/src/pages/accounts/edit.tsx b/gui/app/src/pages/accounts/edit.tsx index 3d96d05..3e3ef8c 100644 --- a/gui/app/src/pages/accounts/edit.tsx +++ b/gui/app/src/pages/accounts/edit.tsx @@ -13,9 +13,9 @@ export const AccountEdit: React.FC = () => { id={id} /> ); diff --git a/gui/app/src/providers/data-provider.tsx b/gui/app/src/providers/data-provider.tsx index 940ed91..33943db 100644 --- a/gui/app/src/providers/data-provider.tsx +++ b/gui/app/src/providers/data-provider.tsx @@ -14,8 +14,7 @@ const fetcher = async (url: string, options?: RequestInit) => { export const dataProvider: DataProvider = { getOne: async ({ resource, id, meta }) => { - const response = await fetcher(`${API_URL}/${resource}/${id}`); - + const response = id !== "" ? await fetcher(`${API_URL}/${resource}/${id}`) : await fetcher(`${API_URL}/${resource}`); if (response.status < 200 || response.status > 299) throw response; const data = await response.json();