This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Commit 0a17f43
Merge branch 'main' into btabaska/183-implement-opportunity-page
acouch committed Sep 17, 2024
2 parents eba976f + 5f919a8

Showing 36 changed files with 2,405 additions and 170 deletions.
7 changes: 5 additions & 2 deletions api/bin/create_erds.py
@@ -12,7 +12,7 @@
 import src.db.models.staging.opportunity as staging_opportunity_models
 import src.db.models.staging.synopsis as staging_synopsis_models
 import src.logging
-from src.db.models import opportunity_models
+from src.db.models import agency_models, opportunity_models
 from src.db.models.transfer import topportunity_models

 logger = logging.getLogger(__name__)
@@ -23,7 +23,10 @@
 ERD_FOLDER = pathlib.Path(__file__).parent.resolve()

 # If we want to generate separate files for more specific groups, we can set that up here
-API_MODULES = (opportunity_models,)
+API_MODULES = (
+    opportunity_models,
+    agency_models,
+)
 STAGING_TABLE_MODULES = (
     staging_opportunity_models,
     staging_forecast_models,
111 changes: 106 additions & 5 deletions api/src/adapters/search/opensearch_client.py
@@ -1,5 +1,5 @@
 import logging
-from typing import Any, Sequence
+from typing import Any, Generator, Iterable

 import opensearchpy
@@ -75,7 +75,7 @@ def delete_index(self, index_name: str) -> None:
     def bulk_upsert(
         self,
         index_name: str,
-        records: Sequence[dict[str, Any]],
+        records: Iterable[dict[str, Any]],
         primary_key_field: str,
         *,
         refresh: bool = True
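
Note: relaxing records from Sequence to Iterable lets callers stream documents from a generator instead of materializing a list first. A minimal sketch (fetch_rows and the field names are illustrative, not part of this commit):

    def iter_records():  # hypothetical generator over database rows
        for row in fetch_rows():
            yield {"opportunity_id": row.id, "opportunity_title": row.title}

    search_client.bulk_upsert("opportunity-index", iter_records(), primary_key_field="opportunity_id")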
@@ -103,10 +103,51 @@ def bulk_upsert(
         logger.info(
             "Upserting records to %s",
             index_name,
-            extra={"index_name": index_name, "record_count": int(len(bulk_operations) / 2)},
+            extra={
+                "index_name": index_name,
+                "record_count": int(len(bulk_operations) / 2),
+                "operation": "update",
+            },
         )
         self._client.bulk(index=index_name, body=bulk_operations, refresh=refresh)

+    def bulk_delete(self, index_name: str, ids: Iterable[Any], *, refresh: bool = True) -> None:
+        """
+        Bulk delete records from an index
+
+        See: https://opensearch.org/docs/latest/api-reference/document-apis/bulk/ for details.
+        In this method, we delete records based on the IDs passed in.
+        """
+        bulk_operations = []
+
+        for _id in ids:
+            # { "delete": { "_id": "tt2229499" } }
+            bulk_operations.append({"delete": {"_id": _id}})
+
+        logger.info(
+            "Deleting records from %s",
+            index_name,
+            extra={
+                "index_name": index_name,
+                "record_count": len(bulk_operations),
+                "operation": "delete",
+            },
+        )
+        self._client.bulk(index=index_name, body=bulk_operations, refresh=refresh)
+
+    def index_exists(self, index_name: str) -> bool:
+        """
+        Check if an index OR alias exists by a given name
+        """
+        return self._client.indices.exists(index_name)
+
+    def alias_exists(self, alias_name: str) -> bool:
+        """
+        Check if an alias exists
+        """
+        existing_index_mapping = self._client.cat.aliases(alias_name, format="json")
+        return len(existing_index_mapping) > 0
+
     def swap_alias_index(
         self, index_name: str, alias_name: str, *, delete_prior_indexes: bool = False
     ) -> None:
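
A short usage sketch for the new helpers above, assuming a connected client (index and alias names are illustrative):

    # Remove three documents by ID and refresh so the deletes are immediately visible.
    search_client.bulk_delete("opportunity-index", [1, 2, 3], refresh=True)

    # index_exists() matches indexes or aliases; alias_exists() matches aliases only.
    if not search_client.alias_exists("opportunity-index-alias"):
        search_client.swap_alias_index("opportunity-index", "opportunity-index-alias")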
@@ -144,11 +185,71 @@ def search_raw(self, index_name: str, search_query: dict) -> dict:
         return self._client.search(index=index_name, body=search_query)

     def search(
-        self, index_name: str, search_query: dict, include_scores: bool = True
+        self,
+        index_name: str,
+        search_query: dict,
+        include_scores: bool = True,
+        params: dict | None = None,
     ) -> SearchResponse:
-        response = self._client.search(index=index_name, body=search_query)
+        if params is None:
+            params = {}
+
+        response = self._client.search(index=index_name, body=search_query, params=params)
         return SearchResponse.from_opensearch_response(response, include_scores)

+    def scroll(
+        self,
+        index_name: str,
+        search_query: dict,
+        include_scores: bool = True,
+        duration: str = "10m",
+    ) -> Generator[SearchResponse, None, None]:
+        """
+        Scroll (iterate) over a large result set for a given search query.
+
+        This query uses additional resources to keep the response open, but
+        keeps a consistent set of results and is useful for backend processes
+        that need to fetch a large amount of search data. After processing the results,
+        the scroll lock is closed for you.
+
+        This method is set up as a generator and the results can be iterated over::
+
+            for response in search_client.scroll("my_index", {"size": 10000}):
+                for record in response.records:
+                    process_record(record)
+
+        See: https://opensearch.org/docs/latest/api-reference/scroll/
+        """
+
+        # start scroll
+        response = self.search(
+            index_name=index_name,
+            search_query=search_query,
+            include_scores=include_scores,
+            params={"scroll": duration},
+        )
+        scroll_id = response.scroll_id
+
+        yield response
+
+        # iterate
+        while True:
+            raw_response = self._client.scroll({"scroll_id": scroll_id, "scroll": duration})
+            response = SearchResponse.from_opensearch_response(raw_response, include_scores)
+
+            # The scroll ID can change between queries according to the docs, so we
+            # keep updating the value while iterating in case it changes.
+            scroll_id = response.scroll_id
+
+            if len(response.records) == 0:
+                break
+
+            yield response
+
+        # close scroll
+        self._client.clear_scroll(scroll_id=scroll_id)
+

 def _get_connection_parameters(opensearch_config: OpensearchConfig) -> dict[str, Any]:
     # TODO - we'll want to add the AWS connection params here when we set that up
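
Taken together, the new params hook and the scroll generator can be used like this — a minimal sketch, assuming an index larger than a single page of results:

    # One-off search, passing an extra query-string parameter through to OpenSearch.
    response = search_client.search("opportunity-index", {"size": 25}, params={"timeout": "30s"})

    # Stream the full result set in pages; the scroll context is cleared when iteration ends.
    all_records = []
    for page in search_client.scroll("opportunity-index", {"size": 10000}, duration="5m"):
        all_records.extend(page.records)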
6 changes: 5 additions & 1 deletion api/src/adapters/search/opensearch_response.py
@@ -10,6 +10,8 @@ class SearchResponse:

     aggregations: dict[str, dict[str, int]]

+    scroll_id: str | None
+
     @classmethod
     def from_opensearch_response(
         cls, raw_json: dict[str, typing.Any], include_scores: bool = True
@@ -40,6 +42,8 @@ def from_opensearch_response(
             ]
         }
         """
+        scroll_id = raw_json.get("_scroll_id", None)
+
         hits = raw_json.get("hits", {})
         hits_total = hits.get("total", {})
         total_records = hits_total.get("value", 0)
@@ -59,7 +63,7 @@
         raw_aggs: dict[str, dict[str, typing.Any]] = raw_json.get("aggregations", {})
         aggregations = _parse_aggregations(raw_aggs)

-        return cls(total_records, records, aggregations)
+        return cls(total_records, records, aggregations, scroll_id)


 def _parse_aggregations(
Expand Down
16 changes: 6 additions & 10 deletions api/src/api/opportunities_v1/opportunity_schemas.py
@@ -337,38 +337,34 @@ class OpportunitySearchFilterV1Schema(Schema):
     )
     expected_number_of_awards = fields.Nested(
         IntegerSearchSchemaBuilder("ExpectedNumberAwardsFilterV1Schema")
-        .with_minimum_value(example=0)
-        .with_maximum_value(example=25)
+        .with_integer_range(min_example=0, max_example=25)
         .build()
     )

     award_floor = fields.Nested(
         IntegerSearchSchemaBuilder("AwardFloorFilterV1Schema")
-        .with_minimum_value(example=0)
-        .with_maximum_value(example=10_000)
+        .with_integer_range(min_example=0, max_example=10_000)
         .build()
     )

     award_ceiling = fields.Nested(
         IntegerSearchSchemaBuilder("AwardCeilingFilterV1Schema")
-        .with_minimum_value(example=0)
-        .with_maximum_value(example=10_000_000)
+        .with_integer_range(min_example=0, max_example=10_000_000)
         .build()
     )

     estimated_total_program_funding = fields.Nested(
         IntegerSearchSchemaBuilder("EstimatedTotalProgramFundingFilterV1Schema")
-        .with_minimum_value(example=0)
-        .with_maximum_value(example=10_000_000)
+        .with_integer_range(min_example=0, max_example=10_000_000)
         .build()
     )

     post_date = fields.Nested(
-        DateSearchSchemaBuilder("PostDateFilterV1Schema").with_start_date().with_end_date().build()
+        DateSearchSchemaBuilder("PostDateFilterV1Schema").with_date_range().build()
     )

     close_date = fields.Nested(
-        DateSearchSchemaBuilder("CloseDateFilterV1Schema").with_start_date().with_end_date().build()
+        DateSearchSchemaBuilder("CloseDateFilterV1Schema").with_date_range().build()
     )
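
The request wire format is unchanged by this refactor; a filter payload still supplies min/max and start_date/end_date keys (values illustrative):

    {
        "award_ceiling": {"min": 0, "max": 10000000},
        "post_date": {"start_date": "2024-01-01", "end_date": "2024-06-30"}
    }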
87 changes: 64 additions & 23 deletions api/src/api/schemas/search_schema.py
@@ -1,5 +1,5 @@
 from enum import StrEnum
-from typing import Any, Pattern, Type
+from typing import Any, Callable, Pattern, Type

 from marshmallow import ValidationError, validates_schema
@@ -37,8 +37,9 @@ def validates_non_empty(self, data: dict, **kwargs: Any) -> None:

 class BaseSearchSchemaBuilder:
     def __init__(self, schema_class_name: str):
+        # schema fields are the fields and functions of the class
+        self.schema_fields: dict[str, fields.MixinField | Callable[..., Any]] = {}
         # The schema class name is used on the endpoint
-        self.schema_fields: dict[str, fields.MixinField] = {}
         self.schema_class_name = schema_class_name

     def build(self) -> Schema:
@@ -147,13 +148,23 @@ class IntegerSearchSchemaBuilder(BaseSearchSchemaBuilder):
         class OpportunitySearchFilterSchema(Schema):
             example_int_field = fields.Nested(
                 IntegerSearchSchemaBuilder("ExampleIntFieldSchema")
-                .with_minimum_value(example=1)
-                .with_maximum_value(example=25)
+                .with_integer_range(min_example=1, max_example=25)
                 .build()
             )
     """

-    def with_minimum_value(
+    def with_integer_range(
+        self,
+        min_example: int | None = None,
+        max_example: int | None = None,
+        positive_only: bool = True,
+    ) -> "IntegerSearchSchemaBuilder":
+        self._with_minimum_value(min_example, positive_only)
+        self._with_maximum_value(max_example, positive_only)
+        self._with_int_range_validator()
+        return self
+
+    def _with_minimum_value(
         self, example: int | None = None, positive_only: bool = True
     ) -> "IntegerSearchSchemaBuilder":
         metadata = {}
@@ -169,7 +180,7 @@ def with_minimum_value(
         )
         return self

-    def with_maximum_value(
+    def _with_maximum_value(
         self, example: int | None = None, positive_only: bool = True
     ) -> "IntegerSearchSchemaBuilder":
         metadata = {}
@@ -185,6 +196,28 @@ def with_maximum_value(
         )
         return self

+    def _with_int_range_validator(self) -> "IntegerSearchSchemaBuilder":
+        # Define a schema validator function that we'll use to define any
+        # rules that go across fields in the validation
+        @validates_schema
+        def validate_int_range(_: Any, data: dict, **kwargs: Any) -> None:
+            min_value = data.get("min", None)
+            max_value = data.get("max", None)
+
+            # Error if min and max value are None (either explicitly set, or because they are missing)
+            if min_value is None and max_value is None:
+                raise ValidationError(
+                    [
+                        MarshmallowErrorContainer(
+                            ValidationErrorType.REQUIRED,
+                            "At least one of min or max must be provided.",
+                        )
+                    ]
+                )
+
+        self.schema_fields["validate_int_range"] = validate_int_range
+        return self
+

 class BoolSearchSchemaBuilder(BaseSearchSchemaBuilder):
     """
@@ -250,30 +283,38 @@ class DateSearchSchemaBuilder(BaseSearchSchemaBuilder):
     Usage::

         # In a search request schema, you would use it like so:

-        example_start_date_field = fields.Nested(
-            DateSearchSchemaBuilder("ExampleStartDateFieldSchema")
-            .with_start_date()
-            .build()
-        )
-
-        example_end_date_field = fields.Nested(
-            DateSearchSchemaBuilder("ExampleEndDateFieldSchema")
-            .with_end_date()
-            .build()
-        )
-
         example_startend_date_field = fields.Nested(
             DateSearchSchemaBuilder("ExampleStartEndDateFieldSchema")
-            .with_start_date()
-            .with_end_date()
+            .with_date_range()
             .build()
         )
     """

-    def with_start_date(self) -> "DateSearchSchemaBuilder":
+    def with_date_range(self) -> "DateSearchSchemaBuilder":
         self.schema_fields["start_date"] = fields.Date(allow_none=True)
+        self.schema_fields["end_date"] = fields.Date(allow_none=True)
+        self._with_date_range_validator()
+
         return self

-    def with_end_date(self) -> "DateSearchSchemaBuilder":
-        self.schema_fields["end_date"] = fields.Date(allow_none=True)
-        return self
+    def _with_date_range_validator(self) -> "DateSearchSchemaBuilder":
+        # Define a schema validator function that we'll use to define any
+        # rules that go across fields in the validation
+        @validates_schema
+        def validate_date_range(_: Any, data: dict, **kwargs: Any) -> None:
+            start_date = data.get("start_date", None)
+            end_date = data.get("end_date", None)
+
+            # Error if start and end date are None (either explicitly set, or because they are missing)
+            if start_date is None and end_date is None:
+                raise ValidationError(
+                    [
+                        MarshmallowErrorContainer(
+                            ValidationErrorType.REQUIRED,
+                            "At least one of start_date or end_date must be provided.",
+                        )
+                    ]
+                )
+
+        self.schema_fields["validate_date_range"] = validate_date_range
+        return self
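
Both builders now reject an empty filter object at load time. A brief sketch of the resulting behavior (schema names are illustrative):

    int_schema = (
        IntegerSearchSchemaBuilder("AwardCeilingFilterV1Schema")
        .with_integer_range(min_example=0, max_example=25)
        .build()
    )
    int_schema.load({"min": 10})  # OK: one bound is enough
    int_schema.load({})           # ValidationError: "At least one of min or max must be provided."

    date_schema = DateSearchSchemaBuilder("CloseDateFilterV1Schema").with_date_range().build()
    date_schema.load({"start_date": "2024-01-01"})  # OK
    date_schema.load({})  # ValidationError: "At least one of start_date or end_date must be provided."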
11 changes: 11 additions & 0 deletions api/src/constants/lookup_constants.py
@@ -105,3 +105,14 @@ class FundingInstrument(StrEnum):
     GRANT = "grant"  # G
     PROCUREMENT_CONTRACT = "procurement_contract"  # PC
     OTHER = "other"  # O
+
+
+class AgencyDownloadFileType(StrEnum):
+    XML = "xml"
+    PDF = "pdf"
+
+
+class AgencySubmissionNotificationSetting(StrEnum):
+    NEVER = "never"
+    FIRST_APPLICATION_ONLY = "first_application_only"
+    ALWAYS = "always"
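
Because these are StrEnums, members compare and serialize as plain strings, e.g.:

    file_type = AgencyDownloadFileType.PDF
    assert file_type == "pdf"
    assert f"download.{file_type}" == "download.pdf"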
(The remaining changed files are not shown.)
