From 56ade7f43b93b152628881f742fdde18c13561ea Mon Sep 17 00:00:00 2001 From: Pete Gadomski Date: Fri, 16 Jun 2023 09:45:08 -0600 Subject: [PATCH 1/3] feat: add test --- tests/api/test_api.py | 53 ++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 52 insertions(+), 1 deletion(-) diff --git a/tests/api/test_api.py b/tests/api/test_api.py index ae766eb..62335f2 100644 --- a/tests/api/test_api.py +++ b/tests/api/test_api.py @@ -1,10 +1,17 @@ from datetime import datetime, timedelta -from typing import Any, Dict, List +from typing import Any, Callable, Coroutine, Dict, List, Optional, TypeVar from urllib.parse import quote_plus import orjson import pytest +from fastapi import Request +from httpx import AsyncClient from pystac import Collection, Extent, Item, SpatialExtent, TemporalExtent +from stac_fastapi.api.app import StacApi +from stac_fastapi.types import stac as stac_types + +from stac_fastapi.pgstac.core import CoreCrudClient, Settings +from stac_fastapi.pgstac.db import close_db_connection, connect_to_db STAC_CORE_ROUTES = [ "GET /", @@ -622,3 +629,47 @@ async def search(query: Dict[str, Any]) -> List[Item]: } items = await search(query) assert len(items) == 10, items + + +@pytest.mark.asyncio +async def test_wrapped_function() -> None: + # Ensure wrappers, e.g. Planetary Computer's rate limiting, work. + # https://github.com/gadomski/planetary-computer-apis/blob/2719ccf6ead3e06de0784c39a2918d4d1811368b/pccommon/pccommon/redis.py#L205-L238 + + T = TypeVar("T") + + def wrap() -> ( + Callable[ + [Callable[..., Coroutine[Any, Any, T]]], + Callable[..., Coroutine[Any, Any, T]], + ] + ): + def decorator( + fn: Callable[..., Coroutine[Any, Any, T]] + ) -> Callable[..., Coroutine[Any, Any, T]]: + async def _wrapper(*args: Any, **kwargs: Any) -> T: + request: Optional[Request] = kwargs.get("request") + if request: + pass # This is where rate limiting would be applied + else: + raise ValueError(f"Missing request in {fn.__name__}") + return await fn(*args, **kwargs) + + return _wrapper + + return decorator + + class Client(CoreCrudClient): + @wrap() + async def all_collections(self, **kwargs) -> stac_types.Collections: + return await super().all_collections(**kwargs) + + api = StacApi(client=Client(), settings=Settings(testing=True)) + app = api.app + await connect_to_db(app) + try: + async with AsyncClient(app=app) as client: + response = await client.get("http://test/collections") + assert response.status_code == 200 + finally: + await close_db_connection(app) From 62bd0477fa29bb75ad098781d250ea11b9714c8f Mon Sep 17 00:00:00 2001 From: Pete Gadomski Date: Fri, 16 Jun 2023 11:47:29 -0600 Subject: [PATCH 2/3] fixup the test --- tests/api/test_api.py | 43 +++++++++++++++++++++++++++++++++++++------ 1 file changed, 37 insertions(+), 6 deletions(-) diff --git a/tests/api/test_api.py b/tests/api/test_api.py index 62335f2..02f9505 100644 --- a/tests/api/test_api.py +++ b/tests/api/test_api.py @@ -8,10 +8,14 @@ from httpx import AsyncClient from pystac import Collection, Extent, Item, SpatialExtent, TemporalExtent from stac_fastapi.api.app import StacApi +from stac_fastapi.api.models import create_post_request_model +from stac_fastapi.extensions.core import FieldsExtension, TransactionExtension from stac_fastapi.types import stac as stac_types from stac_fastapi.pgstac.core import CoreCrudClient, Settings from stac_fastapi.pgstac.db import close_db_connection, connect_to_db +from stac_fastapi.pgstac.transactions import TransactionsClient +from stac_fastapi.pgstac.types.search import PgstacSearch STAC_CORE_ROUTES = [ "GET /", @@ -632,7 +636,7 @@ async def search(query: Dict[str, Any]) -> List[Item]: @pytest.mark.asyncio -async def test_wrapped_function() -> None: +async def test_wrapped_function(load_test_data) -> None: # Ensure wrappers, e.g. Planetary Computer's rate limiting, work. # https://github.com/gadomski/planetary-computer-apis/blob/2719ccf6ead3e06de0784c39a2918d4d1811368b/pccommon/pccommon/redis.py#L205-L238 @@ -661,15 +665,42 @@ async def _wrapper(*args: Any, **kwargs: Any) -> T: class Client(CoreCrudClient): @wrap() - async def all_collections(self, **kwargs) -> stac_types.Collections: - return await super().all_collections(**kwargs) + async def get_collection( + self, collection_id: str, request: Request, **kwargs + ) -> stac_types.Item: + return await super().get_collection( + collection_id, request=request, **kwargs + ) - api = StacApi(client=Client(), settings=Settings(testing=True)) + settings = Settings(testing=True) + extensions = [ + TransactionExtension(client=TransactionsClient(), settings=settings), + FieldsExtension(), + ] + post_request_model = create_post_request_model(extensions, base_model=PgstacSearch) + api = StacApi( + client=Client(post_request_model=post_request_model), + settings=settings, + extensions=extensions, + search_post_request_model=post_request_model, + ) app = api.app await connect_to_db(app) try: async with AsyncClient(app=app) as client: - response = await client.get("http://test/collections") - assert response.status_code == 200 + response = await client.post( + "http://test/collections", + json=load_test_data("test_collection.json"), + ) + assert response.status_code == 200 + response = await client.post( + "http://test/collections/test-collection/items", + json=load_test_data("test_item.json"), + ) + assert response.status_code == 200 + response = await client.get( + "http://test/collections/test-collection/items/test-item" + ) + assert response.status_code == 200 finally: await close_db_connection(app) From 1708e8e61f991dedbef9ae7b9f68050c5d3f5b28 Mon Sep 17 00:00:00 2001 From: Pete Gadomski Date: Tue, 13 Jun 2023 12:24:08 -0600 Subject: [PATCH 3/3] fix: pass request by name into methods This makes #22 less breaking. --- CHANGES.md | 4 ++++ stac_fastapi/pgstac/core.py | 12 ++++++------ 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/CHANGES.md b/CHANGES.md index 1b93d96..7d2912b 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -2,6 +2,10 @@ ## [Unreleased] +### Fixed + +- Pass `request` by name when calling endpoints from other endpoints ([#44](https://github.com/stac-utils/stac-fastapi-pgstac/pull/44)) + ## [2.4.8] - 2023-06-08 ### Changed diff --git a/stac_fastapi/pgstac/core.py b/stac_fastapi/pgstac/core.py index 7e1b4a3..ef20f25 100644 --- a/stac_fastapi/pgstac/core.py +++ b/stac_fastapi/pgstac/core.py @@ -213,7 +213,7 @@ async def _add_item_links( if settings.use_api_hydrate: async def _get_base_item(collection_id: str) -> Dict[str, Any]: - return await self._get_base_item(collection_id, request) + return await self._get_base_item(collection_id, request=request) base_item_cache = settings.base_item_cache( fetch_base_item=_get_base_item, request=request @@ -267,7 +267,7 @@ async def item_collection( An ItemCollection. """ # If collection does not exist, NotFoundError wil be raised - await self.get_collection(collection_id, request) + await self.get_collection(collection_id, request=request) base_args = { "collections": [collection_id], @@ -285,7 +285,7 @@ async def item_collection( search_request = self.post_request_model( **clean, ) - item_collection = await self._search_base(search_request, request) + item_collection = await self._search_base(search_request, request=request) links = await ItemCollectionLinks( collection_id=collection_id, request=request ).get_links(extra_links=item_collection["links"]) @@ -307,12 +307,12 @@ async def get_item( Item. """ # If collection does not exist, NotFoundError wil be raised - await self.get_collection(collection_id, request) + await self.get_collection(collection_id, request=request) search_request = self.post_request_model( ids=[item_id], collections=[collection_id], limit=1 ) - item_collection = await self._search_base(search_request, request) + item_collection = await self._search_base(search_request, request=request) if not item_collection["features"]: raise NotFoundError( f"Item {item_id} in Collection {collection_id} does not exist." @@ -333,7 +333,7 @@ async def post_search( Returns: ItemCollection containing items which match the search criteria. """ - item_collection = await self._search_base(search_request, request) + item_collection = await self._search_base(search_request, request=request) return ItemCollection(**item_collection) async def get_search(