Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pass request by name into methods #44

Merged
merged 4 commits into from
Jun 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@

## [Unreleased]

### Fixed

- Pass `request` by name when calling endpoints from other endpoints ([#44](https://github.com/stac-utils/stac-fastapi-pgstac/pull/44))

## [2.4.8] - 2023-06-08

### Changed
Expand Down
12 changes: 6 additions & 6 deletions stac_fastapi/pgstac/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ async def _add_item_links(
if settings.use_api_hydrate:

async def _get_base_item(collection_id: str) -> Dict[str, Any]:
return await self._get_base_item(collection_id, request)
return await self._get_base_item(collection_id, request=request)

base_item_cache = settings.base_item_cache(
fetch_base_item=_get_base_item, request=request
Expand Down Expand Up @@ -267,7 +267,7 @@ async def item_collection(
An ItemCollection.
"""
# If collection does not exist, NotFoundError wil be raised
await self.get_collection(collection_id, request)
await self.get_collection(collection_id, request=request)

base_args = {
"collections": [collection_id],
Expand All @@ -285,7 +285,7 @@ async def item_collection(
search_request = self.post_request_model(
**clean,
)
item_collection = await self._search_base(search_request, request)
item_collection = await self._search_base(search_request, request=request)
links = await ItemCollectionLinks(
collection_id=collection_id, request=request
).get_links(extra_links=item_collection["links"])
Expand All @@ -307,12 +307,12 @@ async def get_item(
Item.
"""
# If collection does not exist, NotFoundError wil be raised
await self.get_collection(collection_id, request)
await self.get_collection(collection_id, request=request)

search_request = self.post_request_model(
ids=[item_id], collections=[collection_id], limit=1
)
item_collection = await self._search_base(search_request, request)
item_collection = await self._search_base(search_request, request=request)
if not item_collection["features"]:
raise NotFoundError(
f"Item {item_id} in Collection {collection_id} does not exist."
Expand All @@ -333,7 +333,7 @@ async def post_search(
Returns:
ItemCollection containing items which match the search criteria.
"""
item_collection = await self._search_base(search_request, request)
item_collection = await self._search_base(search_request, request=request)
return ItemCollection(**item_collection)

async def get_search(
Expand Down
84 changes: 83 additions & 1 deletion tests/api/test_api.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,21 @@
from datetime import datetime, timedelta
from typing import Any, Dict, List
from typing import Any, Callable, Coroutine, Dict, List, Optional, TypeVar
from urllib.parse import quote_plus

import orjson
import pytest
from fastapi import Request
from httpx import AsyncClient
from pystac import Collection, Extent, Item, SpatialExtent, TemporalExtent
from stac_fastapi.api.app import StacApi
from stac_fastapi.api.models import create_post_request_model
from stac_fastapi.extensions.core import FieldsExtension, TransactionExtension
from stac_fastapi.types import stac as stac_types

from stac_fastapi.pgstac.core import CoreCrudClient, Settings
from stac_fastapi.pgstac.db import close_db_connection, connect_to_db
from stac_fastapi.pgstac.transactions import TransactionsClient
from stac_fastapi.pgstac.types.search import PgstacSearch

STAC_CORE_ROUTES = [
"GET /",
Expand Down Expand Up @@ -622,3 +633,74 @@ async def search(query: Dict[str, Any]) -> List[Item]:
}
items = await search(query)
assert len(items) == 10, items


@pytest.mark.asyncio
async def test_wrapped_function(load_test_data) -> None:
# Ensure wrappers, e.g. Planetary Computer's rate limiting, work.
# https://github.com/gadomski/planetary-computer-apis/blob/2719ccf6ead3e06de0784c39a2918d4d1811368b/pccommon/pccommon/redis.py#L205-L238

T = TypeVar("T")

def wrap() -> (
Callable[
[Callable[..., Coroutine[Any, Any, T]]],
Callable[..., Coroutine[Any, Any, T]],
]
):
def decorator(
fn: Callable[..., Coroutine[Any, Any, T]]
) -> Callable[..., Coroutine[Any, Any, T]]:
async def _wrapper(*args: Any, **kwargs: Any) -> T:
request: Optional[Request] = kwargs.get("request")
if request:
pass # This is where rate limiting would be applied
else:
raise ValueError(f"Missing request in {fn.__name__}")
return await fn(*args, **kwargs)

return _wrapper

return decorator

class Client(CoreCrudClient):
@wrap()
async def get_collection(
self, collection_id: str, request: Request, **kwargs
) -> stac_types.Item:
return await super().get_collection(
collection_id, request=request, **kwargs
)

settings = Settings(testing=True)
extensions = [
TransactionExtension(client=TransactionsClient(), settings=settings),
FieldsExtension(),
]
post_request_model = create_post_request_model(extensions, base_model=PgstacSearch)
api = StacApi(
client=Client(post_request_model=post_request_model),
settings=settings,
extensions=extensions,
search_post_request_model=post_request_model,
)
app = api.app
await connect_to_db(app)
try:
async with AsyncClient(app=app) as client:
response = await client.post(
"http://test/collections",
json=load_test_data("test_collection.json"),
)
assert response.status_code == 200
response = await client.post(
"http://test/collections/test-collection/items",
json=load_test_data("test_item.json"),
)
assert response.status_code == 200
response = await client.get(
"http://test/collections/test-collection/items/test-item"
)
assert response.status_code == 200
finally:
await close_db_connection(app)
Loading