Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Commit

Permalink
[Issue #178] Finish connecting new search parameters to backend queri…
Browse files Browse the repository at this point in the history
…es (#197)

## Summary
Fixes #178

### Time to review: __10 mins__

## Changes proposed
Adjusted the logic that connects the API requests to the builder in the
search layer to now connect all of the new fields.

A few minor validation adjustments to the API to prevent a few common
mistakes a user could make

## Context for reviewers
The length of the search tests are getting pretty long, I think a good
follow-up would be to split the test file into validation and response
testing.

I adjusted some validation/setup of the API schemas because I don't see
a scenario where min/max OR start/end dates would not ever be needed
together. This also let me add a quick validation rule that a user would
provide at least one of the values.

I adjusted some of the way the search_opportunities file was structured
as we only supported filtering by strings before, and it used the name
of the variables to determine the type. I made it a bit more explicit,
as before random variables could be passed through to the search layer
which seems potentially problematic if not filtered out somewhere.
  • Loading branch information
chouinar committed Sep 13, 2024
1 parent 2854d43 commit 4b0bd5b
Show file tree
Hide file tree
Showing 5 changed files with 514 additions and 110 deletions.
16 changes: 6 additions & 10 deletions api/src/api/opportunities_v1/opportunity_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,38 +337,34 @@ class OpportunitySearchFilterV1Schema(Schema):
)
expected_number_of_awards = fields.Nested(
IntegerSearchSchemaBuilder("ExpectedNumberAwardsFilterV1Schema")
.with_minimum_value(example=0)
.with_maximum_value(example=25)
.with_integer_range(min_example=0, max_example=25)
.build()
)

award_floor = fields.Nested(
IntegerSearchSchemaBuilder("AwardFloorFilterV1Schema")
.with_minimum_value(example=0)
.with_maximum_value(example=10_000)
.with_integer_range(min_example=0, max_example=10_000)
.build()
)

award_ceiling = fields.Nested(
IntegerSearchSchemaBuilder("AwardCeilingFilterV1Schema")
.with_minimum_value(example=0)
.with_maximum_value(example=10_000_000)
.with_integer_range(min_example=0, max_example=10_000_000)
.build()
)

estimated_total_program_funding = fields.Nested(
IntegerSearchSchemaBuilder("EstimatedTotalProgramFundingFilterV1Schema")
.with_minimum_value(example=0)
.with_maximum_value(example=10_000_000)
.with_integer_range(min_example=0, max_example=10_000_000)
.build()
)

post_date = fields.Nested(
DateSearchSchemaBuilder("PostDateFilterV1Schema").with_start_date().with_end_date().build()
DateSearchSchemaBuilder("PostDateFilterV1Schema").with_date_range().build()
)

close_date = fields.Nested(
DateSearchSchemaBuilder("CloseDateFilterV1Schema").with_start_date().with_end_date().build()
DateSearchSchemaBuilder("CloseDateFilterV1Schema").with_date_range().build()
)


Expand Down
87 changes: 64 additions & 23 deletions api/src/api/schemas/search_schema.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from enum import StrEnum
from typing import Any, Pattern, Type
from typing import Any, Callable, Pattern, Type

from marshmallow import ValidationError, validates_schema

Expand Down Expand Up @@ -37,8 +37,9 @@ def validates_non_empty(self, data: dict, **kwargs: Any) -> None:

class BaseSearchSchemaBuilder:
def __init__(self, schema_class_name: str):
# schema fields are the fields and functions of the class
self.schema_fields: dict[str, fields.MixinField | Callable[..., Any]] = {}
# The schema class name is used on the endpoint
self.schema_fields: dict[str, fields.MixinField] = {}
self.schema_class_name = schema_class_name

def build(self) -> Schema:
Expand Down Expand Up @@ -147,13 +148,23 @@ class IntegerSearchSchemaBuilder(BaseSearchSchemaBuilder):
class OpportunitySearchFilterSchema(Schema):
example_int_field = fields.Nested(
IntegerSearchSchemaBuilder("ExampleIntFieldSchema")
.with_minimum_value(example=1)
.with_maximum_value(example=25)
.with_integer_range(min_example=1, max_example=25)
.build()
)
"""

def with_minimum_value(
def with_integer_range(
self,
min_example: int | None = None,
max_example: int | None = None,
positive_only: bool = True,
) -> "IntegerSearchSchemaBuilder":
self._with_minimum_value(min_example, positive_only)
self._with_maximum_value(max_example, positive_only)
self._with_int_range_validator()
return self

def _with_minimum_value(
self, example: int | None = None, positive_only: bool = True
) -> "IntegerSearchSchemaBuilder":
metadata = {}
Expand All @@ -169,7 +180,7 @@ def with_minimum_value(
)
return self

def with_maximum_value(
def _with_maximum_value(
self, example: int | None = None, positive_only: bool = True
) -> "IntegerSearchSchemaBuilder":
metadata = {}
Expand All @@ -185,6 +196,28 @@ def with_maximum_value(
)
return self

def _with_int_range_validator(self) -> "IntegerSearchSchemaBuilder":
# Define a schema validator function that we'll use to define any
# rules that go across fields in the validation
@validates_schema
def validate_int_range(_: Any, data: dict, **kwargs: Any) -> None:
min_value = data.get("min", None)
max_value = data.get("max", None)

# Error if min and max value are None (either explicitly set, or because they are missing)
if min_value is None and max_value is None:
raise ValidationError(
[
MarshmallowErrorContainer(
ValidationErrorType.REQUIRED,
"At least one of min or max must be provided.",
)
]
)

self.schema_fields["validate_int_range"] = validate_int_range
return self


class BoolSearchSchemaBuilder(BaseSearchSchemaBuilder):
"""
Expand Down Expand Up @@ -250,30 +283,38 @@ class DateSearchSchemaBuilder(BaseSearchSchemaBuilder):
Usage::
# In a search request schema, you would use it like so:
example_start_date_field = fields.Nested(
DateSearchSchemaBuilder("ExampleStartDateFieldSchema")
.with_start_date()
.build()
)
example_end_date_field = fields.Nested(
DateSearchSchemaBuilder("ExampleEndDateFieldSchema")
.with_end_date()
.build()
)
example_startend_date_field = fields.Nested(
DateSearchSchemaBuilder("ExampleStartEndDateFieldSchema")
.with_start_date()
.with_end_date()
.with_date_range()
.build()
)
"""

def with_start_date(self) -> "DateSearchSchemaBuilder":
def with_date_range(self) -> "DateSearchSchemaBuilder":
self.schema_fields["start_date"] = fields.Date(allow_none=True)
self.schema_fields["end_date"] = fields.Date(allow_none=True)
self._with_date_range_validator()

return self

def with_end_date(self) -> "DateSearchSchemaBuilder":
self.schema_fields["end_date"] = fields.Date(allow_none=True)
def _with_date_range_validator(self) -> "DateSearchSchemaBuilder":
# Define a schema validator function that we'll use to define any
# rules that go across fields in the validation
@validates_schema
def validate_date_range(_: Any, data: dict, **kwargs: Any) -> None:
start_date = data.get("start_date", None)
end_date = data.get("end_date", None)

# Error if start and end date are None (either explicitly set, or because they are missing)
if start_date is None and end_date is None:
raise ValidationError(
[
MarshmallowErrorContainer(
ValidationErrorType.REQUIRED,
"At least one of start_date or end_date must be provided.",
)
]
)

self.schema_fields["validate_date_range"] = validate_date_range
return self
21 changes: 21 additions & 0 deletions api/src/search/search_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from datetime import date

from pydantic import BaseModel


class StrSearchFilter(BaseModel):
one_of: list[str] | None = None


class BoolSearchFilter(BaseModel):
one_of: list[bool] | None = None


class IntSearchFilter(BaseModel):
min: int | None = None
max: int | None = None


class DateSearchFilter(BaseModel):
start_date: date | None = None
end_date: date | None = None
61 changes: 53 additions & 8 deletions api/src/services/opportunities_v1/search_opportunities.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,12 @@
from src.api.opportunities_v1.opportunity_schemas import OpportunityV1Schema
from src.pagination.pagination_models import PaginationInfo, PaginationParams, SortDirection
from src.search.search_config import get_search_config
from src.search.search_models import (
BoolSearchFilter,
DateSearchFilter,
IntSearchFilter,
StrSearchFilter,
)

logger = logging.getLogger(__name__)

Expand All @@ -28,6 +34,11 @@
"funding_instrument": "summary.funding_instruments.keyword",
"funding_category": "summary.funding_categories.keyword",
"applicant_type": "summary.applicant_types.keyword",
"is_cost_sharing": "summary.is_cost_sharing",
"expected_number_of_awards": "summary.expected_number_of_awards",
"award_floor": "summary.award_floor",
"award_ceiling": "summary.award_ceiling",
"estimated_total_program_funding": "summary.estimated_total_program_funding",
}

SEARCH_FIELDS = [
Expand All @@ -45,11 +56,31 @@
SCHEMA = OpportunityV1Schema()


class OpportunityFilters(BaseModel):
applicant_type: StrSearchFilter | None = None
funding_instrument: StrSearchFilter | None = None
funding_category: StrSearchFilter | None = None
funding_applicant_type: StrSearchFilter | None = None
opportunity_status: StrSearchFilter | None = None
agency: StrSearchFilter | None = None
assistance_listing_number: StrSearchFilter | None = None

is_cost_sharing: BoolSearchFilter | None = None

expected_number_of_awards: IntSearchFilter | None = None
award_floor: IntSearchFilter | None = None
award_ceiling: IntSearchFilter | None = None
estimated_total_program_funding: IntSearchFilter | None = None

post_date: DateSearchFilter | None = None
close_date: DateSearchFilter | None = None


class SearchOpportunityParams(BaseModel):
pagination: PaginationParams

query: str | None = Field(default=None)
filters: dict | None = Field(default=None)
filters: OpportunityFilters | None = Field(default=None)


def _adjust_field_name(field: str) -> str:
Expand All @@ -68,16 +99,30 @@ def _get_sort_by(pagination: PaginationParams) -> list[tuple[str, SortDirection]
return sort_by


def _add_search_filters(builder: search.SearchQueryBuilder, filters: dict | None) -> None:
def _add_search_filters(
builder: search.SearchQueryBuilder, filters: OpportunityFilters | None
) -> None:
if filters is None:
return

for field, field_filters in filters.items():
# one_of filters translate to an opensearch term filter
# see: https://opensearch.org/docs/latest/query-dsl/term/terms/
one_of_filters = field_filters.get("one_of", None)
if one_of_filters:
builder.filter_terms(_adjust_field_name(field), one_of_filters)
for field in filters.model_fields_set:
field_filters = getattr(filters, field)
field_name = _adjust_field_name(field)

# We use the type of the search filter to determine what methods
# we call on the builder. This way we can make sure we have the proper
# type mappings.
if isinstance(field_filters, StrSearchFilter) and field_filters.one_of:
builder.filter_terms(field_name, field_filters.one_of)

elif isinstance(field_filters, BoolSearchFilter) and field_filters.one_of:
builder.filter_terms(field_name, field_filters.one_of)

elif isinstance(field_filters, IntSearchFilter):
builder.filter_int_range(field_name, field_filters.min, field_filters.max)

elif isinstance(field_filters, DateSearchFilter):
builder.filter_date_range(field_name, field_filters.start_date, field_filters.end_date)


def _add_aggregations(builder: search.SearchQueryBuilder) -> None:
Expand Down
Loading

0 comments on commit 4b0bd5b

Please sign in to comment.