diff --git a/api/app/account/fixtures.py b/api/app/account/fixtures.py index 88f3c52..94ab0a7 100644 --- a/api/app/account/fixtures.py +++ b/api/app/account/fixtures.py @@ -1,4 +1,5 @@ from datetime import date +from decimal import Decimal from account.resource import AccountResource from account.schemas import AccountCreate, CategoryCreate @@ -29,36 +30,42 @@ fixtures_account = [ "parent_path": None, "type": "Asset", "opening_date": date(1970, 1, 1), + "opening_balance": Decimal("0.00"), }, { "name": "Cash in Wallet", "parent_path": "/Accounts/Asset/Current Assets/", "type": "Asset", "opening_date": date(1970, 1, 1), + "opening_balance": Decimal("0.00"), }, { "name": "Checking Account", "parent_path": "/Accounts/Asset/Current Assets/", "type": "Asset", "opening_date": date(1970, 1, 1), + "opening_balance": Decimal("0.00"), }, { "name": "Savings Account", "parent_path": "/Accounts/Asset/Current Assets/", "type": "Asset", "opening_date": date(1970, 1, 1), + "opening_balance": Decimal("0.00"), }, { "name": "Debt Accounts", "parent_path": None, "type": "Liability", "opening_date": date(1970, 1, 1), + "opening_balance": Decimal("0.00"), }, { "name": "Credit Card", "parent_path": "/Accounts/Liability/Debt Accounts/", "type": "Liability", "opening_date": date(1970, 1, 1), + "opening_balance": Decimal("0.00"), }, ] diff --git a/api/app/account/resource.py b/api/app/account/resource.py index 65309c1..4442b87 100644 --- a/api/app/account/resource.py +++ b/api/app/account/resource.py @@ -1,6 +1,8 @@ from sqlmodel import select from account.models import Account +from transaction.models import Split, Transaction + class AccountResource: @classmethod @@ -18,6 +20,25 @@ class AccountResource: model.parent_account = cls.get(session, model.parent_account_id) return model.parent_account + @classmethod + def create_opening_transaction(cls, session, account, schema): + + equity_account = cls.get_by_path(session, "/Equity/") + t = Transaction() + split_opening = Split() + split_opening.id = 0 + split_opening.transaction = t + split_opening.account = account + split_opening.amount = schema.opening_balance + + split_equity = Split() + split_equity.id = 1 + split_equity.transaction = t + split_equity.account = equity_account + split_equity.amount = - schema.opening_balance + + account.transaction_splits.append(split_opening) + @classmethod def schema_to_model(cls, session, schema, model=None): try: @@ -27,6 +48,7 @@ class AccountResource: schema.path = "" schema.family = "" model = Account.model_validate(schema) + cls.create_opening_transaction(session, model, schema) except Exception as e: print(e) raise @@ -34,6 +56,7 @@ class AccountResource: model.compute_family() cls.validate_parent(session, model) model.compute_path() + return model @classmethod @@ -68,12 +91,10 @@ class AccountResource: @classmethod def create_equity_account(cls, session): - account_db = Account(name="Equity", family="Equity", type="Equity", path="/Equity/") - session.add(account_db) - session.commit() - session.refresh(account_db) - - return account_db + if cls.get_by_path(session, "/Equity/") is None: + account_db = Account(name="Equity", family="Equity", type="Equity", path="/Equity/") + session.add(account_db) + session.commit() @classmethod def select(cls): diff --git a/api/app/account/schemas.py b/api/app/account/schemas.py index fb10721..23fdd63 100644 --- a/api/app/account/schemas.py +++ b/api/app/account/schemas.py @@ -63,6 +63,9 @@ class CategoryWrite(BaseAccountWrite): } }) + opening_date: SkipJsonSchema[date] = Field(default=date(1970, 1, 1)) + opening_balance: SkipJsonSchema[Decimal] = Field(default=0) + class CategoryCreate(CategoryWrite): pass diff --git a/api/app/core/types.py b/api/app/core/types.py index 21a05a4..d46a769 100644 --- a/api/app/core/types.py +++ b/api/app/core/types.py @@ -7,8 +7,6 @@ from pydantic_core import core_schema @dataclass class MonetaryAmount: - amount: Decimal# = Field(decimal_places=2, default=0) - @classmethod def __get_pydantic_core_schema__( cls, source: type[Any], handler: GetCoreSchemaHandler @@ -16,26 +14,24 @@ class MonetaryAmount: assert source is MonetaryAmount return core_schema.no_info_after_validator_function( cls._validate, - core_schema.float_schema(multiple_of=0.01), + core_schema.decimal_schema(multiple_of=0.01), serialization=core_schema.plain_serializer_function_ser_schema( cls._serialize, info_arg=False, - return_schema=core_schema.float_schema(multiple_of=0.01), + return_schema=core_schema.decimal_schema(multiple_of=0.01), ), ) - @staticmethod - def _validate(value: str) -> 'CompressedString': - inverse_dictionary: dict[str, int] = {} - text: list[int] = [] - for word in value.split(' '): - if word not in inverse_dictionary: - inverse_dictionary[word] = len(inverse_dictionary) - text.append(inverse_dictionary[word]) - return MonetaryAmount( - {v: k for k, v in inverse_dictionary.items()}, text - ) + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) @staticmethod - def _serialize(value: 'CompressedString') -> str: + def _validate(value: Decimal) -> 'MonetaryAmount': + if value.as_tuple()[2] < -2: + raise ValueError(f'{value} has more than two decimal places.') + + return value + + @staticmethod + def _serialize(value: 'MonetaryAmount') -> str: return value.amount diff --git a/api/app/initialize_db.py b/api/app/initialize_db.py index e7cc9c7..949bd1a 100644 --- a/api/app/initialize_db.py +++ b/api/app/initialize_db.py @@ -1,12 +1,14 @@ import main +from account.resource import AccountResource from account.fixtures import inject_fixtures as account_inject_fixtures from db import create_db_and_tables, get_session, drop_tables -from user import create_admin_account +from user import create_admin_user drop_tables() create_db_and_tables() session = get_session().__next__() -create_admin_account(session) +create_admin_user(session) +AccountResource.create_equity_account(session) account_inject_fixtures(session) diff --git a/api/app/transaction/models.py b/api/app/transaction/models.py index 9114e6c..b7b9cc2 100644 --- a/api/app/transaction/models.py +++ b/api/app/transaction/models.py @@ -1,4 +1,5 @@ from decimal import Decimal +from typing import Optional from uuid import UUID, uuid4 from sqlmodel import Field, SQLModel, select, Relationship @@ -50,7 +51,7 @@ class Transaction(TransactionBaseId, table=True): class SplitBase(SQLModel): account_id: UUID = Field(foreign_key="account.id") - payee_id: UUID = Field(foreign_key="payee.id") + payee_id: Optional[UUID] = Field(foreign_key="payee.id") amount: Decimal = Field(decimal_places=2) class SplitBaseId(SplitBase): @@ -74,7 +75,7 @@ class SplitWrite(SplitBase): } } }) - payee_id: UUID = PydField(json_schema_extra={ + payee_id: UUID | None = PydField(json_schema_extra={ "foreign_key": { "reference": { "resource": "payees",