diff --git a/api/rpk-api/firm/core/filter.py b/api/rpk-api/firm/core/filter.py index e5cc03b..666d10f 100644 --- a/api/rpk-api/firm/core/filter.py +++ b/api/rpk-api/firm/core/filter.py @@ -1,7 +1,7 @@ +from collections import defaultdict from collections.abc import Callable, Mapping -from typing import Any, Optional, Union - from pydantic import ValidationInfo, field_validator +from typing import Any, Optional, Union from fastapi_filter.base.filter import BaseFilterModel @@ -24,28 +24,6 @@ _odm_operator_transformer: dict[str, Callable[[Optional[str]], Optional[dict[str class Filter(BaseFilterModel): - """Base filter for beanie related filters. - - Example: - ```python - class MyModel: - id: PrimaryKey() - name: StringField(null=True) - count: IntField() - created_at: DatetimeField() - - class MyModelFilter(Filter): - id: Optional[int] - id__in: Optional[str] - count: Optional[int] - count__lte: Optional[int] - created_at__gt: Optional[datetime] - name__ne: Optional[str] - name__nin: Optional[list[str]] - name__isnull: Optional[bool] - ``` - """ - def sort(self): if not self.ordering_values: return None @@ -130,6 +108,51 @@ class Filter(BaseFilterModel): query[field_name] = value return query + @staticmethod + def field_exists(model, field_path: str) -> bool: + if "." in field_path: + [root, field] = field_path.split(".", 1) + return hasattr(model, "model_fields") and root in model.model_fields \ + and model.model_fields[root].discriminator == field + + return hasattr(model, field_path) or (hasattr(model, "model_fields") and field_path in model.model_fields) + + @field_validator("*", mode="before", check_fields=False) + def validate_order_by(cls, value, field: ValidationInfo): + if field.field_name != cls.Constants.ordering_field_name: + return value + + if not value: + return None + + field_name_usages = defaultdict(list) + duplicated_field_names = set() + + for field_name_with_direction in value: + field_name = field_name_with_direction.replace("-", "").replace("+", "") + + if not cls.field_exists(cls.Constants.model, field_name): + raise ValueError(f"{field_name} is not a valid ordering field.") + + field_name_usages[field_name].append(field_name_with_direction) + if len(field_name_usages[field_name]) > 1: + duplicated_field_names.add(field_name) + + if duplicated_field_names: + ambiguous_field_names = ", ".join( + [ + field_name_with_direction + for field_name in sorted(duplicated_field_names) + for field_name_with_direction in field_name_usages[field_name] + ] + ) + raise ValueError( + f"Field names can appear at most once for {cls.Constants.ordering_field_name}. " + f"The following was ambiguous: {ambiguous_field_names}." + ) + + return value + class FilterSchema(Filter): label__ilike: Optional[str] = None order_by: Optional[list[str]] = None