diff --git a/api/app/account/models.py b/api/app/account/models.py index 6096d42..c89fcf5 100644 --- a/api/app/account/models.py +++ b/api/app/account/models.py @@ -17,7 +17,7 @@ class AccountBaseId(AccountBase): id: UUID | None = Field(default_factory=uuid4, primary_key=True) family: str = Field(index=True) type: str = Field(index=True) - path: str = Field(index=True) + path: str = Field(index=True, unique=True) class Account(AccountBaseId, table=True): parent_account: Optional["Account"] = Relationship( @@ -40,12 +40,19 @@ class Account(AccountBaseId, table=True): def is_category(self): return self.type in [v.value for v in CategoryType] + def get_parent(self, session): + if self.parent_account_id is None: + return None + + self.parent_account = self.get(session, self.parent_account_id) + return self.parent_account + def get_path(self, session): if self.parent_account_id is None: return self.get_root_path() - parent = self.get(session, self.parent_account_id) - return parent.get_child_path() + self.parent_account = self.get(session, self.parent_account_id) + return self.parent_account.get_child_path() def get_family(self): if self.type in Asset: @@ -67,9 +74,27 @@ class Account(AccountBaseId, table=True): print(e) model.family = model.get_family() + cls.validate_parent(session, model) model.path = model.get_path(session) return model + @classmethod + def validate_parent(cls, session, model): + if model.parent_account_id is None: + return True + + parent = model.get_parent(session) + if not parent: + raise ValueError("Parent account not found.") + + if parent.family != model.family: + raise ValueError("Account family mismatch with parent account..") + + if parent.path.startswith(model.path): + raise ValueError("Parent Account is descendant") + + return True + @classmethod def create(cls, account, session): account_db = cls.schema_to_model(session, account)