refactor aggregate in database_logic (#294)
**Related Issue(s):**
N/A

**Description:**
Refactor `aggregate()` in the database logic to allow extending the
supported set of aggregations. The mapping of aggregation name to
Elasticsearch/OpenSearch functionality used to live inside the
`aggregate()` function itself, which made the set of supported
aggregations hard to change. I moved the mapping to a property of the
database logic, so it can be modified when the database logic is
instantiated, as sketched below.
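
For example, an application could register an extra aggregation when wiring up the backend. A minimal sketch (the `sun_elevation_max` aggregation is hypothetical, and the instantiation details are illustrative rather than part of this PR):

```python
from copy import deepcopy

from stac_fastapi.elasticsearch.database_logic import DatabaseLogic

database = DatabaseLogic()

# Give this instance its own copy so the class-level default stays untouched,
# then register an extra aggregation alongside the built-in ones.
database.aggregation_mapping = deepcopy(DatabaseLogic.aggregation_mapping)
database.aggregation_mapping["sun_elevation_max"] = {
    "max": {"field": "properties.view:sun_elevation"}
}
```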

**PR Checklist:**

- [x] Code is formatted and linted (run `pre-commit run --all-files`)
- [x] Tests pass (run `make test`)
- [x] Documentation has been updated to reflect changes, if applicable
- [x] Changes are added to the changelog

---------

Co-authored-by: Jonathan Healy <[email protected]>
StijnCaerts and jonhealy1 committed Sep 5, 2024
1 parent 7b2b191 commit c25229e
Showing 4 changed files with 225 additions and 170 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -9,7 +9,11 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.

### Added

- Added `datetime_frequency_interval` parameter for `datetime_frequency` aggregation. [#294](https://github.com/stac-utils/stac-fastapi-elasticsearch-opensearch/pull/294)

### Changed

- Refactored aggregation in database logic. [#294](https://github.com/stac-utils/stac-fastapi-elasticsearch-opensearch/pull/294)
- Fixed the `self` link for the `/collections/{collection_id}/aggregations` endpoint. [#295](https://github.com/stac-utils/stac-fastapi-elasticsearch-opensearch/pull/295)

## [v3.1.0] - 2024-09-02
35 changes: 35 additions & 0 deletions stac_fastapi/core/stac_fastapi/core/extensions/aggregation.py
@@ -50,6 +50,7 @@ class EsAggregationExtensionGetRequest(
    centroid_geotile_grid_frequency_precision: Optional[int] = attr.ib(default=None)
    geometry_geohash_grid_frequency_precision: Optional[int] = attr.ib(default=None)
    geometry_geotile_grid_frequency_precision: Optional[int] = attr.ib(default=None)
    datetime_frequency_interval: Optional[str] = attr.ib(default=None)


class EsAggregationExtensionPostRequest(
@@ -62,6 +63,7 @@ class EsAggregationExtensionPostRequest(
    centroid_geotile_grid_frequency_precision: Optional[int] = None
    geometry_geohash_grid_frequency_precision: Optional[int] = None
    geometry_geotile_grid_frequency_precision: Optional[int] = None
    datetime_frequency_interval: Optional[str] = None


@attr.s
@@ -124,6 +126,8 @@ class EsAsyncAggregationClient(AsyncBaseAggregationClient):
    MAX_GEOHASH_PRECISION = 12
    MAX_GEOHEX_PRECISION = 15
    MAX_GEOTILE_PRECISION = 29
    SUPPORTED_DATETIME_INTERVAL = {"day", "month", "year"}
    DEFAULT_DATETIME_INTERVAL = "month"

    async def get_aggregations(self, collection_id: Optional[str] = None, **kwargs):
        """Get the available aggregations for a catalog or collection defined in the STAC JSON. If no aggregations, default aggregations are used."""
@@ -182,6 +186,30 @@ def extract_precision(
        else:
            return min_value

    def extract_date_histogram_interval(self, value: Optional[str]) -> str:
        """
        Ensure that the interval for the date histogram is valid. If no value is provided, the default will be returned.

        Args:
            value: value entered by the user

        Returns:
            string containing the date histogram interval to use.

        Raises:
            HTTPException: if the supplied value is not in the supported intervals
        """
        if value is not None:
            if value not in self.SUPPORTED_DATETIME_INTERVAL:
                raise HTTPException(
                    status_code=400,
                    detail=f"Invalid datetime interval. Must be one of {self.SUPPORTED_DATETIME_INTERVAL}",
                )
            else:
                return value
        else:
            return self.DEFAULT_DATETIME_INTERVAL
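    # For example (hypothetical calls, given the class defaults above):
    #   self.extract_date_histogram_interval(None)    returns "month" (DEFAULT_DATETIME_INTERVAL)
    #   self.extract_date_histogram_interval("year")  returns "year"
    #   self.extract_date_histogram_interval("hour")  raises HTTPException (status 400)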

    @staticmethod
    def _return_date(
        interval: Optional[Union[DateTimeType, str]]
@@ -319,6 +347,7 @@ async def aggregate(
        centroid_geotile_grid_frequency_precision: Optional[int] = None,
        geometry_geohash_grid_frequency_precision: Optional[int] = None,
        geometry_geotile_grid_frequency_precision: Optional[int] = None,
        datetime_frequency_interval: Optional[str] = None,
        **kwargs,
    ) -> Union[Dict, Exception]:
        """Get aggregations from the database."""
@@ -339,6 +368,7 @@ async def aggregate(
"centroid_geotile_grid_frequency_precision": centroid_geotile_grid_frequency_precision,
"geometry_geohash_grid_frequency_precision": geometry_geohash_grid_frequency_precision,
"geometry_geotile_grid_frequency_precision": geometry_geotile_grid_frequency_precision,
"datetime_frequency_interval": datetime_frequency_interval,
}

if collection_id:
@@ -475,6 +505,10 @@ async def aggregate(
            self.MAX_GEOTILE_PRECISION,
        )

        datetime_frequency_interval = self.extract_date_histogram_interval(
            aggregate_request.datetime_frequency_interval,
        )

        try:
            db_response = await self.database.aggregate(
                collections,
@@ -485,6 +519,7 @@ async def aggregate(
                centroid_geotile_grid_precision,
                geometry_geohash_grid_precision,
                geometry_geotile_grid_precision,
                datetime_frequency_interval,
            )
        except Exception as error:
            if not isinstance(error, IndexError):
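With `datetime_frequency_interval` wired through the request models and `aggregate()` above, a client can choose the date-histogram bucket size per request. A minimal sketch (the host, port, and `/aggregate` path are assumptions based on the STAC API aggregation extension; they do not appear in this diff):

import requests

resp = requests.get(
    "http://localhost:8080/aggregate",
    params={
        "aggregations": "datetime_frequency",
        "datetime_frequency_interval": "year",  # must be one of {"day", "month", "year"}
    },
)
print(resp.json())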
@@ -4,6 +4,7 @@
import logging
import os
from base64 import urlsafe_b64decode, urlsafe_b64encode
from copy import deepcopy
from typing import Any, Dict, Iterable, List, Optional, Protocol, Tuple, Type, Union

import attr
@@ -316,6 +317,77 @@ class DatabaseLogic:

    extensions: List[str] = attr.ib(default=attr.Factory(list))

    aggregation_mapping: Dict[str, Dict[str, Any]] = {
        "total_count": {"value_count": {"field": "id"}},
        "collection_frequency": {"terms": {"field": "collection", "size": 100}},
        "platform_frequency": {"terms": {"field": "properties.platform", "size": 100}},
        "cloud_cover_frequency": {
            "range": {
                "field": "properties.eo:cloud_cover",
                "ranges": [
                    {"to": 5},
                    {"from": 5, "to": 15},
                    {"from": 15, "to": 40},
                    {"from": 40},
                ],
            }
        },
        "datetime_frequency": {
            "date_histogram": {
                "field": "properties.datetime",
                "calendar_interval": "month",
            }
        },
        "datetime_min": {"min": {"field": "properties.datetime"}},
        "datetime_max": {"max": {"field": "properties.datetime"}},
        "grid_code_frequency": {
            "terms": {
                "field": "properties.grid:code",
                "missing": "none",
                "size": 10000,
            }
        },
        "sun_elevation_frequency": {
            "histogram": {"field": "properties.view:sun_elevation", "interval": 5}
        },
        "sun_azimuth_frequency": {
            "histogram": {"field": "properties.view:sun_azimuth", "interval": 5}
        },
        "off_nadir_frequency": {
            "histogram": {"field": "properties.view:off_nadir", "interval": 5}
        },
        "centroid_geohash_grid_frequency": {
            "geohash_grid": {
                "field": "properties.proj:centroid",
                "precision": 1,
            }
        },
        "centroid_geohex_grid_frequency": {
            "geohex_grid": {
                "field": "properties.proj:centroid",
                "precision": 0,
            }
        },
        "centroid_geotile_grid_frequency": {
            "geotile_grid": {
                "field": "properties.proj:centroid",
                "precision": 0,
            }
        },
        "geometry_geohash_grid_frequency": {
            "geohash_grid": {
                "field": "geometry",
                "precision": 1,
            }
        },
        "geometry_geotile_grid_frequency": {
            "geotile_grid": {
                "field": "geometry",
                "precision": 0,
            }
        },
    }
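    # Note: the "precision" and "calendar_interval" values above are defaults;
    # aggregate() below overwrites them per request via _fill_aggregation_parameters,
    # and the mapping itself can be modified when DatabaseLogic is instantiated.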

"""CORE LOGIC"""

    async def get_all_collections(
@@ -657,104 +729,41 @@ async def aggregate(
        centroid_geotile_grid_precision: int,
        geometry_geohash_grid_precision: int,
        geometry_geotile_grid_precision: int,
        datetime_frequency_interval: str,
        ignore_unavailable: Optional[bool] = True,
    ):
        """Return aggregations of STAC Items."""
        search_body: Dict[str, Any] = {}
        query = search.query.to_dict() if search.query else None
        if query:
            search_body["query"] = query

        logger.debug("Aggregations: %s", aggregations)

        def _fill_aggregation_parameters(name: str, agg: dict) -> dict:
            [key] = agg.keys()
            agg_precision = {
                "centroid_geohash_grid_frequency": centroid_geohash_grid_precision,
                "centroid_geohex_grid_frequency": centroid_geohex_grid_precision,
                "centroid_geotile_grid_frequency": centroid_geotile_grid_precision,
                "geometry_geohash_grid_frequency": geometry_geohash_grid_precision,
                "geometry_geotile_grid_frequency": geometry_geotile_grid_precision,
            }
            if name in agg_precision:
                agg[key]["precision"] = agg_precision[name]

            if key == "date_histogram":
                agg[key]["calendar_interval"] = datetime_frequency_interval

            return agg

        # include all aggregations specified
        # this will ignore aggregations with the wrong names
        search_body["aggregations"] = {
            k: _fill_aggregation_parameters(k, deepcopy(v))
            for k, v in self.aggregation_mapping.items()
            if k in aggregations
        }
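        # For example, requesting aggregations=["total_count", "datetime_frequency"]
        # with datetime_frequency_interval="year" (hypothetical values) yields:
        #   search_body["aggregations"] == {
        #       "total_count": {"value_count": {"field": "id"}},
        #       "datetime_frequency": {
        #           "date_histogram": {
        #               "field": "properties.datetime",
        #               "calendar_interval": "year",
        #           }
        #       },
        #   }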

        index_param = indices(collection_ids)
        search_task = asyncio.create_task(